Skip to content

Commit d61ec51

Browse files
authored
[mlir][spirv] Add IsInf/IsNan expansion for WebGPU (#86903)
These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into `false`. Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in #86578. Also do some misc cleanups in the surrounding code.
1 parent 5990278 commit d61ec51

File tree

3 files changed

+82
-16
lines changed

3 files changed

+82
-16
lines changed

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,18 @@
1818
namespace mlir {
1919
namespace spirv {
2020

21-
/// Appends to a pattern list additional patterns to expand extended
22-
/// multiplication ops into regular arithmetic ops. Extended multiplication ops
23-
/// are not supported by the WebGPU Shading Language (WGSL).
21+
/// Appends patterns to expand extended multiplication and adition ops into
22+
/// regular arithmetic ops. Extended arithmetic ops are not supported by the
23+
/// WebGPU Shading Language (WGSL).
2424
void populateSPIRVExpandExtendedMultiplicationPatterns(
2525
RewritePatternSet &patterns);
2626

27+
/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`.
28+
/// These are not supported by the WebGPU Shading Language (WGSL). We follow
29+
/// fast math assumptions and assume that all floating point values are finite.
30+
void populateSPIRVExpandNonFiniteArithmeticPatterns(
31+
RewritePatternSet &patterns);
32+
2733
} // namespace spirv
2834
} // namespace mlir
2935

mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,17 @@ namespace {
3939
//===----------------------------------------------------------------------===//
4040
// Helpers
4141
//===----------------------------------------------------------------------===//
42-
Attribute getScalarOrSplatAttr(Type type, int64_t value) {
42+
static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
4343
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
4444
if (auto intTy = dyn_cast<IntegerType>(type))
4545
return IntegerAttr::get(intTy, sizedValue);
4646

4747
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
4848
}
4949

50-
Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
51-
Value lhs, Value rhs,
52-
bool signExtendArguments) {
50+
static Value lowerExtendedMultiplication(Operation *mulOp,
51+
PatternRewriter &rewriter, Value lhs,
52+
Value rhs, bool signExtendArguments) {
5353
Location loc = mulOp->getLoc();
5454
Type argTy = lhs.getType();
5555
// Emulate 64-bit multiplication by splitting each input element of type i32
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
203203
}
204204
};
205205

206+
struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
207+
using OpRewritePattern::OpRewritePattern;
208+
209+
LogicalResult matchAndRewrite(IsInfOp op,
210+
PatternRewriter &rewriter) const override {
211+
// We assume values to be finite and turn `IsInf` info `false`.
212+
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
213+
op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
214+
return success();
215+
}
216+
};
217+
218+
struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
219+
using OpRewritePattern::OpRewritePattern;
220+
221+
LogicalResult matchAndRewrite(IsNanOp op,
222+
PatternRewriter &rewriter) const override {
223+
// We assume values to be finite and turn `IsNan` info `false`.
224+
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
225+
op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
226+
return success();
227+
}
228+
};
229+
206230
//===----------------------------------------------------------------------===//
207231
// Passes
208232
//===----------------------------------------------------------------------===//
209-
class WebGPUPreparePass
210-
: public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
211-
public:
233+
struct WebGPUPreparePass final
234+
: impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
212235
void runOnOperation() override {
213236
RewritePatternSet patterns(&getContext());
214237
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
238+
populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
215239

216240
if (failed(
217241
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
227251
RewritePatternSet &patterns) {
228252
// WGSL currently does not support extended multiplication ops, see:
229253
// https://github.com/gpuweb/gpuweb/issues/1565.
230-
patterns.add<
231-
// clang-format off
232-
ExpandSMulExtendedPattern,
233-
ExpandUMulExtendedPattern,
234-
ExpandAddCarryPattern
235-
>(patterns.getContext());
254+
patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255+
ExpandAddCarryPattern>(patterns.getContext());
236256
}
257+
258+
void populateSPIRVExpandNonFiniteArithmeticPatterns(
259+
RewritePatternSet &patterns) {
260+
// WGSL currently does not support `isInf` and `isNan`, see:
261+
// https://github.com/gpuweb/gpuweb/pull/2311.
262+
patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
263+
}
264+
237265
} // namespace spirv
238266
} // namespace mlir

mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
182182
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
183183
}
184184

185+
// CHECK-LABEL: func @is_inf_f32
186+
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
187+
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
188+
spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
189+
%0 = spirv.IsInf %a : f32
190+
spirv.ReturnValue %0 : i1
191+
}
192+
193+
// CHECK-LABEL: func @is_inf_4xf32
194+
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
195+
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
196+
spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
197+
%0 = spirv.IsInf %a : vector<4xf32>
198+
spirv.ReturnValue %0 : vector<4xi1>
199+
}
200+
201+
// CHECK-LABEL: func @is_nan_f32
202+
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
203+
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
204+
spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
205+
%0 = spirv.IsNan %a : f32
206+
spirv.ReturnValue %0 : i1
207+
}
208+
209+
// CHECK-LABEL: func @is_nan_4xf32
210+
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
211+
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
212+
spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
213+
%0 = spirv.IsNan %a : vector<4xf32>
214+
spirv.ReturnValue %0 : vector<4xi1>
215+
}
216+
185217
} // end module

0 commit comments

Comments
 (0)