15
15
#include " mlir/Dialect/SPIRV/Transforms/Passes.h"
16
16
#include " mlir/IR/BuiltinAttributes.h"
17
17
#include " mlir/IR/Location.h"
18
+ #include " mlir/IR/PatternMatch.h"
18
19
#include " mlir/IR/TypeUtilities.h"
20
+ #include " mlir/Support/LogicalResult.h"
19
21
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
20
22
#include " llvm/ADT/ArrayRef.h"
21
23
#include " llvm/ADT/STLExtras.h"
@@ -45,90 +47,126 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
45
47
return SplatElementsAttr::get (type, sizedValue);
46
48
}
47
49
50
+ Value lowerExtendedMultiplication (Operation *mulOp, PatternRewriter &rewriter,
51
+ Value lhs, Value rhs,
52
+ bool signExtendArguments) {
53
+ Location loc = mulOp->getLoc ();
54
+ Type argTy = lhs.getType ();
55
+ // Emulate 64-bit multiplication by splitting each input element of type i32
56
+ // into 2 16-bit digits of type i32. This is so that the intermediate
57
+ // multiplications and additions do not overflow. We extract these 16-bit
58
+ // digits from i32 vector elements by masking (low digit) and shifting right
59
+ // (high digit).
60
+ //
61
+ // The multiplication algorithm used is the standard (long) multiplication.
62
+ // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
63
+ // digits.
64
+ // - With zero-extended arguments, we end up emitting only 4 multiplications
65
+ // and 4 additions after constant folding.
66
+ // - With sign-extended arguments, we end up emitting 8 multiplications and
67
+ // and 12 additions after CSE.
68
+ Value cstLowMask = rewriter.create <ConstantOp>(
69
+ loc, lhs.getType (), getScalarOrSplatAttr (argTy, (1 << 16 ) - 1 ));
70
+ auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
71
+ return rewriter.create <BitwiseAndOp>(loc, val, cstLowMask);
72
+ };
73
+
74
+ Value cst16 = rewriter.create <ConstantOp>(loc, lhs.getType (),
75
+ getScalarOrSplatAttr (argTy, 16 ));
76
+ auto getHighDigit = [&rewriter, loc, cst16](Value val) {
77
+ return rewriter.create <ShiftRightLogicalOp>(loc, val, cst16);
78
+ };
79
+
80
+ auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
81
+ // We only need to shift arithmetically by 15, but the extra
82
+ // sign-extension bit will be truncated by the logical shift, so this is
83
+ // fine. We do not have to introduce an extra constant since any
84
+ // value in [15, 32) would do.
85
+ return getHighDigit (
86
+ rewriter.create <ShiftRightArithmeticOp>(loc, val, cst16));
87
+ };
88
+
89
+ Value cst0 = rewriter.create <ConstantOp>(loc, lhs.getType (),
90
+ getScalarOrSplatAttr (argTy, 0 ));
91
+
92
+ Value lhsLow = getLowDigit (lhs);
93
+ Value lhsHigh = getHighDigit (lhs);
94
+ Value lhsExt = signExtendArguments ? getSignDigit (lhs) : cst0;
95
+ Value rhsLow = getLowDigit (rhs);
96
+ Value rhsHigh = getHighDigit (rhs);
97
+ Value rhsExt = signExtendArguments ? getSignDigit (rhs) : cst0;
98
+
99
+ std::array<Value, 4 > lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
100
+ std::array<Value, 4 > rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
101
+ std::array<Value, 4 > resultDigits = {cst0, cst0, cst0, cst0};
102
+
103
+ for (auto [i, lhsDigit] : llvm::enumerate (lhsDigits)) {
104
+ for (auto [j, rhsDigit] : llvm::enumerate (rhsDigits)) {
105
+ if (i + j >= resultDigits.size ())
106
+ continue ;
107
+
108
+ if (lhsDigit == cst0 || rhsDigit == cst0)
109
+ continue ;
110
+
111
+ Value &thisResDigit = resultDigits[i + j];
112
+ Value mul = rewriter.create <IMulOp>(loc, lhsDigit, rhsDigit);
113
+ Value current = rewriter.createOrFold <IAddOp>(loc, thisResDigit, mul);
114
+ thisResDigit = getLowDigit (current);
115
+
116
+ if (i + j + 1 != resultDigits.size ()) {
117
+ Value &nextResDigit = resultDigits[i + j + 1 ];
118
+ Value carry = rewriter.createOrFold <IAddOp>(loc, nextResDigit,
119
+ getHighDigit (current));
120
+ nextResDigit = carry;
121
+ }
122
+ }
123
+ }
124
+
125
+ auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
126
+ Value highBits = rewriter.create <ShiftLeftLogicalOp>(loc, high, cst16);
127
+ return rewriter.create <BitwiseOrOp>(loc, low, highBits);
128
+ };
129
+ Value low = combineDigits (resultDigits[0 ], resultDigits[1 ]);
130
+ Value high = combineDigits (resultDigits[2 ], resultDigits[3 ]);
131
+
132
+ return rewriter.create <CompositeConstructOp>(
133
+ loc, mulOp->getResultTypes ().front (), llvm::makeArrayRef ({low, high}));
134
+ }
135
+
48
136
// ===----------------------------------------------------------------------===//
49
137
// Rewrite Patterns
50
138
// ===----------------------------------------------------------------------===//
51
- struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
52
- using OpRewritePattern::OpRewritePattern;
53
139
54
- LogicalResult matchAndRewrite (UMulExtendedOp op,
140
+ template <typename MulExtendedOp, bool SignExtendArguments>
141
+ struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
142
+ using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
143
+
144
+ LogicalResult matchAndRewrite (MulExtendedOp op,
55
145
PatternRewriter &rewriter) const override {
56
146
Location loc = op->getLoc ();
57
147
Value lhs = op.getOperand1 ();
58
148
Value rhs = op.getOperand2 ();
59
- Type argTy = lhs.getType ();
60
149
61
150
// Currently, WGSL only supports 32-bit integer types. Any other integer
62
151
// types should already have been promoted/demoted to i32.
63
- auto elemTy = getElementTypeOrSelf (argTy ).cast <IntegerType>();
152
+ auto elemTy = getElementTypeOrSelf (lhs. getType () ).cast <IntegerType>();
64
153
if (elemTy.getIntOrFloatBitWidth () != 32 )
65
154
return rewriter.notifyMatchFailure (
66
155
loc,
67
156
llvm::formatv (" Unexpected integer type for WebGPU: '{0}'" , elemTy));
68
157
69
- // Emulate 64-bit multiplication by splitting each input element of type i32
70
- // into 2 16-bit digits of type i32. This is so that the intermediate
71
- // multiplications and additions do not overflow. We extract these 16-bit
72
- // digits from i32 vector elements by masking (low digit) and shifting right
73
- // (high digit).
74
- //
75
- // The multiplication algorithm used is the standard (long) multiplication.
76
- // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
77
- // digits. After constant-folding, we end up emitting only 4 multiplications
78
- // and 4 additions.
79
- Value cstLowMask = rewriter.create <ConstantOp>(
80
- loc, lhs.getType (), getScalarOrSplatAttr (argTy, (1 << 16 ) - 1 ));
81
- auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
82
- return rewriter.create <BitwiseAndOp>(loc, val, cstLowMask);
83
- };
84
-
85
- Value cst16 = rewriter.create <ConstantOp>(loc, lhs.getType (),
86
- getScalarOrSplatAttr (argTy, 16 ));
87
- auto getHighDigit = [&rewriter, loc, cst16](Value val) {
88
- return rewriter.create <ShiftRightLogicalOp>(loc, val, cst16);
89
- };
90
-
91
- Value cst0 = rewriter.create <ConstantOp>(loc, lhs.getType (),
92
- getScalarOrSplatAttr (argTy, 0 ));
93
-
94
- Value lhsLow = getLowDigit (lhs);
95
- Value lhsHigh = getHighDigit (lhs);
96
- Value rhsLow = getLowDigit (rhs);
97
- Value rhsHigh = getHighDigit (rhs);
98
-
99
- std::array<Value, 2 > lhsDigits = {lhsLow, lhsHigh};
100
- std::array<Value, 2 > rhsDigits = {rhsLow, rhsHigh};
101
- std::array<Value, 4 > resultDigits = {cst0, cst0, cst0, cst0};
102
-
103
- for (auto [i, lhsDigit] : llvm::enumerate (lhsDigits)) {
104
- for (auto [j, rhsDigit] : llvm::enumerate (rhsDigits)) {
105
- Value &thisResDigit = resultDigits[i + j];
106
- Value mul = rewriter.create <IMulOp>(loc, lhsDigit, rhsDigit);
107
- Value current = rewriter.createOrFold <IAddOp>(loc, thisResDigit, mul);
108
- thisResDigit = getLowDigit (current);
109
-
110
- if (i + j + 1 != resultDigits.size ()) {
111
- Value &nextResDigit = resultDigits[i + j + 1 ];
112
- Value carry = rewriter.createOrFold <IAddOp>(loc, nextResDigit,
113
- getHighDigit (current));
114
- nextResDigit = carry;
115
- }
116
- }
117
- }
118
-
119
- auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
120
- Value highBits = rewriter.create <ShiftLeftLogicalOp>(loc, high, cst16);
121
- return rewriter.create <BitwiseOrOp>(loc, low, highBits);
122
- };
123
- Value low = combineDigits (resultDigits[0 ], resultDigits[1 ]);
124
- Value high = combineDigits (resultDigits[2 ], resultDigits[3 ]);
125
-
126
- rewriter.replaceOpWithNewOp <CompositeConstructOp>(
127
- op, op.getType (), llvm::makeArrayRef ({low, high}));
158
+ Value mul = lowerExtendedMultiplication (op, rewriter, lhs, rhs,
159
+ SignExtendArguments);
160
+ rewriter.replaceOp (op, mul);
128
161
return success ();
129
162
}
130
163
};
131
164
165
+ using ExpandSMulExtendedPattern =
166
+ ExpandMulExtendedPattern<SMulExtendedOp, true >;
167
+ using ExpandUMulExtendedPattern =
168
+ ExpandMulExtendedPattern<UMulExtendedOp, false >;
169
+
132
170
// ===----------------------------------------------------------------------===//
133
171
// Passes
134
172
// ===----------------------------------------------------------------------===//
@@ -153,9 +191,8 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
153
191
RewritePatternSet &patterns) {
154
192
// WGSL currently does not support extended multiplication ops, see:
155
193
// https://github.com/gpuweb/gpuweb/issues/1565.
156
- // TODO(https://github.com/llvm/llvm-project/issues/59563): Add SMulExtended
157
- // expansion.
158
- patterns.add <ExpandUMulExtendedPattern>(patterns.getContext ());
194
+ patterns.add <ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
195
+ patterns.getContext ());
159
196
}
160
197
} // namespace spirv
161
198
} // namespace mlir
0 commit comments