
Commit 1b82245

[mlir][spirv] Add smul_extended expansion for WebGPU

We need this because WGSL does not support extended multiplication ops.

Fixes: llvm#59563

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D141096

1 parent 2b1a517 commit 1b82245

3 files changed: +245 -67 lines changed
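The lowering added below implements schoolbook long multiplication in base 2^16. As a companion to the diff, here is a minimal host-side C++ sketch of the same algorithm; it is not part of the commit, the helper names are illustrative, and it omits the pattern's skip of multiplications by the constant 0 (an emission optimization that does not change the result):

#include <array>
#include <cstdint>
#include <cstdio>

// Split each 32-bit input into base-2^16 digits, multiply digit pairs, and
// propagate carries. For the signed case, the sign digit (0xFFFF for negative
// inputs, 0 otherwise) is replicated into positions 2 and 3 so each input
// behaves like a 4-digit sign-extended number.
static std::array<uint32_t, 2> mulExtended(uint32_t lhs, uint32_t rhs,
                                           bool signExtend) {
  auto lowDigit = [](uint32_t v) { return v & 0xFFFFu; };
  auto highDigit = [](uint32_t v) { return v >> 16; };
  auto signDigit = [&](uint32_t v) {
    // Arithmetic shift right by 16, then logical shift right by 16.
    // Assumes two's complement narrowing (guaranteed since C++20).
    return highDigit(static_cast<uint32_t>(static_cast<int32_t>(v) >> 16));
  };

  uint32_t lhsExt = signExtend ? signDigit(lhs) : 0;
  uint32_t rhsExt = signExtend ? signDigit(rhs) : 0;
  std::array<uint32_t, 4> a = {lowDigit(lhs), highDigit(lhs), lhsExt, lhsExt};
  std::array<uint32_t, 4> b = {lowDigit(rhs), highDigit(rhs), rhsExt, rhsExt};
  std::array<uint32_t, 4> res = {0, 0, 0, 0};

  for (size_t i = 0; i != a.size(); ++i) {
    for (size_t j = 0; j != b.size(); ++j) {
      if (i + j >= res.size())
        continue; // Digits beyond the 64-bit result are discarded.
      // Each step fits in 32 bits: 0xFFFF * 0xFFFF plus a few pending
      // carries stays below 2^32, so the intermediates never overflow.
      uint32_t current = res[i + j] + a[i] * b[j];
      res[i + j] = lowDigit(current);
      if (i + j + 1 != res.size())
        res[i + j + 1] += highDigit(current);
    }
  }
  return {res[0] | (res[1] << 16), res[2] | (res[3] << 16)};
}

int main() {
  // Signed example: -1 * 2 == -2, so the low word is 0xFFFFFFFE (-2) and
  // the high word is the sign extension 0xFFFFFFFF (-1).
  auto [low, high] = mulExtended(static_cast<uint32_t>(-1), 2u,
                                 /*signExtend=*/true);
  std::printf("low=0x%08X high=0x%08X\n", static_cast<unsigned>(low),
              static_cast<unsigned>(high));
}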

mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp

Lines changed: 104 additions & 67 deletions
@@ -15,7 +15,9 @@
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -45,90 +47,126 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   return SplatElementsAttr::get(type, sizedValue);
 }
 
+Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
+                                  Value lhs, Value rhs,
+                                  bool signExtendArguments) {
+  Location loc = mulOp->getLoc();
+  Type argTy = lhs.getType();
+  // Emulate 64-bit multiplication by splitting each input element of type i32
+  // into 2 16-bit digits of type i32. This is so that the intermediate
+  // multiplications and additions do not overflow. We extract these 16-bit
+  // digits from i32 vector elements by masking (low digit) and shifting right
+  // (high digit).
+  //
+  // The multiplication algorithm used is the standard (long) multiplication.
+  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
+  // digits.
+  // - With zero-extended arguments, we end up emitting only 4 multiplications
+  //   and 4 additions after constant folding.
+  // - With sign-extended arguments, we end up emitting 8 multiplications and
+  //   12 additions after CSE.
+  Value cstLowMask = rewriter.create<ConstantOp>(
+      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
+    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+  };
+
+  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                            getScalarOrSplatAttr(argTy, 16));
+  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
+    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+  };
+
+  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
+    // We only need to shift arithmetically by 15, but the extra
+    // sign-extension bit will be truncated by the logical shift, so this is
+    // fine. We do not have to introduce an extra constant since any
+    // value in [15, 32) would do.
+    return getHighDigit(
+        rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
+  };
+
+  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                           getScalarOrSplatAttr(argTy, 0));
+
+  Value lhsLow = getLowDigit(lhs);
+  Value lhsHigh = getHighDigit(lhs);
+  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
+  Value rhsLow = getLowDigit(rhs);
+  Value rhsHigh = getHighDigit(rhs);
+  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
+
+  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
+  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
+  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
+
+  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
+    for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
+      if (i + j >= resultDigits.size())
+        continue;
+
+      if (lhsDigit == cst0 || rhsDigit == cst0)
+        continue;
+
+      Value &thisResDigit = resultDigits[i + j];
+      Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
+      Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
+      thisResDigit = getLowDigit(current);
+
+      if (i + j + 1 != resultDigits.size()) {
+        Value &nextResDigit = resultDigits[i + j + 1];
+        Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
+                                                    getHighDigit(current));
+        nextResDigit = carry;
+      }
+    }
+  }
+
+  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
+    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
+    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+  };
+  Value low = combineDigits(resultDigits[0], resultDigits[1]);
+  Value high = combineDigits(resultDigits[2], resultDigits[3]);
+
+  return rewriter.create<CompositeConstructOp>(
+      loc, mulOp->getResultTypes().front(), llvm::makeArrayRef({low, high}));
+}
+
 //===----------------------------------------------------------------------===//
 // Rewrite Patterns
 //===----------------------------------------------------------------------===//
-struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
-  using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(UMulExtendedOp op,
+template <typename MulExtendedOp, bool SignExtendArguments>
+struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
+  using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulExtendedOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     Value lhs = op.getOperand1();
     Value rhs = op.getOperand2();
-    Type argTy = lhs.getType();
 
     // Currently, WGSL only supports 32-bit integer types. Any other integer
     // types should already have been promoted/demoted to i32.
-    auto elemTy = getElementTypeOrSelf(argTy).cast<IntegerType>();
+    auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>();
     if (elemTy.getIntOrFloatBitWidth() != 32)
       return rewriter.notifyMatchFailure(
           loc,
           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
 
-    // Emulate 64-bit multiplication by splitting each input element of type i32
-    // into 2 16-bit digits of type i32. This is so that the intermediate
-    // multiplications and additions do not overflow. We extract these 16-bit
-    // digits from i32 vector elements by masking (low digit) and shifting right
-    // (high digit).
-    //
-    // The multiplication algorithm used is the standard (long) multiplication.
-    // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
-    // digits. After constant-folding, we end up emitting only 4 multiplications
-    // and 4 additions.
-    Value cstLowMask = rewriter.create<ConstantOp>(
-        loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
-    auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
-      return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
-    };
-
-    Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                              getScalarOrSplatAttr(argTy, 16));
-    auto getHighDigit = [&rewriter, loc, cst16](Value val) {
-      return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
-    };
-
-    Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                             getScalarOrSplatAttr(argTy, 0));
-
-    Value lhsLow = getLowDigit(lhs);
-    Value lhsHigh = getHighDigit(lhs);
-    Value rhsLow = getLowDigit(rhs);
-    Value rhsHigh = getHighDigit(rhs);
-
-    std::array<Value, 2> lhsDigits = {lhsLow, lhsHigh};
-    std::array<Value, 2> rhsDigits = {rhsLow, rhsHigh};
-    std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
-
-    for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
-      for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
-        Value &thisResDigit = resultDigits[i + j];
-        Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
-        Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
-        thisResDigit = getLowDigit(current);
-
-        if (i + j + 1 != resultDigits.size()) {
-          Value &nextResDigit = resultDigits[i + j + 1];
-          Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
-                                                      getHighDigit(current));
-          nextResDigit = carry;
-        }
-      }
-    }
-
-    auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
-      Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
-      return rewriter.create<BitwiseOrOp>(loc, low, highBits);
-    };
-    Value low = combineDigits(resultDigits[0], resultDigits[1]);
-    Value high = combineDigits(resultDigits[2], resultDigits[3]);
-
-    rewriter.replaceOpWithNewOp<CompositeConstructOp>(
-        op, op.getType(), llvm::makeArrayRef({low, high}));
+    Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
+                                            SignExtendArguments);
+    rewriter.replaceOp(op, mul);
     return success();
   }
 };
 
+using ExpandSMulExtendedPattern =
+    ExpandMulExtendedPattern<SMulExtendedOp, true>;
+using ExpandUMulExtendedPattern =
+    ExpandMulExtendedPattern<UMulExtendedOp, false>;
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -153,9 +191,8 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns) {
   // WGSL currently does not support extended multiplication ops, see:
   // https://github.com/gpuweb/gpuweb/issues/1565.
-  // TODO(https://github.com/llvm/llvm-project/issues/59563): Add SMulExtended
-  // expansion.
-  patterns.add<ExpandUMulExtendedPattern>(patterns.getContext());
+  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
+      patterns.getContext());
 }
 } // namespace spirv
 } // namespace mlir
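The getSignDigit helper above relies on a small trick: shifting arithmetically by 16 (instead of the minimal 15) smears the sign bit across the upper half, and the following logical shift extracts exactly the sign digit, so the existing constant 16 can be reused. A standalone C++ check of this property (illustrative only; arithmetic right shift of negative values is guaranteed since C++20):

#include <cassert>
#include <cstdint>

// Sign digit of a 32-bit value in base 2^16: 0xFFFF when the value is
// negative, 0 otherwise. Any arithmetic shift amount in [15, 32) would
// work equally well here.
static uint32_t signDigit(int32_t v) {
  uint32_t smeared = static_cast<uint32_t>(v >> 16); // arithmetic shift
  return smeared >> 16;                              // logical shift
}

int main() {
  assert(signDigit(-1) == 0xFFFFu);
  assert(signDigit(INT32_MIN) == 0xFFFFu);
  assert(signDigit(0) == 0u);
  assert(signDigit(INT32_MAX) == 0u);
}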

mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir

Lines changed: 75 additions & 0 deletions

@@ -70,4 +70,79 @@ spirv.func @umul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
   spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @smul_extended_i32
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
+// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
+// CHECK-DAG: [[CST16:%.+]] = spirv.Constant 16 : i32
+// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
+// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT: [[LHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT: [[LHSEXT:%.+]] = spirv.ShiftRightLogical [[LHSSIGN]], [[CST16]] : i32
+// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
+// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT: [[RHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT: [[RHSEXT:%.+]] = spirv.ShiftRightLogical [[RHSSIGN]], [[CST16]] : i32
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSEXT]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSEXT]]
+// CHECK-DAG: spirv.IMul [[LHSEXT]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSEXT]], [[RHSHI]]
+// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK: spirv.BitwiseOr
+// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK: spirv.BitwiseOr
+// CHECK: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @smul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
+  %0 = spirv.SMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
+  spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: func @smul_extended_vector_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
+// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
+// CHECK-DAG: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
+// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[LHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[LHSEXT:%.+]] = spirv.ShiftRightLogical [[LHSSIGN]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RHSEXT:%.+]] = spirv.ShiftRightLogical [[RHSSIGN]], [[CST16]] : vector<3xi32>
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSEXT]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSEXT]]
+// CHECK-DAG: spirv.IMul [[LHSEXT]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSEXT]], [[RHSHI]]
+// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK: spirv.BitwiseOr
+// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK: spirv.BitwiseOr
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW:%.+]], [[RESHI:%.+]]
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @smul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
+    -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+  %0 = spirv.SMulExtended %arg0, %arg1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @smul_extended_i16
+// CHECK-NEXT: spirv.SMulExtended
+// CHECK-NEXT: spirv.ReturnValue
+spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
+  %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i16, i16)>
+  spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
 } // end module
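These FileCheck cases run under the webgpu-prepare pass; the file's RUN line sits outside this hunk. Assuming the command-line flag matches the pass registration in SPIRVWebGPUTransforms.cpp (an assumption, since the flag is not shown in this diff), a standalone invocation would look roughly like:

mlir-opt --spirv-webgpu-prepare webgpu-prepare.mlir | FileCheck webgpu-prepare.mlir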
Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+// Make sure that signed extended multiplication produces expected results
+// with and without expansion to primitive mul/add ops for WebGPU.
+
+// RUN: mlir-vulkan-runner %s \
+// RUN:   --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:   --entry-point-result=void | FileCheck %s
+
+// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
+// RUN:   --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:   --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 1, -2, 1, 1048560, -87620295, -131071, 560969770]
+// CHECK: [0, 0, -1, 0, 0, -1, 0, -499807318]
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_add(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %arg2 : memref<8xi32>, %arg3 : memref<8xi32>)
+        kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      %0 = gpu.block_id x
+      %lhs = memref.load %arg0[%0] : memref<8xi32>
+      %rhs = memref.load %arg1[%0] : memref<8xi32>
+      %low, %hi = arith.mulsi_extended %lhs, %rhs : i32
+      memref.store %low, %arg2[%0] : memref<8xi32>
+      memref.store %hi, %arg3[%0] : memref<8xi32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %buf0 = memref.alloc() : memref<8xi32>
+    %buf1 = memref.alloc() : memref<8xi32>
+    %buf2 = memref.alloc() : memref<8xi32>
+    %buf3 = memref.alloc() : memref<8xi32>
+    %i32_0 = arith.constant 0 : i32
+
+    // Initialize output buffers.
+    %buf4 = memref.cast %buf2 : memref<8xi32> to memref<?xi32>
+    %buf5 = memref.cast %buf3 : memref<8xi32> to memref<?xi32>
+    call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
+    call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
+
+    %idx_0 = arith.constant 0 : index
+    %idx_1 = arith.constant 1 : index
+    %idx_8 = arith.constant 8 : index
+
+    // Initialize input buffers.
+    %lhs_vals = arith.constant dense<[0, 1, -1, -1, 65535, 65535, -65535, 2088183954]> : vector<8xi32>
+    %rhs_vals = arith.constant dense<[0, 1, 2, -1, 16, -1337, -65535, -1028001427]> : vector<8xi32>
+    vector.store %lhs_vals, %buf0[%idx_0] : memref<8xi32>, vector<8xi32>
+    vector.store %rhs_vals, %buf1[%idx_0] : memref<8xi32>, vector<8xi32>
+
+    gpu.launch_func @kernels::@kernel_add
+        blocks in (%idx_8, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
+        args(%buf0 : memref<8xi32>, %buf1 : memref<8xi32>, %buf2 : memref<8xi32>, %buf3 : memref<8xi32>)
+    %buf_low = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+    %buf_hi = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
+    call @printMemrefI32(%buf_low) : (memref<*xi32>) -> ()
+    call @printMemrefI32(%buf_hi) : (memref<*xi32>) -> ()
+    return
+  }
+  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}
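The CHECK vectors above can be reproduced on the host with plain 64-bit arithmetic: arith.mulsi_extended yields the low and high 32-bit words of the full signed product. A small illustrative C++ reference (not part of the commit):

#include <cstdint>
#include <cstdio>

// Reference for arith.mulsi_extended: widen to 64 bits, multiply, and split
// the product into low and high 32-bit words. The narrowing casts assume
// two's complement wrap-around, which C++20 guarantees.
static void mulsiExtended(int32_t a, int32_t b, int32_t &low, int32_t &high) {
  int64_t prod = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  low = static_cast<int32_t>(static_cast<uint32_t>(prod)); // bits [0, 32)
  high = static_cast<int32_t>(prod >> 32);                 // bits [32, 64)
}

int main() {
  int32_t low, high;
  // Element 5 of the test vectors: 65535 * -1337.
  mulsiExtended(65535, -1337, low, high);
  std::printf("[%d, %d]\n", static_cast<int>(low),
              static_cast<int>(high)); // Prints [-87620295, -1], as in CHECK.
}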
