-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Fix some issues related to converting ub.poison to SPIR-V #125905
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Andrea Faulds (andfau-amd) ChangesThis is a follow-up to 5df62bd. It should not be necessary for the vector.insert and vector.extract conversions to SPIR-V to directly handle the static poison index case, as there is a fold from those to ub.poison, and a conversion pattern from ub.poison to spirv.Undef, however:
With these fixed, this commit removes the redundant static poison index handling. Full diff: https://github.com/llvm/llvm-project/pull/125905.diff 6 Files Affected:
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index a3806189e40608..01c35cba48c490 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -29,11 +29,6 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
matchAndRewrite(ub::PoisonOp op, OpAdaptor,
ConversionPatternRewriter &rewriter) const override {
Type origType = op.getType();
- if (!origType.isIntOrIndexOrFloat())
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << "unsupported type " << origType;
- });
-
Type resType = getTypeConverter()->convertType(origType);
if (!resType)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2c8bc149dc708d..1c70cb4d287d45 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -182,17 +182,12 @@ struct VectorExtractOpConvert final
if (std::optional<int64_t> id =
getConstantIntValue(extractOp.getMixedPosition()[0])) {
- // TODO: ExtractOp::fold() already can fold a static poison index to
- // ub.poison; remove this once ub.poison can be converted to SPIR-V.
- if (id == vector::ExtractOp::kPoisonIndex) {
- // Arbitrary choice of poison result, intended to stick out.
- Value zero =
- spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
- rewriter.replaceOp(extractOp, zero);
- } else
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
- rewriter.getI32ArrayAttr(id.value()));
+ // Static use of the poison index is handled elsewhere (folded to poison).
+ if (id == vector::ExtractOp::kPoisonIndex)
+ return failure();
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ extractOp, dstType, adaptor.getVector(),
+ rewriter.getI32ArrayAttr(id.value()));
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
@@ -306,16 +301,11 @@ struct VectorInsertOpConvert final
if (std::optional<int64_t> id =
getConstantIntValue(insertOp.getMixedPosition()[0])) {
- // TODO: ExtractOp::fold() already can fold a static poison index to
- // ub.poison; remove this once ub.poison can be converted to SPIR-V.
- if (id == vector::InsertOp::kPoisonIndex) {
- // Arbitrary choice of poison result, intended to stick out.
- Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
- insertOp.getLoc(), rewriter);
- rewriter.replaceOp(insertOp, zero);
- } else
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ // Static use of the poison index is handled elsewhere (folded to poison).
+ if (id == vector::InsertOp::kPoisonIndex)
+ return failure();
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index 510f7a2d94c9ec..951bcf607ccb5f 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -22,6 +22,16 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
// -----
+// CHECK-LABEL: @extract_poison_idx
+// CHECK: %[[R:.+]] = spirv.Undef : f32
+// CHECK: spirv.ReturnValue %[[R]] : f32
+func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
+ %0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
+ return %0: f32
+}
+
+// -----
+
// CHECK-LABEL: @insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -51,6 +61,16 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
// -----
+// CHECK-LABEL: @insert_poison_idx
+// CHECK: %[[R:.+]] = spirv.Undef : vector<4xf32>
+// CHECK: spirv.ReturnValue %[[R]] : vector<4xf32>
+func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
+ %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
+ return %1: vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @extract_element
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
index 771b53ad123b92..f497eb3bc552ca 100644
--- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
+++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
@@ -13,8 +13,7 @@ func.func @check_poison() {
%1 = ub.poison : i16
// CHECK: {{.*}} = spirv.Undef : f64
%2 = ub.poison : f64
-// TODO: vector is not covered yet
-// CHECK: {{.*}} = ub.poison : vector<4xf32>
+// CHECK: {{.*}} = spirv.Undef : vector<4xf32>
%3 = ub.poison : vector<4xf32>
return
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 5fd7324b1d3c73..408b09cd6d794b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -175,9 +175,11 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
// -----
+// CHECK-LABEL: @extract_poison_idx
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+// CHECK: %[[R:.+]] = vector.extract %[[ARG0]][-1] : f32 from vector<4xf32>
+// CHECK: return %[[R]] : f32
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
- // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
- // CHECK: return %[[ZERO]]
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
return %0: f32
}
@@ -285,8 +287,9 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
// CHECK-LABEL: @insert_poison_idx
-// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
-// CHECK: return %[[ZERO]]
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
+// CHECK: %[[R:.*]] = vector.insert %[[S]], %[[V]] [-1] : f32 into vector<4xf32>
+// CHECK: return %[[R]] : vector<4xf32>
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
return %1: vector<4xf32>
diff --git a/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp b/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp
index 3c99d3c5b60ced..d406a74ea024cd 100644
--- a/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp
+++ b/mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -93,6 +94,7 @@ struct TestConvertToSPIRVPass final
}
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<spirv::SPIRVDialect>();
+ registry.insert<ub::UBDialect>();
registry.insert<vector::VectorDialect>();
}
|
aba1a10
to
4d1f345
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % one suggestion
4d1f345
to
208e584
Compare
@Hardcode84 I added you as a reviewer here because you seem to be the one that added this restriction against converting |
This is a follow-up to 5df62bd. That commit should not have needed to make the vector.insert and vector.extract conversions to SPIR-V directly handle the static poison index case, as there is a fold from those to ub.poison, and a conversion pattern from ub.poison to spirv.Undef, however: - The ub.poison fold result could not be materialized by the vector dialect (fixed as of d13940e). - The conversion pattern wasn't being populated in VectorToSPIRVPass, which is used by the tests. This commit changes this. - The ub.poison to spirv.Undef pattern rejected non-scalar types, which prevented its use for vector results. It is unclear why this restriction existed; a remark in D156163 said this was to avoid converting "user types", but it is not obvious why these shouldn't be permitted (the SPIR-V specification allows OpUndef for all types except OpTypeVoid). This commit removes this restriction. With these fixed, this commit removes the redundant static poison index handling, and updates the tests.
208e584
to
a806d81
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Hardcode84 I added you as a reviewer here because you seem to be the one that added this restriction against converting ub.poison to spirv.Undef for some types.
Yeah, I'm fine with removing the restriction
Thanks! |
…-V (llvm#125905) This is a follow-up to 5df62bd. That commit should not have needed to make the vector.insert and vector.extract conversions to SPIR-V directly handle the static poison index case, as there is a fold from those to ub.poison, and a conversion pattern from ub.poison to spirv.Undef, however: - The ub.poison fold result could not be materialized by the vector dialect (fixed as of d13940e). - The conversion pattern wasn't being populated in VectorToSPIRVPass, which is used by the tests. This commit changes this. - The ub.poison to spirv.Undef pattern rejected non-scalar types, which prevented its use for vector results. It is unclear why this restriction existed; a remark in D156163 said this was to avoid converting "user types", but it is not obvious why these shouldn't be permitted (the SPIR-V specification allows OpUndef for all types except OpTypeVoid). This commit removes this restriction. With these fixed, this commit removes the redundant static poison index handling, and updates the tests.
This is a follow-up to 5df62bd. That commit should not have needed to make the vector.insert and vector.extract conversions to SPIR-V directly handle the static poison index case, as there is a fold from those to ub.poison, and a conversion pattern from ub.poison to spirv.Undef, however:
With these fixed, this commit removes the redundant static poison index handling, and updates the tests.