Skip to content

Commit 81c13b6

Browse files
committed
Implement folding for constant tosa.muls
* Folds multiplications with constant operands (limited to muls with shift = 0) * Add unit test for the folding * Implement saturating semantics for i32 overflows, might require changes if the spec clarification comes in [0] [0] https://discuss.mlplatform.org/t/integer-multiplication-overflow-handling/187
1 parent c697ece commit 81c13b6

File tree

5 files changed

+285
-0
lines changed

5 files changed

+285
-0
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ void populateTosaFoldConstantClampPatterns(MLIRContext *ctx,
3737
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
3838
RewritePatternSet &patterns,
3939
bool enableIntCastFolding);
40+
void populateTosaFoldConstantMulPatterns(MLIRContext *ctx,
41+
RewritePatternSet &patterns);
4042
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
4143
RewritePatternSet &patterns);
4244
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
66
TosaFoldConstantAdd.cpp
77
TosaFoldConstantCast.cpp
88
TosaFoldConstantClamp.cpp
9+
TosaFoldConstantMul.cpp
910
TosaFoldConstantPow.cpp
1011
TosaFoldConstantReciprocal.cpp
1112
TosaFoldConstantRSQRT.cpp
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//===- TosaFoldConstantMul.cpp --------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fold TOSA Mul operation on constant data
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
16+
#include "mlir/IR/Matchers.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include <llvm/ADT/APInt.h>
19+
#include <mlir/Support/LogicalResult.h>
20+
21+
using namespace mlir;
22+
using namespace mlir::tosa;
23+
24+
namespace {
25+
26+
struct TosaFoldConstantMul : public OpRewritePattern<MulOp> {
27+
28+
using OpRewritePattern::OpRewritePattern;
29+
30+
LogicalResult matchAndRewrite(MulOp mulOp,
31+
PatternRewriter &rewriter) const override {
32+
if (mulOp.getShift() > 0) {
33+
return rewriter.notifyMatchFailure(
34+
mulOp, "Non-zero shift folding is currently not implemented.");
35+
}
36+
37+
auto leftOp = mulOp.getInput1();
38+
auto rightOp = mulOp.getInput2();
39+
40+
// Check if both tensors are constant
41+
auto rhsIsConstantCheck =
42+
notifyIfNoTosaDenseConstantTensor(leftOp, mulOp, rewriter);
43+
if (failed(rhsIsConstantCheck)) {
44+
return rhsIsConstantCheck;
45+
}
46+
auto lhsIsConstantCheck =
47+
notifyIfNoTosaDenseConstantTensor(rightOp, mulOp, rewriter);
48+
if (failed(lhsIsConstantCheck)) {
49+
return lhsIsConstantCheck;
50+
}
51+
52+
// Extract the tensor values
53+
DenseElementsAttr lhsValues;
54+
matchPattern(leftOp, m_Constant(&lhsValues));
55+
56+
DenseElementsAttr rhsValues;
57+
matchPattern(rightOp, m_Constant(&rhsValues));
58+
59+
if (!constantBinaryOpShouldBeFolded(mulOp, lhsValues, rhsValues)) {
60+
return rewriter.notifyMatchFailure(
61+
mulOp, "Currently, muls will only be folded if this requires only "
62+
"little additional memory usage.");
63+
}
64+
65+
DenseElementsAttr newTensor;
66+
67+
auto lhsElemType = leftOp.getType().getElementType();
68+
auto rhsElemType = rightOp.getType().getElementType();
69+
assert(lhsElemType == rhsElemType);
70+
71+
auto resultType = mulOp.getType();
72+
auto resultElementType = resultType.getElementType();
73+
if (isa<IntegerType>(lhsElemType)) {
74+
assert(isa<IntegerType>(rhsElemType) &&
75+
isa<IntegerType>(resultElementType));
76+
auto resultElementWidth = resultElementType.getIntOrFloatBitWidth();
77+
assert(resultElementWidth == 32 &&
78+
"All integer multiplications in TOSA are specified to result in "
79+
"32 bit width");
80+
// TODO: To implement shifts > 0, capture the shift value stored in the
81+
// mul here
82+
auto intMulFun = [&resultElementWidth](const APInt &first,
83+
const APInt &second) {
84+
// TODO the documentation has conflicting definitions for the behavior
85+
// of overflows
86+
// The sign extend should always be valid as the result type is required
87+
// to be i32 and all other integer input types are smaller or equal
88+
// to 32.
89+
return first.sext(resultElementWidth)
90+
.smul_sat(second.sext(resultElementWidth));
91+
};
92+
newTensor = applyElementWise<APInt, APInt>(lhsValues, rhsValues,
93+
resultType, intMulFun);
94+
} else {
95+
assert(isa<FloatType>(lhsElemType) && isa<FloatType>(rhsElemType) &&
96+
isa<FloatType>(resultType.getElementType()));
97+
auto mulFun = [](const APFloat &first, const APFloat &second) {
98+
return first * second;
99+
};
100+
newTensor = applyElementWise<APFloat, APFloat>(lhsValues, rhsValues,
101+
resultType, mulFun);
102+
}
103+
rewriter.replaceOpWithNewOp<ConstOp>(mulOp, newTensor.getType(), newTensor);
104+
105+
return success();
106+
}
107+
};
108+
109+
} // namespace
110+
111+
void mlir::tosa::populateTosaFoldConstantMulPatterns(
112+
MLIRContext *ctx, RewritePatternSet &patterns) {
113+
patterns.add<TosaFoldConstantMul>(ctx);
114+
}

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ struct TosaLayerwiseConstantFoldPass
5454
mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
5555
enableIntCastFolding);
5656
mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns);
57+
mlir::tosa::populateTosaFoldConstantMulPatterns(ctx, patterns);
5758
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
5859
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
5960
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s
2+
3+
// Float multiplications
4+
5+
// CHECK-LABEL: @mul_fold_float
6+
func.func @mul_fold_float() -> tensor<4xf16> {
7+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.32{{.*}}e+03, -1.49{{.*}}e+01, -0.{{0*}}e+00, -0.{{0*}}e+00
8+
// CHECK-NOT: tosa.mul
9+
// CHECK: return [[RES]]
10+
%0 = "tosa.const"() {value =
11+
dense<[-17.4978, 4.9882, 0.0, -0.0]> :
12+
tensor<4xf16>
13+
} : () -> tensor<4xf16>
14+
%1 = "tosa.const"() {value =
15+
dense<[-132.7, -3.0, -0.0, 5.0]> :
16+
tensor<4xf16>
17+
} : () -> tensor<4xf16>
18+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16>
19+
return %2 : tensor<4xf16>
20+
}
21+
22+
// CHECK-LABEL: @mul_fold_float_infinity_nan
23+
func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> {
24+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0x7F800000, 0xFF800000, 0xFF800000, 0x7FC00000, 0xFF800000, 0x7FC00000
25+
// CHECK-NOT: tosa.mul
26+
// CHECK: return [[RES]]
27+
%0 = "tosa.const"() {value =
28+
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000, 0xFF800000]> :
29+
tensor<7xf32>
30+
} : () -> tensor<7xf32>
31+
%1 = "tosa.const"() {value =
32+
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> :
33+
tensor<7xf32>
34+
} : () -> tensor<7xf32>
35+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32>
36+
return %2 : tensor<7xf32>
37+
}
38+
39+
// CHECK-LABEL: @add_fold_float_overflow
40+
func.func @add_fold_float_overflow() -> tensor<2xf32> {
41+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000
42+
// CHECK-NOT: tosa.mul
43+
// CHECK: return [[RES]]
44+
%0 = "tosa.const"() {value =
45+
dense<[3.1e+38, -3.1e+38]> :
46+
tensor<2xf32>
47+
} : () -> tensor<2xf32>
48+
%1 = "tosa.const"() {value =
49+
dense<[2.1e+38, 1.1e+38]> :
50+
tensor<2xf32>
51+
} : () -> tensor<2xf32>
52+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
53+
return %2 : tensor<2xf32>
54+
}
55+
56+
// -----
57+
// Int multiplications
58+
59+
// CHECK-LABEL: @mul_fold_int
60+
func.func @mul_fold_int() -> tensor<4xi32> {
61+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2244, -12, 0, 0
62+
// CHECK-NOT: tosa.mul
63+
// CHECK: return [[RES]]
64+
%0 = "tosa.const"() {value =
65+
dense<[-17, 4, 0, 0]> :
66+
tensor<4xi32>
67+
} : () -> tensor<4xi32>
68+
%1 = "tosa.const"() {value =
69+
dense<[-132, -3, 0, 5]> :
70+
tensor<4xi32>
71+
} : () -> tensor<4xi32>
72+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
73+
return %2 : tensor<4xi32>
74+
}
75+
76+
// -----
77+
// self-multiplication
78+
79+
// CHECK-LABEL: @mul_fold_int_overflow
80+
// TODO: Change expected behavior if the tosa.mul on i32 should not be
81+
// saturating. Also add a test with different widths in that case.
82+
func.func @mul_fold_int_overflow() -> tensor<4xi32> {
83+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2147483647, 2147483647, -2147483648, -2147483648
84+
// CHECK-NOT: tosa.mul
85+
// CHECK: return [[RES]]
86+
%0 = "tosa.const"() {value =
87+
dense<[2147483647, 2147483640, -2147483648, -2147483640]> :
88+
tensor<4xi32>
89+
} : () -> tensor<4xi32>
90+
%1 = "tosa.const"() {value =
91+
dense<[1, 10, 1, 30]> :
92+
tensor<4xi32>
93+
} : () -> tensor<4xi32>
94+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
95+
return %2 : tensor<4xi32>
96+
}
97+
98+
// CHECK-LABEL: @mul_fold_equal_args
99+
func.func @mul_fold_equal_args() -> tensor<3xi32> {
100+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}289, 16, 0
101+
// CHECK-NOT: tosa.mul
102+
// CHECK: return [[RES]]
103+
%0 = "tosa.const"() {value =
104+
dense<[-17, 4, 0]> :
105+
tensor<3xi32>
106+
} : () -> tensor<3xi32>
107+
%2 = "tosa.mul"(%0, %0) {shift = 0 : i32} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
108+
return %2 : tensor<3xi32>
109+
}
110+
111+
// -----
112+
// Broadcasted multiplications
113+
114+
// CHECK-LABEL: @mul_fold_int_broadcast_simple
115+
func.func @mul_fold_int_broadcast_simple() -> tensor<3xi32> {
116+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}204, -48, 0
117+
// CHECK-NOT: tosa.mul
118+
// CHECK: return [[RES]]
119+
%0 = "tosa.const"() {value =
120+
dense<[-17, 4, 0]> :
121+
tensor<3xi32>
122+
} : () -> tensor<3xi32>
123+
%1 = "tosa.const"() {value =
124+
dense<-12> :
125+
tensor<1xi32>
126+
} : () -> tensor<1xi32>
127+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
128+
return %2 : tensor<3xi32>
129+
}
130+
131+
// CHECK-LABEL: @mul_fold_int_broadcast_complex
132+
func.func @mul_fold_int_broadcast_complex() -> tensor<3x3xi32> {
133+
// CHECK: [[RES:]] ={{.*}}tosa.const
134+
// CHECK-SAME{LITERAL}: [[204, -119, -68],
135+
// CHECK-SAME{LITERAL}: [-12, 7, 4],
136+
// CHECK-SAME{LITERAL}: [-228, 133, 76]]
137+
// CHECK-NOT: tosa.mul
138+
// CHECK: return [[RES]]
139+
%0 = "tosa.const"() {value =
140+
dense<[[-17], [1], [19]]> :
141+
tensor<3x1xi32>
142+
} : () -> tensor<3x1xi32>
143+
%1 = "tosa.const"() {value =
144+
dense<[[-12, 7, 4]]> :
145+
tensor<1x3xi32>
146+
} : () -> tensor<1x3xi32>
147+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32>
148+
return %2 : tensor<3x3xi32>
149+
}
150+
151+
// CHECK-LABEL: @mul_fold_int_non_zero_shift
152+
func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> {
153+
// CHECK: [[FIRST:]] ={{.*}}tosa.const
154+
// CHECK-NEXT: [[SECOND:]] ={{.*}}tosa.const
155+
// CHECK-NEXT: [[MUL:]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]]
156+
// CHECK-NEXT: return [[MUL]]
157+
%0 = "tosa.const"() {value =
158+
dense<[-17, 4, 0, 0]> :
159+
tensor<4xi32>
160+
} : () -> tensor<4xi32>
161+
%1 = "tosa.const"() {value =
162+
dense<[-132, -3, 0, 5]> :
163+
tensor<4xi32>
164+
} : () -> tensor<4xi32>
165+
%2 = "tosa.mul"(%0, %1) {shift = 1 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
166+
return %2 : tensor<4xi32>
167+
}

0 commit comments

Comments
 (0)