Skip to content

Commit d13940e

Browse files
authored
[mlir][Vector] Teach how to materialize UB constant to Vector (llvm#125596)
This PR adds support for UB constant materialization (i.e., generating `ub::PoisonOp` to `VectorDialect::materializeConstant`. This was the reason why the vector folders generating poison didn't work.
1 parent 005b23b commit d13940e

File tree

2 files changed

+17
-20
lines changed

2 files changed

+17
-20
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,9 @@ void VectorDialect::initialize() {
437437
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
438438
Attribute value, Type type,
439439
Location loc) {
440+
if (isa<ub::PoisonAttrInterface>(value))
441+
return value.getDialect().materializeConstant(builder, value, type, loc);
442+
440443
return arith::ConstantOp::materialize(builder, value, type, loc);
441444
}
442445

@@ -2273,20 +2276,6 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
22732276
return success();
22742277
}
22752278

2276-
/// Fold an insert or extract operation into an poison value when a poison index
2277-
/// is found at any dimension of the static position.
2278-
template <typename OpTy>
2279-
LogicalResult
2280-
canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
2281-
if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
2282-
op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
2283-
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
2284-
return success();
2285-
}
2286-
2287-
return failure();
2288-
}
2289-
22902279
} // namespace
22912280

22922281
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2295,7 +2284,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
22952284
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
22962285
results.add(foldExtractFromShapeCastToShapeCast);
22972286
results.add(foldExtractFromFromElements);
2298-
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
22992287
}
23002288

23012289
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -3068,7 +3056,6 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
30683056
MLIRContext *context) {
30693057
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
30703058
InsertOpConstantFolder>(context);
3071-
results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
30723059
}
30733060

30743061
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,13 +1250,13 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {
12501250

12511251
// -----
12521252

1253-
func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
1253+
func.func @extract_scalar_from_vec_1d_f32_poison_idx(%arg0: vector<16xf32>) -> f32 {
12541254
%0 = vector.extract %arg0[-1]: f32 from vector<16xf32>
12551255
return %0 : f32
12561256
}
1257-
// CHECK-LABEL: @extract_poison_idx
1258-
// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64
1259-
// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32>
1257+
// CHECK-LABEL: @extract_scalar_from_vec_1d_f32_poison_idx
1258+
// CHECK: %[[UB:.*]] = ub.poison : f32
1259+
// CHECK: return %[[UB]] : f32
12601260

12611261
// -----
12621262

@@ -1335,6 +1335,16 @@ func.func @extract_vec_2d_from_vec_3d_f32(%arg0: vector<4x3x16xf32>) -> vector<3
13351335

13361336
// -----
13371337

1338+
func.func @extract_vec_2d_from_vec_3d_f32_poison_idx(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
1339+
%0 = vector.extract %arg0[-1]: vector<3x16xf32> from vector<4x3x16xf32>
1340+
return %0 : vector<3x16xf32>
1341+
}
1342+
// CHECK-LABEL: @extract_vec_2d_from_vec_3d_f32_poison_idx
1343+
// CHECK: %[[UB:.*]] = ub.poison : vector<3x16xf32>
1344+
// CHECK: return %[[UB]] : vector<3x16xf32>
1345+
1346+
// -----
1347+
13381348
func.func @extract_vec_2d_from_vec_3d_f32_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<3x[16]xf32> {
13391349
%0 = vector.extract %arg0[0]: vector<3x[16]xf32> from vector<4x3x[16]xf32>
13401350
return %0 : vector<3x[16]xf32>

0 commit comments

Comments
 (0)