Skip to content

Commit eea8648

Browse files
authored
[mlir][vector] Fold broadcast(poison) -> poison (#135677)
In addition to the new folder, I've also a test for broadcast(splat) -> splat which I think was missing Signed-off-by: James Newling <[email protected]>
1 parent 73b8750 commit eea8648

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2590,6 +2590,8 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
25902590
}
25912591
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
25922592
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2593+
if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
2594+
return ub::PoisonAttr::get(getContext());
25932595
return {};
25942596
}
25952597

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,6 +1151,28 @@ func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) {
11511151

11521152
// -----
11531153

1154+
// CHECK-LABEL: broadcast_poison
1155+
// CHECK: %[[POISON:.*]] = ub.poison : vector<4x6xi8>
1156+
// CHECK: return %[[POISON]] : vector<4x6xi8>
1157+
func.func @broadcast_poison() -> vector<4x6xi8> {
1158+
%poison = ub.poison : vector<6xi8>
1159+
%broadcast = vector.broadcast %poison : vector<6xi8> to vector<4x6xi8>
1160+
return %broadcast : vector<4x6xi8>
1161+
}
1162+
1163+
// -----
1164+
1165+
// CHECK-LABEL: broadcast_splat_constant
1166+
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
1167+
// CHECK: return %[[CONST]] : vector<4x6xi8>
1168+
func.func @broadcast_splat_constant() -> vector<4x6xi8> {
1169+
%cst = arith.constant dense<1> : vector<6xi8>
1170+
%broadcast = vector.broadcast %cst : vector<6xi8> to vector<4x6xi8>
1171+
return %broadcast : vector<4x6xi8>
1172+
}
1173+
1174+
// -----
1175+
11541176
// CHECK-LABEL: broadcast_folding1
11551177
// CHECK: %[[CST:.*]] = arith.constant dense<42> : vector<4xi32>
11561178
// CHECK-NOT: vector.broadcast

0 commit comments

Comments
 (0)