[mlir] GPUToROCDL: Fix crashes with unsupported shuffle datatypes #135504

Merged
Hardcode84 merged 1 commit into llvm:main from fix-shuffle-crash on Apr 13, 2025

Hardcode84
Contributor

Calling `getIntOrFloatBitWidth` on non-int/float types (gpu.shuffle also accepts vectors) will crash.
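
To illustrate why the order of the two checks matters, here is a minimal self-contained sketch of the guard. `ToyType` and `checkShuffleType` are hypothetical stand-ins invented for this example, not the real `mlir::Type` API; the sketch only assumes that the width query asserts when called on a non-scalar type, which matches the crash described above.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for an MLIR type; NOT the real mlir::Type API.
struct ToyType {
  bool isScalar;   // true for int/float, false for e.g. vectors
  unsigned width;  // bit width (element width for non-scalars)

  bool isIntOrFloat() const { return isScalar; }

  // Models a width query that is only valid for int/float types.
  unsigned getIntOrFloatBitWidth() const {
    assert(isScalar && "width query is only valid for int/float types");
    return width;
  }
};

// Returns std::nullopt when the type is accepted, otherwise the reason
// the pattern would report as a match failure.
std::optional<std::string> checkShuffleType(const ToyType &type) {
  // `||` short-circuits, so the width query runs only for scalar types.
  if (!type.isIntOrFloat() || type.getIntOrFloatBitWidth() != 32)
    return "only 32-bit int/float types are supported";
  return std::nullopt;
}
```

With this shape, a vector type is rejected with a diagnostic instead of tripping the assertion inside the width query.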
@llvmbot
Member

llvmbot commented Apr 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Ivan Butygin (Hardcode84)

Changes

Calling `getIntOrFloatBitWidth` on non-int/float types (`gpu.shuffle` also accepts vectors) will crash.


Full diff: https://github.com/llvm/llvm-project/pull/135504.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+9-7)
  • (modified) mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp (+3-2)
  • (added) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir (+13)
  • (modified) mlir/test/Dialect/GPU/shuffle-rewrite.mlir (+11)
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 4891dab3aa1d0..c6c695b442b4f 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    Value initShflValue = adaptor.getValue();
+    Type shflType = initShflValue.getType();
     // TODO: Add support for non 32-bit shuffle values.
-    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
-      return failure();
+    if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
+      return rewriter.notifyMatchFailure(
+          op, "only 32-bit int/float types are supported");
+
     const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
     Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
 
@@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
     Value dwordAlignedDstLane =
         rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
-    Value initShflValue = adaptor.getValue();
-    if (adaptor.getValue().getType().isF32()) {
+    if (shflType.isF32()) {
       initShflValue =
           rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
     }
     Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
         loc, int32Type, dwordAlignedDstLane, initShflValue);
-    if (adaptor.getValue().getType().isF32()) {
-      shflValue = rewriter.create<LLVM::BitcastOp>(
-          loc, adaptor.getValue().getType(), shflValue);
+    if (shflType.isF32()) {
+      shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
     }
     rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
     return success();
diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
index 4bd4da25f6e52..9f2900214e8b1 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
@@ -40,8 +40,9 @@ struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
     auto i64 = rewriter.getI64Type();
 
     // If the type of the value is either i32 or f32, the op is already valid.
-    if (valueType.getIntOrFloatBitWidth() == 32)
-      return failure();
+    if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
+      return rewriter.notifyMatchFailure(
+          op, "only 64-bit int/float types are supported");
 
     Value lo, hi;
 
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
new file mode 100644
index 0000000000000..90f2e5f047cd9
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl -verify-diagnostics
+
+gpu.module @test_module {
+  // ROCDL lowering only supports shuffles for 32-bit ints/floats, but they
+  // shouldn't crash on unsupported types.
+  func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 64 : i32
+    // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
+    %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
+    return %shfl : vector<4xf16>
+  }
+}
diff --git a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
index 4618258201532..c0ccae05a0572 100644
--- a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
+++ b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
@@ -49,3 +49,14 @@ module {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: @gpu_shuffle_unsupported
+func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 64 : i32
+  // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : vector<4xf16>
+  %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
+  return %shfl : vector<4xf16>
+}

Contributor

@krzysz00 krzysz00 left a comment

Approved to fix crashing but maybe we should add bitcasts for the vector cases as well?

@Hardcode84
Contributor Author

> Approved to fix crashing but maybe we should add bitcasts for the vector cases as well?

Yeah, I was planning to work on it soon.
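
As a rough illustration of the bitcast idea raised in the review (not something this PR implements), the host-side sketch below treats the 64 bits of a `vector<4xf16>` as two 32-bit words, shuffles each word separately, and reassembles the result. All names here (`Half4`, `Words`, `shuffleXor32`, `shuffleXorHalf4`) are hypothetical, and the lane model is a plain array rather than real GPU hardware; out-of-range source lanes simply keep their own value, which is an assumption made for this toy.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using Half4 = std::array<std::uint16_t, 4>;  // raw bits of one vector<4xf16>
using Words = std::array<std::uint32_t, 2>;  // the same 64 bits as two words
static_assert(sizeof(Half4) == sizeof(Words), "both views must be 64 bits");

// Reinterpret the four 16-bit elements as two 32-bit words (a "bitcast").
Words toWords(const Half4 &v) {
  Words w;
  std::memcpy(w.data(), v.data(), sizeof(w));
  return w;
}

// Inverse of toWords.
Half4 fromWords(const Words &w) {
  Half4 v;
  std::memcpy(v.data(), w.data(), sizeof(v));
  return v;
}

// Toy 32-bit xor shuffle: lane i reads from lane (i ^ offset) and keeps
// its own value when that source lane is out of range.
std::vector<std::uint32_t> shuffleXor32(const std::vector<std::uint32_t> &lanes,
                                        std::size_t offset) {
  std::vector<std::uint32_t> out(lanes.size());
  for (std::size_t i = 0; i < lanes.size(); ++i) {
    std::size_t src = i ^ offset;
    out[i] = src < lanes.size() ? lanes[src] : lanes[i];
  }
  return out;
}

// Shuffle a 64-bit payload as two independent 32-bit shuffles.
std::vector<Half4> shuffleXorHalf4(const std::vector<Half4> &lanes,
                                   std::size_t offset) {
  std::vector<std::uint32_t> lo, hi;
  for (const Half4 &v : lanes) {
    Words w = toWords(v);
    lo.push_back(w[0]);
    hi.push_back(w[1]);
  }
  lo = shuffleXor32(lo, offset);
  hi = shuffleXor32(hi, offset);

  std::vector<Half4> out;
  for (std::size_t i = 0; i < lanes.size(); ++i)
    out.push_back(fromWords(Words{lo[i], hi[i]}));
  return out;
}
```

Because both words of a lane are read from the same source lane, the reassembled value is exactly the source lane's original vector.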

@Hardcode84 Hardcode84 merged commit d893d12 into llvm:main Apr 13, 2025
14 checks passed
@Hardcode84 Hardcode84 deleted the fix-shuffle-crash branch April 15, 2025 06:00
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
…vm#135504)

Calling `getIntOrFloatBitWidth` on non-int/float types (`gpu.shuffle`
also accepts vectors) will crash.