Skip to content

Commit aa1c533

Browse files
committed
[mlir][tosa] Expand tosa.apply_scale lowering for vectors
Apply scale may encounter scalar, tensor, or vector operations. Expand the lowering so that it can lower arbitrary container types. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D117080
1 parent 2db4cf5 commit aa1c533

File tree

2 files changed

+74
-20
lines changed

2 files changed

+74
-20
lines changed

mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,23 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
5252
}
5353
};
5454

55+
/// Returns `element` wrapped in the same container as `container`.
///
/// If `container` is a ShapedType, the result is a clone of that shaped type
/// with its element type replaced by `element`. Otherwise `container` is
/// treated as a scalar and `element` is returned unchanged.
///
/// Declared `static` because this is a file-local helper and should not be
/// exported from the translation unit.
static Type matchContainerType(Type element, Type container) {
  if (auto shapedTy = container.dyn_cast<ShapedType>())
    return shapedTy.clone(element);

  return element;
}
61+
62+
/// Builds an integer constant attribute of type `type` holding `value`.
///
/// For a ShapedType the result is a DenseIntElementsAttr of that shaped type
/// built from the single value `value`, using the bit width of the shaped
/// type's element type. For any other type the result is the integer
/// attribute produced by `rewriter.getIntegerAttr(type, value)`.
///
/// Declared `static` because this is a file-local helper and should not be
/// exported from the translation unit.
static Attribute getConstantAttr(Type type, int64_t value,
                                 PatternRewriter &rewriter) {
  if (auto shapedTy = type.dyn_cast<ShapedType>()) {
    Type eTy = shapedTy.getElementType();
    // `value` is a signed int64_t, so construct the APInt as signed: this
    // sign-extends (rather than zero-extends) negative values when the element
    // type is wider than 64 bits, matching the scalar getIntegerAttr path.
    APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
    return DenseIntElementsAttr::get(shapedTy, valueInt);
  }

  return rewriter.getIntegerAttr(type, value);
}
71+
5572
// This converts the TOSA ApplyScale operator to a set of StandardOps ops,
5673
// using 64-bit operations to perform the necessary multiply, bias, and shift.
5774
// Multiple types are used to use minimal bit width operations.
@@ -65,13 +82,19 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
6582
Value value32 = op.value();
6683
Value multiplier32 = op.multiplier();
6784
Value shift8 = op.shift();
85+
6886
bool doubleRound = op.double_round();
6987
Type inType = op.value().getType();
88+
Type resultTy = op.getType();
89+
90+
Type i8Ty = matchContainerType(rewriter.getIntegerType(8), resultTy);
91+
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
92+
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
7093

7194
Value one8 = rewriter.create<arith::ConstantOp>(
72-
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1));
95+
loc, getConstantAttr(i8Ty, 1, rewriter));
7396
Value one64 = rewriter.create<arith::ConstantOp>(
74-
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
97+
loc, getConstantAttr(i64Ty, 1, rewriter));
7598

7699
Value shiftSubOne8 = rewriter.create<arith::SubIOp>(loc, shift8, one8);
77100

@@ -85,23 +108,20 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
85108
// Note that minimal bitwidth operators are used throughout the block.
86109

87110
Value round64 = rewriter.create<arith::ShLIOp>(
88-
loc, one64,
89-
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(),
90-
shiftSubOne8));
111+
loc, one64, rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftSubOne8));
91112

92113
// Double rounding is performing a round operation before the shift
93114
if (doubleRound) {
94115
Value one32 = rewriter.create<arith::ConstantOp>(
95-
loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
96-
Value shift32 =
97-
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), shift8);
116+
loc, getConstantAttr(i32Ty, 1, rewriter));
117+
Value shift32 = rewriter.create<arith::ExtSIOp>(loc, i32Ty, shift8);
98118
Value thirty32 = rewriter.create<arith::ConstantOp>(
99-
loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30));
119+
loc, getConstantAttr(i32Ty, 30, rewriter));
100120

101121
Value shiftThirty32 =
102122
rewriter.create<arith::ShLIOp>(loc, one32, thirty32);
103-
Value shiftThirty64 = rewriter.create<arith::ExtSIOp>(
104-
loc, rewriter.getI64Type(), shiftThirty32);
123+
Value shiftThirty64 =
124+
rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftThirty32);
105125

106126
// Round value needs to be added or subtracted depending on the sign
107127
// of the input value.
@@ -120,7 +140,7 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
120140

121141
// We only perform double rounding if the shift value is greater than 32.
122142
Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
123-
loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32));
143+
loc, getConstantAttr(i32Ty, 32, rewriter));
124144
Value shiftGreaterThanThirtyTwo = rewriter.create<arith::CmpIOp>(
125145
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
126146
round64 = rewriter.create<mlir::SelectOp>(loc, shiftGreaterThanThirtyTwo,
@@ -133,20 +153,17 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
133153
//
134154
// Note that multiply and shift need to be performed in i64 to preserve bits.
135155

136-
Value value64 =
137-
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), value32);
138-
Value multiplier64 = rewriter.create<arith::ExtSIOp>(
139-
loc, rewriter.getI64Type(), multiplier32);
140-
Value shift64 =
141-
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), shift8);
156+
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
157+
Value multiplier64 =
158+
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
159+
Value shift64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, shift8);
142160

143161
// Multiply as a pair of i64 values to guarantee the end value fits.
144162
Value result64 = rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
145163
result64 = rewriter.create<arith::AddIOp>(loc, result64, round64);
146164
result64 = rewriter.create<arith::ShRSIOp>(loc, result64, shift64);
147165

148-
Value result32 =
149-
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), result64);
166+
Value result32 = rewriter.create<arith::TruncIOp>(loc, resultTy, result64);
150167

151168
rewriter.replaceOp(op, result32);
152169
return success();

mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,43 @@ func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
5656

5757
// -----
5858

59+
// Verifies the lowering of tosa.apply_scale with double_round = true when the
// value and multiplier are vector<4xi32> and the shift is vector<4xi8>.
// Every constant is expected as a dense value of the matching vector type and
// every extension is expected to operate on 4-element vectors.
// CHECK-LABEL: @apply_scale_test_vector
60+
func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
61+
// CHECK-DAG: [[C1_8:%.+]] = arith.constant dense<1> : vector<4xi8>
62+
// CHECK-DAG: [[C1_32:%.+]] = arith.constant dense<1> : vector<4xi32>
63+
// CHECK-DAG: [[C1_64:%.+]] = arith.constant dense<1> : vector<4xi64>
64+
// CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]]
65+
66+
// CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi32>
67+
// CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : vector<4xi8> to vector<4xi64>
68+
// CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]]
69+
70+
// CHECK-DAG: [[C0_32:%.+]] = arith.constant dense<0> : vector<4xi32>
71+
// CHECK-DAG: [[C30_32:%.+]] = arith.constant dense<30> : vector<4xi32>
72+
// CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]]
73+
// CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : vector<4xi32> to vector<4xi64>
74+
// CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
75+
// CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
76+
// Note: despite its name, the capture below holds the result of `cmpi sge`
// against zero, so it is true for lanes whose input value is non-negative;
// the select that follows picks the positive rounding term for those lanes.
// CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : vector<4xi32>
77+
// CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : vector<4xi1>, vector<4xi64>
78+
// CHECK-DAG: [[C32_32:%.+]] = arith.constant dense<32> : vector<4xi32>
79+
// CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]]
80+
// CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
81+
82+
// CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : vector<4xi32> to vector<4xi64>
83+
// CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : vector<4xi32> to vector<4xi64>
84+
// CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi64>
85+
// CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]]
86+
// CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]]
87+
// CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
88+
// CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]]
89+
90+
%0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
91+
return %0 : vector<4xi32>
92+
}
93+
94+
// -----
95+
5996
// CHECK-LABEL: @apply_scale_test_i48
6097
func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
6198
// CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8

0 commit comments

Comments
 (0)