Skip to content

Commit 0f83f1a

Browse files
authored
Merge pull request #27 from Xilinx/tina.tosamulfolding
[FXML-1930] Implement mul folding
2 parents c697ece + 75b6df8 commit 0f83f1a

File tree

5 files changed

+307
-0
lines changed

5 files changed

+307
-0
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ void populateTosaFoldConstantClampPatterns(MLIRContext *ctx,
3737
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
3838
RewritePatternSet &patterns,
3939
bool enableIntCastFolding);
40+
void populateTosaFoldConstantMulPatterns(MLIRContext *ctx,
41+
RewritePatternSet &patterns);
4042
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
4143
RewritePatternSet &patterns);
4244
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
66
TosaFoldConstantAdd.cpp
77
TosaFoldConstantCast.cpp
88
TosaFoldConstantClamp.cpp
9+
TosaFoldConstantMul.cpp
910
TosaFoldConstantPow.cpp
1011
TosaFoldConstantReciprocal.cpp
1112
TosaFoldConstantRSQRT.cpp
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//===- TosaFoldConstantMul.cpp --------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fold TOSA Mul operation on constant data
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
16+
#include "mlir/IR/Matchers.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include <llvm/ADT/APInt.h>
19+
#include <mlir/Support/LogicalResult.h>
20+
21+
using namespace mlir;
22+
using namespace mlir::tosa;
23+
24+
namespace {
25+
26+
/// Folds a `tosa.mul` whose two operands are dense constant tensors into a
/// single `tosa.const` holding the element-wise product.
///
/// The pattern reports a match failure (and leaves the IR untouched) when
///   * the op carries a non-zero `shift`, which is not implemented here,
///   * `notifyIfNoTosaDenseConstantTensor` rejects either operand,
///   * the constant values of an operand cannot be bound, or
///   * `constantBinaryOpShouldBeFolded` declines the fold.
///
/// Integer operands are sign-extended to the result element width before the
/// multiplication. If any element overflows, a warning is emitted on the op
/// and the fold is still performed.
struct TosaFoldConstantMul : public OpRewritePattern<MulOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MulOp mulOp,
                                PatternRewriter &rewriter) const override {
    // Reject every non-zero shift so that the guard agrees with the
    // diagnostic below.
    if (mulOp.getShift() != 0) {
      return rewriter.notifyMatchFailure(
          mulOp, "Non-zero shift folding is currently not implemented.");
    }

    auto lhs = mulOp.getInput1();
    auto rhs = mulOp.getInput2();

    // Check if both tensors are constant
    auto lhsIsConstantCheck =
        notifyIfNoTosaDenseConstantTensor(lhs, mulOp, rewriter);
    if (failed(lhsIsConstantCheck)) {
      return lhsIsConstantCheck;
    }
    auto rhsIsConstantCheck =
        notifyIfNoTosaDenseConstantTensor(rhs, mulOp, rewriter);
    if (failed(rhsIsConstantCheck)) {
      return rhsIsConstantCheck;
    }

    // Extract the tensor values. The attributes stay null if binding fails,
    // so bail out instead of dereferencing them further down.
    DenseElementsAttr lhsValues;
    DenseElementsAttr rhsValues;
    if (!matchPattern(lhs, m_Constant(&lhsValues)) ||
        !matchPattern(rhs, m_Constant(&rhsValues))) {
      return rewriter.notifyMatchFailure(
          mulOp, "Could not extract the constant values of both operands.");
    }

    if (!constantBinaryOpShouldBeFolded(mulOp, lhsValues, rhsValues)) {
      return rewriter.notifyMatchFailure(
          mulOp, "Currently, muls will only be folded if this requires only "
                 "little additional memory usage.");
    }

    auto lhsElemType = lhs.getType().getElementType();
    auto rhsElemType = rhs.getType().getElementType();
    assert(lhsElemType == rhsElemType);
    // Only read inside asserts; keep NDEBUG builds free of unused-variable
    // warnings.
    (void)rhsElemType;

    auto resultType = mulOp.getType();
    auto resultElementType = resultType.getElementType();

    DenseElementsAttr newTensor;
    if (isa<IntegerType>(lhsElemType)) {
      assert(isa<IntegerType>(rhsElemType) &&
             isa<IntegerType>(resultElementType));
      auto resultElementWidth = resultElementType.getIntOrFloatBitWidth();
      assert(resultElementWidth >= lhsElemType.getIntOrFloatBitWidth() &&
             "The multiplication is expected to have an at least as big output "
             "as input type");

      // Compute the multiplication and track if an overflow occurred to enable
      // emitting a warning
      bool mulOverflowed = false;
      auto intMulFun = [resultElementWidth, &mulOverflowed](
                           const APInt &first, const APInt &second) {
        bool didOverflow = false;
        // Widen both factors to the result width first so that e.g. an
        // i8 x i8 -> i32 product is computed with the full result precision.
        auto res = first.sext(resultElementWidth)
                       .smul_ov(second.sext(resultElementWidth), didOverflow);
        mulOverflowed |= didOverflow;
        return res;
      };
      newTensor = applyElementWise<APInt, APInt>(lhsValues, rhsValues,
                                                 resultType, intMulFun);
      if (mulOverflowed) {
        mulOp.emitWarning(
            "Multiplication did overflow. The results are unspecified.");
      }
    } else {
      assert(isa<FloatType>(lhsElemType) && isa<FloatType>(rhsElemType) &&
             isa<FloatType>(resultElementType));
      auto mulFun = [](const APFloat &first, const APFloat &second) {
        return first * second;
      };
      newTensor = applyElementWise<APFloat, APFloat>(lhsValues, rhsValues,
                                                     resultType, mulFun);
    }
    rewriter.replaceOpWithNewOp<ConstOp>(mulOp, newTensor.getType(), newTensor);

    return success();
  }
};
112+
113+
} // namespace
114+
115+
// Adds the TosaFoldConstantMul rewrite pattern, constructed with `ctx`, to
// `patterns`.
void mlir::tosa::populateTosaFoldConstantMulPatterns(
116+
MLIRContext *ctx, RewritePatternSet &patterns) {
117+
patterns.add<TosaFoldConstantMul>(ctx);
118+
}

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ struct TosaLayerwiseConstantFoldPass
5454
mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
5555
enableIntCastFolding);
5656
mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns);
57+
mlir::tosa::populateTosaFoldConstantMulPatterns(ctx, patterns);
5758
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
5859
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
5960
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s
2+
3+
// Float multiplications
4+
5+
// CHECK-LABEL: @mul_fold_float
6+
func.func @mul_fold_float() -> tensor<4xf16> {
7+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.32{{.*}}e+03, -1.49{{.*}}e+01, -0.{{0*}}e+00, -0.{{0*}}e+00
8+
// CHECK-NOT: tosa.mul
9+
// CHECK: return [[RES]]
10+
%0 = "tosa.const"() {value =
11+
dense<[-17.4978, 4.9882, 0.0, -0.0]> :
12+
tensor<4xf16>
13+
} : () -> tensor<4xf16>
14+
%1 = "tosa.const"() {value =
15+
dense<[-132.7, -3.0, -0.0, 5.0]> :
16+
tensor<4xf16>
17+
} : () -> tensor<4xf16>
18+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16>
19+
return %2 : tensor<4xf16>
20+
}
21+
22+
// CHECK-LABEL: @mul_fold_float_infinity_nan
23+
func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> {
24+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0x7F800000, 0xFF800000, 0xFF800000, 0x7FC00000, 0xFF800000, 0x7FC00000
25+
// CHECK-NOT: tosa.mul
26+
// CHECK: return [[RES]]
27+
%0 = "tosa.const"() {value =
28+
dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000, 0xFF800000]> :
29+
tensor<7xf32>
30+
} : () -> tensor<7xf32>
31+
%1 = "tosa.const"() {value =
32+
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> :
33+
tensor<7xf32>
34+
} : () -> tensor<7xf32>
35+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32>
36+
return %2 : tensor<7xf32>
37+
}
38+
39+
// CHECK-LABEL: @mul_fold_float_overflow
func.func @mul_fold_float_overflow() -> tensor<2xf32> {
  // Both products exceed the finite f32 range, so the folded constant is
  // expected to hold +inf and -inf.
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[3.1e+38, -3.1e+38]> :
    tensor<2xf32>
  } : () -> tensor<2xf32>
  %1 = "tosa.const"() {value =
    dense<[2.1e+38, 1.1e+38]> :
    tensor<2xf32>
  } : () -> tensor<2xf32>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %2 : tensor<2xf32>
}
55+
56+
// -----
57+
// Int multiplications
58+
59+
// CHECK-LABEL: @mul_fold_int
60+
func.func @mul_fold_int() -> tensor<4xi32> {
61+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2244, -12, 0, 0
62+
// CHECK-NOT: tosa.mul
63+
// CHECK: return [[RES]]
64+
%0 = "tosa.const"() {value =
65+
dense<[-17, 4, 0, 0]> :
66+
tensor<4xi32>
67+
} : () -> tensor<4xi32>
68+
%1 = "tosa.const"() {value =
69+
dense<[-132, -3, 0, 5]> :
70+
tensor<4xi32>
71+
} : () -> tensor<4xi32>
72+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
73+
return %2 : tensor<4xi32>
74+
}
75+
76+
// CHECK-LABEL: @mul_fold_i8
77+
func.func @mul_fold_i8() -> tensor<4xi32> {
78+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}204, -12, 0, 0
79+
// CHECK-NOT: tosa.mul
80+
// CHECK: return [[RES]]
81+
%0 = "tosa.const"() {value =
82+
dense<[-17, 4, -2, 0]> :
83+
tensor<4xi8>
84+
} : () -> tensor<4xi8>
85+
%1 = "tosa.const"() {value =
86+
dense<[-12, -3, 0, 5]> :
87+
tensor<4xi8>
88+
} : () -> tensor<4xi8>
89+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi32>
90+
return %2 : tensor<4xi32>
91+
}
92+
93+
// CHECK-LABEL: @mul_fold_int_overflow
94+
func.func @mul_fold_int_overflow() -> tensor<4xi32> {
95+
// Don't expect any specific results for the overflowing multiplication, just
96+
// that it is folded.
97+
// CHECK: [[RES:]] ={{.*}}tosa.const
98+
// CHECK-NOT: tosa.mul
99+
// CHECK: return [[RES]]
100+
%0 = "tosa.const"() {value =
101+
dense<[2147483647, 2147483640, -2147483648, -2147483640]> :
102+
tensor<4xi32>
103+
} : () -> tensor<4xi32>
104+
%1 = "tosa.const"() {value =
105+
dense<[1, 10, 1, 30]> :
106+
tensor<4xi32>
107+
} : () -> tensor<4xi32>
108+
// expected-warning@below {{Multiplication did overflow. The results are unspecified.}}
109+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
110+
return %2 : tensor<4xi32>
111+
}
112+
113+
// -----
114+
// self-multiplication
115+
116+
// CHECK-LABEL: @mul_fold_equal_args
117+
func.func @mul_fold_equal_args() -> tensor<3xi32> {
118+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}289, 16, 0
119+
// CHECK-NOT: tosa.mul
120+
// CHECK: return [[RES]]
121+
%0 = "tosa.const"() {value =
122+
dense<[-17, 4, 0]> :
123+
tensor<3xi32>
124+
} : () -> tensor<3xi32>
125+
%2 = "tosa.mul"(%0, %0) {shift = 0 : i32} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
126+
return %2 : tensor<3xi32>
127+
}
128+
129+
// -----
130+
// Broadcasted multiplications
131+
132+
// CHECK-LABEL: @mul_fold_int_broadcast_simple
133+
func.func @mul_fold_int_broadcast_simple() -> tensor<3xi32> {
134+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}204, -48, 0
135+
// CHECK-NOT: tosa.mul
136+
// CHECK: return [[RES]]
137+
%0 = "tosa.const"() {value =
138+
dense<[-17, 4, 0]> :
139+
tensor<3xi32>
140+
} : () -> tensor<3xi32>
141+
%1 = "tosa.const"() {value =
142+
dense<-12> :
143+
tensor<1xi32>
144+
} : () -> tensor<1xi32>
145+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
146+
return %2 : tensor<3xi32>
147+
}
148+
149+
// CHECK-LABEL: @mul_fold_int_broadcast_complex
150+
func.func @mul_fold_int_broadcast_complex() -> tensor<3x3xi32> {
151+
// CHECK: [[RES:]] ={{.*}}tosa.const
152+
// CHECK-SAME{LITERAL}: [[204, -119, -68],
153+
// CHECK-SAME{LITERAL}: [-12, 7, 4],
154+
// CHECK-SAME{LITERAL}: [-228, 133, 76]]
155+
// CHECK-NOT: tosa.mul
156+
// CHECK: return [[RES]]
157+
%0 = "tosa.const"() {value =
158+
dense<[[-17], [1], [19]]> :
159+
tensor<3x1xi32>
160+
} : () -> tensor<3x1xi32>
161+
%1 = "tosa.const"() {value =
162+
dense<[[-12, 7, 4]]> :
163+
tensor<1x3xi32>
164+
} : () -> tensor<1x3xi32>
165+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32>
166+
return %2 : tensor<3x3xi32>
167+
}
168+
169+
// CHECK-LABEL: @mul_fold_int_non_zero_shift
170+
func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> {
171+
// CHECK: [[FIRST:]] ={{.*}}tosa.const
172+
// CHECK-NEXT: [[SECOND:]] ={{.*}}tosa.const
173+
// CHECK-NEXT: [[MUL:]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]]
174+
// CHECK-NEXT: return [[MUL]]
175+
%0 = "tosa.const"() {value =
176+
dense<[-17, 4, 0, 0]> :
177+
tensor<4xi32>
178+
} : () -> tensor<4xi32>
179+
%1 = "tosa.const"() {value =
180+
dense<[-132, -3, 0, 5]> :
181+
tensor<4xi32>
182+
} : () -> tensor<4xi32>
183+
%2 = "tosa.mul"(%0, %1) {shift = 1 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
184+
return %2 : tensor<4xi32>
185+
}

0 commit comments

Comments
 (0)