
Commit c9529f7 (1 parent: 580343d)

[mlir] Drop outermost dims in slice rank reduction inference (#95020)

The `getDroppedDims` utility function does not follow the convention of dropping the outermost unit dimensions first when inferring a rank-reduction mask for a slice. This PR updates the implementation to match that convention.
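Why a convention is needed at all: when several unit dimensions could account for the same rank reduction, the inferred mask is ambiguous. The following minimal standalone C++ sketch (the helper `isValidMask` and the values are illustrative, not part of the MLIR sources) shows that slice sizes [1, 1, 6] reduced to shape [1, 6] admit two valid masks; the convention this commit enforces selects the outermost unit dimension.

#include <cassert>
#include <cstdint>
#include <vector>

// Check whether dropping the dimensions flagged in `mask` from `sizes`
// yields exactly `reducedShape`. Purely illustrative; the real utility
// operates on OpFoldResult and llvm::SmallBitVector.
static bool isValidMask(const std::vector<int64_t> &sizes,
                        const std::vector<bool> &mask,
                        const std::vector<int64_t> &reducedShape) {
  std::vector<int64_t> kept;
  for (size_t i = 0; i < sizes.size(); ++i)
    if (!mask[i])
      kept.push_back(sizes[i]);
  return kept == reducedShape;
}

int main() {
  std::vector<int64_t> sizes = {1, 1, 6};  // static slice sizes
  std::vector<int64_t> reduced = {1, 6};   // rank-reduced result shape

  std::vector<bool> dropOuter = {true, false, false};  // drop dim 0
  std::vector<bool> dropInner = {false, true, false};  // drop dim 1

  // Both masks are consistent with the shapes...
  assert(isValidMask(sizes, dropOuter, reduced));
  assert(isValidMask(sizes, dropInner, reduced));

  // ...but the convention this commit implements picks the outermost
  // unit dimension, i.e. dropOuter.
  return 0;
}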

3 files changed (+24, -10 lines)

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (9 additions, 9 deletions)

@@ -135,40 +135,40 @@ bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                            ArrayRef<OpFoldResult> mixedSizes) {
   llvm::SmallBitVector droppedDims(mixedSizes.size());
-  int64_t shapePos = 0;
+  int64_t shapePos = reducedShape.size() - 1;
 
-  for (const auto &size : enumerate(mixedSizes)) {
+  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
+    size_t idx = mixedSizes.size() - size.index() - 1;
     // Rank-reduced dims must have a static unit dimension.
     bool isStaticUnitSize =
         size.value().is<Attribute>() &&
         llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
 
-    if (shapePos == static_cast<int64_t>(reducedShape.size())) {
+    if (shapePos < 0) {
       // There are no more dims in the reduced shape. All remaining sizes must
       // be rank-reduced dims.
       assert(isStaticUnitSize && "expected unit dim");
-      droppedDims.set(size.index());
+      droppedDims.set(idx);
       continue;
     }
 
     // Dim is preserved if the size is not a static 1.
     if (!isStaticUnitSize) {
-      ++shapePos;
+      --shapePos;
       continue;
     }
 
     // Dim is preserved if the reduced shape dim is also 1.
     if (reducedShape[shapePos] == 1) {
-      ++shapePos;
+      --shapePos;
       continue;
     }
 
     // Otherwise: Dim is dropped.
-    droppedDims.set(size.index());
+    droppedDims.set(idx);
   }
 
-  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
-         "dimension mismatch");
+  assert(shapePos < 0 && "dimension mismatch");
   return droppedDims;
 }
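For readers who want to trace the new logic outside of MLIR, here is a simplified, self-contained sketch of the reverse scan (names such as `droppedDims` and `kDynamic` are illustrative stand-ins; the real utility operates on `OpFoldResult` and `llvm::SmallBitVector`):

#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for a non-constant size;
                                  // a dynamic size is never a static unit

// Simplified mirror of the updated getDroppedDims: scan the sizes from the
// innermost dimension outwards, so that surplus unit dims are attributed
// to the outermost positions.
static std::vector<bool> droppedDims(const std::vector<int64_t> &reducedShape,
                                     const std::vector<int64_t> &sizes) {
  std::vector<bool> dropped(sizes.size(), false);
  int64_t shapePos = static_cast<int64_t>(reducedShape.size()) - 1;
  for (int64_t idx = static_cast<int64_t>(sizes.size()) - 1; idx >= 0; --idx) {
    bool isStaticUnitSize = sizes[idx] == 1;
    if (shapePos < 0) {
      // No reduced-shape dims left: everything remaining must be dropped.
      assert(isStaticUnitSize && "expected unit dim");
      dropped[idx] = true;
      continue;
    }
    if (!isStaticUnitSize) {  // preserved: non-unit size
      --shapePos;
      continue;
    }
    if (reducedShape[shapePos] == 1) {  // preserved: matching unit dim
      --shapePos;
      continue;
    }
    dropped[idx] = true;  // unit size with no matching reduced dim: dropped
  }
  assert(shapePos < 0 && "dimension mismatch");
  return dropped;
}

int main() {
  // sizes [1, 1, 6] rank-reduced to shape [1, 6]: the reverse scan reports
  // the outermost unit dim (index 0) as dropped; the old forward scan
  // reported index 1 instead.
  std::vector<bool> expected = {true, false, false};
  assert(droppedDims({1, 6}, {1, 1, 6}) == expected);
  return 0;
}

Scanning from the innermost dimension outwards pushes the ambiguity toward the outermost positions, which is exactly the convention named in the commit title.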

mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir (12 additions, 0 deletions)

@@ -102,6 +102,18 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
   return %1 : tensor<?x12xf32>
 }
 
+// CHECK-LABEL: func @unit_insert_slice_of_unit_transfer_write(
+// CHECK-SAME: %[[t1:.*]]: tensor<1x1x12xf32>, %[[v:.*]]: vector<1x6xf32>, %[[s:.*]]: index
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c0]], %[[c0]], %[[s]]] {in_bounds = [true, true]} : vector<1x6xf32>, tensor<1x1x12xf32>
+// CHECK: return %[[r]]
+func.func @unit_insert_slice_of_unit_transfer_write(%t1 : tensor<1x1x12xf32>, %v : vector<1x6xf32>, %s : index, %t2 : tensor<1x6xf32>) -> tensor<1x1x12xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<1x6xf32>, tensor<1x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[0, 0, %s] [1, 1, 6] [1, 1, 1] : tensor<1x6xf32> into tensor<1x1x12xf32>
+  return %1 : tensor<1x1x12xf32>
+}
+
 // CHECK-LABEL: func @insert_slice_of_transfer_write_non_leading_rank_reduction(
 // CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
 // CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index

mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir (3 additions, 1 deletion)

@@ -282,11 +282,13 @@ func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1x1xf32>, %
 
 // -----
 
+// CHECK-DAG: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
 // CHECK-LABEL: func @insert_slice_of_insert_slice(
 // CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<f32>
 // CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<1x14xf32>
 // CHECK-SAME: %[[pos:[0-9a-z]*]]: index
-// CHECK: tensor.insert_slice %[[t]] into %[[r1]][5, %[[pos]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
+// CHECK: %[[composed_pos:.+]] = affine.apply #[[$map]]()[%[[pos]]]
+// CHECK: tensor.insert_slice %[[t]] into %[[r1]][3, %[[composed_pos]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
 func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1xf32>, %r1: tensor<1x14xf32>, %pos: index)
     -> tensor<1x14xf32>
 {
