Skip to content

Commit 4880bfc

Browse files
rafaelubalmw authored and mgehre-amd committed
Lowering for 'tosa.scatter'
This patch adds support for `tosa.scatter` lowering in the `--tosa-to-scf` pass. Here's an example for this lowering:

```
func.func @tosa(
    %valuesIn : tensor<3x7x5xi32>,
    %indices : tensor<3x6xi32>,
    %input : tensor<3x6x5xi32>)
    -> tensor<3x7x5xi32> {
  %0 = "tosa.scatter"(%valuesIn, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>)
  return %0 : tensor<3x7x5xi32>
}
```

translates to

```
func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  %c1 = arith.constant 1 : index
  %c6 = arith.constant 6 : index
  %c2 = arith.constant 2 : index
  %c5 = arith.constant 5 : index
  %c0_0 = arith.constant 0 : index
  %c1_1 = arith.constant 1 : index
  %0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) {
    %1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) {
      %extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32>
      %2 = arith.index_cast %extracted : i32 to index
      %extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
      %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
      scf.yield %inserted_slice : tensor<3x7x5xi32>
    }
    scf.yield %1 : tensor<3x7x5xi32>
  }
  return %0 : tensor<3x7x5xi32>
}
```

We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons:

- The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon).
- The `tosa.scatter` and `tensor.scatter` ops have different indexing semantics.
The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt to the needs of `tensor.scatter`. While this overhead may be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through the `tensor` dialect that we're missing, this additional complexity may not be justified.

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D151117
1 parent 07d8cd0 commit 4880bfc

File tree

3 files changed

+102
-3
lines changed

3 files changed

+102
-3
lines changed

mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,75 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
8282
}
8383
};
8484

85+
/// Rewrite pattern that lowers `tosa.scatter` to a two-level loop nest built
/// with `scf::buildLoopNest`. Each iteration reads one entry of `indices`,
/// extracts the matching 1x1xC slice of `input`, and inserts it into the
/// result tensor carried through the loops (initialised with `values_in`).
class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
86+
/// Returns a Value holding the size of dimension `dim` of `tensor`, created
/// through `createOrFold<tensor::DimOp>`.
static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
87+
int64_t dim) {
88+
return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
89+
}
90+
91+
/// Creates an `arith::ConstantIndexOp` holding `value` and returns its result.
static Value createIndexConst(OpBuilder &builder, Location loc,
92+
int64_t value) {
93+
return builder.create<arith::ConstantIndexOp>(loc, value);
94+
}
95+
96+
public:
97+
using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
98+
99+
/// Replaces `scatter` with the results of the generated loop nest.
/// This implementation has no failure path: it always returns success().
LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
100+
PatternRewriter &rewriter) const final {
101+
auto valuesIn = scatter.getValuesIn();
102+
auto indices = scatter.getIndices();
103+
auto input = scatter.getInput();
104+
auto loc = scatter.getLoc();
105+
106+
// N, W, C are chosen to match the TOSA spec
// They are dimensions 0, 1 and 2 of `input`, respectively.
107+
auto dimN = createTensorDim(rewriter, loc, input, 0);
108+
auto dimW = createTensorDim(rewriter, loc, input, 1);
109+
auto dimC = createTensorDim(rewriter, loc, input, 2);
110+
111+
auto zero = createIndexConst(rewriter, loc, 0);
112+
auto one = createIndexConst(rewriter, loc, 1);
113+
114+
// Loop bounds
// Two loops, both running from 0 with step 1: the first up to N, the
// second up to W.
115+
auto lbs = llvm::SmallVector<Value>(2, zero);
116+
auto steps = llvm::SmallVector<Value>(2, one);
117+
auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
118+
119+
// Body builder passed to scf::buildLoopNest. `ivs` are the induction
// variables of the two loops and `args[0]` is the tensor carried through the
// loops. The returned value is the updated tensor.
auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
120+
ValueRange args) -> scf::ValueVector {
121+
// Induction variable of the first (N) loop.
auto n = ivs[0];
122+
123+
// Read the index and cast it to index type
// The element is read from `indices` at position `ivs`.
// NOTE(review): no range check on the extracted index is emitted here.
124+
auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
125+
auto castIndex = builder.create<arith::IndexCastOp>(
126+
loc, builder.getIndexType(), index);
127+
128+
// Offset, sizes, and strides for the input tensor
// The offset is the two induction variables followed by 0.
129+
auto inputOffset = llvm::to_vector(ivs);
130+
inputOffset.push_back(zero);
131+
132+
// A 1 x 1 x C slice with unit strides; the same sizes and strides are
// reused for the insertion below.
llvm::SmallVector<Value> sizes = {one, one, dimC};
133+
llvm::SmallVector<Value> strides = {one, one, one};
134+
135+
auto slice = builder.create<tensor::ExtractSliceOp>(
136+
loc, input, inputOffset, sizes, strides);
137+
138+
// Insert the slice into the output accumulator tensor.
// The destination offset is (n, index read from `indices`, 0).
139+
llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
140+
auto updated = builder.create<tensor::InsertSliceOp>(
141+
loc, slice, args[0], outputOffset, sizes, strides);
142+
143+
return {updated};
144+
};
145+
146+
// `valuesIn` is the initial value of the loop-carried tensor; the loop
// nest's results replace the scatter op.
auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
147+
ValueRange{valuesIn}, buildBody);
148+
rewriter.replaceOp(scatter, loops.results);
149+
150+
return success();
151+
}
152+
};
153+
85154
class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
86155
public:
87156
using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
@@ -106,6 +175,6 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
106175

107176
/// Adds the TOSA-to-SCF converter patterns from this file to `patterns`,
/// each constructed with the context obtained from `patterns->getContext()`.
void mlir::tosa::populateTosaToSCFConversionPatterns(
108177
RewritePatternSet *patterns) {
109-
patterns->add<IfOpConverter>(patterns->getContext());
110-
patterns->add<WhileOpConverter>(patterns->getContext());
178+
patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
179+
patterns->getContext());
111180
}

mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct TosaToSCF : public impl::TosaToSCFBase<TosaToSCF> {
3737
RewritePatternSet patterns(&getContext());
3838
ConversionTarget target(getContext());
3939
target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
40-
target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
40+
target.addIllegalOp<tosa::IfOp, tosa::ScatterOp, tosa::WhileOp>();
4141
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
4242

4343
auto *op = getOperation();

mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,33 @@ func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>)
5656

5757
return %0 : tensor<f32>
5858
}
59+
60+
// -----
61+
62+
// Checks the lowering of "tosa.scatter" on statically shaped i32 tensors: the
// expected output is two nested scf.for loops (bounds 3 and 6) that read an
// entry of the indices tensor, cast it to index, extract a slice of the input
// and insert it into the loop-carried tensor, which is finally returned.
// CHECK-LABEL: func @scatter_test
63+
// CHECK-SAME: ([[VALUES_IN:%.+]]: tensor<3x7x5xi32>, [[INDICES:%.+]]: tensor<3x6xi32>, [[INPUT:%.+]]: tensor<3x6x5xi32>)
64+
func.func @scatter_test(%values_in: tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
65+
66+
// CHECK-DAG: [[C_0:%.+]] = arith.constant 0 : index
67+
// CHECK-DAG: [[C_1:%.+]] = arith.constant 1 : index
68+
// CHECK-DAG: [[C_2:%.+]] = arith.constant 2 : index
69+
// CHECK-DAG: [[C_3:%.+]] = arith.constant 3 : index
70+
// CHECK-DAG: [[C_5:%.+]] = arith.constant 5 : index
71+
// CHECK-DAG: [[C_6:%.+]] = arith.constant 6 : index
72+
// CHECK-DAG: [[C_0_0:%.+]] = arith.constant 0 : index
73+
// CHECK-DAG: [[C_1_0:%.+]] = arith.constant 1 : index
74+
// CHECK: [[RESULT_0:%.+]] = scf.for [[ITER_VAR_0:%.+]] = [[C_0_0]] to [[C_3]] step [[C_1_0]] iter_args([[ITER_ARG_0:%.+]] = [[VALUES_IN]]) -> (tensor<3x7x5xi32>) {
75+
// CHECK: [[RESULT_1:%.+]] = scf.for [[ITER_VAR_1:%.+]] = [[C_0_0]] to [[C_6]] step [[C_1_0]] iter_args([[ITER_ARG_1:%.+]] = [[ITER_ARG_0]]) -> (tensor<3x7x5xi32>) {
76+
// CHECK-DAG: [[EXTRACTED:%.+]] = tensor.extract [[INDICES]][[[ITER_VAR_0]], [[ITER_VAR_1]]] : tensor<3x6xi32>
77+
// CHECK-DAG: [[EXTRACTED_CAST:%.+]] = arith.index_cast [[EXTRACTED]] : i32 to index
78+
// CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
79+
// CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
80+
// CHECK: scf.yield [[INSERTED_SLICE]] : tensor<3x7x5xi32>
81+
// CHECK: }
82+
// CHECK: scf.yield [[RESULT_1]] : tensor<3x7x5xi32>
83+
// CHECK: }
84+
%0 = "tosa.scatter"(%values_in, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>)
85+
86+
// CHECK: return [[RESULT_0]] : tensor<3x7x5xi32>
87+
return %0 : tensor<3x7x5xi32>
88+
}

0 commit comments

Comments
 (0)