Commit 5e94f5d

[mlir][tosa] Always generate pad_const and remove input_zp attr for PadOp
Co-authored-by: Udaya Ranga <[email protected]>
Co-authored-by: Tai Ly <[email protected]>
Signed-off-by: Jerry Ge <[email protected]>
Change-Id: I2b7a0169b7ec1158d28779713ad125c061e04592
1 parent 8f4ee42 commit 5e94f5d

14 files changed: +155 −215 lines
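
In short: tosa.pad's pad_const becomes a required operand and the input_zp attribute is dropped; the quant-info builder now materializes the pad constant itself. A minimal caller-side sketch (assuming an OpBuilder `b`, a Location `loc`, and existing `input`, `padding`, and `outputType` values; illustrative, not code from this commit):

    // Build a pad with an explicit pad value: pad_const is now always the
    // third operand, created here via the new helper with value 0.
    mlir::Value padConst =
        mlir::tosa::createPadConstTensor(b, loc, input, /*val=*/0).value();
    auto padOp = b.create<mlir::tosa::PadOp>(loc, outputType, input, padding,
                                             padConst);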

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 0 additions & 8 deletions

@@ -197,14 +197,6 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
                             input, paddings);
   }]>;
 
-def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
-  (ins "Type":$outputType, "Value":$input, "Value":$paddings,
-       "Value":$pad_value),
-  [{
-    buildExplicitValuePadOpWithQuantInfo($_builder, $_state, outputType,
-                                         input, paddings, pad_value);
-  }]>;
-
 // Wrapper over base I32EnumAttr to set common fields.
 class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
   : I32EnumAttr<name, description, cases> {

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 4 additions & 0 deletions

@@ -168,6 +168,10 @@ namespace tosa {
 std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
                                            Type srcElemType, int64_t zp = 0);
 
+// Create a pad-const constant tensor with value `val` of the required data type
+std::optional<Value> createPadConstTensor(OpBuilder &builder, Location loc,
+                                          Value src, int32_t val = 0);
+
 } // namespace tosa
 } // namespace mlir
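
A usage sketch for the new helper (hypothetical surrounding code; `builder`, `loc`, and `src` are assumed to exist): for an i8-typed `src` and val = 42 it produces a tosa.const of type tensor<1xi8> holding dense<42>; float element types get a FloatAttr splat instead.

    // The helper returns std::optional<Value>; callers unwrap it.
    std::optional<mlir::Value> padConst =
        mlir::tosa::createPadConstTensor(builder, loc, src, /*val=*/42);
    assert(padConst.has_value());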

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 5 deletions

@@ -1882,8 +1882,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   let arguments = (ins
     Tosa_RankedTensor:$input1,
     Tosa_Shape:$padding,
-    Optional<Tosa_ScalarTensor>:$pad_const,
-    OptionalAttr<I32Attr>:$input_zp
+    Tosa_ScalarTensor:$pad_const
   );
 
   let results = (outs
@@ -1895,10 +1894,8 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
 
-  let builders = [Tosa_PadOpQuantInfoBuilder,
-                  Tosa_ExplicitValuePadOpQuantInfoBuilder];
+  let builders = [Tosa_PadOpQuantInfoBuilder];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
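
One downstream consequence, sketched from how ODS treats required vs optional operands (illustrative, not from the commit):

    // pad_const is now required, so getPadConst() always yields a usable
    // Value; the getInputZp()/getInputZpAttr() accessors no longer exist,
    // and call sites that branched on them must be updated.
    mlir::Value padConst = padOp.getPadConst();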

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 41 additions & 51 deletions

@@ -36,10 +36,10 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
     return input.getType();
 
   // The input type must be cast into a tensor with the same rank and all static
-  // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
-  // op that converts a dynamically shaped tensor into a 0D tensor. While such
-  // construct is not incorrect on its own, bufferization cannot properly handle
-  // it at the moment, so we avoid it.
+  // dimensions set to 1. This prevents the generation of a
+  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
+  // 0D tensor. While such construct is not incorrect on its own, bufferization
+  // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
   return input.getType().clone(shape);
 }
@@ -58,29 +58,31 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
 
   // Compute result shape
-  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
-    // If this is not a placeholder, do not change it
-    if (size >= 0)
-      return size;
-
-    // If we do not know the total size of the tensor, keep this dimension
-    // dynamic in the result shape.
-    if (!inputIsStatic)
-      return ShapedType::kDynamic;
-
-    // Calculate the product of all elements in 'newShape' except for the -1
-    // placeholder, which we discard by negating the result.
-    int64_t totalSizeNoPlaceholder = -std::accumulate(
-        newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
-
-    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
-    if (totalSizeNoPlaceholder == 0)
-      return 0;
-
-    // Resolve the placeholder as the quotient between the total tensor size and
-    // the product of all other sizes.
-    return totalSize / totalSizeNoPlaceholder;
-  });
+  auto resultShape =
+      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+        // If this is not a placeholder, do not change it
+        if (size >= 0)
+          return size;
+
+        // If we do not know the total size of the tensor, keep this dimension
+        // dynamic in the result shape.
+        if (!inputIsStatic)
+          return ShapedType::kDynamic;
+
+        // Calculate the product of all elements in 'newShape' except for the -1
+        // placeholder, which we discard by negating the result.
+        int64_t totalSizeNoPlaceholder = -std::accumulate(
+            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
+
+        // If there is a 0 component in 'newShape', resolve the placeholder as
+        // 0.
+        if (totalSizeNoPlaceholder == 0)
+          return 0;
+
+        // Resolve the placeholder as the quotient between the total tensor size
+        // and the product of all other sizes.
+        return totalSize / totalSizeNoPlaceholder;
+      });
 
   bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
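
A quick worked example of the placeholder arithmetic above (mine, not part of the commit): with totalSize = 24 and newShape = {2, -1, 3}, negating the product over newShape gives -(2 × -1 × 3) = 6, so the -1 resolves to 24 / 6 = 4.

    // Standalone check of the same computation.
    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>
    int main() {
      std::vector<int64_t> newShape{2, -1, 3};
      int64_t totalSize = 24; // product of the static input dims
      int64_t noPlaceholder =
          -std::accumulate(newShape.begin(), newShape.end(), int64_t{1},
                           std::multiplies<int64_t>());
      assert(noPlaceholder == 6 && totalSize / noPlaceholder == 4);
    }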

@@ -108,7 +110,8 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
   if (lhsShape.empty() || rhsShape.empty())
     return lhsType.clone(ArrayRef<int64_t>{});
 
-  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
+  if (ShapedType::isDynamicShape(lhsShape) ||
+      ShapedType::isDynamicShape(rhsShape))
     return lhsType.clone({ShapedType::kDynamic});
 
   SmallVector<int64_t> intermediateShape;
@@ -150,14 +153,16 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
 }
 
 SmallVector<ReassociationExprs>
-createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
+createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
+                                  Type dstType) {
   auto srcShape = cast<TensorType>(srcType).getShape();
   auto dstShape = cast<TensorType>(dstType).getShape();
 
   if (srcShape.empty() || dstShape.empty())
     return {};
 
-  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
+  if (ShapedType::isDynamicShape(srcShape) ||
+      ShapedType::isDynamicShape(dstShape)) {
     assert(dstShape.size() == 1);
     SmallVector<AffineExpr, 2> exprs;
     for (auto i : llvm::seq<int64_t>(srcShape.size()))
@@ -249,14 +254,16 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
     auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
 
     // Cast input if needed
-    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
+    auto castInput =
+        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
 
     // Emit collapse-expand pair
     auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
     auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
 
     // Cast to final result type if needed
-    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
+    auto result =
+        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
     rewriter.replaceOp(reshape, result);
     return success();
   }
@@ -350,29 +357,12 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     }
 
     ShapedType inputTy = cast<ShapedType>(input.getType());
-    Type elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
     // Setup the default constantAttr.
 
-    Value padConstant;
-
-    if (padOp.getPadConst()) {
-      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padOp.getPadConst(), ValueRange({}));
-    } else {
-      TypedAttr constantAttr;
-      if (isa<FloatType>(elementTy)) {
-        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-      } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
-        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-      } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
-        int64_t value = padOp.getInputZpAttr().getInt();
-        constantAttr = rewriter.getIntegerAttr(elementTy, value);
-      }
-      if (constantAttr)
-        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
-    }
+    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
+        loc, padOp.getPadConst(), ValueRange({}));
 
     if (!padConstant) {
       return rewriter.notifyMatchFailure(

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 0 additions & 47 deletions

@@ -175,53 +175,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::PadOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getPadConst())
-      return failure();
-
-    auto input = op.getInput1();
-    auto padding = op.getPadding();
-
-    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
-    Type elementTy = inputTy.getElementType();
-
-    Attribute constantAttr;
-    if (llvm::isa<FloatType>(elementTy)) {
-      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
-      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
-      int64_t value = op.getInputZpAttr().getInt();
-      constantAttr = rewriter.getIntegerAttr(elementTy, value);
-    }
-
-    if (!constantAttr) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "tosa.pad to linalg lowering encountered an unknown element type");
-    }
-
-    auto denseAttr = DenseElementsAttr::get(
-        RankedTensorType::get({1}, elementTy), constantAttr);
-    auto constantVal = rewriter.create<tosa::ConstOp>(
-        op.getLoc(), denseAttr.getType(), denseAttr);
-
-    rewriter.replaceOpWithNewOp<tosa::PadOp>(
-        op, op.getType(), ValueRange{input, padding, constantVal},
-        op->getAttrs());
-    return success();
-  }
-};
-
-void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<MaterializePadValue>(context);
-}
-
 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
   using OpRewritePattern::OpRewritePattern;
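
The removal is safe by construction: MaterializePadValue only matched pads whose pad_const was absent, and the builder change in TosaOps.cpp below makes that state unrepresentable, so the pattern (and the op's hasCanonicalizer flag) would otherwise have been dead code.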

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 23 additions & 22 deletions

@@ -214,6 +214,23 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+// Create a pad-const constant tensor with value `val` of the required data type
+std::optional<Value> mlir::tosa::createPadConstTensor(OpBuilder &builder,
+                                                      Location loc, Value src,
+                                                      int32_t val) {
+  const auto srcType = getElementTypeOrSelf(src);
+  const auto srcElemType = getElementTypeOrSelf(src);
+  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
+  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
+  const auto padConstAttr{
+      llvm::isa<FloatType>(srcElemType)
+          ? DenseElementsAttr::get(padConstEType,
+                                   builder.getFloatAttr(srcElemType, val))
+          : DenseElementsAttr::get(padConstEType,
+                                   builder.getIntegerAttr(srcElemType, val))};
+  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -679,30 +696,14 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input,
                                     Value paddings) {
-  result.addOperands({input, paddings});
-  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
+  const Location loc{result.location};
+  int32_t zp{0};
+  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
   if (quantAttr) {
-    result.addAttribute("input_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getInputZp())));
-  }
-  result.types.push_back(outputType);
-}
-
-/// This builder is called on TOSA pad operator when an explicit pad_const
-/// value is passed in. It also optionally constructs quantization_attr.
-static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
-                                                 OperationState &result,
-                                                 Type outputType, Value input,
-                                                 Value paddings,
-                                                 Value padConst) {
-  result.addOperands({input, paddings, padConst});
-  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
-  if (quantAttr) {
-    result.addAttribute("input_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getInputZp())));
+    zp = static_cast<int32_t>(quantAttr.getInputZp());
   }
+  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
+  result.addOperands({input, paddings, padConstOp.value()});
  result.types.push_back(outputType);
 }
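
Caller-visible effect of the builder change (a sketch; `b`, `loc`, `outputType`, `quantizedInput`, and `paddingShape` are assumed names): the two-operand quant-info builder still works, but it now leaves a materialized constant in the IR.

    // For an i8 input with zero point -128 this now creates a tosa.const of
    // type tensor<1xi8> holding dense<-128> and wires it in as pad_const,
    // instead of attaching an input_zp attribute to tosa.pad.
    auto pad = b.create<mlir::tosa::PadOp>(loc, outputType, quantizedInput,
                                           paddingShape);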

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 13 additions & 20 deletions

@@ -148,16 +148,16 @@ class TransposeConvStridedConverter
       return rewriter.notifyMatchFailure(
           op, "zero point must be zero for non-int8 integer types");
 
-    if (weightZpVal != 0) {
-      weight = CreateOpAndInferShape<tosa::PadOp>(
-          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-          weightPaddingVal, nullptr, rewriter.getI32IntegerAttr(weightZpVal));
-
-    } else {
-      weight = CreateOpAndInferShape<tosa::PadOp>(
-          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-          weightPaddingVal);
-    }
+    // construct pad_const values from zp values
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    const Value inputPadConst =
+        createPadConstTensor(builder, op->getLoc(), input, inputZpVal).value();
+    const Value weightPadConst =
+        createPadConstTensor(builder, op->getLoc(), input, weightZpVal).value();
+
+    weight = CreateOpAndInferShape<tosa::PadOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        weightPaddingVal, weightPadConst);
 
     weightTy = cast<ShapedType>(weight.getType());
     weightHeight = weightTy.getDimSize(1);
@@ -169,7 +169,6 @@ class TransposeConvStridedConverter
                                      stride[0], weightWidth / stride[1],
                                      stride[1], inputChannels};
 
-    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
         builder, UnrankedTensorType::get(weightETy), weight,
         getTosaConstShape(rewriter, loc, weightReshapeDims0));
@@ -206,15 +205,9 @@ class TransposeConvStridedConverter
     Value inputPaddingVal =
         getTosaConstShape(rewriter, op->getLoc(), inputPadding);
 
-    if (inputZpVal != 0) {
-      input = CreateOpAndInferShape<tosa::PadOp>(
-          rewriter, loc, UnrankedTensorType::get(inputETy), input,
-          inputPaddingVal, nullptr, rewriter.getI32IntegerAttr(inputZpVal));
-    } else {
-      input = CreateOpAndInferShape<tosa::PadOp>(
-          rewriter, loc, UnrankedTensorType::get(inputETy), input,
-          inputPaddingVal);
-    }
+    input = CreateOpAndInferShape<tosa::PadOp>(
+        rewriter, loc, UnrankedTensorType::get(inputETy), input,
+        inputPaddingVal, inputPadConst);
 
     // We use a zero bias as we need to broadcast the bias.
     auto zeroBias = rewriter.create<tosa::ConstOp>(
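
Net effect in this file: both pad sites now thread their zero points through explicit pad_const tensors built once up front, which is also why the ImplicitLocOpBuilder declaration moves earlier in the function. Note that both constants are created from `input`, so the weight pad value inherits the input's element type.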
