
Commit 9190e1c

[mlir][linalg] Handle reassociationIndices correctly for 0D tensor (#121683)
This PR fixes a bug where a value was assigned into a 0-sized reassociationIndices vector, causing a crash. Fixes #116043.
1 parent 2d10b7b commit 9190e1c
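Background: expandRank sizes reassociationIndices to the input tensor's rank, so for a 0D input the vector holds zero groups; the pre-fix code nevertheless wrote to reassociationIndices[0], indexing past the end. The fix guards the group computation behind tensorRank != 0, leaving a rank-0 tensor with an empty reassociation list (see the C++ sketch after the first diff and the new regression test below).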

2 files changed: +37 −11 lines changed


mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 14 additions & 11 deletions
@@ -600,30 +600,33 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                         int64_t rank) {
   // No need to expand if we are already at the desired rank
-  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
-  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
-  int64_t numExtraDims = rank - shapedType.getRank();
+  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
+  assert(tensorType && "expected a ranked tensor type");
+  int64_t tensorRank = tensorType.getRank();
+  int64_t numExtraDims = rank - tensorRank;
   assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
   if (!numExtraDims)
     return tensor;
 
   // Compute reassociation indices
-  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
-      shapedType.getRank());
+  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
   int64_t index = 0;
-  for (index = 0; index <= numExtraDims; index++)
-    reassociationIndices[0].push_back(index);
-  for (size_t position = 1; position < reassociationIndices.size(); position++)
-    reassociationIndices[position].push_back(index++);
+  if (tensorRank != 0) {
+    for (index = 0; index <= numExtraDims; index++)
+      reassociationIndices[0].push_back(index);
+    for (size_t position = 1; position < reassociationIndices.size();
+         position++)
+      reassociationIndices[position].push_back(index++);
+  }
 
   // Compute result type
   SmallVector<int64_t> resultShape;
   for (index = 0; index < numExtraDims; index++)
     resultShape.push_back(1);
-  for (auto size : shapedType.getShape())
+  for (auto size : tensorType.getShape())
     resultShape.push_back(size);
   auto resultType =
-      RankedTensorType::get(resultShape, shapedType.getElementType());
+      RankedTensorType::get(resultShape, tensorType.getElementType());
 
   // Emit 'tensor.expand_shape' op
   return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
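For intuition, here is a minimal standalone sketch of the fixed grouping logic, in plain C++ with std::vector standing in for SmallVector. The helper name computeReassociation is invented here for illustration and is not part of the patch:

#include <cassert>
#include <cstdint>
#include <vector>

using ReassociationIndices = std::vector<std::int64_t>;

// One group per source dimension; the first group absorbs all of the
// prepended unit dimensions.
std::vector<ReassociationIndices>
computeReassociation(std::int64_t tensorRank, std::int64_t numExtraDims) {
  std::vector<ReassociationIndices> reassociation(tensorRank);
  // A 0D tensor has no groups at all; before the fix, the loops below
  // ran anyway and wrote to reassociation[0], which does not exist.
  if (tensorRank != 0) {
    std::int64_t index = 0;
    for (index = 0; index <= numExtraDims; index++)
      reassociation[0].push_back(index);
    // Remaining source dims map one-to-one to trailing result dims.
    for (std::size_t position = 1; position < reassociation.size(); position++)
      reassociation[position].push_back(index++);
  }
  return reassociation;
}

int main() {
  // 0D -> 2D: empty reassociation, matching the regression test's
  // tensor.expand_shape ... [] : tensor<f32> into tensor<1x1xf32>.
  assert(computeReassociation(0, 2).empty());
  // 1D -> 3D: the single source dim covers all three result dims.
  assert(computeReassociation(1, 2)[0] == (ReassociationIndices{0, 1, 2}));
  // 2D -> 3D: groups are [[0, 1], [2]].
  auto groups = computeReassociation(2, 1);
  assert(groups[0] == (ReassociationIndices{0, 1}));
  assert(groups[1] == (ReassociationIndices{2}));
  return 0;
}

The guard is a semantic fix rather than a workaround: tensor.expand_shape genuinely takes an empty reassociation when the source is rank 0, as the new test's CHECK lines show.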

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 23 additions & 0 deletions
@@ -100,6 +100,29 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
 
 // -----
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @test_add_0d_broadcast(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
+// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<2x1xf32>
+// CHECK: return %[[RESULT]] : tensor<2x1xf32>
+// CHECK: }
+func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
+  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
+  return %0 : tensor<2x1xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: @test_add_1d_all_dynamic
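A note on the new checks: because the rank-0 operand %arg1 contributes no reassociation groups, the emitted tensor.expand_shape carries an empty reassociation list ([]) and the static output_shape [1, 1] alone describes the tensor<f32> into tensor<1x1xf32> expansion. The broadcast itself then happens inside the linalg.generic, via the (0, 0) indexing map (#MAP1) on the expanded operand.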
