Skip to content

Commit f847f45

Browse files
committed
Review fixups
1 parent 6475f49 commit f847f45

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ struct ConvertIllegalShapeCastOpsToTransposes
803803
/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
804804
/// %c4_vscale = arith.muli %vscale, %c4 : index
805805
/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
806-
/// vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
806+
/// vector.transfer_write %4, %dest[%y, %x], %mask
807807
/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
808808
/// : vector<[4]x[4]xf32>, memref<?x?xf32>
809809
/// ```
@@ -832,7 +832,7 @@ struct LowerIllegalTransposeStoreViaZA
832832
auto resultType = transposeOp.getResultVectorType();
833833

834834
if (resultType.getRank() != 2)
835-
return rewriter.notifyMatchFailure(transposeOp, "not rank 2");
835+
return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");
836836

837837
if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
838838
return rewriter.notifyMatchFailure(
@@ -865,7 +865,7 @@ struct LowerIllegalTransposeStoreViaZA
865865
// vscale (and emitting multiple implementations) we can't make use of the
866866
// rows of the tile after 1*vscale rows.
867867
Value tile = undefTile;
868-
for (int d = 0, e = numSlicesPerTile; d < e; ++d) {
868+
for (int d = 0; d < numSlicesPerTile; ++d) {
869869
Value vector = rewriter.create<vector::ExtractOp>(
870870
loc, transposeOp.getVector(),
871871
rewriter.getIndexAttr(d + smeTile.row));

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -573,9 +573,9 @@ func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memre
573573

574574
// -----
575575

576-
// CHECK: @transpose_store_scalable_via_za_masked(
577-
// CHECK-SAME: %[[A:[a-z0-9]+]]: index,
578-
// CHECK-SAME: %[[B:[a-z0-9]+]]: index)
576+
// CHECK-LABEL: @transpose_store_scalable_via_za_masked(
577+
// CHECK-SAME: %[[A:[a-z0-9]+]]: index,
578+
// CHECK-SAME: %[[B:[a-z0-9]+]]: index)
579579
func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
580580
// CHECK: %[[C2:.*]] = arith.constant 2 : index
581581
// CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]]
@@ -590,11 +590,11 @@ func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest
590590

591591
// -----
592592

593-
// CHECK: @transpose_store_scalable_via_za_multi_tile(
594-
// CHECK-SAME: %[[VEC:.*]]: vector<8x[4]xf32>
595-
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
596-
// CHECK-SAME: %[[I:.*]]: index,
597-
// CHECK-SAME: %[[J:.*]]: index)
593+
// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile(
594+
// CHECK-SAME: %[[VEC:.*]]: vector<8x[4]xf32>
595+
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
596+
// CHECK-SAME: %[[I:.*]]: index,
597+
// CHECK-SAME: %[[J:.*]]: index)
598598
func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
599599
// CHECK: %[[C4:.*]] = arith.constant 4 : index
600600
// CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -615,3 +615,34 @@ func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %
615615
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>, memref<?x?xf32>
616616
return
617617
}
618+
619+
// -----
620+
621+
// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile_with_scalable_extracts
622+
func.func @transpose_store_scalable_via_za_multi_tile_with_scalable_extracts(%vec: vector<2x[8]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
623+
// <check extracts from lower 4 x vscale of %vec>
624+
// CHECK: vector.scalable.extract
625+
// CHECK: %[[ROW_2_LOWER:.*]] = vector.scalable.extract %{{.*}}[0] : vector<[4]xf32> from vector<[8]xf32>
626+
// CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_LOWER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
627+
// CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I:[a-z0-9]+]], %[[J:[a-z0-9]+]]]
628+
629+
// <check extracts from upper 4 x vscale of %vec>
630+
// CHECK: vector.scalable.extract
631+
// CHECK: %[[ROW_2_UPPER:.*]] = vector.scalable.extract %{{.*}}[4] : vector<[4]xf32> from vector<[8]xf32>
632+
// CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_UPPER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
633+
// CHECK: %[[I_OFFSET:.*]] = arith.addi %{{.*}}, %[[I]] : index
634+
// CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I_OFFSET]], %[[J]]]
635+
%tr = vector.transpose %vec, [1, 0] : vector<2x[8]xf32> to vector<[8]x2xf32>
636+
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[8]x2xf32>, memref<?x?xf32>
637+
return
638+
}
639+
640+
// -----
641+
642+
// CHECK-LABEL: @negative_transpose_store_scalable_via_za__bad_source_shape
643+
// CHECK-NOT: arm_sme.get_tile
644+
func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vector<2x[7]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
645+
%tr = vector.transpose %vec, [1, 0] : vector<2x[7]xf32> to vector<[7]x2xf32>
646+
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
647+
return
648+
}

0 commit comments

Comments
 (0)