Skip to content

[mlir][spirv] Fix some issues related to converting ub.poison to SPIR-V #125905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 6, 2025

Conversation

andfau-amd
Copy link
Contributor

@andfau-amd andfau-amd commented Feb 5, 2025

This is a follow-up to 5df62bd. That commit should not have needed to make the vector.insert and vector.extract conversions to SPIR-V directly handle the static poison index case, as there is a fold from those to ub.poison, and a conversion pattern from ub.poison to spirv.Undef, however:

  • The ub.poison fold result could not be materialized by the vector dialect (fixed as of d13940e).
  • The conversion pattern wasn't being populated in VectorToSPIRVPass, which is used by the tests. This commit changes this.
  • The ub.poison to spirv.Undef pattern rejected non-scalar types, which prevented its use for vector results. It is unclear why this restriction existed; a remark in D156163 said this was to avoid converting "user types", but it is not obvious why these shouldn't be permitted (the SPIR-V specification allows OpUndef for all types except OpTypeVoid). This commit removes this restriction.

With these fixed, this commit removes the redundant static poison index handling, and updates the tests.

@llvmbot
Copy link
Member

llvmbot commented Feb 5, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Andrea Faulds (andfau-amd)

Changes

This is a follow-up to 5df62bd. It should not be necessary for the vector.insert and vector.extract conversions to SPIR-V to directly handle the static poison index case, as there is a fold from those to ub.poison, and a conversion pattern from ub.poison to spirv.Undef, however:

  • The ub.poison fold result could not be materialized by the vector dialect (fixed as of d13940e).
  • The conversion pattern from ub.poison to spirv.Undef wasn't being used. This commit adds it to TestConvertToSPIRVPass for testing. There may be other places that should also have it.
  • The ub.poison to spirv.Undef pattern rejected non-scalar types, which prevented its use for vector results. It is unclear why this restriction existed; a remark in D156163 said this was to avoid converting "user types", but it is not obvious why these shouldn't be permitted (the SPIR-V specification allows OpUndef for all types except OpTypeVoid). This commit removes this restriction.

With these fixed, this commit removes the redundant static poison index handling.


Full diff: https://github.com/llvm/llvm-project/pull/125905.diff

6 Files Affected:

  • (modified) mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp (-5)
  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+11-21)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+20)
  • (modified) mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir (+1-2)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+7-4)
  • (modified) mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp (+2)
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index a3806189e40608..01c35cba48c490 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -29,11 +29,6 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
   matchAndRewrite(ub::PoisonOp op, OpAdaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type origType = op.getType();
-    if (!origType.isIntOrIndexOrFloat())
-      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-        diag << "unsupported type " << origType;
-      });
-
     Type resType = getTypeConverter()->convertType(origType);
     if (!resType)
       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2c8bc149dc708d..1c70cb4d287d45 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -182,17 +182,12 @@ struct VectorExtractOpConvert final
 
     if (std::optional<int64_t> id =
             getConstantIntValue(extractOp.getMixedPosition()[0])) {
-      // TODO: ExtractOp::fold() already can fold a static poison index to
-      //       ub.poison; remove this once ub.poison can be converted to SPIR-V.
-      if (id == vector::ExtractOp::kPoisonIndex) {
-        // Arbitrary choice of poison result, intended to stick out.
-        Value zero =
-            spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
-        rewriter.replaceOp(extractOp, zero);
-      } else
-        rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-            extractOp, dstType, adaptor.getVector(),
-            rewriter.getI32ArrayAttr(id.value()));
+      // Static use of the poison index is handled elsewhere (folded to poison).
+      if (id == vector::ExtractOp::kPoisonIndex)
+        return failure();
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, dstType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr(id.value()));
     } else {
       Value sanitizedIndex = sanitizeDynamicIndex(
           rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
@@ -306,16 +301,11 @@ struct VectorInsertOpConvert final
 
     if (std::optional<int64_t> id =
             getConstantIntValue(insertOp.getMixedPosition()[0])) {
-      // TODO: ExtractOp::fold() already can fold a static poison index to
-      //       ub.poison; remove this once ub.poison can be converted to SPIR-V.
-      if (id == vector::InsertOp::kPoisonIndex) {
-        // Arbitrary choice of poison result, intended to stick out.
-        Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
-                                                insertOp.getLoc(), rewriter);
-        rewriter.replaceOp(insertOp, zero);
-      } else
-        rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-            insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+      // Static use of the poison index is handled elsewhere (folded to poison).
+      if (id == vector::InsertOp::kPoisonIndex)
+        return failure();
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
     } else {
       Value sanitizedIndex = sanitizeDynamicIndex(
           rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index 510f7a2d94c9ec..951bcf607ccb5f 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -22,6 +22,16 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @extract_poison_idx
+//       CHECK:   %[[R:.+]] = spirv.Undef : f32
+//       CHECK:   spirv.ReturnValue %[[R]] : f32
+func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
+  %0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
 //       CHECK:   spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -51,6 +61,16 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
 
 // -----
 
+// CHECK-LABEL: @insert_poison_idx
+//       CHECK:   %[[R:.+]] = spirv.Undef : vector<4xf32>
+//       CHECK:   spirv.ReturnValue %[[R]] : vector<4xf32>
+func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
+  %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
+  return %1: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
index 771b53ad123b92..f497eb3bc552ca 100644
--- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
+++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
@@ -13,8 +13,7 @@ func.func @check_poison() {
   %1 = ub.poison : i16
 // CHECK: {{.*}} = spirv.Undef : f64
   %2 = ub.poison : f64
-// TODO: vector is not covered yet
-// CHECK: {{.*}} = ub.poison : vector<4xf32>
+// CHECK: {{.*}} = spirv.Undef : vector<4xf32>
   %3 = ub.poison : vector<4xf32>
   return
 }
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 5fd7324b1d3c73..408b09cd6d794b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -175,9 +175,11 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
 
 // -----
 
+// CHECK-LABEL: @extract_poison_idx
+//  CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+//       CHECK:   %[[R:.+]] = vector.extract %[[ARG0]][-1] : f32 from vector<4xf32>
+//       CHECK:   return %[[R]] : f32
 func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
-  // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
-  // CHECK: return %[[ZERO]]
   %0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
   return %0: f32
 }
@@ -285,8 +287,9 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 // -----
 
 // CHECK-LABEL: @insert_poison_idx
-// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
-// CHECK: return %[[ZERO]]
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
+//       CHECK:   %[[R:.*]] = vector.insert %[[S]], %[[V]] [-1] : f32 into vector<4xf32>
+//       CHECK:   return %[[R]] : vector<4xf32>
 func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
   %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
   return %1: vector<4xf32>
diff --git a/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp b/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp
index 3c99d3c5b60ced..d406a74ea024cd 100644
--- a/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp
+++ b/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -93,6 +94,7 @@ struct TestConvertToSPIRVPass final
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<spirv::SPIRVDialect>();
+    registry.insert<ub::UBDialect>();
     registry.insert<vector::VectorDialect>();
   }
 

@andfau-amd andfau-amd force-pushed the vector-to-ub-to-spirv branch 3 times, most recently from aba1a10 to 4d1f345 Compare February 6, 2025 14:10
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % one suggestion

@andfau-amd andfau-amd force-pushed the vector-to-ub-to-spirv branch from 4d1f345 to 208e584 Compare February 6, 2025 15:54
@andfau-amd
Copy link
Contributor Author

@Hardcode84 I added you as a reviewer here because you seem to be the one that added this restriction against converting ub.poison to spirv.Undef for some types.

This is a follow-up to 5df62bd. That
commit should not have needed to make the vector.insert and
vector.extract conversions to SPIR-V directly handle the static poison
index case, as there is a fold from those to ub.poison, and a conversion
pattern from ub.poison to spirv.Undef, however:

- The ub.poison fold result could not be materialized by the vector
  dialect (fixed as of d13940e).
- The conversion pattern wasn't being populated in VectorToSPIRVPass,
  which is used by the tests. This commit changes this.
- The ub.poison to spirv.Undef pattern rejected non-scalar types, which
  prevented its use for vector results. It is unclear why this
  restriction existed; a remark in D156163 said this was to avoid
  converting "user types", but it is not obvious why these shouldn't
  be permitted (the SPIR-V specification allows OpUndef for all types
  except OpTypeVoid). This commit removes this restriction.

With these fixed, this commit removes the redundant static poison index
handling, and updates the tests.
@andfau-amd andfau-amd force-pushed the vector-to-ub-to-spirv branch from 208e584 to a806d81 Compare February 6, 2025 15:59
Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Hardcode84 I added you as a reviewer here because you seem to be the one that added this restriction against converting ub.poison to spirv.Undef for some types.

Yeah, I'm fine with removing the restriction

@andfau-amd
Copy link
Contributor Author

Thanks!

@andfau-amd andfau-amd merged commit f497fe4 into llvm:main Feb 6, 2025
6 of 7 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…-V (llvm#125905)

This is a follow-up to 5df62bd. That
commit should not have needed to make the vector.insert and
vector.extract conversions to SPIR-V directly handle the static poison
index case, as there is a fold from those to ub.poison, and a conversion
pattern from ub.poison to spirv.Undef, however:

- The ub.poison fold result could not be materialized by the vector
dialect (fixed as of d13940e).
- The conversion pattern wasn't being populated in VectorToSPIRVPass,
which is used by the tests. This commit changes this.
- The ub.poison to spirv.Undef pattern rejected non-scalar types, which
prevented its use for vector results. It is unclear why this restriction
existed; a remark in D156163 said this was to avoid converting "user
types", but it is not obvious why these shouldn't be permitted (the
SPIR-V specification allows OpUndef for all types except OpTypeVoid).
This commit removes this restriction.

With these fixed, this commit removes the redundant static poison index
handling, and updates the tests.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants