Skip to content

Commit 0cb0df4

Browse files
authored
[mlir][SME] Add vector.splat -> SME conversion (#67659)
This conversion is identical to vector.broadcast when broadcasting a scalar.
1 parent 898b961 commit 0cb0df4

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,60 @@ struct BroadcastOpToArmSMELowering
301301
}
302302
};
303303

304+
/// Conversion pattern for vector.splat.
305+
///
306+
/// Example:
307+
///
308+
/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
309+
///
310+
/// is converted to:
311+
///
312+
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
313+
/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
314+
/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
315+
/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
316+
/// }
317+
///
318+
/// This is identical to vector.broadcast of a scalar.
319+
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
320+
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
321+
322+
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
323+
PatternRewriter &rewriter) const final {
324+
auto tileType = splatOp.getResult().getType();
325+
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
326+
return failure();
327+
328+
OpBuilder::InsertionGuard g(rewriter);
329+
auto loc = splatOp.getLoc();
330+
331+
auto srcType = splatOp.getOperand().getType();
332+
auto tileElementType = tileType.getElementType();
333+
334+
assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
335+
336+
// First, broadcast the scalar to a 1-d vector.
337+
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
338+
Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
339+
loc, tileSliceType, splatOp.getInput());
340+
341+
arm_sme::CastTileToVector tile =
342+
getSMETileAndCastToVector(rewriter, loc, tileType);
343+
344+
// Next, create a loop over ZA tile slices and "move" the generated 1-d
345+
// vector to each slice.
346+
auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
347+
auto tileSliceIndex = forOp.getInductionVar();
348+
349+
rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
350+
loc, tileType, broadcastOp1D, tile, tileSliceIndex);
351+
352+
rewriter.replaceOp(splatOp, tile);
353+
354+
return success();
355+
}
356+
};
357+
304358
/// Conversion pattern for vector.transpose.
305359
///
306360
/// Stores the input tile to memory and reloads vertically.
@@ -381,5 +435,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
381435
patterns.add<TransferReadPermutationToArmSMELowering,
382436
TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
383437
VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
384-
BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
438+
BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
439+
TransposeOpToArmSMELowering>(&ctx);
385440
}

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,40 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
427427
return
428428
}
429429

430+
//===----------------------------------------------------------------------===//
431+
// vector.splat
432+
//===----------------------------------------------------------------------===//
433+
434+
// -----
435+
436+
// CHECK-LABEL: func.func @splat_vec2d_from_i32(
437+
// CHECK-SAME: %[[SRC:.*]]: i32) {
438+
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
439+
// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
440+
// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
441+
// CHECK: %[[VSCALE:.*]] = vector.vscale
442+
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
443+
// CHECK: scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
444+
// CHECK: arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
445+
func.func @splat_vec2d_from_i32(%arg0: i32) {
446+
%0 = vector.splat %arg0 : vector<[4]x[4]xi32>
447+
"prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
448+
return
449+
}
450+
451+
// -----
452+
453+
// CHECK-LABEL: func.func @splat_vec2d_from_f16(
454+
// CHECK-SAME: %[[SRC:.*]]: f16) {
455+
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
456+
// CHECK: scf.for
457+
// CHECK: arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
458+
func.func @splat_vec2d_from_f16(%arg0: f16) {
459+
%0 = vector.splat %arg0 : vector<[8]x[8]xf16>
460+
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
461+
return
462+
}
463+
430464
//===----------------------------------------------------------------------===//
431465
// vector.transpose
432466
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)