@@ -82,6 +82,75 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
82
82
}
83
83
};
84
84
85
+ class ScatterOpConverter : public OpRewritePattern <tosa::ScatterOp> {
86
+ static Value createTensorDim (OpBuilder &builder, Location loc, Value tensor,
87
+ int64_t dim) {
88
+ return builder.createOrFold <tensor::DimOp>(loc, tensor, dim);
89
+ }
90
+
91
+ static Value createIndexConst (OpBuilder &builder, Location loc,
92
+ int64_t value) {
93
+ return builder.create <arith::ConstantIndexOp>(loc, value);
94
+ }
95
+
96
+ public:
97
+ using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
98
+
99
+ LogicalResult matchAndRewrite (tosa::ScatterOp scatter,
100
+ PatternRewriter &rewriter) const final {
101
+ auto valuesIn = scatter.getValuesIn ();
102
+ auto indices = scatter.getIndices ();
103
+ auto input = scatter.getInput ();
104
+ auto loc = scatter.getLoc ();
105
+
106
+ // N, W, C are chosen to match the TOSA spec
107
+ auto dimN = createTensorDim (rewriter, loc, input, 0 );
108
+ auto dimW = createTensorDim (rewriter, loc, input, 1 );
109
+ auto dimC = createTensorDim (rewriter, loc, input, 2 );
110
+
111
+ auto zero = createIndexConst (rewriter, loc, 0 );
112
+ auto one = createIndexConst (rewriter, loc, 1 );
113
+
114
+ // Loop bounds
115
+ auto lbs = llvm::SmallVector<Value>(2 , zero);
116
+ auto steps = llvm::SmallVector<Value>(2 , one);
117
+ auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
118
+
119
+ auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
120
+ ValueRange args) -> scf::ValueVector {
121
+ auto n = ivs[0 ];
122
+
123
+ // Read the index and cast it to index type
124
+ auto index = builder.create <tensor::ExtractOp>(loc, indices, ivs);
125
+ auto castIndex = builder.create <arith::IndexCastOp>(
126
+ loc, builder.getIndexType (), index);
127
+
128
+ // Offset, sizes, and strides for the input tensor
129
+ auto inputOffset = llvm::to_vector (ivs);
130
+ inputOffset.push_back (zero);
131
+
132
+ llvm::SmallVector<Value> sizes = {one, one, dimC};
133
+ llvm::SmallVector<Value> strides = {one, one, one};
134
+
135
+ auto slice = builder.create <tensor::ExtractSliceOp>(
136
+ loc, input, inputOffset, sizes, strides);
137
+
138
+ // Insert the slice into the output accumulator tensor.
139
+ llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
140
+ auto updated = builder.create <tensor::InsertSliceOp>(
141
+ loc, slice, args[0 ], outputOffset, sizes, strides);
142
+
143
+ return {updated};
144
+ };
145
+
146
+ auto loops = scf::buildLoopNest (rewriter, loc, lbs, ubs, steps,
147
+ ValueRange{valuesIn}, buildBody);
148
+ rewriter.replaceOp (scatter, loops.results );
149
+
150
+ return success ();
151
+ }
152
+ };
153
+
85
154
class WhileOpConverter : public OpRewritePattern <tosa::WhileOp> {
86
155
public:
87
156
using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
@@ -106,6 +175,6 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
106
175
107
176
void mlir::tosa::populateTosaToSCFConversionPatterns (
108
177
RewritePatternSet *patterns) {
109
- patterns->add <IfOpConverter>(patterns-> getContext ());
110
- patterns-> add <WhileOpConverter>( patterns->getContext ());
178
+ patterns->add <IfOpConverter, ScatterOpConverter, WhileOpConverter>(
179
+ patterns->getContext ());
111
180
}
0 commit comments