-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][SME] Add vector.splat -> SME conversion #67659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This conversion is identical to vector.broadcast when broadcasting a scalar.
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Changes: This conversion is identical to vector.broadcast when broadcasting a scalar. Full diff: https://github.com/llvm/llvm-project/pull/67659.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 264539b85c0ee23..a83c0e9cdafa521 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -240,6 +240,63 @@ struct BroadcastOpToArmSMELowering
}
};
+/// Conversion pattern for vector.splat whose result is an SME tile vector.
+///
+/// Example:
+///
+/// %splat_to_tile = vector.splat %src : vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
+/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
+/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
+/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// }
+///
+/// This is identical to the conversion of vector.broadcast when
+/// broadcasting a scalar.
+struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+ PatternRewriter &rewriter) const final {
+ auto tileType = splatOp.getResult().getType();
+ if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+ return failure(); // Only results that are valid SME tile types are rewritten.
+
+ OpBuilder::InsertionGuard g(rewriter); // NOTE(review): presumably needed because getLoopOverTileSlices moves the insertion point into the loop body - confirm.
+ auto loc = splatOp.getLoc();
+
+ auto srcType = splatOp.getOperand().getType(); // Only referenced by the assert below.
+ auto tileElementType = tileType.getElementType();
+
+ assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
+
+ // First, broadcast the scalar to a 1-d scalable vector (one tile slice).
+ auto tileSliceType =
+ VectorType::get(tileType.getShape().drop_front(), tileElementType,
+ /*scalableDims=*/{true});
+ Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
+ loc, tileSliceType, splatOp.getOperand());
+
+ arm_sme::CastTileToVector tile =
+ getSMETileAndCastToVector(rewriter, loc, tileType);
+
+ // Next, create a loop over ZA tile slices and "move" the generated 1-d
+ // vector to each slice of the tile obtained above.
+ auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
+ auto tileSliceIndex = forOp.getInductionVar();
+
+ rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+
+ rewriter.replaceOp(splatOp, tile); // All uses of the splat result now use the tile.
+
+ return success();
+ }
+};
+
/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
@@ -319,5 +376,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
- BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
+ BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ TransposeOpToArmSMELowering>(&ctx);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a64753578a1c861..3c08b1deafad27d 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -220,6 +220,44 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
return
}
+//===----------------------------------------------------------------------===//
+// vector.splat
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func.func @splat_vec2d_from_i32(
+// CHECK-SAME: %[[SRC:.*]]: i32) {
+// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK: scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK: arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @splat_vec2d_from_i32(%arg0: i32) {
+ %0 = vector.splat %arg0 : vector<[4]x[4]xi32>
+ "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @splat_vec2d_from_f16(
+// CHECK-SAME: %[[SRC:.*]]: f16) {
+// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK: scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK: arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
+func.func @splat_vec2d_from_f16(%arg0: f16) {
+ %0 = vector.splat %arg0 : vector<[8]x[8]xf16>
+ "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+ return
+}
+
//===----------------------------------------------------------------------===//
// vector.transpose
//===----------------------------------------------------------------------===//
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just some nits:
auto tileSliceType = | ||
VectorType::get(tileType.getShape().drop_front(), tileElementType, | ||
/*scalableDims=*/{true}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto tileSliceType = | |
VectorType::get(tileType.getShape().drop_front(), tileElementType, | |
/*scalableDims=*/{true}); | |
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) | ||
return failure(); | ||
|
||
OpBuilder::InsertionGuard g(rewriter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this guard necessary? It exists for the entire scope of the rewrite
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically speaking, it should be sufficient to move this guard to just before the call to getLoopOverTileSlices. See #67668 for some clarification on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the patch, Andrzej; one more vector op we can tick off the list.
VectorType::get(tileType.getShape().drop_front(), tileElementType, | ||
/*scalableDims=*/{true}); | ||
Value broadcastOp1D = rewriter.create<vector::BroadcastOp>( | ||
loc, tileSliceType, splatOp.getOperand()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems getInput is typically used (it is the name of the operand):
loc, tileSliceType, splatOp.getOperand()); | |
loc, tileSliceType, splatOp.getInput()); |
// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 | ||
// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32> | ||
// CHECK: %[[VSCALE:.*]] = vector.vscale | ||
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index | |
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index |
// CHECK-LABEL: func.func @splat_vec2d_from_f16( | ||
// CHECK-SAME: %[[SRC:.*]]: f16) { | ||
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16> | ||
// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16 | ||
// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16> | ||
// CHECK: %[[VSCALE:.*]] = vector.vscale | ||
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index | ||
// CHECK: scf.for {{.*}} to %[[UB]] {{.*}} { | ||
// CHECK: arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much of this is already covered by the previous test; could you remove some CHECK lines?
/// This should, in practice, be identical to vector.broadcast when | ||
/// broadcasting a scalar. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's identical; I think the ambiguity can be removed:
/// This should, in practice, be identical to vector.broadcast when | |
/// broadcasting a scalar. | |
/// This is identical to vector.broadcast of a scalar. |
Incorporate suggestions from Ben and Cullen
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
Add missing insertion guard
This conversion is identical to vector.broadcast when broadcasting a scalar.
This conversion is identical to vector.broadcast when broadcasting a
scalar.