Skip to content

Commit 3ce3281

Browse files
authored
[mlir][spirv] Check output of getConstantInt (#140568)
This patch adds an assert to check if the result of `getConstantInt` is non-null. Previously the code failed with Segmentation Fault if `getConstantInt` failed to look up the value. This primarily occurrs when the value is defined as OpSpecConstant rather than OpConstant.
1 parent 5a531b1 commit 3ce3281

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,12 +1062,30 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
10621062
<< operands[2];
10631063
}
10641064

1065-
unsigned rows = getConstantInt(operands[3]).getInt();
1066-
unsigned columns = getConstantInt(operands[4]).getInt();
1065+
IntegerAttr rowsAttr = getConstantInt(operands[3]);
1066+
IntegerAttr columnsAttr = getConstantInt(operands[4]);
1067+
IntegerAttr useAttr = getConstantInt(operands[5]);
1068+
1069+
if (!rowsAttr)
1070+
return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1071+
"undefined constant <id> ")
1072+
<< operands[3];
1073+
1074+
if (!columnsAttr)
1075+
return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1076+
"references undefined constant <id> ")
1077+
<< operands[4];
1078+
1079+
if (!useAttr)
1080+
return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1081+
"undefined constant <id> ")
1082+
<< operands[5];
1083+
1084+
unsigned rows = rowsAttr.getInt();
1085+
unsigned columns = columnsAttr.getInt();
10671086

10681087
std::optional<spirv::CooperativeMatrixUseKHR> use =
1069-
spirv::symbolizeCooperativeMatrixUseKHR(
1070-
getConstantInt(operands[5]).getInt());
1088+
spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
10711089
if (!use) {
10721090
return emitError(
10731091
unknownLoc,

0 commit comments

Comments
 (0)