Skip to content

Commit 8a47e90

Browse files
akuegelTensorFlow MLIR Team
authored andcommitted
Adapt Stablehlo to recent upstream Tosa changes.
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
1 parent 4b7071b commit 8a47e90

File tree

11 files changed

+201
-80
lines changed

11 files changed

+201
-80
lines changed

stablehlo/BUILD

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,8 +1542,6 @@ glob_lit_tests(
15421542
driver = "@llvm-project//mlir:run_lit.sh",
15431543
exclude = [
15441544
# TODO: remove the following excludes once #2751 is fixed.
1545-
"stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir",
1546-
"stablehlo/conversions/tosa/tests/legalize_tosa_rescale_to_stablehlo.mlir",
15471545
"stablehlo/conversions/tosa/tests/rescale_interpreter.mlir",
15481546
"stablehlo/conversions/tosa/tests/binary.mlir",
15491547
"stablehlo/conversions/tosa/tests/nullary.mlir",
@@ -1638,9 +1636,8 @@ cc_library(
16381636
srcs = [
16391637
"stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp",
16401638
"stablehlo/conversions/tosa/transforms/StablehloPrepareForTosa.cpp",
1641-
# TODO: un-comment the following once #2751 is fixed
1642-
# "stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp",
1643-
# "stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp",
1639+
"stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp",
1640+
"stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp",
16441641
],
16451642
hdrs = [
16461643
"stablehlo/conversions/tosa/transforms/Passes.h",

stablehlo/BUILD.bazel

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,9 +1577,8 @@ cc_library(
15771577
srcs = [
15781578
"stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp",
15791579
"stablehlo/conversions/tosa/transforms/StablehloPrepareForTosa.cpp",
1580-
# TODO: un-comment the following once #2751 is fixed
1581-
# "stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp",
1582-
# "stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp",
1580+
"stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp",
1581+
"stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp",
15831582
],
15841583
hdrs = [
15851584
"stablehlo/conversions/tosa/transforms/Passes.h",

stablehlo/docs/generated/stablehlo_passes.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,17 @@ _Convert between versions of VHLO._
351351
```
352352
-target : The target version. Must be a version of the form #.#.# .
353353
```
354+
355+
### `-stablehlo-quant-legalize-to-tosa-rescale`
356+
357+
_Legalize StableHLO Quantized operations to TOSA rescale operations_
358+
359+
This pass rewrites StableHLO quantized operations to integer operations
360+
by inserting TOSA rescale operations at the inputs and outputs of the
361+
integer operations.
362+
363+
### `-tosa-rescale-legalize-to-stablehlo`
364+
365+
_Legalize TOSA rescales to StableHlo primitive math operations_
366+
367+
This pass rewrites TOSA rescale operations to StableHLO primitive math operations.

stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ RUNFILES_DIR = LITE_CFG_PY.parents[4].absolute().as_posix()''',
5959
for src in glob(
6060
["**/*.mlir"],
6161
exclude = [
62-
"legalize_quant_ops_to_tosa_rescale.mlir",
63-
"legalize_tosa_rescale_to_stablehlo.mlir",
6462
"rescale_interpreter.mlir",
6563
"binary.mlir",
6664
"nullary.mlir",

stablehlo/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir

Lines changed: 93 additions & 27 deletions
Large diffs are not rendered by default.

stablehlo/stablehlo/conversions/tosa/tests/legalize_tosa_rescale_to_stablehlo.mlir

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
// -----
44
// CHECK-LABEL: @rescale
55
func.func @rescale(%arg0 : tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>) -> tensor<2x2xi32> {
6-
%0 = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1431655765>, input_unsigned = false, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 13>} :
7-
(tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>) -> tensor<2x2xi32>
6+
%multiplier = "tosa.const"() {values = dense<1431655765> : tensor<1xi32> } : () -> tensor<1xi32>
7+
%shift = "tosa.const"() {values = dense<13> : tensor<1xi8> } : () -> tensor<1xi8>
8+
%input_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
9+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
10+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", input_unsigned = false, output_unsigned = false, per_channel = false, scale32 = true} :
11+
(tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<2x2xi32>
812

913
// convert input quantized type to storage type
1014
// CHECK-DAG: %[[arg:.+]] = stablehlo.bitcast_convert %arg0 : (tensor<2x2x!quant.uniform<i8:f32, 2.500000e-02:-1>>) -> tensor<2x2xi8>

stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
config.suffixes = ['.mlir']
2727
config.test_source_root = os.path.dirname(__file__)
2828
# TODO: remove the following once #2751 is fixed.
29-
config.excludes = ['legalize_quant_ops_to_tosa_rescale.mlir', 'legalize_tosa_rescale_to_stablehlo.mlir', 'rescale_interpreter.mlir', 'binary.mlir', 'nullary.mlir', 'unary.mlir']
29+
config.excludes = [
30+
'rescale_interpreter.mlir',
31+
'binary.mlir',
32+
'nullary.mlir',
33+
'unary.mlir',
34+
]
3035

3136
# Disallow reusing variables across CHECK-LABEL matches.
3237
# A variable can eschew this (be made "global") by prefixing its name with $.

stablehlo/stablehlo/conversions/tosa/transforms/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ add_mlir_pdll_library(StablehloTOSAPDLLPatternsIncGen
2525
add_mlir_library(StablehloTOSATransforms
2626
StablehloLegalizeToTosa.cpp
2727
StablehloPrepareForTosa.cpp
28-
# TODO: un-comment the following once #2751 is fixed.
29-
# StablehloQuantLegalizeToTosaRescale.cpp
30-
# TosaRescaleLegalizeToStablehlo.cpp
28+
StablehloQuantLegalizeToTosaRescale.cpp
29+
TosaRescaleLegalizeToStablehlo.cpp
3130

32-
PARTIAL_SOURCES_INTENDED
3331
DEPENDS
3432
StablehloTOSATransformsPassIncGen
3533
StablehloTOSAPDLLPatternsIncGen

stablehlo/stablehlo/conversions/tosa/transforms/Passes.td

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,24 @@ def StablehloPrepareForTosaPass : Pass<"stablehlo-prepare-for-tosa", "mlir::func
3333
let dependentDialects = ["::mlir::tosa::TosaDialect"];
3434
}
3535

36-
// TODO: un-comment the following once #2751 is fixed.
37-
// def StablehloQuantLegalizeToTosaRescalePass : Pass<"stablehlo-quant-legalize-to-tosa-rescale", "mlir::func::FuncOp"> {
38-
// let summary = "Legalize StableHLO Quantized operations to TOSA rescale operations";
39-
// let description = [{
40-
// This pass rewrites StableHLO quantized operations to integer operations
41-
// by inserting TOSA rescale operations at the inputs and outputs of the
42-
// integer operations.
43-
// }];
44-
// let dependentDialects = [
45-
// "::mlir::tosa::TosaDialect",
46-
// ];
47-
// }
48-
//
49-
// def TosaRescaleLegalizeToStablehloPass : Pass<"tosa-rescale-legalize-to-stablehlo", "mlir::func::FuncOp"> {
50-
// let summary = "Legalize TOSA rescales to StableHlo primitive math operations";
51-
// let description = [{
52-
// This pass rewrites TOSA rescale operations to StableHLO primitive math operations.
53-
// }];
54-
// let dependentDialects = [
55-
// "::mlir::stablehlo::StablehloDialect"
56-
// ];
57-
// }
36+
def StablehloQuantLegalizeToTosaRescalePass : Pass<"stablehlo-quant-legalize-to-tosa-rescale", "mlir::func::FuncOp"> {
37+
let summary = "Legalize StableHLO Quantized operations to TOSA rescale operations";
38+
let description = [{
39+
This pass rewrites StableHLO quantized operations to integer operations
40+
by inserting TOSA rescale operations at the inputs and outputs of the
41+
integer operations.
42+
}];
43+
let dependentDialects = [
44+
"::mlir::tosa::TosaDialect",
45+
];
46+
}
47+
48+
def TosaRescaleLegalizeToStablehloPass : Pass<"tosa-rescale-legalize-to-stablehlo", "mlir::func::FuncOp"> {
49+
let summary = "Legalize TOSA rescales to StableHlo primitive math operations";
50+
let description = [{
51+
This pass rewrites TOSA rescale operations to StableHLO primitive math operations.
52+
}];
53+
let dependentDialects = [
54+
"::mlir::stablehlo::StablehloDialect"
55+
];
56+
}

stablehlo/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,24 @@ Value buildRescale(PatternRewriter &rewriter, Location loc,
4545
ShapedType outputType, Value inputVal, int32_t multiplier,
4646
int32_t shift, int64_t inputZp, int64_t outputZp,
4747
bool doubleRound, bool scale32, bool perChannel) {
48+
auto multiplierVal = getConstTensorInt<int32_t>(rewriter, loc, {multiplier});
49+
auto shiftVal =
50+
getConstTensorInt<int8_t>(rewriter, loc, static_cast<int8_t>(shift));
51+
auto inputZpVal =
52+
createZeroPointTensor(rewriter, loc, inputVal.getType(), inputZp);
53+
if (!inputZpVal) {
54+
(void)emitError(loc,
55+
"Failed to create input zero point tensor for RescaleOp");
56+
}
57+
auto outputZpVal = createZeroPointTensor(rewriter, loc, outputType, outputZp);
58+
if (!outputZpVal) {
59+
(void)emitError(loc,
60+
"Failed to create output zero point tensor for RescaleOp");
61+
}
4862
auto rescale_op = rewriter.create<RescaleOp>(
49-
loc, outputType, inputVal,
50-
rewriter.getI32IntegerAttr(static_cast<int32_t>(inputZp)),
51-
rewriter.getI32IntegerAttr(static_cast<int32_t>(outputZp)),
52-
rewriter.getDenseI32ArrayAttr({multiplier}),
53-
rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
54-
rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(doubleRound),
63+
loc, outputType, inputVal, multiplierVal, shiftVal, inputZpVal.value(),
64+
outputZpVal.value(), rewriter.getBoolAttr(scale32),
65+
rewriter.getStringAttr(doubleRound ? "DOUBLE_ROUND" : "SINGLE_ROUND"),
5566
rewriter.getBoolAttr(perChannel),
5667
/*input_unsigned=*/rewriter.getBoolAttr(false),
5768
/*output_unsigned=*/rewriter.getBoolAttr(false));

stablehlo/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ LogicalResult ConvertTosaRescaleToStablehlo::matchAndRewrite(
6565
}
6666

6767
bool scale32 = op.getScale32();
68-
bool doubleRound = op.getDoubleRound();
68+
bool doubleRound = op.getRoundingMode() == "DOUBLE_ROUND";
6969
bool perChannel = op.getPerChannel();
7070

7171
if (perChannel || doubleRound || !scale32) {
@@ -106,18 +106,48 @@ LogicalResult ConvertTosaRescaleToStablehlo::matchAndRewrite(
106106
auto i32Type = inputType.clone(rewriter.getI32Type());
107107
auto i64Type = inputType.clone(rewriter.getI64Type());
108108

109-
// construct multiplier, shift constant values from op attrs
110-
// for scale32, multiplier is tensor of i32
109+
// construct multiplier, shift constant values for scale32, multiplier and
110+
// shift are constant tensors of i32 or i8, respectively.
111+
DenseElementsAttr multiplierElems;
112+
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) {
113+
return rewriter.notifyMatchFailure(
114+
op, "tosa.rescale requires constant multiplier input values");
115+
}
116+
llvm::SmallVector<int32_t> multiplierValues =
117+
llvm::to_vector(multiplierElems.getValues<int32_t>());
111118
Value multiplier = getStablehloConstantOp(
112-
rewriter, loc, DenseElementsAttr::get(i32Type, op.getMultiplier()));
119+
rewriter, loc,
120+
DenseElementsAttr::get(
121+
i32Type, rewriter.getI32IntegerAttr(multiplierValues.front())));
122+
DenseElementsAttr shiftElems;
123+
if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) {
124+
return rewriter.notifyMatchFailure(
125+
op, "tosa.rescale requires constant shift input values");
126+
}
127+
llvm::SmallVector<int8_t> shiftValues =
128+
llvm::to_vector(shiftElems.getValues<int8_t>());
113129
Value shift = getStablehloConstantOp(
114-
rewriter, loc, DenseElementsAttr::get(i8Type, op.getShift()));
130+
rewriter, loc,
131+
DenseElementsAttr::get(i8Type,
132+
rewriter.getI8IntegerAttr(shiftValues.front())));
115133

116-
// construct inputZp and outputZp from op attrs
134+
// construct inputZp and outputZp
135+
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
136+
if (failed(maybeIZp)) {
137+
return rewriter.notifyMatchFailure(
138+
op, "input zero point cannot be statically determined");
139+
}
117140
Value inputZpI32 = getStablehloConstantOp(
118-
rewriter, loc, DenseElementsAttr::get(i32Type, op.getInputZpAttr()));
141+
rewriter, loc,
142+
DenseElementsAttr::get(i32Type, rewriter.getI32IntegerAttr(*maybeIZp)));
143+
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
144+
if (failed(maybeOZp)) {
145+
return rewriter.notifyMatchFailure(
146+
op, "output zero point cannot be statically determined");
147+
}
119148
Value outputZpI32 = getStablehloConstantOp(
120-
rewriter, loc, DenseElementsAttr::get(i32Type, op.getOutputZpAttr()));
149+
rewriter, loc,
150+
DenseElementsAttr::get(i32Type, rewriter.getI32IntegerAttr(*maybeOZp)));
121151

122152
// construct constant 1, min and max tensors
123153
Value onesI64 = getStablehloConstantOp(rewriter, loc,

0 commit comments

Comments
 (0)