-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add IsInf/IsNan expansion for WebGPU #86903
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into `false`. Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in llvm#86578.
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) ChangesThese non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in Also do some misc cleanups in the surrounding code. Full diff: https://github.com/llvm/llvm-project/pull/86903.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
index ac4d38e0c5b1eb..d0fc85ccc9de49 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
@@ -18,12 +18,18 @@
namespace mlir {
namespace spirv {
-/// Appends to a pattern list additional patterns to expand extended
-/// multiplication ops into regular arithmetic ops. Extended multiplication ops
-/// are not supported by the WebGPU Shading Language (WGSL).
+/// Appends patterns to expand extended multiplication and adition ops into
+/// regular arithmetic ops. Extended arithmetic ops are not supported by the
+/// WebGPU Shading Language (WGSL).
void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns);
+/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`.
+/// These are not supported by the WebGPU Shading Language (WGSL). We follow
+/// fast math assumptions and assume that all floating point values are finite.
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+ RewritePatternSet &patterns);
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 21de1c9e867c04..5d4dd5b3a1e013 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -39,7 +39,7 @@ namespace {
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
-Attribute getScalarOrSplatAttr(Type type, int64_t value) {
+static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
if (auto intTy = dyn_cast<IntegerType>(type))
return IntegerAttr::get(intTy, sizedValue);
@@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
}
-Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
- Value lhs, Value rhs,
- bool signExtendArguments) {
+static Value lowerExtendedMultiplication(Operation *mulOp,
+ PatternRewriter &rewriter, Value lhs,
+ Value rhs, bool signExtendArguments) {
Location loc = mulOp->getLoc();
Type argTy = lhs.getType();
// Emulate 64-bit multiplication by splitting each input element of type i32
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
}
};
+struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IsInfOp op,
+ PatternRewriter &rewriter) const override {
+ // We assume values to be finite and turn `IsInf` info `false`.
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+ return success();
+ }
+};
+
+struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IsNanOp op,
+ PatternRewriter &rewriter) const override {
+ // We assume values to be finite and turn `IsNan` info `false`.
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
-class WebGPUPreparePass
- : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
-public:
+struct WebGPUPreparePass final
+ : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
+ populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
// WGSL currently does not support extended multiplication ops, see:
// https://github.com/gpuweb/gpuweb/issues/1565.
- patterns.add<
- // clang-format off
- ExpandSMulExtendedPattern,
- ExpandUMulExtendedPattern,
- ExpandAddCarryPattern
- >(patterns.getContext());
+ patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
+ ExpandAddCarryPattern>(patterns.getContext());
}
+
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+ RewritePatternSet &patterns) {
+ // WGSL currently does not support `isInf` and `isNan`, see:
+ // https://github.com/gpuweb/gpuweb/pull/2311.
+ patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
+}
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index 1ec4e5e4f9664b..45f188da3815cf 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}
+// CHECK-LABEL: func @is_inf_f32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
+ %0 = spirv.IsInf %a : f32
+ spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_inf_4xf32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+ %0 = spirv.IsInf %a : vector<4xf32>
+ spirv.ReturnValue %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: func @is_nan_f32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
+ %0 = spirv.IsNan %a : f32
+ spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_nan_4xf32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+ %0 = spirv.IsNan %a : vector<4xf32>
+ spirv.ReturnValue %0 : vector<4xi1>
+}
+
} // end module
|
These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into
false
.Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in
#86578.
Also do some misc cleanups in the surrounding code.