Skip to content

Commit bb99503

Browse files
committed
Pick getTosaConstShape helper from 571a987
1 parent e896a3e commit bb99503

File tree

3 files changed

+40
-24
lines changed

3 files changed

+40
-24
lines changed

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
8484
LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
8585
Value &input2);
8686

87+
Value getTosaConstShape(ImplicitLocOpBuilder &builder,
88+
llvm::ArrayRef<int64_t> shape);
89+
90+
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
91+
llvm::ArrayRef<int64_t> shape);
92+
8793
namespace {
8894

8995
// Creates a TOSA operation and performs shape inference on the individual
@@ -217,7 +223,8 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
217223
}
218224

219225
// Apply an int32_t permutation to some input, that should be of the same
220-
// size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
226+
// size as perms. Perms should contain some permutation of 0 - perms.size()
227+
// - 1.
221228
template <typename T>
222229
SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
223230
ArrayRef<int32_t> perms) {

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,11 @@ struct SelfConcatToTile : public OpRewritePattern<tosa::ConcatOp> {
9393
}
9494
SmallVector<int64_t> multiplies(concatType.getRank(), 1);
9595
multiplies[concatOp.getAxis()] = concatOp->getNumOperands();
96-
const int64_t rank = multiplies.size();
97-
auto constantShapeOp = rewriter.create<ConstShapeOp>(
98-
concatOp->getLoc(), shapeType::get(concatOp->getContext(), rank),
99-
DenseIntElementsAttr::get(
100-
RankedTensorType::get({rank}, rewriter.getIndexType()),
101-
multiplies));
96+
auto constantShapeValue =
97+
getTosaConstShape(rewriter, concatOp->getLoc(), multiplies);
10298
auto tileOp = rewriter.createOrFold<tosa::TileOp>(
10399
concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0),
104-
constantShapeOp);
100+
constantShapeValue);
105101
rewriter.replaceOp(concatOp, {tileOp});
106102
return success();
107103
}
@@ -140,19 +136,15 @@ struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
140136
for (auto [idx, multiplier] : llvm::enumerate(inputTileMultiples)) {
141137
multiplies[idx] *= multiplier;
142138
}
143-
auto constantShapeOp = rewriter.create<ConstShapeOp>(
139+
auto constantShapeValue = getTosaConstShape(
140+
rewriter,
144141
rewriter.getFusedLoc(
145142
{op.getMultiples().getLoc(), inputTile.getMultiples().getLoc()}),
146-
op.getMultiples().getType(),
147-
DenseIntElementsAttr::get(
148-
RankedTensorType::get(
149-
{cast<shapeType>(op.getMultiples().getType()).getRank()},
150-
rewriter.getIndexType()),
151-
multiplies));
143+
multiplies);
152144

153145
rewriter.modifyOpInPlace(op, [&]() {
154146
op.setOperand(0, inputTile->getOperand(0));
155-
op.setOperand(1, constantShapeOp);
147+
op.setOperand(1, constantShapeValue);
156148
op.getOperation()->setLoc(
157149
FusedLoc::get(getContext(), {inputTile->getLoc(), op.getLoc()}));
158150
});
@@ -828,16 +820,11 @@ struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
828820
llvm::zip_equal(newTileShape, requiredMultipliers)) {
829821
newShape *= multiplier;
830822
}
831-
auto constantShapeOp = rewriter.create<ConstShapeOp>(
832-
tileOp.getMultiples().getLoc(), tileOp.getMultiples().getType(),
833-
DenseIntElementsAttr::get(
834-
RankedTensorType::get(
835-
{cast<shapeType>(tileOp.getMultiples().getType()).getRank()},
836-
rewriter.getIndexType()),
837-
requiredMultipliers));
823+
auto constantShapeValue = getTosaConstShape(
824+
rewriter, tileOp.getMultiples().getLoc(), requiredMultipliers);
838825
auto newTile = rewriter.create<tosa::TileOp>(
839826
tileOp->getLoc(), tileOpInputType.clone(newTileShape),
840-
tileOp->getOperand(0), constantShapeOp);
827+
tileOp->getOperand(0), constantShapeValue);
841828
auto newSlice = rewriter.create<tosa::SliceOp>(
842829
sliceOp->getLoc(), sliceOp.getType(), newTile,
843830
rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr());

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,25 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
160160

161161
return success();
162162
}
163+
164+
namespace {
165+
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
166+
return to_vector(llvm::map_range(shape, [](int64_t dim) {
167+
return ShapedType::isDynamic(dim) ? -1 : dim;
168+
}));
169+
}
170+
} // namespace
171+
172+
Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
173+
llvm::ArrayRef<int64_t> shape) {
174+
auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
175+
auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
176+
mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
177+
return mlir_op->getResult(0);
178+
}
179+
180+
Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
181+
llvm::ArrayRef<int64_t> shape) {
182+
ImplicitLocOpBuilder builder(loc, rewriter);
183+
return getTosaConstShape(builder, shape);
184+
}

0 commit comments

Comments
 (0)