Skip to content

Commit a6cb5cc

Browse files
authored
[mlir] Add nullptr checks in SparseElementsAttr parser (#133222)
This PR adds nullptr checks in the SparseElementsAttr parser to improve robustness and prevent crashes. Fixes #132891.
1 parent ebe1ece commit a6cb5cc

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

mlir/lib/AsmParser/AttributeParser.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,8 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
10811081
indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
10821082
}
10831083
auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1084+
if (!indices)
1085+
return nullptr;
10841086

10851087
// If the values are a splat, set the shape explicitly based on the number of
10861088
// indices. The number of indices is encoded in the first dimension of the
@@ -1091,6 +1093,8 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
10911093
? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
10921094
: RankedTensorType::get(valuesParser.getShape(), valuesEltType);
10931095
auto values = valuesParser.getAttr(valuesLoc, valuesType);
1096+
if (!values)
1097+
return nullptr;
10941098

10951099
// Build the sparse elements attribute by the indices and values.
10961100
return getChecked<SparseElementsAttr>(loc, type, indices, values);

mlir/test/IR/invalid-builtin-attributes.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,20 @@ func.func @invalid_tensor_literal() {
109109

110110
// -----
111111

112+
func.func @invalid_sparse_indices() {
113+
// expected-error @+1 {{expected integer elements, but parsed floating-point}}
114+
"foo"(){bar = sparse<0.5, 1> : tensor<1xi16>} : () -> ()
115+
}
116+
117+
// -----
118+
119+
func.func @invalid_sparse_values() {
120+
// expected-error @+1 {{expected integer elements, but parsed floating-point}}
121+
"foo"(){bar = sparse<0, 1.1> : tensor<1xi16>} : () -> ()
122+
}
123+
124+
// -----
125+
112126
func.func @hexadecimal_float_leading_minus() {
113127
// expected-error @+1 {{hexadecimal float literal should not have a leading minus}}
114128
"foo"() {value = -0x7fff : f16} : () -> ()

0 commit comments

Comments
 (0)