Skip to content

Commit dadcaf8

Browse files
authored
[mlir][ArmSME] Support decomposing constant splats into ArmSME tiles (#88762)
This adds a simple rewrite/legalization to decompose constant splats larger than a single ArmSME tile into multiple SME virtual tile sized splats. E.g. a constant splat to `vector<[8]x[8]xi32>` would decompose into four `vector<[4]x[4]xi32>` splats.
1 parent 71b9f66 commit dadcaf8

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
165165
return (vectorRows * vectorCols) / (minNumElts * minNumElts);
166166
}
167167

168+
/// Legalize `arith.constant dense<value>` splat operations to fit within SME
169+
/// tiles by decomposing them into tile-sized operations.
170+
struct LegalizeArithConstantOpsByDecomposition
171+
: public OneToNOpConversionPattern<arith::ConstantOp> {
172+
using OneToNOpConversionPattern::OneToNOpConversionPattern;
173+
174+
LogicalResult
175+
matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
176+
OneToNPatternRewriter &rewriter) const override {
177+
auto vectorType = dyn_cast<VectorType>(constantOp.getType());
178+
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
179+
if (!vectorType || !denseAttr || !denseAttr.isSplat())
180+
return failure();
181+
182+
if (!isMultipleOfSMETileVectorType(vectorType))
183+
return rewriter.notifyMatchFailure(constantOp,
184+
kMatchFailureNotSMETileTypeMultiple);
185+
186+
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
187+
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
188+
auto tileSplat = rewriter.create<arith::ConstantOp>(
189+
constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
190+
rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
191+
adaptor.getResultMapping());
192+
193+
return success();
194+
}
195+
};
196+
168197
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
169198
/// decomposing them into tile-sized operations.
170199
struct LegalizeVectorOuterProductOpsByDecomposition
@@ -637,7 +666,8 @@ struct VectorLegalizationPass
637666
// Note: High benefit to ensure masked outer products are lowered first.
638667
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
639668
converter, context, 1024);
640-
patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
669+
patterns.add<LegalizeArithConstantOpsByDecomposition,
670+
LegalizeVectorOuterProductOpsByDecomposition,
641671
LegalizeTransferReadOpsByDecomposition,
642672
LegalizeTransferWriteOpsByDecomposition>(converter, context);
643673
populateFuncTypeConversionPatterns(converter, patterns);

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,14 @@ func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: m
433433
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
434434
return %cast : vector<[4]xf32>
435435
}
436+
437+
// -----
438+
439+
// CHECK-LABEL: @multi_tile_splat
440+
func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
441+
{
442+
// CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32>
443+
// CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
444+
%0 = arith.constant dense<42> : vector<[8]x[8]xi32>
445+
return %0 : vector<[8]x[8]xi32>
446+
}

0 commit comments

Comments
 (0)