Skip to content

Commit 042800a

Browse files
authored
[mlir][ArmSME] Add initial SME vector legalization pass (#79152)
This adds a new pass (`-arm-sme-vector-legalization`) which legalizes vector operations so that they can be lowered to ArmSME. This initial patch adds decomposition for `vector.outerproduct`, `vector.transfer_read`, and `vector.transfer_write` when they operate on vector types larger than a single SME tile. For example, a [8]x[8]xf32 outer product would be decomposed into four [4]x[4]xf32 outer products, which could then be lowered to ArmSME. These three ops have been picked as supporting them alone allows lowering matmuls that use all ZA accumulators to ArmSME. For it to be possible to legalize a vector type it has to be a multiple of an SME tile size, but other than that any shape can be used. E.g. `vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>` can all be lowered to four `vector<[4]x[4]xf32>` operations. In future, this pass will be extended with more SME-specific rewrites to legalize unrolling the reduction dimension of matmuls (which is not type-decomposition), which is why the pass has quite a general name.
1 parent d74619a commit 042800a

File tree

9 files changed

+991
-1
lines changed

9 files changed

+991
-1
lines changed

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ std::unique_ptr<Pass> createTileAllocationPass();
3636
/// variants.
3737
std::unique_ptr<Pass> createOuterProductFusionPass();
3838

39+
/// Pass that legalizes vectors so they can be lowered to ArmSME.
40+
std::unique_ptr<Pass> createVectorLegalizationPass();
41+
3942
//===----------------------------------------------------------------------===//
4043
// Registration
4144
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,27 @@ def OuterProductFusion
156156
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
157157
}
158158

159+
// Declarative definition of the ArmSME vector legalization pass. The pass is
// registered under the argument `arm-sme-vector-legalization` and is anchored
// on `mlir::ModuleOp`.
def VectorLegalization
160+
: Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
161+
let summary = "Legalize vectors for ArmSME";
162+
let description = [{
163+
This pass legalizes vector operations so that they can be lowered to ArmSME.
164+
This includes decomposing operations that operate on vector types larger
165+
than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
166+
tile-sized operations, as well as rewrites needed to get operations into
167+
forms compatible with SME lowerings.
168+
169+
Note: Decomposition is currently limited to vector types that are an exact
170+
multiple of SME tiles. That is scalable in two dimensions, with both the
171+
rows and columns divisible by the SVE vector length for the element type.
172+
}];
173+
// Factory function used to construct the pass; its prototype is
// `std::unique_ptr<Pass> createVectorLegalizationPass()`.
let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
174+
// Dialects declared as dependencies of this pass.
// NOTE(review): presumably these are the dialects whose ops the rewrites may
// create — confirm against VectorLegalization.cpp.
let dependentDialects = [
175+
"func::FuncDialect",
176+
"arm_sme::ArmSMEDialect",
177+
"vector::VectorDialect",
178+
"arith::ArithDialect"
179+
];
180+
}
181+
159182
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
5656
PatternRewriter &rewriter, Location loc, Value initTile,
5757
std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
5858

59+
/// Returns true if `vType` is a multiple of an SME tile size. Returns false if
60+
/// the `vType` exactly matches the size of an SME tile.
61+
bool isMultipleOfSMETileVectorType(VectorType vType);
62+
63+
/// Creates a vector type for the SME tile of `elementType`.
64+
VectorType getSMETileTypeForElement(Type elementType);
65+
5966
} // namespace mlir::arm_sme
6067

6168
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_

mlir/lib/Dialect/ArmSME/IR/Utils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
9494
return forOp;
9595
}
9696

97+
/// Reports whether `vType` spans more than one SME tile while still being
/// evenly divisible into tiles.
///
/// The vector must be rank-2, scalable in every dimension, and have an element
/// type accepted by `isValidSMETileElementType`. Each dimension must be a whole
/// multiple of `getSMETileSliceMinNumElts(elementType)`, and at least one
/// dimension must be strictly larger than that value, so a vector whose shape
/// is exactly one tile yields false.
bool isMultipleOfSMETileVectorType(VectorType vType) {
  const bool isFullyScalable2D =
      vType.getRank() == 2 && vType.allDimsScalable();
  if (!isFullyScalable2D)
    return false;

  auto eltType = vType.getElementType();
  if (!isValidSMETileElementType(eltType))
    return false;

  const unsigned tileDim = getSMETileSliceMinNumElts(eltType);
  const int64_t numRows = vType.getDimSize(0);
  const int64_t numCols = vType.getDimSize(1);

  // Reject shapes that fit within a single tile before testing divisibility.
  const bool largerThanOneTile = numRows > tileDim || numCols > tileDim;
  if (!largerThanOneTile)
    return false;

  return numRows % tileDim == 0 && numCols % tileDim == 0;
}
113+
114+
/// Builds the vector type of the SME tile for `elementType`: a square 2-D
/// vector, scalable in both dimensions, whose side length is the value returned
/// by `getSMETileSliceMinNumElts(elementType)`.
VectorType getSMETileTypeForElement(Type elementType) {
  const unsigned tileDim = getSMETileSliceMinNumElts(elementType);
  return VectorType::get(/*shape=*/{tileDim, tileDim}, elementType,
                         /*scalableDims=*/{true, true});
}
118+
97119
} // namespace mlir::arm_sme

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
22
EnableArmStreaming.cpp
33
OuterProductFusion.cpp
44
TileAllocation.cpp
5+
VectorLegalization.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -10,10 +11,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
1011
MLIRArmSMETransformsIncGen
1112

1213
LINK_LIBS PUBLIC
14+
MLIRPass
1315
MLIRArmSMEDialect
1416
MLIRFuncDialect
1517
MLIRLLVMCommonConversion
1618
MLIRVectorDialect
1719
MLIRSCFDialect
18-
MLIRPass
20+
MLIRSCFTransforms
21+
MLIRFuncTransforms
1922
)

0 commit comments

Comments
 (0)