Skip to content

Commit e2919fb

Browse files
CoTinker authored and DanielCChen committed
[mlir][linalg] Bugfix for InlineScalarOperands (llvm#111534)
This PR fixes a bug in which a `scalarOperand` that is already a plain scalar (rather than a tensor) was still accessed via `tensor.extract`; such an operand is now used directly. Fixes llvm#111243.
1 parent 7dc1faf commit e2919fb

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,12 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
7878
for (auto idx : indices)
7979
indicesValues.emplace_back(
8080
rewriter.create<arith::ConstantIndexOp>(loc, idx));
81-
Value extractedValue = rewriter.create<tensor::ExtractOp>(
82-
loc, opOperand->get(), indicesValues);
83-
body->getArgument(idx).replaceAllUsesWith(extractedValue);
81+
Value scalarValue = opOperand->get();
82+
if (isa<RankedTensorType>(scalarValue.getType())) {
83+
scalarValue =
84+
rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
85+
}
86+
body->getArgument(idx).replaceAllUsesWith(scalarValue);
8487
body->eraseArgument(idx);
8588
}
8689

mlir/test/Dialect/Linalg/inline-scalar-operands.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,27 @@ func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4
4646
} -> tensor<4xf32>
4747
return %1 : tensor<4xf32>
4848
}
49+
50+
// -----
51+
52+
// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
53+
#map2 = affine_map<(d0) -> (d0)>
54+
#map3 = affine_map<(d0) -> ()>
55+
56+
// CHECK: func @inline_scalar(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: f32)
57+
func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {
58+
%0 = tensor.empty() : tensor<4xf32>
59+
// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
60+
// CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
61+
%1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
62+
iterator_types = ["parallel"]}
63+
ins(%arg0, %scalar : tensor<4xf32>, f32)
64+
outs(%0 : tensor<4xf32>) {
65+
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
66+
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
67+
// CHECK: arith.divf %[[IN]], %[[SCALAR]] : f32
68+
%2 = arith.divf %arg1, %arg2 : f32
69+
linalg.yield %2 : f32
70+
} -> tensor<4xf32>
71+
return %1 : tensor<4xf32>
72+
}

0 commit comments

Comments
 (0)