Skip to content

Commit 070d211

Browse files
[mlir][Linalg] Fix SoftmaxOp's reify result shape calculation (#67790)
-- SoftmaxOp's `reifyResultShapes` function was wrongly casting it as a `LinalgOp`. -- This commit thus adds a fix to SoftmaxOp's reify result shape calculation. Signed-off-by: Abhishek Varma <[email protected]>
1 parent 1a4b9b6 commit 070d211

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2387,8 +2387,23 @@ LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 LogicalResult
 SoftmaxOp::reifyResultShapes(OpBuilder &b,
                              ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  return cast<LinalgOp>(getOperation())
-      .reifyResultShapes(b, reifiedReturnShapes);
+  SmallVector<OpFoldResult> shapes;
+  Location loc = getOperation()->getLoc();
+  IRRewriter rewriter(b);
+  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
+  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
+  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
+    if (!outputShapedType.isDynamicDim(dim)) {
+      // Static dim: Return IntegerAttr.
+      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
+    } else {
+      // Dynamic dim: Return Value.
+      OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
+      shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
+    }
+  }
+  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return success();
 }
 
 void SoftmaxOp::getEffects(

mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,22 @@ func.func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index
 // CHECK:   %[[IN_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
 // CHECK:   %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
 // CHECK:   return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
+
+// -----
+
+func.func @dim_of_softmax_op(%arg0: tensor<?x16x?xf32>, %arg1: tensor<2x?x?xf32>) -> (index, index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = linalg.softmax dimension(2) ins(%arg0 : tensor<?x16x?xf32>) outs(%arg1 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+  %dim = tensor.dim %0, %c0 : tensor<2x?x?xf32>
+  %dim_0 = tensor.dim %0, %c1 : tensor<2x?x?xf32>
+  %dim_1 = tensor.dim %0, %c2 : tensor<2x?x?xf32>
+  return %dim, %dim_0, %dim_1 : index, index, index
+}
+// CHECK-LABEL: @dim_of_softmax_op
+// CHECK-SAME:  (%[[INPUT:.*]]: tensor<?x16x?xf32>
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK-NEXT:    %[[C16:.*]] = arith.constant 16 : index
+// CHECK-NEXT:    %[[IN_DIM2:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<?x16x?xf32>
+// CHECK-NEXT:    return %[[C2]], %[[C16]], %[[IN_DIM2]] : index, index, index

0 commit comments

Comments
 (0)