Skip to content

[mlir][SME] Add vector.splat -> SME conversion #67659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 28, 2023

Conversation

banach-space
Copy link
Contributor

This conversion is identical to vector.broadcast when broadcasting a
scalar.

This conversion is identical to vector.broadcast when broadcasting a
scalar.
@llvmbot
Copy link
Member

llvmbot commented Sep 28, 2023

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Changes

This conversion is identical to vector.broadcast when broadcasting a
scalar.


Full diff: https://github.com/llvm/llvm-project/pull/67659.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+59-1)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+38)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 264539b85c0ee23..a83c0e9cdafa521 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -240,6 +240,63 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.splat.
+///
+/// Example:
+///
+///   %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
+///   scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
+///     arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
+///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+///   }
+///
+/// This should, in practice, be identical to vector.broadcast when
+/// broadcasting a scalar.
+struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
+  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = splatOp.getResult().getType();
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = splatOp.getLoc();
+
+    auto srcType = splatOp.getOperand().getType();
+    auto tileElementType = tileType.getElementType();
+
+    assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
+
+    // First, broadcast the scalar to a 1-d vector.
+    auto tileSliceType =
+        VectorType::get(tileType.getShape().drop_front(), tileElementType,
+                        /*scalableDims=*/{true});
+    Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
+        loc, tileSliceType, splatOp.getOperand());
+
+    arm_sme::CastTileToVector tile =
+        getSMETileAndCastToVector(rewriter, loc, tileType);
+
+    // Next, create a loop over ZA tile slices and "move" the generated 1-d
+    // vector to each slice.
+    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
+    auto tileSliceIndex = forOp.getInductionVar();
+
+    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+
+    rewriter.replaceOp(splatOp, tile);
+
+    return success();
+  }
+};
+
 /// Conversion pattern for vector.transpose.
 ///
 /// Stores the input tile to memory and reloads vertically.
@@ -319,5 +376,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
-               BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
+               BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+               TransposeOpToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a64753578a1c861..3c08b1deafad27d 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -220,6 +220,44 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// vector.splat
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL:   func.func @splat_vec2d_from_i32(
+// CHECK-SAME:      %[[SRC:.*]]: i32) {
+// CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
+// CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK:   %[[VSCALE:.*]] = vector.vscale
+// CHECK:   %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK:   scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK:     arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @splat_vec2d_from_i32(%arg0: i32) {
+  %0 = vector.splat %arg0 : vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @splat_vec2d_from_f16(
+// CHECK-SAME:      %[[SRC:.*]]: f16) {
+// CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
+// CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
+// CHECK:   %[[VSCALE:.*]] = vector.vscale
+// CHECK:   %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK:   scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK:     arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
+func.func @splat_vec2d_from_f16(%arg0: f16) {
+  %0 = vector.splat %arg0 : vector<[8]x[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.transpose
 //===----------------------------------------------------------------------===//

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just some nits:

Comment on lines 277 to 279
auto tileSliceType =
VectorType::get(tileType.getShape().drop_front(), tileElementType,
/*scalableDims=*/{true});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto tileSliceType =
VectorType::get(tileType.getShape().drop_front(), tileElementType,
/*scalableDims=*/{true});
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);

if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();

OpBuilder::InsertionGuard g(rewriter);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this guard necessary? It exists for the entire scope of the rewrite

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically speaking, it should be sufficient to move this just next to where getLoopOverTileSlices is called. See #67668 for some clarification on this.

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for patch Andrzej, one more vector op we can tick of the list

VectorType::get(tileType.getShape().drop_front(), tileElementType,
/*scalableDims=*/{true});
Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
loc, tileSliceType, splatOp.getOperand());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems getInput is typically used (the name of the operand)

Suggested change
loc, tileSliceType, splatOp.getOperand());
loc, tileSliceType, splatOp.getInput());

// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index

Comment on lines 246 to 254
// CHECK-LABEL: func.func @splat_vec2d_from_f16(
// CHECK-SAME: %[[SRC:.*]]: f16) {
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
// CHECK: scf.for {{.*}} to %[[UB]] {{.*}} {
// CHECK: arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much of this is tested by the previous test, could you remove some CHECK lines?

Comment on lines 257 to 258
/// This should, in practice, be identical to vector.broadcast when
/// broadcasting a scalar.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's identical, I think the ambiguity can be removed

Suggested change
/// This should, in practice, be identical to vector.broadcast when
/// broadcasting a scalar.
/// This is identical to vector.broadcast of a scalar.

Incorporate suggestions from Ben and Cullen
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

@banach-space banach-space merged commit 0cb0df4 into llvm:main Sep 28, 2023
legrosbuffle pushed a commit to legrosbuffle/llvm-project that referenced this pull request Sep 29, 2023
This conversion is identical to vector.broadcast when broadcasting a
scalar.
@banach-space banach-space deleted the andrzej/splat_sme branch September 29, 2023 09:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants