Skip to content

Commit f497fe4

Browse files
authored
[mlir][spirv] Fix some issues related to converting ub.poison to SPIR-V (#125905)
This is a follow-up to 5df62bd. That commit should not have needed to make the vector.insert and vector.extract conversions to SPIR-V directly handle the static poison index case, as there is a fold from those to ub.poison, and a conversion pattern from ub.poison to spirv.Undef, however: - The ub.poison fold result could not be materialized by the vector dialect (fixed as of d13940e). - The conversion pattern wasn't being populated in VectorToSPIRVPass, which is used by the tests. This commit changes this. - The ub.poison to spirv.Undef pattern rejected non-scalar types, which prevented its use for vector results. It is unclear why this restriction existed; a remark in D156163 said this was to avoid converting "user types", but it is not obvious why these shouldn't be permitted (the SPIR-V specification allows OpUndef for all types except OpTypeVoid). This commit removes this restriction. With these fixed, this commit removes the redundant static poison index handling, and updates the tests.
1 parent b884be8 commit f497fe4

File tree

6 files changed

+23
-32
lines changed

6 files changed

+23
-32
lines changed

mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
2929
matchAndRewrite(ub::PoisonOp op, OpAdaptor,
3030
ConversionPatternRewriter &rewriter) const override {
3131
Type origType = op.getType();
32-
if (!origType.isIntOrIndexOrFloat())
33-
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
34-
diag << "unsupported type " << origType;
35-
});
36-
3732
Type resType = getTypeConverter()->convertType(origType);
3833
if (!resType)
3934
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {

mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
1515
MLIRSPIRVConversion
1616
MLIRVectorDialect
1717
MLIRTransforms
18+
MLIRUBToSPIRV
1819
)

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,13 @@ struct VectorExtractOpConvert final
182182

183183
if (std::optional<int64_t> id =
184184
getConstantIntValue(extractOp.getMixedPosition()[0])) {
185-
// TODO: ExtractOp::fold() already can fold a static poison index to
186-
// ub.poison; remove this once ub.poison can be converted to SPIR-V.
187-
if (id == vector::ExtractOp::kPoisonIndex) {
188-
// Arbitrary choice of poison result, intended to stick out.
189-
Value zero =
190-
spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
191-
rewriter.replaceOp(extractOp, zero);
192-
} else
193-
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
194-
extractOp, dstType, adaptor.getVector(),
195-
rewriter.getI32ArrayAttr(id.value()));
185+
if (id == vector::ExtractOp::kPoisonIndex)
186+
return rewriter.notifyMatchFailure(
187+
extractOp,
188+
"Static use of poison index handled elsewhere (folded to poison)");
189+
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
190+
extractOp, dstType, adaptor.getVector(),
191+
rewriter.getI32ArrayAttr(id.value()));
196192
} else {
197193
Value sanitizedIndex = sanitizeDynamicIndex(
198194
rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
@@ -306,16 +302,12 @@ struct VectorInsertOpConvert final
306302

307303
if (std::optional<int64_t> id =
308304
getConstantIntValue(insertOp.getMixedPosition()[0])) {
309-
// TODO: ExtractOp::fold() already can fold a static poison index to
310-
// ub.poison; remove this once ub.poison can be converted to SPIR-V.
311-
if (id == vector::InsertOp::kPoisonIndex) {
312-
// Arbitrary choice of poison result, intended to stick out.
313-
Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
314-
insertOp.getLoc(), rewriter);
315-
rewriter.replaceOp(insertOp, zero);
316-
} else
317-
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
318-
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
305+
if (id == vector::InsertOp::kPoisonIndex)
306+
return rewriter.notifyMatchFailure(
307+
insertOp,
308+
"Static use of poison index handled elsewhere (folded to poison)");
309+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
310+
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
319311
} else {
320312
Value sanitizedIndex = sanitizeDynamicIndex(
321313
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
1414

15+
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
1516
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1617
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1718
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -49,6 +50,8 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
4950

5051
RewritePatternSet patterns(context);
5152
populateVectorToSPIRVPatterns(typeConverter, patterns);
53+
// Used for folds, e.g. vector.extract[-1] -> ub.poison -> spirv.Undef.
54+
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
5255

5356
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
5457
return signalPassFailure();

mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ func.func @check_poison() {
1313
%1 = ub.poison : i16
1414
// CHECK: {{.*}} = spirv.Undef : f64
1515
%2 = ub.poison : f64
16-
// TODO: vector is not covered yet
17-
// CHECK: {{.*}} = ub.poison : vector<4xf32>
16+
// CHECK: {{.*}} = spirv.Undef : vector<4xf32>
1817
%3 = ub.poison : vector<4xf32>
1918
return
2019
}

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,10 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
175175

176176
// -----
177177

178+
// CHECK-LABEL: @extract_poison_idx
179+
// CHECK: %[[R:.+]] = spirv.Undef : f32
180+
// CHECK: return %[[R]]
178181
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
179-
// CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
180-
// CHECK: return %[[ZERO]]
181182
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
182183
return %0: f32
183184
}
@@ -285,8 +286,8 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
285286
// -----
286287

287288
// CHECK-LABEL: @insert_poison_idx
288-
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
289-
// CHECK: return %[[ZERO]]
289+
// CHECK: %[[R:.+]] = spirv.Undef : vector<4xf32>
290+
// CHECK: return %[[R]]
290291
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
291292
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
292293
return %1: vector<4xf32>

0 commit comments

Comments
 (0)