Skip to content

Commit ac31713

Browse files
committed
Separate vector interleave lowering and to-shuffle rewrite patterns from dialect conversions
1 parent e558d21 commit ac31713

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -830,8 +830,6 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
830830
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
831831
PatternBenefit(2));
832832

833-
// Need this until vector.interleave is handled.
834-
vector::populateVectorInterleaveToShufflePatterns(patterns);
835833
}
836834

837835
void mlir::populateVectorReductionToSPIRVDotProductPatterns(

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1616
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1717
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/DialectConversion.h"
21+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2022

2123
namespace mlir {
2224
#define GEN_PASS_DEF_CONVERTVECTORTOSPIRV
@@ -36,6 +38,15 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
3638
MLIRContext *context = &getContext();
3739
Operation *op = getOperation();
3840

41+
// Rewrite patterns need to be matched separately from the dialect conversion
42+
{
43+
RewritePatternSet patterns(context);
44+
vector::populateVectorInterleaveLoweringPatterns(patterns);
45+
vector::populateVectorInterleaveToShufflePatterns(patterns);
46+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
47+
return signalPassFailure();
48+
}
49+
3950
auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
4051
std::unique_ptr<ConversionTarget> target =
4152
SPIRVConversionTarget::get(targetAttr);

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,18 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
483483

484484
// -----
485485

486+
// CHECK-LABEL: func @interleave
487+
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
488+
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
489+
// CHECK: return %[[SHUFFLE]]
490+
func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32>
491+
{
492+
%0 = vector.interleave %a, %b : vector<2xf32>
493+
return %0 : vector<4xf32>
494+
}
495+
496+
// -----
497+
486498
// CHECK-LABEL: func @reduction_add
487499
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
488500
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

0 commit comments

Comments
 (0)