Skip to content

[mlir][spirv] Add IsInf/IsNan expansion for WebGPU #86903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2024

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Mar 28, 2024

These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into false.

Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in
#86578.

Also do some misc cleanups in the surrounding code.

These non-finite math ops are supported by SPIR-V but not by WGSL.
Assume finite floating point values and expand these ops into `false`.

Previously, this worked by adding fast math flags during conversion from
arith to spirv, but this got removed in
llvm#86578.
@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2024

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into false.

Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in
#86578.

Also do some misc cleanups in the surrounding code.


Full diff: https://github.com/llvm/llvm-project/pull/86903.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h (+9-3)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp (+41-13)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir (+32)
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
index ac4d38e0c5b1eb..d0fc85ccc9de49 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
@@ -18,12 +18,18 @@
 namespace mlir {
 namespace spirv {
 
-/// Appends to a pattern list additional patterns to expand extended
-/// multiplication ops into regular arithmetic ops. Extended multiplication ops
-/// are not supported by the WebGPU Shading Language (WGSL).
+/// Appends patterns to expand extended multiplication and adition ops into
+/// regular arithmetic ops. Extended arithmetic ops are not supported by the
+/// WebGPU Shading Language (WGSL).
 void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns);
 
+/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`.
+/// These are not supported by the WebGPU Shading Language (WGSL). We follow
+/// fast math assumptions and assume that all floating point values are finite.
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+    RewritePatternSet &patterns);
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 21de1c9e867c04..5d4dd5b3a1e013 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -39,7 +39,7 @@ namespace {
 //===----------------------------------------------------------------------===//
 // Helpers
 //===----------------------------------------------------------------------===//
-Attribute getScalarOrSplatAttr(Type type, int64_t value) {
+static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
   if (auto intTy = dyn_cast<IntegerType>(type))
     return IntegerAttr::get(intTy, sizedValue);
@@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
 }
 
-Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
-                                  Value lhs, Value rhs,
-                                  bool signExtendArguments) {
+static Value lowerExtendedMultiplication(Operation *mulOp,
+                                         PatternRewriter &rewriter, Value lhs,
+                                         Value rhs, bool signExtendArguments) {
   Location loc = mulOp->getLoc();
   Type argTy = lhs.getType();
   // Emulate 64-bit multiplication by splitting each input element of type i32
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
   }
 };
 
+struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsInfOp op,
+                                PatternRewriter &rewriter) const override {
+    // We assume values to be finite and turn `IsInf` info `false`.
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+    return success();
+  }
+};
+
+struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsNanOp op,
+                                PatternRewriter &rewriter) const override {
+    // We assume values to be finite and turn `IsNan` info `false`.
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
-class WebGPUPreparePass
-    : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
-public:
+struct WebGPUPreparePass final
+    : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
+    populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
 
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns) {
   // WGSL currently does not support extended multiplication ops, see:
   // https://github.com/gpuweb/gpuweb/issues/1565.
-  patterns.add<
-      // clang-format off
-    ExpandSMulExtendedPattern,
-    ExpandUMulExtendedPattern,
-    ExpandAddCarryPattern
-  >(patterns.getContext());
+  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
+               ExpandAddCarryPattern>(patterns.getContext());
 }
+
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+    RewritePatternSet &patterns) {
+  // WGSL currently does not support `isInf` and `isNan`, see:
+  // https://github.com/gpuweb/gpuweb/pull/2311.
+  patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
+}
+
 } // namespace spirv
 } // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index 1ec4e5e4f9664b..45f188da3815cf 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
   spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
 }
 
+// CHECK-LABEL: func @is_inf_f32
+// CHECK-NEXT:       [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT:       spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
+  %0 = spirv.IsInf %a : f32
+  spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_inf_4xf32
+// CHECK-NEXT:       [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT:       spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+  %0 = spirv.IsInf %a : vector<4xf32>
+  spirv.ReturnValue %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: func @is_nan_f32
+// CHECK-NEXT:       [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT:       spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
+  %0 = spirv.IsNan %a : f32
+  spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_nan_4xf32
+// CHECK-NEXT:       [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT:       spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+  %0 = spirv.IsNan %a : vector<4xf32>
+  spirv.ReturnValue %0 : vector<4xi1>
+}
+
 } // end module

@kuhar kuhar requested a review from Groverkss March 28, 2024 15:37
@kuhar kuhar merged commit d61ec51 into llvm:main Mar 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants