Skip to content

Commit edba22f

Browse files
lhutton1Tai78641
authored andcommitted
[mlir][tosa] Switch zero point of negate to input variable type
This commit changes the zero point attribute to an input to align with the 1.0 spec. Change-Id: Ibc9e5959b36c182a9e0c5c23a2f9d42a572a1184 Signed-off-by: Tai Ly <[email protected]>
1 parent 78631ac commit edba22f

File tree

17 files changed

+365
-92
lines changed

17 files changed

+365
-92
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,12 @@ profileComplianceMap = {
112112
{"tosa.logical_not",
113113
{{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
114114
{"tosa.negate",
115-
{{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
116-
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
115+
{{{Profile::pro_int},
116+
{{i8T, i8T, i8T, i8T},
117+
{i16T, i16T, i16T, i16T},
118+
{i32T, i32T, i32T, i32T}}},
119+
{{Profile::pro_fp},
120+
{{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
117121
{"tosa.reciprocal",
118122
{{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
119123
{"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -308,7 +312,7 @@ extensionComplianceMap = {
308312
{"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
309313
{"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
310314
{"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
311-
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
315+
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
312316
{"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
313317
{"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
314318
{"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
178178
input, kernel, stride, pad, acc_type);
179179
}]>;
180180

181-
// This builder is called on single-parameter unary operators that have a scale
181+
// This builder is called on single-parameter negate operators that have a scale
182182
// relationship between their input and output, expressed by the
183183
// UnaryOpQuantizationAttr.
184-
def Tosa_UnaryOpQuantInfoBuilder : OpBuilder<
184+
def Tosa_NegateOpQuantInfoBuilder : OpBuilder<
185185
(ins "Type":$outputType, "Value":$input),
186186
[{
187-
buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
187+
buildNegateOpWithQuantInfo($_builder, $_state, outputType, input);
188188
}]>;
189189

190190
// These builders are called on the TOSA pad operator that needs to create its

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,7 +1349,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
13491349
//===----------------------------------------------------------------------===//
13501350
// Operator: negate
13511351
//===----------------------------------------------------------------------===//
1352-
def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
1352+
def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [
1353+
TosaElementwiseOperator,
1354+
Pure]> {
13531355
let summary = "Elementwise negate op";
13541356

13551357
let description = [{
@@ -1358,8 +1360,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
13581360

13591361
let arguments = (ins
13601362
Tosa_Tensor:$input1,
1361-
OptionalAttr<I32Attr>:$input1_zp,
1362-
OptionalAttr<I32Attr>:$output_zp
1363+
Tosa_ScalarIntOrFloatTensor:$input1_zp,
1364+
Tosa_ScalarIntOrFloatTensor:$output_zp
13631365
);
13641366

13651367
let results = (outs
@@ -1371,9 +1373,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
13711373
Extension<[Tosa_EXT_BF16]>,
13721374
];
13731375

1374-
let builders = [Tosa_UnaryOpQuantInfoBuilder];
1376+
let builders = [Tosa_NegateOpQuantInfoBuilder];
1377+
1378+
let extraClassDeclaration = [{
1379+
FailureOr<int64_t> getInput1ZeroPoint();
1380+
FailureOr<int64_t> getOutputZeroPoint();
1381+
LogicalResult verifyInput1ZeroPoint(int64_t zp);
1382+
LogicalResult verifyOutputZeroPoint(int64_t zp);
1383+
}];
13751384

13761385
let hasFolder = 1;
1386+
let hasVerifier = 1;
1387+
1388+
let assemblyFormat =
1389+
"operands attr-dict `:` functional-type(operands, results)";
13771390
}
13781391

13791392
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp(
193193

194194
// tosa::NegateOp
195195
if (isa<tosa::NegateOp>(op)) {
196-
if (isa<FloatType>(elementTy))
197-
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
196+
auto negate = cast<tosa::NegateOp>(op);
198197

199-
if (isa<IntegerType>(elementTy)) {
200-
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
201-
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
198+
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
199+
if (failed(maybeInZp)) {
200+
(void)rewriter.notifyMatchFailure(
201+
op, "input1 zero point cannot be statically determined");
202+
return nullptr;
203+
}
204+
205+
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
206+
if (failed(maybeOutZp)) {
207+
(void)rewriter.notifyMatchFailure(
208+
op, "output zero point cannot be statically determined");
209+
return nullptr;
210+
}
202211

203-
const int64_t inZp =
204-
inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
205-
const int64_t outZp =
206-
outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
212+
int64_t inZp = *maybeInZp;
213+
int64_t outZp = *maybeOutZp;
207214

215+
if (isa<FloatType>(elementTy))
216+
return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
217+
218+
if (isa<IntegerType>(elementTy)) {
208219
if (!inZp && !outZp) {
209220
auto constant = rewriter.create<arith::ConstantOp>(
210221
loc, IntegerAttr::get(elementTy, 0));

mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,45 @@ struct MatMulOpSharding
6060
}
6161
};
6262

63+
struct NegateOpSharding
64+
: public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
65+
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
66+
Value val = op->getOperand(0);
67+
auto type = dyn_cast<RankedTensorType>(val.getType());
68+
if (!type)
69+
return {};
70+
SmallVector<utils::IteratorType> types(type.getRank(),
71+
utils::IteratorType::parallel);
72+
return types;
73+
}
74+
75+
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
76+
MLIRContext *ctx = op->getContext();
77+
Value val = op->getOperand(0);
78+
auto type = dyn_cast<RankedTensorType>(val.getType());
79+
if (!type)
80+
return {};
81+
int64_t rank = type.getRank();
82+
SmallVector<AffineMap> maps = {
83+
AffineMap::getMultiDimIdentityMap(rank, ctx),
84+
AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
85+
AffineMap::getMultiDimIdentityMap(rank, ctx)};
86+
return maps;
87+
}
88+
89+
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
90+
ArrayRef<MeshSharding> operandShardings,
91+
ArrayRef<MeshSharding> resultShardings,
92+
IRMapping &spmdizationMap,
93+
SymbolTableCollection &symbolTable,
94+
OpBuilder &builder) const {
95+
spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
96+
resultShardings, spmdizationMap,
97+
symbolTable, builder);
98+
return success();
99+
}
100+
};
101+
63102
template <typename OpType>
64103
static void registerElemwiseOne(MLIRContext *ctx) {
65104
OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
@@ -82,9 +121,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels(
82121
BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
83122
LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
84123
MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
85-
LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
124+
LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
86125
GreaterOp, GreaterEqualOp>(ctx);
87126

88127
MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
128+
NegateOp::attachInterface<NegateOpSharding>(*ctx);
89129
});
90130
}

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,13 +1143,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
11431143
}
11441144

11451145
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1146-
auto input = getInput1();
11471146
// Element-wise negate(negate(x)) = x
1148-
if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
1149-
return op.getInput1();
1147+
// iff all zero points are constant 0
1148+
auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1149+
if (!definingOp) {
1150+
// defining op of input1 is not a negate, cannot fold
1151+
return {};
11501152
}
11511153

1152-
return {};
1154+
if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1155+
failed(maybeIZp) || *maybeIZp != 0) {
1156+
// input1 zero point is not constant 0, cannot fold
1157+
return {};
1158+
}
1159+
if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1160+
failed(maybeOZp) || *maybeOZp != 0) {
1161+
// output zero point is not constant 0, cannot fold
1162+
return {};
1163+
}
1164+
if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1165+
failed(maybeIZp) || *maybeIZp != 0) {
1166+
// definingOp's input1 zero point is not constant 0, cannot fold
1167+
return {};
1168+
}
1169+
if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1170+
failed(maybeOZp) || *maybeOZp != 0) {
1171+
// definingOp's output zero point is not constant 0, cannot fold
1172+
return {};
1173+
}
1174+
1175+
return definingOp.getInput1();
11531176
}
11541177

11551178
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -708,23 +708,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
708708
result.types.push_back(outputType);
709709
}
710710

711-
/// This builder is called on single-parameter unary operators that have scale
712-
/// relationship between their input and output, expressed by the
713-
/// UnaryOpQuantizationAttr.
714-
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
715-
OperationState &result, Type outputType,
716-
Value input) {
717-
result.addOperands(input);
711+
/// This builder is called on single-parameter negate operator
712+
/// to construct input and output zero points based on their
713+
/// types.
714+
static void buildNegateOpWithQuantInfo(OpBuilder &builder,
715+
OperationState &result, Type outputType,
716+
Value input) {
717+
const Location loc{result.location};
718+
int64_t input1Zp{0};
719+
int64_t outputZp{0};
718720
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
719721
if (quantAttr) {
720-
// note: negateOp has attributes input1_zp and output_zp
721-
result.addAttribute("input1_zp",
722-
builder.getI32IntegerAttr(
723-
static_cast<int32_t>(quantAttr.getInputZp())));
724-
result.addAttribute("output_zp",
725-
builder.getI32IntegerAttr(
726-
static_cast<int32_t>(quantAttr.getOutputZp())));
722+
input1Zp = quantAttr.getInputZp();
723+
outputZp = quantAttr.getOutputZp();
724+
}
725+
const std::optional<Value> input1ZpOp =
726+
createZeroPointTensor(builder, loc, input.getType(), input1Zp);
727+
if (!input1ZpOp) {
728+
(void)emitError(
729+
loc, "Failed to create input1 zero point for quantized NEGATE op");
730+
}
731+
732+
const std::optional<Value> outputZpOp =
733+
createZeroPointTensor(builder, loc, input.getType(), outputZp);
734+
if (!outputZpOp) {
735+
(void)emitError(
736+
loc, "Failed to create output zero point for quantized NEGATE op");
727737
}
738+
739+
if (input1ZpOp && outputZpOp) {
740+
result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
741+
} else {
742+
// failed to create one or more zero points above: just add input as
743+
// operands. This will trigger error in building the op because of
744+
// missing zero points
745+
result.addOperands({input});
746+
}
747+
728748
result.types.push_back(outputType);
729749
}
730750

@@ -1714,6 +1734,9 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
17141734
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
17151735
ZERO_POINT_HELPER(AvgPool2dOp, Input)
17161736
ZERO_POINT_HELPER(AvgPool2dOp, Output)
1737+
ZERO_POINT_HELPER(NegateOp, Input1)
1738+
ZERO_POINT_HELPER(NegateOp, Output)
1739+
17171740
#undef ZERO_POINT_HELPER
17181741

17191742
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2216,7 +2239,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
22162239
NARY_SHAPE_INFER(tosa::LogicalXorOp)
22172240
NARY_SHAPE_INFER(tosa::MaximumOp)
22182241
NARY_SHAPE_INFER(tosa::MinimumOp)
2219-
NARY_SHAPE_INFER(tosa::NegateOp)
22202242
NARY_SHAPE_INFER(tosa::PowOp)
22212243
NARY_SHAPE_INFER(tosa::ReciprocalOp)
22222244
NARY_SHAPE_INFER(tosa::ReverseOp)
@@ -2229,6 +2251,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
22292251
NARY_SHAPE_INFER(tosa::SigmoidOp)
22302252
#undef PRED_SHAPE_INFER
22312253

2254+
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2255+
MLIRContext *context, ::std::optional<Location> location,
2256+
NegateOp::Adaptor adaptor,
2257+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2258+
ShapeAdaptor inputShape(adaptor.getInput1().getType());
2259+
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2260+
return success();
2261+
}
2262+
2263+
LogicalResult tosa::NegateOp::verify() {
2264+
// Verify same element type
2265+
const Type input1Type = getInput1().getType();
2266+
const Type outputType = getOutput().getType();
2267+
if (verifySameElementTypes(*this, input1Type, outputType).failed())
2268+
return failure();
2269+
2270+
// Verify same shape
2271+
const SmallVector<Type, 2> types = {input1Type, outputType};
2272+
if (failed(verifyCompatibleShapes(types)))
2273+
return emitOpError() << "requires the same shape for input1 and output";
2274+
2275+
const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2276+
const Type input1ZpEType =
2277+
getStorageElementTypeOrSelf(getInput1Zp().getType());
2278+
if (input1EType != input1ZpEType) {
2279+
return emitOpError("expect both input1 and its zero point are the same "
2280+
"element type, got ")
2281+
<< input1EType << " and " << input1ZpEType;
2282+
}
2283+
const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2284+
const Type outputZpEType =
2285+
getStorageElementTypeOrSelf(getOutputZp().getType());
2286+
if (outputEType != outputZpEType) {
2287+
return emitOpError("expect both output and its zero point are the same "
2288+
"element type, got ")
2289+
<< outputEType << " and " << outputZpEType;
2290+
}
2291+
2292+
FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2293+
if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2294+
return failure();
2295+
2296+
FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2297+
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2298+
return failure();
2299+
2300+
return success();
2301+
}
2302+
22322303
static LogicalResult poolingInferReturnTypes(
22332304
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
22342305
ArrayRef<int64_t> pad,

0 commit comments

Comments
 (0)