Skip to content

Commit 6fa17ac

Browse files
committed
[mlir][linalg] Handle reassociationIndices correctly for 0D tensor
This PR fixes a bug where a value is assigned to a 0-sized reassociationIndices, preventing a crash.
1 parent fac4646 commit 6fa17ac

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -611,10 +611,13 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
611611
SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
612612
shapedType.getRank());
613613
int64_t index = 0;
614-
for (index = 0; index <= numExtraDims; index++)
615-
reassociationIndices[0].push_back(index);
616-
for (size_t position = 1; position < reassociationIndices.size(); position++)
617-
reassociationIndices[position].push_back(index++);
614+
if (shapedType.getRank() != 0) {
615+
for (index = 0; index <= numExtraDims; index++)
616+
reassociationIndices[0].push_back(index);
617+
for (size_t position = 1; position < reassociationIndices.size();
618+
position++)
619+
reassociationIndices[position].push_back(index++);
620+
}
618621

619622
// Compute result type
620623
SmallVector<int64_t> resultShape;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,3 +1964,26 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
19641964
%0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
19651965
return %0: tensor<1xi64>
19661966
}
1967+
1968+
// -----
1969+
1970+
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
1971+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
1972+
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
1973+
1974+
// CHECK-LABEL: func.func @test_add_0d_broadcast(
1975+
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
1976+
// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
1977+
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
1978+
// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
1979+
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
1980+
// CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
1981+
// CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
1982+
// CHECK: linalg.yield %[[ADD]] : f32
1983+
// CHECK: } -> tensor<2x1xf32>
1984+
// CHECK: return %[[RESULT]] : tensor<2x1xf32>
1985+
// CHECK: }
1986+
func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
1987+
%0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
1988+
return %0 : tensor<2x1xf32>
1989+
}

0 commit comments

Comments
 (0)