Skip to content

Commit 9c43021

Browse files
committed
rebase
1 parent ec1495d commit 9c43021

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,16 +1109,15 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
11091109
static LogicalResult ReduceInferReturnTypes(
11101110
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
11111111
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1112-
if (!operandShape.hasRank() || operandShape.getRank() == 0) {
1112+
int64_t axisVal = axis.getValue().getSExtValue();
1113+
if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
11131114
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
11141115
return success();
11151116
}
11161117

11171118
SmallVector<int64_t> outputShape;
11181119
operandShape.getDims(outputShape);
1119-
int64_t axisVal = axis.getValue().getSExtValue();
1120-
if (axisVal < operandShape.getRank())
1121-
outputShape[axisVal] = 1;
1120+
outputShape[axisVal] = 1;
11221121
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
11231122
return success();
11241123
}

0 commit comments

Comments
 (0)