Skip to content

Commit 497452b

Browse files
committed
[mlir] Guard sccp pass from crashing with different source type
Vector::BroadCastOp expects the idential element type in folding. It causes the crash if the different source type is given to the SCCP pass. We need to guard the pass from crashing if the non-idential element type is given, but still compatible. (e.g. index vs integer type)
1 parent 34f8573 commit 497452b

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,8 +2523,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
25232523
if (!adaptor.getSource())
25242524
return {};
25252525
auto vectorType = getResultVectorType();
2526-
if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2527-
return DenseElementsAttr::get(vectorType, adaptor.getSource());
2526+
if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2527+
if (vectorType.getElementType() != attr.getType())
2528+
return {};
2529+
return DenseElementsAttr::get(vectorType, attr);
2530+
}
2531+
if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2532+
if (vectorType.getElementType() != attr.getType())
2533+
return {};
2534+
return DenseElementsAttr::get(vectorType, attr);
2535+
}
25282536
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
25292537
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
25302538
return {};

mlir/test/Transforms/sccp.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,16 @@ func.func @op_with_region() -> (i32) {
246246
^b:
247247
return %1 : i32
248248
}
249+
250+
// CHECK-LABEL: no_crash_with_different_source_type
251+
func.func @no_crash_with_different_source_type() {
252+
// CHECK: llvm.mlir.constant(0 : index) : i64
253+
%0 = llvm.mlir.constant(0 : index) : i64
254+
llvm.br ^b1(%0 : i64)
255+
^b1(%1: i64):
256+
llvm.br ^b2
257+
^b2:
258+
// CHECK: vector.broadcast %[[CST:.*]] : i64 to vector<128xi64>
259+
%2 = vector.broadcast %1 : i64 to vector<128xi64>
260+
llvm.return
261+
}

0 commit comments

Comments
 (0)