Skip to content

Commit 0a53005

Browse files
jpienaarkuhar
andauthored
[mlir] Avoid common folder assuming all types are supported (llvm#68054)
Previously this would just assume all conversions are possible and this would crash. Use an in-tree testing case here. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 9c1c221 commit 0a53005

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
9393
if (lhs.getType() != rhs.getType())
9494
return {};
9595

96-
auto lhsIt = lhs.value_begin<ElementValueT>();
97-
auto rhsIt = rhs.value_begin<ElementValueT>();
96+
auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
97+
auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
98+
if (!maybeLhsIt || !maybeRhsIt)
99+
return {};
100+
auto lhsIt = *maybeLhsIt;
101+
auto rhsIt = *maybeRhsIt;
98102
SmallVector<ElementValueT, 4> elementResults;
99103
elementResults.reserve(lhs.getNumElements());
100104
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
@@ -227,7 +231,10 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
227231
// expanding the values.
228232
auto op = cast<ElementsAttr>(operands[0]);
229233

230-
auto opIt = op.value_begin<ElementValueT>();
234+
auto maybeOpIt = op.try_value_begin<ElementValueT>();
235+
if (!maybeOpIt)
236+
return {};
237+
auto opIt = *maybeOpIt;
231238
SmallVector<ElementValueT> elementResults;
232239
elementResults.reserve(op.getNumElements());
233240
for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
@@ -293,12 +300,14 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
293300
return {};
294301
return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
295302
}
296-
if (isa<ElementsAttr>(operands[0])) {
303+
if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
297304
// Operand is ElementsAttr-derived; perform an element-wise fold by
298305
// expanding the value.
299-
auto op = cast<ElementsAttr>(operands[0]);
300306
bool castStatus = true;
301-
auto opIt = op.value_begin<ElementValueT>();
307+
auto maybeOpIt = op.try_value_begin<ElementValueT>();
308+
if (!maybeOpIt)
309+
return {};
310+
auto opIt = *maybeOpIt;
302311
SmallVector<TargetElementValueT> elementResults;
303312
elementResults.reserve(op.getNumElements());
304313
for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2669,3 +2669,20 @@ func.func @extsi_poison() -> i64 {
26692669
%1 = arith.extsi %0 : i32 to i64
26702670
return %1 : i64
26712671
}
2672+
2673+
// Just checks that this doesn't crash.
2674+
// CHECK-LABEL: @unsignedExtendConstantResource
2675+
func.func @unsignedExtendConstantResource() -> tensor<i16> {
2676+
%c2 = arith.constant dense_resource<blob1> : tensor<i8>
2677+
%ext = arith.extui %c2 : tensor<i8> to tensor<i16>
2678+
return %ext : tensor<i16>
2679+
}
2680+
2681+
{-#
2682+
dialect_resources: {
2683+
builtin: {
2684+
// Note: This is just copied blob, the actual value isn't used or checked.
2685+
blob1: "0x08000000010000000000000002000000000000000300000000000000"
2686+
}
2687+
}
2688+
#-}

0 commit comments

Comments
 (0)