Skip to content

Commit 96c1611

Browse files
authored
[mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern (#141613)
Given the following example: ``` module { func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> { %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> return %pack : tensor<1x1x1x4x1xf32> } } ``` we would generate an invalid transpose operation, because the calculated permutation would be `[0, 2, 0]`. This is semantically incorrect, as the permutation must contain unique integers corresponding to the source tensor dimensions. This change modifies how we calculate the permutation array and ensures that the dimension indices in the permutation array are unique. With it, the above example translates to a transpose with a permutation of `[1, 2, 0]`, following the rule that the `inner_dims_pos` indices are appended to the permutation array and the preceding entries are filled with the remaining dimensions in increasing order.
1 parent 23384cd commit 96c1611

File tree

3 files changed

+95
-15
lines changed

3 files changed

+95
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,13 +1178,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11781178
int64_t destRank = packOp.getDestRank();
11791179
int64_t numTiles = destRank - srcRank;
11801180

1181-
if (!llvm::all_of(packOp.getInnerDimsPos(),
1182-
[&srcRank, &numTiles](int64_t dimPos) {
1183-
return dimPos >= (srcRank - numTiles - 1);
1184-
}))
1185-
return rewriter.notifyMatchFailure(
1186-
packOp, "Attempting to tile non-trailing source dims!");
1187-
11881181
// 1. Extract the inner tile sizes.
11891182
// Where possible, values are replaced with constant attributes (to match the
11901183
// behaviour of `getPackOpSourceOrPaddedSource`).
@@ -1204,16 +1197,24 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12041197
// %init = tensor.empty()
12051198
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
12061199
// outs(%init)
1207-
// Two assumptions are made:
1208-
// 1. All outer dims are 1 - the corresponding transposition doesn't matter.
1209-
// 2. Inner dims position correspond to the trailing `numTiles` dims.
1210-
SmallVector<int64_t> tilesPermNormalized =
1211-
getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
1200+
// Assumptions made:
1201+
// 1. All outer dims are 1 - the corresponding transposition order doesn't
1202+
// matter, but requires all dim indices to be present.
12121203
SmallVector<int64_t> srcPermForTranspose;
1213-
for (int64_t i = 0; i < (srcRank - numTiles); i++)
1204+
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
1205+
for (int64_t i = 0; i < srcRank; i++) {
1206+
// We assume the `k` dimensions of the inner dim position, where `k` is the
1207+
// rank of the inner tiling, correspond to the last `k` indices of the
1208+
// transpose permutation. This is done by adding the indices not contained
1209+
// in the inner dimension position in order from 0 to `n - 1`, where `n` is the
1210+
// rank of the source tensor. For example if we have a source tensor with
1211+
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
1212+
// indices are [1, 2], and the transpose will be [1, 2, 3, 0].
1213+
if (llvm::is_contained(innerDimPos, i))
1214+
continue;
12141215
srcPermForTranspose.push_back(i);
1215-
1216-
srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
1216+
}
1217+
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
12171218

12181219
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
12191220
<< "perm: " << llvm::interleaved(srcPermForTranspose)

mlir/test/Dialect/Linalg/decompose-pack.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,48 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
229229
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
230230
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
231231
// CHECK: return %[[INSERT]]
232+
233+
// -----
234+
235+
// The following example shows a pack operation that is defined with inner
236+
// dimension positions that are not adjacent, i.e. `[2, 0]`, and the outer
237+
// dimensions of the packed tensor are of unit values, i.e. `1x1x1`.
238+
func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
239+
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
240+
return %pack : tensor<1x1x1x4x1xf32>
241+
}
242+
// CHECK-LABEL: func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer
243+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
244+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
245+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
246+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
247+
// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>)
248+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>)
249+
// CHECK-SAME: permutation = [1, 2, 0]
250+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
251+
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
252+
// CHECK: return %[[INSERT]]
253+
254+
// -----
255+
256+
// The following example shows a pack operation where the inner dimension
257+
// positions are specified as [2, 1] which are termed adjacent trailing
258+
// dimensions as they contain the last dimension of the source tensor with a
259+
// neighboring dimension. [1, 2] would also be considered trailing adjacent.
260+
// And the outer dimensions of the packed tensor are all set to unit values
261+
// of `1x1x1`.
262+
func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
263+
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
264+
return %pack : tensor<1x1x1x4x1xf32>
265+
}
266+
// CHECK-LABEL: func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer
267+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
268+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
269+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
270+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
271+
// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>)
272+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>)
273+
// CHECK-SAME: permutation = [0, 2, 1]
274+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
275+
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
276+
// CHECK: return %[[INSERT]]

mlir/test/Dialect/Linalg/decompose-unpack.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,37 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tens
169169
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
170170
// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
171171
// CHECK: return %[[INSERT]]
172+
173+
// -----
174+
175+
func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
176+
%0 = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32>
177+
return %0 : tensor<1x1x4xf32>
178+
}
179+
// CHECK-LABEL: func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer
180+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
181+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
182+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32>
183+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32>
184+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
185+
// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>)
186+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
187+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
188+
// CHECK: return %[[INSERT]]
189+
190+
// -----
191+
192+
func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
193+
%pack = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32>
194+
return %pack : tensor<1x1x4xf32>
195+
}
196+
// CHECK-LABEL: func.func @unpack_with_non_trailing_dimensions_in_inner_dims
197+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
198+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
199+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32>
200+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32>
201+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
202+
// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>)
203+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
204+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
205+
// CHECK: return %[[INSERT]]

0 commit comments

Comments
 (0)