Skip to content

Commit 26255c1

Browse files
committed
Fixups
1 parent 4880622 commit 26255c1

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,8 @@ class DropInnerMostUnitDimsTransferWrite
13821382

13831383
auto resultTargetVecType =
13841384
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1385-
targetType.getElementType());
1385+
targetType.getElementType(),
1386+
targetType.getScalableDims().drop_back(dimsToDrop));
13861387

13871388
Location loc = writeOp.getLoc();
13881389
SmallVector<OpFoldResult> sizes =

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,9 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
179179

180180
func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<[4]x1xf32> {
181181
%c0 = arith.constant 0 : index
182-
%cst_0 = arith.constant 0.000000e+00 : f32
183-
%4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
184-
return %4 : vector<[4]x1xf32>
182+
%pad = arith.constant 0.0 : f32
183+
%0 = vector.transfer_read %dest[%index, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
184+
return %0 : vector<[4]x1xf32>
185185
}
186186
// CHECK: func.func @leading_scalable_dimension_transfer_read
187187
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
@@ -196,11 +196,39 @@ func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %i
196196
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
197197
func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<4x[1]xf32> {
198198
%c0 = arith.constant 0 : index
199-
%cst_0 = arith.constant 0.000000e+00 : f32
200-
%4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
201-
return %4 : vector<4x[1]xf32>
199+
%pad = arith.constant 0.0 : f32
200+
%0 = vector.transfer_read %dest[%index, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
201+
return %0 : vector<4x[1]xf32>
202202
}
203203
// CHECK: func.func @trailing_scalable_one_dim_transfer_read
204204
// CHECK-NOT: vector.shape_cast
205205
// CHECK: vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
206206
// CHECK-NOT: vector.shape_cast
207+
208+
// -----
209+
210+
func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>, %index: index) {
211+
%c0 = arith.constant 0 : index
212+
vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32>
213+
return
214+
}
215+
// CHECK: func.func @leading_scalable_dimension_transfer_write
216+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
217+
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
218+
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
219+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
220+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
221+
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]][%[[IDX]]] {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
222+
223+
// -----
224+
225+
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
226+
func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) {
227+
%c0 = arith.constant 0 : index
228+
vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>, memref<24x1xf32>
229+
return
230+
}
231+
// CHECK: func.func @trailing_scalable_one_dim_transfer_write
232+
// CHECK-NOT: vector.shape_cast
233+
// CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
234+
// CHECK-NOT: vector.shape_cast

0 commit comments

Comments
 (0)