Skip to content

Commit 97256a4

Browse files
angelz913yuxuanchen1997
authored andcommitted
[mlir][spirv] Implement vector type legalization for function signatures (#98337)
Summary: ### Description This PR implements a minimal version of function signature conversion to unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). This PR also includes new unit tests that only check for function signature conversion. ### Future Plans - Check for capabilities that support vectors of size 8 or 16. - Set up `OneToNTypeConversion` and `DialectConversion` to replace the current implementation that uses `GreedyPatternRewriteDriver`. - Introduce other vector unrolling patterns to cancel out the `vector.insert_strided_slice` and `vector.extract_strided_slice` ops and fully legalize the vector types in the function body. - Handle `func::CallOp` and declarations. - Restructure the code in `SPIRVConversion.cpp`. - Create test passes for testing sets of patterns in isolation. - Optimize the way original shape is splitted into target shapes, e.g. `vector<5xi32>` can be splitted into `vector<4xi32>` and `vector<1xi32>`. --------- Co-authored-by: Jakub Kuderski <[email protected]> Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250907
1 parent ece69c6 commit 97256a4

File tree

20 files changed

+578
-13
lines changed

20 files changed

+578
-13
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,15 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
4040
let description = [{
4141
This is a generic pass to convert to SPIR-V.
4242
}];
43-
let dependentDialects = ["spirv::SPIRVDialect"];
43+
let dependentDialects = [
44+
"spirv::SPIRVDialect",
45+
"vector::VectorDialect",
46+
];
47+
let options = [
48+
Option<"runSignatureConversion", "run-signature-conversion", "bool",
49+
/*default=*/"true",
50+
"Run function signature conversion to convert vector types">
51+
];
4452
}
4553

4654
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1919
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2021
#include "mlir/Transforms/DialectConversion.h"
22+
#include "mlir/Transforms/OneToNTypeConversion.h"
2123
#include "llvm/ADT/SmallSet.h"
2224

2325
namespace mlir {
@@ -134,6 +136,10 @@ class SPIRVConversionTarget : public ConversionTarget {
134136
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
135137
RewritePatternSet &patterns);
136138

139+
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
140+
141+
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
142+
137143
namespace spirv {
138144
class AccessChainOp;
139145

mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,31 @@ namespace {
3939
/// A pass to perform the SPIR-V conversion.
4040
struct ConvertToSPIRVPass final
4141
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
42+
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
4243

4344
void runOnOperation() override {
4445
MLIRContext *context = &getContext();
4546
Operation *op = getOperation();
4647

48+
if (runSignatureConversion) {
49+
// Unroll vectors in function signatures to native vector size.
50+
RewritePatternSet patterns(context);
51+
populateFuncOpVectorRewritePatterns(patterns);
52+
populateReturnOpVectorRewritePatterns(patterns);
53+
GreedyRewriteConfig config;
54+
config.strictMode = GreedyRewriteStrictness::ExistingOps;
55+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
56+
return signalPassFailure();
57+
}
58+
4759
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
60+
std::unique_ptr<ConversionTarget> target =
61+
SPIRVConversionTarget::get(targetAttr);
4862
SPIRVTypeConverter typeConverter(targetAttr);
49-
5063
RewritePatternSet patterns(context);
5164
ScfToSPIRVContext scfToSPIRVContext;
5265

53-
// Populate patterns.
66+
// Populate patterns for each dialect.
5467
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
5568
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
5669
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
@@ -60,9 +73,6 @@ struct ConvertToSPIRVPass final
6073
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
6174
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
6275

63-
std::unique_ptr<ConversionTarget> target =
64-
SPIRVConversionTarget::get(targetAttr);
65-
6676
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
6777
return signalPassFailure();
6878
}

mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,15 @@ add_mlir_dialect_library(MLIRSPIRVConversion
1616
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
1717

1818
LINK_LIBS PUBLIC
19+
MLIRArithDialect
20+
MLIRDialectUtils
1921
MLIRFuncDialect
22+
MLIRIR
2023
MLIRSPIRVDialect
24+
MLIRSupport
2125
MLIRTransformUtils
26+
MLIRVectorDialect
27+
MLIRVectorTransforms
2228
)
2329

2430
add_mlir_dialect_library(MLIRSPIRVTransforms

0 commit comments

Comments
 (0)