Skip to content

Commit 538f135

Browse files
authored
[mlir][ArmSME] Fix scalable dims check in isValidSMETileVectorType (#65254)
Check for allDimsScalable is incorrect and currently permits fixed vectors.
1 parent 4ae0db4 commit 538f135

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mlir/lib/Dialect/ArmSME/Utils/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ bool mlir::arm_sme::isValidSMETileElementType(Type type) {
3131
}
3232

3333
bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
34-
if ((vType.getRank() != 2) && vType.allDimsScalable())
34+
if ((vType.getRank() != 2) || !vType.allDimsScalable())
3535
return false;
3636

3737
auto elemType = vType.getElementType();

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,17 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
154154
return %0 : tensor<?x?xi8>
155155
}
156156

157+
// -----
158+
159+
// CHECK-LABEL: @transfer_write_2d__fixed
160+
// CHECK: vector.transfer_write
161+
// CHECK-NOT: arm_sme.tile_store
162+
func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
163+
%c0 = arith.constant 0 : index
164+
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
165+
return
166+
}
167+
157168
// =============================================================================
158169
// vector.broadcast
159170
// =============================================================================

0 commit comments

Comments
 (0)