@@ -301,6 +301,60 @@ struct BroadcastOpToArmSMELowering
301
301
}
302
302
};
303
303
304
+ // / Conversion pattern for vector.splat.
305
+ // /
306
+ // / Example:
307
+ // /
308
+ // / %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
309
+ // /
310
+ // / is converted to:
311
+ // /
312
+ // / %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
313
+ // / scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
314
+ // / arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
315
+ // / %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
316
+ // / }
317
+ // /
318
+ // / This is identical to vector.broadcast of a scalar.
319
+ struct SplatOpToArmSMELowering : public OpRewritePattern <vector::SplatOp> {
320
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
321
+
322
+ LogicalResult matchAndRewrite (vector::SplatOp splatOp,
323
+ PatternRewriter &rewriter) const final {
324
+ auto tileType = splatOp.getResult ().getType ();
325
+ if (!tileType || !arm_sme::isValidSMETileVectorType (tileType))
326
+ return failure ();
327
+
328
+ OpBuilder::InsertionGuard g (rewriter);
329
+ auto loc = splatOp.getLoc ();
330
+
331
+ auto srcType = splatOp.getOperand ().getType ();
332
+ auto tileElementType = tileType.getElementType ();
333
+
334
+ assert (srcType.isIntOrFloat () && " Invalid source type for vector.splat" );
335
+
336
+ // First, broadcast the scalar to a 1-d vector.
337
+ VectorType tileSliceType = VectorType::Builder (tileType).dropDim (0 );
338
+ Value broadcastOp1D = rewriter.create <vector::BroadcastOp>(
339
+ loc, tileSliceType, splatOp.getInput ());
340
+
341
+ arm_sme::CastTileToVector tile =
342
+ getSMETileAndCastToVector (rewriter, loc, tileType);
343
+
344
+ // Next, create a loop over ZA tile slices and "move" the generated 1-d
345
+ // vector to each slice.
346
+ auto forOp = getLoopOverTileSlices (rewriter, loc, tileElementType);
347
+ auto tileSliceIndex = forOp.getInductionVar ();
348
+
349
+ rewriter.create <arm_sme::MoveVectorToTileSliceOp>(
350
+ loc, tileType, broadcastOp1D, tile, tileSliceIndex);
351
+
352
+ rewriter.replaceOp (splatOp, tile);
353
+
354
+ return success ();
355
+ }
356
+ };
357
+
304
358
// / Conversion pattern for vector.transpose.
305
359
// /
306
360
// / Stores the input tile to memory and reloads vertically.
@@ -381,5 +435,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
381
435
patterns.add <TransferReadPermutationToArmSMELowering,
382
436
TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
383
437
VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
384
- BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
438
+ BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
439
+ TransposeOpToArmSMELowering>(&ctx);
385
440
}
0 commit comments