Skip to content

Commit 7e10ecd

Browse files
authored
[mlir][tosa] Remove optional for pad_const and remove input_zp attr for PadOp (#129336)
Always generate pad_const and remove the input_zp attr for PadOp. - Co-authored-by: Udaya Ranga <[email protected]> - Co-authored-by: Tai Ly <[email protected]> Signed-off-by: Jerry Ge <[email protected]>
1 parent f8ba0df commit 7e10ecd

File tree

15 files changed

+156
-216
lines changed

15 files changed

+156
-216
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,6 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
197197
input, paddings);
198198
}]>;
199199

200-
def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
201-
(ins "Type":$outputType, "Value":$input, "Value":$paddings,
202-
"Value":$pad_value),
203-
[{
204-
buildExplicitValuePadOpWithQuantInfo($_builder, $_state, outputType,
205-
input, paddings, pad_value);
206-
}]>;
207-
208200
// Wrapper over base I32EnumAttr to set common fields.
209201
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
210202
: I32EnumAttr<name, description, cases> {

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ namespace tosa {
168168
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
169169
Type srcElemType, int64_t zp = 0);
170170

171+
// Create a pad-const const tensor with value of `val` of required data-type
172+
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
173+
int32_t val = 0);
174+
171175
} // namespace tosa
172176
} // namespace mlir
173177

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,8 +1903,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
19031903
let arguments = (ins
19041904
Tosa_RankedTensor:$input1,
19051905
Tosa_Shape:$padding,
1906-
Optional<Tosa_ScalarTensor>:$pad_const,
1907-
OptionalAttr<I32Attr>:$input_zp
1906+
Tosa_ScalarTensor:$pad_const
19081907
);
19091908

19101909
let results = (outs
@@ -1916,10 +1915,8 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
19161915
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
19171916
];
19181917

1919-
let builders = [Tosa_PadOpQuantInfoBuilder,
1920-
Tosa_ExplicitValuePadOpQuantInfoBuilder];
1918+
let builders = [Tosa_PadOpQuantInfoBuilder];
19211919

1922-
let hasCanonicalizer = 1;
19231920
let hasFolder = 1;
19241921
let hasVerifier = 1;
19251922
}

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 41 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
3636
return input.getType();
3737

3838
// The input type must be cast into a tensor with the same rank and all static
39-
// dimensions set to 1. This prevents the generation of a tensor.collapse_shape
40-
// op that converts a dynamically shaped tensor into a 0D tensor. While such
41-
// construct is not incorrect on its own, bufferization cannot properly handle
42-
// it at the moment, so we avoid it.
39+
// dimensions set to 1. This prevents the generation of a
40+
// tensor.collapse_shape op that converts a dynamically shaped tensor into a
41+
// 0D tensor. While such construct is not incorrect on its own, bufferization
42+
// cannot properly handle it at the moment, so we avoid it.
4343
SmallVector<int64_t> shape(input.getType().getRank(), 1);
4444
return input.getType().clone(shape);
4545
}
@@ -58,29 +58,31 @@ TensorType inferReshapeExpandedType(TensorType inputType,
5858
int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
5959

6060
// Compute result shape
61-
auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
62-
// If this is not a placeholder, do not change it
63-
if (size >= 0)
64-
return size;
65-
66-
// If we do not know the total size of the tensor, keep this dimension
67-
// dynamic in the result shape.
68-
if (!inputIsStatic)
69-
return ShapedType::kDynamic;
70-
71-
// Calculate the product of all elements in 'newShape' except for the -1
72-
// placeholder, which we discard by negating the result.
73-
int64_t totalSizeNoPlaceholder = -std::accumulate(
74-
newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
75-
76-
// If there is a 0 component in 'newShape', resolve the placeholder as 0.
77-
if (totalSizeNoPlaceholder == 0)
78-
return 0;
79-
80-
// Resolve the placeholder as the quotient between the total tensor size and
81-
// the product of all other sizes.
82-
return totalSize / totalSizeNoPlaceholder;
83-
});
61+
auto resultShape =
62+
llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
63+
// If this is not a placeholder, do not change it.
64+
if (size >= 0)
65+
return size;
66+
67+
// If we do not know the total size of the tensor, keep this dimension
68+
// dynamic in the result shape.
69+
if (!inputIsStatic)
70+
return ShapedType::kDynamic;
71+
72+
// Calculate the product of all elements in 'newShape' except for the -1
73+
// placeholder, which we discard by negating the result.
74+
int64_t totalSizeNoPlaceholder = -std::accumulate(
75+
newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
76+
77+
// If there is a 0 component in 'newShape', resolve the placeholder as
78+
// 0.
79+
if (totalSizeNoPlaceholder == 0)
80+
return 0;
81+
82+
// Resolve the placeholder as the quotient between the total tensor size
83+
// and the product of all other sizes.
84+
return totalSize / totalSizeNoPlaceholder;
85+
});
8486

8587
bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
8688

@@ -108,7 +110,8 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
108110
if (lhsShape.empty() || rhsShape.empty())
109111
return lhsType.clone(ArrayRef<int64_t>{});
110112

111-
if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
113+
if (ShapedType::isDynamicShape(lhsShape) ||
114+
ShapedType::isDynamicShape(rhsShape))
112115
return lhsType.clone({ShapedType::kDynamic});
113116

114117
SmallVector<int64_t> intermediateShape;
@@ -150,14 +153,16 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
150153
}
151154

152155
SmallVector<ReassociationExprs>
153-
createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
156+
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
157+
Type dstType) {
154158
auto srcShape = cast<TensorType>(srcType).getShape();
155159
auto dstShape = cast<TensorType>(dstType).getShape();
156160

157161
if (srcShape.empty() || dstShape.empty())
158162
return {};
159163

160-
if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
164+
if (ShapedType::isDynamicShape(srcShape) ||
165+
ShapedType::isDynamicShape(dstShape)) {
161166
assert(dstShape.size() == 1);
162167
SmallVector<AffineExpr, 2> exprs;
163168
for (auto i : llvm::seq<int64_t>(srcShape.size()))
@@ -249,14 +254,16 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
249254
auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
250255

251256
// Cast input if needed
252-
auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
257+
auto castInput =
258+
rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
253259

254260
// Emit collaspe-expand pair
255261
auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
256262
auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
257263

258264
// Cast to final result type if needed
259-
auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
265+
auto result =
266+
rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
260267
rewriter.replaceOp(reshape, result);
261268
return success();
262269
}
@@ -350,29 +357,12 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
350357
}
351358

352359
ShapedType inputTy = cast<ShapedType>(input.getType());
353-
Type elementTy = inputTy.getElementType();
354360
int64_t rank = inputTy.getRank();
355361

356362
// Setup the default constantAttr.
357363

358-
Value padConstant;
359-
360-
if (padOp.getPadConst()) {
361-
padConstant = rewriter.createOrFold<tensor::ExtractOp>(
362-
loc, padOp.getPadConst(), ValueRange({}));
363-
} else {
364-
TypedAttr constantAttr;
365-
if (isa<FloatType>(elementTy)) {
366-
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
367-
} else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
368-
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
369-
} else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
370-
int64_t value = padOp.getInputZpAttr().getInt();
371-
constantAttr = rewriter.getIntegerAttr(elementTy, value);
372-
}
373-
if (constantAttr)
374-
padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
375-
}
364+
Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
365+
loc, padOp.getPadConst(), ValueRange({}));
376366

377367
if (!padConstant) {
378368
return rewriter.notifyMatchFailure(

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -175,53 +175,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
175175
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
176176
}
177177

178-
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
179-
using OpRewritePattern::OpRewritePattern;
180-
181-
LogicalResult matchAndRewrite(tosa::PadOp op,
182-
PatternRewriter &rewriter) const override {
183-
if (op.getPadConst())
184-
return failure();
185-
186-
auto input = op.getInput1();
187-
auto padding = op.getPadding();
188-
189-
ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
190-
Type elementTy = inputTy.getElementType();
191-
192-
Attribute constantAttr;
193-
if (llvm::isa<FloatType>(elementTy)) {
194-
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
195-
} else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
196-
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
197-
} else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
198-
int64_t value = op.getInputZpAttr().getInt();
199-
constantAttr = rewriter.getIntegerAttr(elementTy, value);
200-
}
201-
202-
if (!constantAttr) {
203-
return rewriter.notifyMatchFailure(
204-
op,
205-
"tosa.pad to linalg lowering encountered an unknown element type");
206-
}
207-
208-
auto denseAttr = DenseElementsAttr::get(
209-
RankedTensorType::get({1}, elementTy), constantAttr);
210-
auto constantVal = rewriter.create<tosa::ConstOp>(
211-
op.getLoc(), denseAttr.getType(), denseAttr);
212-
213-
rewriter.replaceOpWithNewOp<tosa::PadOp>(
214-
op, op.getType(), ValueRange{input, padding, constantVal},
215-
op->getAttrs());
216-
return success();
217-
}
218-
};
219-
220-
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
221-
MLIRContext *context) {
222-
results.add<MaterializePadValue>(context);
223-
}
224-
225178
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
226179
using OpRewritePattern::OpRewritePattern;
227180

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,22 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
216216
}
217217
}
218218

219+
// Create a pad-const const tensor with value of `val` of required data-type
220+
Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
221+
Value src, int32_t val) {
222+
const auto srcType = getElementTypeOrSelf(src);
223+
const auto srcElemType = getElementTypeOrSelf(src);
224+
const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
225+
const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
226+
const auto padConstAttr{
227+
llvm::isa<FloatType>(srcElemType)
228+
? DenseElementsAttr::get(padConstEType,
229+
builder.getFloatAttr(srcElemType, val))
230+
: DenseElementsAttr::get(padConstEType,
231+
builder.getIntegerAttr(srcElemType, val))};
232+
return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
233+
}
234+
219235
//===----------------------------------------------------------------------===//
220236
// Tosa utilities.
221237
//===----------------------------------------------------------------------===//
@@ -708,30 +724,14 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
708724
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
709725
Type outputType, Value input,
710726
Value paddings) {
711-
result.addOperands({input, paddings});
712-
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
713-
if (quantAttr) {
714-
result.addAttribute("input_zp",
715-
builder.getI32IntegerAttr(
716-
static_cast<int32_t>(quantAttr.getInputZp())));
717-
}
718-
result.types.push_back(outputType);
719-
}
720-
721-
/// This builder is called on TOSA pad operator when an explicit pad_const
722-
/// value is passed in. It also optionally constructs quantization_attr.
723-
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
724-
OperationState &result,
725-
Type outputType, Value input,
726-
Value paddings,
727-
Value padConst) {
728-
result.addOperands({input, paddings, padConst});
729-
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
727+
const Location loc{result.location};
728+
int32_t zp{0};
729+
const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
730730
if (quantAttr) {
731-
result.addAttribute("input_zp",
732-
builder.getI32IntegerAttr(
733-
static_cast<int32_t>(quantAttr.getInputZp())));
731+
zp = static_cast<int32_t>(quantAttr.getInputZp());
734732
}
733+
const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
734+
result.addOperands({input, paddings, padConstOp});
735735
result.types.push_back(outputType);
736736
}
737737

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,16 @@ class TransposeConvStridedConverter
156156
return rewriter.notifyMatchFailure(
157157
op, "weight zero point must be zero for non-int8 integer types");
158158

159-
if (weightZpVal != 0) {
160-
weight = CreateOpAndInferShape<tosa::PadOp>(
161-
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
162-
weightPaddingVal, nullptr, rewriter.getI32IntegerAttr(weightZpVal));
163-
164-
} else {
165-
weight = CreateOpAndInferShape<tosa::PadOp>(
166-
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
167-
weightPaddingVal);
168-
}
159+
// construct pad_const values from zp values
160+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
161+
const Value inputPadConst =
162+
createPadConstTensor(builder, op->getLoc(), input, inputZpVal);
163+
const Value weightPadConst =
164+
createPadConstTensor(builder, op->getLoc(), input, weightZpVal);
165+
166+
weight = CreateOpAndInferShape<tosa::PadOp>(
167+
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
168+
weightPaddingVal, weightPadConst);
169169

170170
weightTy = cast<ShapedType>(weight.getType());
171171
weightHeight = weightTy.getDimSize(1);
@@ -177,7 +177,6 @@ class TransposeConvStridedConverter
177177
stride[0], weightWidth / stride[1],
178178
stride[1], inputChannels};
179179

180-
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
181180
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
182181
builder, UnrankedTensorType::get(weightETy), weight,
183182
getTosaConstShape(rewriter, loc, weightReshapeDims0));
@@ -214,15 +213,9 @@ class TransposeConvStridedConverter
214213
Value inputPaddingVal =
215214
getTosaConstShape(rewriter, op->getLoc(), inputPadding);
216215

217-
if (inputZpVal != 0) {
218-
input = CreateOpAndInferShape<tosa::PadOp>(
219-
rewriter, loc, UnrankedTensorType::get(inputETy), input,
220-
inputPaddingVal, nullptr, rewriter.getI32IntegerAttr(inputZpVal));
221-
} else {
222-
input = CreateOpAndInferShape<tosa::PadOp>(
223-
rewriter, loc, UnrankedTensorType::get(inputETy), input,
224-
inputPaddingVal);
225-
}
216+
input = CreateOpAndInferShape<tosa::PadOp>(
217+
rewriter, loc, UnrankedTensorType::get(inputETy), input,
218+
inputPaddingVal, inputPadConst);
226219

227220
// We use a zero bias as we need to broadcast the bias.
228221
auto zeroBias = rewriter.create<tosa::ConstOp>(

0 commit comments

Comments
 (0)