Skip to content

Commit 6ddc03d

Browse files
authored
[mlir][spirv][webgpu] Add lowering of IAddCarry to IAdd (#68495)
WebGPU does not currently support extended arithmetic, this is an issue when we want to lower from SPIR-V. This commit adds a pattern to transform and emulate spirv.IAddCarry with spirv.IAdd operations Fixes #65154
1 parent ab17ecd commit 6ddc03d

File tree

3 files changed

+147
-2
lines changed

3 files changed

+147
-2
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,42 @@ using ExpandSMulExtendedPattern =
167167
using ExpandUMulExtendedPattern =
168168
ExpandMulExtendedPattern<UMulExtendedOp, false>;
169169

170+
struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
171+
using OpRewritePattern<IAddCarryOp>::OpRewritePattern;
172+
173+
LogicalResult matchAndRewrite(IAddCarryOp op,
174+
PatternRewriter &rewriter) const override {
175+
Location loc = op->getLoc();
176+
Value lhs = op.getOperand1();
177+
Value rhs = op.getOperand2();
178+
179+
// Currently, WGSL only supports 32-bit integer types. Any other integer
180+
// types should already have been promoted/demoted to i32.
181+
Type argTy = lhs.getType();
182+
auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
183+
if (elemTy.getIntOrFloatBitWidth() != 32)
184+
return rewriter.notifyMatchFailure(
185+
loc,
186+
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
187+
188+
Value one =
189+
rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
190+
Value zero =
191+
rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
192+
193+
// Calculate the carry by checking if the addition resulted in an overflow.
194+
Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
195+
Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
196+
Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
197+
198+
Value add = rewriter.create<CompositeConstructOp>(
199+
loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
200+
201+
rewriter.replaceOp(op, add);
202+
return success();
203+
}
204+
};
205+
170206
//===----------------------------------------------------------------------===//
171207
// Passes
172208
//===----------------------------------------------------------------------===//
@@ -191,8 +227,12 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
191227
RewritePatternSet &patterns) {
192228
// WGSL currently does not support extended multiplication ops, see:
193229
// https://github.com/gpuweb/gpuweb/issues/1565.
194-
patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
195-
patterns.getContext());
230+
patterns.add<
231+
// clang-format off
232+
ExpandSMulExtendedPattern,
233+
ExpandUMulExtendedPattern,
234+
ExpandAddCarryPattern
235+
>(patterns.getContext());
196236
}
197237
} // namespace spirv
198238
} // namespace mlir

mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,41 @@ spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
145145
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
146146
}
147147

148+
// CHECK-LABEL: func @iaddcarry_i32
149+
// CHECK-SAME: ([[A:%.+]]: i32, [[B:%.+]]: i32)
150+
// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant 1 : i32
151+
// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant 0 : i32
152+
// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
153+
// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
154+
// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
155+
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (i32, i32) -> !spirv.struct<(i32, i32)>
156+
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
157+
spirv.func @iaddcarry_i32(%a : i32, %b : i32) -> !spirv.struct<(i32, i32)> "None" {
158+
%0 = spirv.IAddCarry %a, %b : !spirv.struct<(i32, i32)>
159+
spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
160+
}
161+
162+
// CHECK-LABEL: func @iaddcarry_vector_i32
163+
// CHECK-SAME: ([[A:%.+]]: vector<3xi32>, [[B:%.+]]: vector<3xi32>)
164+
// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant dense<1> : vector<3xi32>
165+
// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant dense<0> : vector<3xi32>
166+
// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
167+
// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
168+
// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
169+
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
170+
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
171+
spirv.func @iaddcarry_vector_i32(%a : vector<3xi32>, %b : vector<3xi32>)
172+
-> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
173+
%0 = spirv.IAddCarry %a, %b : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
174+
spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
175+
}
176+
177+
// CHECK-LABEL: func @iaddcarry_i16
178+
// CHECK-NEXT: spirv.IAddCarry
179+
// CHECK-NEXT: spirv.ReturnValue
180+
spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None" {
181+
%0 = spirv.IAddCarry %a, %b : !spirv.struct<(i16, i16)>
182+
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
183+
}
184+
148185
} // end module
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Make sure that addition with carry produces expected results
2+
// with and without expansion to primitive add/cmp ops for WebGPU.
3+
4+
// RUN: mlir-vulkan-runner %s \
5+
// RUN: --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
6+
// RUN: --entry-point-result=void | FileCheck %s
7+
8+
// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
9+
// RUN: --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
10+
// RUN: --entry-point-result=void | FileCheck %s
11+
12+
// CHECK: [0, 42, 0, 42]
13+
// CHECK: [1, 0, 1, 1]
14+
module attributes {
15+
gpu.container_module,
16+
spirv.target_env = #spirv.target_env<
17+
#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
18+
} {
19+
gpu.module @kernels {
20+
gpu.func @kernel_add(%arg0 : memref<4xi32>, %arg1 : memref<4xi32>, %arg2 : memref<4xi32>, %arg3 : memref<4xi32>)
21+
kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
22+
%0 = gpu.block_id x
23+
%lhs = memref.load %arg0[%0] : memref<4xi32>
24+
%rhs = memref.load %arg1[%0] : memref<4xi32>
25+
%sum, %carry = arith.addui_extended %lhs, %rhs : i32, i1
26+
27+
%carry_i32 = arith.extui %carry : i1 to i32
28+
29+
memref.store %sum, %arg2[%0] : memref<4xi32> memref.store %carry_i32, %arg3[%0] : memref<4xi32>
30+
gpu.return
31+
}
32+
}
33+
34+
func.func @main() {
35+
%buf0 = memref.alloc() : memref<4xi32>
36+
%buf1 = memref.alloc() : memref<4xi32>
37+
%buf2 = memref.alloc() : memref<4xi32>
38+
%buf3 = memref.alloc() : memref<4xi32>
39+
%i32_0 = arith.constant 0 : i32
40+
41+
// Initialize output buffers.
42+
%buf4 = memref.cast %buf2 : memref<4xi32> to memref<?xi32>
43+
%buf5 = memref.cast %buf3 : memref<4xi32> to memref<?xi32>
44+
call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
45+
call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
46+
47+
%idx_0 = arith.constant 0 : index
48+
%idx_1 = arith.constant 1 : index
49+
%idx_4 = arith.constant 4 : index
50+
51+
// Initialize input buffers.
52+
%lhs_vals = arith.constant dense<[-1, 24, 4294967295, 43]> : vector<4xi32>
53+
%rhs_vals = arith.constant dense<[1, 18, 1, 4294967295]> : vector<4xi32>
54+
vector.store %lhs_vals, %buf0[%idx_0] : memref<4xi32>, vector<4xi32>
55+
vector.store %rhs_vals, %buf1[%idx_0] : memref<4xi32>, vector<4xi32>
56+
57+
gpu.launch_func @kernels::@kernel_add
58+
blocks in (%idx_4, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
59+
args(%buf0 : memref<4xi32>, %buf1 : memref<4xi32>, %buf2 : memref<4xi32>, %buf3 : memref<4xi32>)
60+
%buf_sum = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
61+
%buf_carry = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
62+
call @printMemrefI32(%buf_sum) : (memref<*xi32>) -> ()
63+
call @printMemrefI32(%buf_carry) : (memref<*xi32>) -> ()
64+
return
65+
}
66+
func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
67+
func.func private @printMemrefI32(%ptr : memref<*xi32>)
68+
}

0 commit comments

Comments
 (0)