Skip to content

Commit 5f74d9b

Browse files
authored
[mlir][linalg] Add support for inlined const to isaFillOpInterface (#144870)
1 parent 653d0d0 commit 5f74d9b

File tree

4 files changed

+68
-3
lines changed

4 files changed

+68
-3
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
142142
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
143143

144144
/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
145+
/// Supports two patterns:
146+
/// 1. External: linalg.generic ins(%scalar) outs(%tensor) { yield %scalar }
147+
/// 2. Inlined: linalg.generic outs(%tensor) { yield %constant }
145148
/// Returns the scalar fill value if true.
146149
std::optional<Value> isaFillOpInterface(GenericOp genericOp);
147150

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,37 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
7777
//===----------------------------------------------------------------------===//
7878
// FillOpInterface implementation
7979
//===----------------------------------------------------------------------===//
80-
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
80+
/// Detects if a linalg.generic operation represents a fill with an inlined
81+
/// constant. If so, returns the constant value. Otherwise, returns
82+
/// std::nullopt.
83+
static std::optional<Value> isaInlinedFillOp(GenericOp op) {
84+
if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
85+
op.getNumDpsInputs() != 0)
86+
return std::nullopt;
87+
88+
// Init should not be referenced.
89+
if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
90+
return std::nullopt;
91+
92+
Block *body = op.getBody();
93+
if (body->getOperations().size() != 1)
94+
return std::nullopt;
95+
96+
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
97+
if (!yieldOp || yieldOp.getNumOperands() != 1)
98+
return std::nullopt;
99+
100+
Value yieldOperand = yieldOp->getOperand(0);
101+
if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
102+
!yieldOperand.getDefiningOp<complex::ConstantOp>())
103+
return std::nullopt;
104+
105+
return yieldOperand;
106+
}
107+
108+
/// Detects if a linalg.generic operation represents an external scalar input.
109+
/// If so, returns the constant value. Otherwise, returns std::nullopt.
110+
static std::optional<Value> isaExternalFillOp(GenericOp op) {
81111
// Structural.
82112
if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
83113
!op.isSingleYieldOp())
@@ -94,6 +124,12 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
94124
return value->get();
95125
}
96126

127+
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
128+
if (auto fillVal = isaInlinedFillOp(op))
129+
return fillVal;
130+
return isaExternalFillOp(op);
131+
}
132+
97133
//===----------------------------------------------------------------------===//
98134
// BroadcastOpInterface implementation
99135
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,10 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
267267
}
268268

269269
// Fill
270-
if (isaFillOpInterface(genericOp)) {
270+
if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
271+
// Always use the detected fill value, regardless of pattern
271272
LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
272-
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
273+
genericOp, *fillValue, genericOp.getDpsInits()[0]);
273274
return namedOp;
274275
}
275276

mlir/test/Dialect/Linalg/transform-op-specialize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,28 @@ module attributes {transform.with_named_sequence} {
154154
transform.yield
155155
}
156156
}
157+
158+
// -----
159+
160+
#map = affine_map<(d0, d1) -> (d0, d1)>
161+
func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
162+
%cst = arith.constant 0.000000e+00 : f32
163+
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
164+
^bb0(%out: f32):
165+
linalg.yield %cst : f32
166+
} -> tensor<7x7xf32>
167+
return %0 : tensor<7x7xf32>
168+
}
169+
170+
// CHECK-LABEL: linalg_generic_inlined_constant_fill
171+
// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
172+
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
173+
// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
174+
175+
module attributes {transform.with_named_sequence} {
176+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
177+
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
178+
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
179+
transform.yield
180+
}
181+
}

0 commit comments

Comments
 (0)