Skip to content

Commit d38d606

Browse files
committed
[mlir][tensor][linalg] Enhance pack op propagation across generic ops.
Considering the case of generic + pack (with outer_dim_perms), the truth is that it is equivalent to generic + pack + transpose. There are two steps to bubble up the pack op across the generic op. Step 1. Swap generic + pack -> pack + generic. In this step, we can bind the packing information to dimensions of the iteration domain. With the information, we can pack the operands with corresponding data tile sizes; the packed inner dimensions will be appended to the indexing_maps. Note that the outer dimensions of the indexing maps are not changed at all. Step 2. Fold the transpose into the generic op. Step two just updates the indexing map, so we do not have to handle outer_dim_perms anymore. There could be a step 3 to extract the transpose op out (i.e., generic -> transpose + generic), after which we could fold the transpose into the pack op. That step is not done in this revision. Co-authored-by: Lorenzo Chelini <[email protected]> Reviewed By: chelini Differential Revision: https://reviews.llvm.org/D139680
1 parent 6e6fe27 commit d38d606

File tree

2 files changed

+200
-57
lines changed

2 files changed

+200
-57
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 108 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Tensor/Utils/Utils.h"
1717
#include "mlir/Dialect/Utils/IndexingUtils.h"
1818
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
#include "llvm/Support/Debug.h"
1920

2021
namespace mlir {
2122
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
@@ -29,10 +30,66 @@ using namespace mlir::linalg;
2930

3031
namespace {
3132

33+
// The struct contains the infomation about mapping packing information to
34+
// the iteration domain of Linalg ops.
35+
struct PackInfo {
36+
int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
37+
// InnerDimsPos on iteration domain, which follows the order in pack ops.
38+
SmallVector<int64_t> tiledDimsPos;
39+
// The sizes of tiling data dimensions on iteration domain.
40+
llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
41+
// The mapping from a dimension of iteration domain to the corresponding inner
42+
// tiling dimension on iteration domain.
43+
llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
44+
// The permutation of outer dims (on domain).
45+
SmallVector<int64_t> outerDimsOnDomainPerm;
46+
Optional<Value> paddingValue;
47+
};
48+
49+
static PackInfo getPackingInfoFromConsumer(
50+
AffineMap indexingMap, ArrayRef<OpFoldResult> innerTileSizes,
51+
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
52+
Optional<Value> paddingValue = llvm::None) {
53+
LLVM_DEBUG(
54+
{ llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
55+
PackInfo packInfo;
56+
packInfo.paddingValue = paddingValue;
57+
int64_t origNumDims = indexingMap.getNumDims();
58+
SmallVector<AffineExpr> exprs(indexingMap.getResults());
59+
for (auto [index, innerDimPos, tileSize] :
60+
llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
61+
innerDimsPos, innerTileSizes)) {
62+
int64_t domainDimPos =
63+
exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
64+
packInfo.tiledDimsPos.push_back(domainDimPos);
65+
packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
66+
packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
67+
LLVM_DEBUG({
68+
llvm::dbgs() << "map innerDimPos=" << innerDimPos
69+
<< " to iteration dimension (d" << domainDimPos << ", d"
70+
<< packInfo.tileToPointMapping[domainDimPos]
71+
<< "), which has size=("
72+
<< packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
73+
});
74+
}
75+
76+
for (auto dim : outerDimsPerm)
77+
packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
78+
if (!packInfo.outerDimsOnDomainPerm.empty()) {
79+
LLVM_DEBUG({
80+
llvm::dbgs() << "map outer dimsDimsPerm to ";
81+
for (auto dim : packInfo.outerDimsOnDomainPerm)
82+
llvm::dbgs() << dim << " ";
83+
llvm::dbgs() << "\n";
84+
});
85+
}
86+
87+
return packInfo;
88+
}
89+
3290
/// Returns a tuple for packed operand and indexing_map with the assumptions:
3391
/// 1) The generic op is the producer of the pack op.
3492
/// 2) The generic op has only one result.
35-
/// 3) The indexing map of the output operand is identity.
3693
/// If the operand is a scalar or packing dimensions are all irrelevant to the
3794
/// operand, the opreand and the updated indexing map will be returned.
3895
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
@@ -62,62 +119,57 @@ namespace {
62119
/// inner_tiles = [8]
63120
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
64121
static std::tuple<Value, AffineMap>
65-
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc,
66-
tensor::PackOp packOp, GenericOp genericOp,
67-
OpOperand *opOperand) {
68-
int numOrigLoops = genericOp.getNumLoops();
69-
int64_t numInnerLoops = packOp.getInnerDimsPos().size();
122+
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
123+
GenericOp genericOp, OpOperand *opOperand) {
124+
int64_t numOrigLoops = genericOp.getNumLoops();
125+
int64_t numInnerLoops = packInfo.getNumTiledLoops();
70126
int64_t numLoops = numOrigLoops + numInnerLoops;
71127
AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
128+
llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
72129
SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
73-
74130
if (genericOp.isScalar(opOperand))
75-
return std::make_tuple(
76-
opOperand->get(),
77-
AffineMap::get(numLoops, 0, exprs, packOp.getContext()));
78-
79-
llvm::SetVector<int64_t> innerDimsPosSet(packOp.getInnerDimsPos().begin(),
80-
packOp.getInnerDimsPos().end());
81-
// Mapping from AffinDimExpr of indexing maps to the operand shape dimension.
82-
DenseMap<int64_t, int64_t> iterMapToDim;
83-
for (auto [index, expr] : llvm::enumerate(origIndexingMap.getResults())) {
131+
return std::make_tuple(opOperand->get(),
132+
AffineMap::get(numLoops, 0, exprs, b.getContext()));
133+
134+
// Step 1. Construct the information of packing data dimensions; append inner
135+
// dimensions to the indexing maps for the operand.
136+
for (auto [index, expr] : llvm::enumerate(exprs)) {
84137
int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
85-
if (!innerDimsPosSet.contains(dimPos))
86-
continue;
87-
iterMapToDim[dimPos] = index;
138+
domainDimToOperandDim[dimPos] = index;
88139
}
89-
90-
// Construct the information of packing data dimensions and new indexing maps
91-
// for the operand.
92140
SmallVector<int64_t> innerDimsPos;
93141
SmallVector<OpFoldResult> innerTileSizes;
94-
for (auto [index, value] : llvm::enumerate(
95-
llvm::zip(packOp.getInnerDimsPos(), packOp.getMixedTiles()))) {
96-
int64_t dimPos = std::get<0>(value);
97-
if (!iterMapToDim.count(dimPos))
142+
for (auto dimPos : packInfo.tiledDimsPos) {
143+
if (!domainDimToOperandDim.count(dimPos))
98144
continue;
99-
innerDimsPos.push_back(iterMapToDim[dimPos]);
100-
innerTileSizes.push_back(std::get<1>(value));
101-
exprs.push_back(b.getAffineDimExpr(numOrigLoops + index));
145+
int64_t index = domainDimToOperandDim[dimPos];
146+
innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
147+
innerDimsPos.push_back(index);
148+
exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
102149
}
103-
auto indexingMap = AffineMap::get(numLoops, 0, exprs, packOp.getContext());
104150

151+
// Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op.
152+
// TODO: should we propagate the permutation of outer dims to the pack op?
105153
SmallVector<int64_t> outerDimsPerm;
106-
for (auto outDim : packOp.getOuterDimsPerm()) {
107-
if (!iterMapToDim.count(outDim))
108-
continue;
109-
outerDimsPerm.push_back(iterMapToDim[outDim]);
154+
if (!packInfo.outerDimsOnDomainPerm.empty()) {
155+
SmallVector<int64_t> inversedOuterPerm =
156+
invertPermutationVector(packInfo.outerDimsOnDomainPerm);
157+
for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
158+
int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
159+
exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
160+
}
110161
}
162+
auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
111163

112164
// The operand does not have dimensions that relates to pack op.
113-
if (innerDimsPos.empty() && outerDimsPerm.empty())
165+
if (innerDimsPos.empty())
114166
return std::make_tuple(opOperand->get(), indexingMap);
115167

116168
auto empty = tensor::PackOp::createDestinationTensor(
117169
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
118170
auto packedOperand = b.create<tensor::PackOp>(
119171
loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
120-
packOp.getPaddingValue(), outerDimsPerm);
172+
packInfo.paddingValue, outerDimsPerm);
121173
return std::make_tuple(packedOperand, indexingMap);
122174
}
123175

@@ -187,34 +239,45 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
187239
return failure();
188240

189241
OpOperand *opOperand = genericOp.getDpsInitOperand(0);
190-
// TODO: Add support for all permutation indexing maps.
191-
if (!genericOp.getMatchingIndexingMap(opOperand).isIdentity())
192-
return rewriter.notifyMatchFailure(
193-
packOp, "the result of generic op does not have identity indexing_map");
242+
auto packInfo = getPackingInfoFromConsumer(
243+
genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(),
244+
packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(),
245+
packOp.getPaddingValue());
194246

195247
Location loc = packOp.getLoc();
196248
SmallVector<Value> inputOperands;
197249
SmallVector<AffineMap> indexingMaps;
198250
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
199251
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
200-
rewriter, loc, packOp, genericOp, inputOperand);
252+
rewriter, loc, packInfo, genericOp, inputOperand);
201253
inputOperands.push_back(packedOperand);
202254
indexingMaps.push_back(packedIndexingMap);
203255
}
204256

205257
int64_t numLoops = genericOp.getNumLoops();
206-
int64_t numInnerLoops = packOp.getInnerDimsPos().size();
258+
int64_t numInnerLoops = packInfo.getNumTiledLoops();
207259
int64_t newNumLoops = numLoops + numInnerLoops;
208260
SmallVector<utils::IteratorType> iterTypes =
209261
genericOp.getIteratorTypesArray();
210262
iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
211263

264+
// Rebuild the indexing map for the corresponding init operand.
265+
auto [packedOutOperand, packedOutIndexingMap] =
266+
getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
267+
opOperand);
212268
SmallVector<AffineExpr> outExprs(
213-
genericOp.getMatchingIndexingMap(opOperand).getResults());
269+
packedOutIndexingMap.getResults().drop_back(numInnerLoops));
270+
// Apply transpose to the indexing map, because we'll replace the init operand
271+
// with the destination of pack op.
272+
auto outerDimsPerm = packOp.getOuterDimsPerm();
273+
if (!outerDimsPerm.empty()) {
274+
applyPermutationToVector<AffineExpr>(outExprs, outerDimsPerm);
275+
}
214276
for (int i = 0; i < numInnerLoops; ++i)
215277
outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
216-
indexingMaps.push_back(
217-
AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()));
278+
AffineMap outMap =
279+
AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext());
280+
indexingMaps.push_back(outMap);
218281

219282
auto newGenericOp = rewriter.create<linalg::GenericOp>(
220283
loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,17 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten
9696
into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
9797
return %pack : tensor<16x4x32x16xi32>
9898
}
99-
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
99+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
100+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
100101
// CHECK: func.func @elem_pack_transpose_outer_dims
101102
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
102103
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
103-
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
104+
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16xi32>
104105
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
105-
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
106-
// CHECK-SAME: into %[[ARG0_EMPTY]]
106+
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
107+
// CHECK-SAME: into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
107108
// CHECK: %[[ELEM:.+]] = linalg.generic
108-
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
109+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
109110
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
110111
// CHECK-SAME: ins(%[[PACK_ARG0]]
111112
// CHECK-SAME: outs(%[[DEST]]
@@ -130,16 +131,17 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>,
130131
into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
131132
return %pack : tensor<16x4x16x32xi32>
132133
}
133-
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
134+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
135+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
134136
// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims
135137
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
136138
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
137-
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
139+
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
138140
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
139-
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
141+
// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
140142
// CHECK-SAME: into %[[ARG0_EMPTY]]
141143
// CHECK: %[[ELEM:.+]] = linalg.generic
142-
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
144+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
143145
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
144146
// CHECK-SAME: ins(%[[PACK_ARG0]]
145147
// CHECK-SAME: outs(%[[DEST]]
@@ -200,6 +202,37 @@ func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %d
200202

201203
// -----
202204

205+
#map = affine_map<(d0, d1, d2, d3) -> (d3)>
206+
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
207+
func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> {
208+
%0 = tensor.empty() : tensor<1x56x57x64xf32>
209+
%1 = linalg.generic {
210+
indexing_maps = [#map, #map1],
211+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
212+
ins(%arg0 : tensor<64xf32>)
213+
outs(%0 : tensor<1x56x57x64xf32>) {
214+
^bb0(%in: f32, %out: f32):
215+
linalg.yield %in : f32
216+
} -> tensor<1x56x57x64xf32>
217+
%2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
218+
return %2 : tensor<1x2x56x57x32xf32>
219+
}
220+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
221+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
222+
// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims2
223+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
224+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
225+
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
226+
// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
227+
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32]
228+
// CHECK-SAME: into %[[ARG0_EMPTY]]
229+
// CHECK: %[[RES:.+]] = linalg.generic
230+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
231+
// CHECK-SAME: ins(%[[PACKED_ARG0]]
232+
// CHECK-SAME: outs(%[[DEST]]
233+
234+
// -----
235+
203236
#map0 = affine_map<(d0, d1) -> (d0, d1)>
204237
#map1 = affine_map<(d0, d1) -> (d0)>
205238
#map2 = affine_map<(d0, d1) -> (d1)>
@@ -225,6 +258,53 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
225258
into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
226259
return %4 : tensor<100x200x4x16x16x32xi32>
227260
}
228-
// CHECK: func.func @transpose_pack
229-
// CHECK: linalg.generic
230-
// CHECK: tensor.pack
261+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
262+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
263+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
264+
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
265+
// CHECK: func.func @transpose_pack
266+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
267+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
268+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
269+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
270+
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
271+
// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
272+
// CHECK-SAME: inner_dims_pos = [3, 1] inner_tiles = [16, 32]
273+
// CHECK-SAME: into %[[ARG0_EMPTY]]
274+
// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
275+
// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
276+
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32]
277+
// CHECK-SAME: into %[[ARG2_EMPTY]]
278+
// CHECK: %[[RES:.+]] = linalg.generic
279+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
280+
// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
281+
// CHECK-SAME: outs(%[[DEST]]
282+
283+
// -----
284+
285+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
286+
#map1 = affine_map<(d0, d1) -> (d0)>
287+
#map2 = affine_map<(d0, d1) -> (d1)>
288+
func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
289+
{
290+
%init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
291+
%transpose = linalg.generic {
292+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
293+
affine_map<(d0, d1, d2, d3) -> (d0)>,
294+
affine_map<(d0, d1, d2, d3) -> (d1)>,
295+
affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
296+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
297+
ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
298+
outs(%init_transpose : tensor<100x200x128x256xi32>) {
299+
^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
300+
%0 = arith.addi %b0, %b1 : i32
301+
%1 = arith.addi %0, %b2 : i32
302+
linalg.yield %1 : i32
303+
} -> tensor<100x200x128x256xi32>
304+
%4 = tensor.pack %transpose
305+
outer_dims_perm = [1, 2, 3, 0]
306+
inner_dims_pos = [3, 2]
307+
inner_tiles = [16, 32]
308+
into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
309+
return %4 : tensor<200x4x16x100x16x32xi32>
310+
}

0 commit comments

Comments
 (0)