Skip to content

Commit dacfb24

Browse files
committed
[mlir] Support inlining into affine operations
Introduce support for inlining into affine operations. This uses the generic inline infrastructure and boils down to checking that, if applied, the inlining doesn't violate the affine dimension/symbol value categorization. Given valid IR, only the values that are valid dimensions/symbols thanks to being top-level in their affine scope need special handling. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D92770
1 parent 62b4a69 commit dacfb24

File tree

2 files changed

+207
-22
lines changed

2 files changed

+207
-22
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 142 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1010
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
1111
#include "mlir/Dialect/StandardOps/IR/Ops.h"
12+
#include "mlir/IR/BlockAndValueMapping.h"
1213
#include "mlir/IR/BuiltinOps.h"
1314
#include "mlir/IR/IntegerSet.h"
1415
#include "mlir/IR/Matchers.h"
@@ -25,6 +26,99 @@ using llvm::dbgs;
2526

2627
#define DEBUG_TYPE "affine-analysis"
2728

29+
/// A utility function to check if a value is defined at the top level of
30+
/// `region` or is an argument of `region`. A value of index type defined at the
31+
/// top level of a `AffineScope` region is always a valid symbol for all
32+
/// uses in that region.
33+
static bool isTopLevelValue(Value value, Region *region) {
34+
if (auto arg = value.dyn_cast<BlockArgument>())
35+
return arg.getParentRegion() == region;
36+
return value.getDefiningOp()->getParentRegion() == region;
37+
}
38+
39+
/// Checks if `value` known to be a legal affine dimension or symbol in `src`
40+
/// region remains legal if the operation that uses it is inlined into `dest`
41+
/// with the given value mapping. `legalityCheck` is either `isValidDim` or
42+
/// `isValidSymbol`, depending on the value being required to remain a valid
43+
/// dimension or symbol.
44+
static bool
45+
remainsLegalAfterInline(Value value, Region *src, Region *dest,
46+
const BlockAndValueMapping &mapping,
47+
function_ref<bool(Value, Region *)> legalityCheck) {
48+
// If the value is a valid dimension for any other reason than being
49+
// a top-level value, it will remain valid: constants get inlined
50+
// with the function, transitive affine applies also get inlined and
51+
// will be checked themselves, etc.
52+
if (!isTopLevelValue(value, src))
53+
return true;
54+
55+
// If it's a top-level value because it's a block operand, i.e. a
56+
// function argument, check whether the value replacing it after
57+
// inlining is a valid dimension in the new region.
58+
if (value.isa<BlockArgument>())
59+
return legalityCheck(mapping.lookup(value), dest);
60+
61+
// If it's a top-level value beacuse it's defined in the region,
62+
// it can only be inlined if the defining op is a constant or a
63+
// `dim`, which can appear anywhere and be valid, since the defining
64+
// op won't be top-level anymore after inlining.
65+
Attribute operandCst;
66+
return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
67+
value.getDefiningOp<DimOp>();
68+
}
69+
70+
/// Checks if all values known to be legal affine dimensions or symbols in `src`
71+
/// remain so if their respective users are inlined into `dest`.
72+
static bool
73+
remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
74+
const BlockAndValueMapping &mapping,
75+
function_ref<bool(Value, Region *)> legalityCheck) {
76+
return llvm::all_of(values, [&](Value v) {
77+
return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
78+
});
79+
}
80+
81+
/// Checks if an affine read or write operation remains legal after inlining
82+
/// from `src` to `dest`.
83+
template <typename OpTy>
84+
static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
85+
const BlockAndValueMapping &mapping) {
86+
static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
87+
AffineWriteOpInterface>::value,
88+
"only ops with affine read/write interface are supported");
89+
90+
AffineMap map = op.getAffineMap();
91+
ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
92+
ValueRange symbolOperands =
93+
op.getMapOperands().take_back(map.getNumSymbols());
94+
if (!remainsLegalAfterInline(
95+
dimOperands, src, dest, mapping,
96+
static_cast<bool (*)(Value, Region *)>(isValidDim)))
97+
return false;
98+
if (!remainsLegalAfterInline(
99+
symbolOperands, src, dest, mapping,
100+
static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
101+
return false;
102+
return true;
103+
}
104+
105+
/// Checks if an affine apply operation remains legal after inlining from `src`
106+
/// to `dest`.
107+
template <>
108+
bool remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
109+
const BlockAndValueMapping &mapping) {
110+
// If it's a valid dimension, we need to check that it remains so.
111+
if (isValidDim(op.getResult(), src))
112+
return remainsLegalAfterInline(
113+
op.getMapOperands(), src, dest, mapping,
114+
static_cast<bool (*)(Value, Region *)>(isValidDim));
115+
116+
// Otherwise it must be a valid symbol, check that it remains so.
117+
return remainsLegalAfterInline(
118+
op.getMapOperands(), src, dest, mapping,
119+
static_cast<bool (*)(Value, Region *)>(isValidSymbol));
120+
}
121+
28122
//===----------------------------------------------------------------------===//
29123
// AffineDialect Interfaces
30124
//===----------------------------------------------------------------------===//
@@ -41,22 +135,62 @@ struct AffineInlinerInterface : public DialectInlinerInterface {
41135

42136
/// Returns true if the given region 'src' can be inlined into the region
43137
/// 'dest' that is attached to an operation registered to the current dialect.
138+
/// 'wouldBeCloned' is set if the region is cloned into its new location
139+
/// rather than moved, indicating there may be other users.
44140
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
45141
BlockAndValueMapping &valueMapping) const final {
46-
// Conservatively don't allow inlining into affine structures.
47-
return false;
142+
// We can inline into affine loops and conditionals if this doesn't break
143+
// affine value categorization rules.
144+
Operation *destOp = dest->getParentOp();
145+
if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
146+
return false;
147+
148+
// Multi-block regions cannot be inlined into affine constructs, all of
149+
// which require single-block regions.
150+
if (!llvm::hasSingleElement(*src))
151+
return false;
152+
153+
// Side-effecting operations that the affine dialect cannot understand
154+
// should not be inlined.
155+
Block &srcBlock = src->front();
156+
for (Operation &op : srcBlock) {
157+
// Ops with no side effects are fine,
158+
if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
159+
if (iface.hasNoEffect())
160+
continue;
161+
}
162+
163+
// Assuming the inlined region is valid, we only need to check if the
164+
// inlining would change it.
165+
bool remainsValid =
166+
llvm::TypeSwitch<Operation *, bool>(&op)
167+
.Case<AffineApplyOp, AffineReadOpInterface,
168+
AffineWriteOpInterface>([&](auto op) {
169+
return remainsLegalAfterInline(op, src, dest, valueMapping);
170+
})
171+
.Default([](Operation *) {
172+
// Conservatively disallow inlining ops we cannot reason about.
173+
return false;
174+
});
175+
176+
if (!remainsValid)
177+
return false;
178+
}
179+
180+
return true;
48181
}
49182

50183
/// Returns true if the given operation 'op', that is registered to this
51184
/// dialect, can be inlined into the given region, false otherwise.
52185
bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
53186
BlockAndValueMapping &valueMapping) const final {
54-
// Always allow inlining affine operations into the top-level region of a
55-
// function. There are some edge cases when inlining *into* affine
56-
// structures, but that is handled in the other 'isLegalToInline' hook
57-
// above.
58-
// TODO: We should be able to inline into other regions than functions.
59-
return isa<FuncOp>(region->getParentOp());
187+
// Always allow inlining affine operations into a region that is marked as
188+
// affine scope, or into affine loops and conditionals. There are some edge
189+
// cases when inlining *into* affine structures, but that is handled in the
190+
// other 'isLegalToInline' hook above.
191+
Operation *parentOp = region->getParentOp();
192+
return parentOp->hasTrait<OpTrait::AffineScope>() ||
193+
isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
60194
}
61195

62196
/// Affine regions should be analyzed recursively.
@@ -101,16 +235,6 @@ bool mlir::isTopLevelValue(Value value) {
101235
return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
102236
}
103237

104-
/// A utility function to check if a value is defined at the top level of
105-
/// `region` or is an argument of `region`. A value of index type defined at the
106-
/// top level of a `AffineScope` region is always a valid symbol for all
107-
/// uses in that region.
108-
static bool isTopLevelValue(Value value, Region *region) {
109-
if (auto arg = value.dyn_cast<BlockArgument>())
110-
return arg.getParentRegion() == region;
111-
return value.getDefiningOp()->getParentRegion() == region;
112-
}
113-
114238
/// Returns the closest region enclosing `op` that is held by an operation with
115239
/// trait `AffineScope`; `nullptr` if there is no such region.
116240
// TODO: getAffineScope should be publicly exposed for affine passes/utilities.

mlir/test/Dialect/Affine/inlining.mlir

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,77 @@ func @not_inline_invalid_nest_op() {
5454

5555
// -----
5656

57-
// Test that calls are not inlined into affine structures.
57+
// Test that calls are inlined into affine structures.
5858
func @func_noop() {
5959
return
6060
}
6161

62-
// CHECK-LABEL: func @not_inline_into_affine_ops
63-
func @not_inline_into_affine_ops() {
64-
// CHECK: call @func_noop
62+
// CHECK-LABEL: func @inline_into_affine_ops
63+
func @inline_into_affine_ops() {
64+
// CHECK-NOT: call @func_noop
6565
affine.for %i = 1 to 10 {
6666
call @func_noop() : () -> ()
6767
}
6868
return
6969
}
70+
71+
// -----
72+
73+
// Test that calls with dimension arguments are properly inlined.
74+
func @func_dim(%arg0: index, %arg1: memref<?xf32>) {
75+
affine.load %arg1[%arg0] : memref<?xf32>
76+
return
77+
}
78+
79+
// CHECK-LABEL: @inline_dimension
80+
// CHECK: (%[[ARG0:.*]]: memref<?xf32>)
81+
func @inline_dimension(%arg0: memref<?xf32>) {
82+
// CHECK: affine.for %[[IV:.*]] =
83+
affine.for %i = 1 to 42 {
84+
// CHECK-NOT: call @func_dim
85+
// CHECK: affine.load %[[ARG0]][%[[IV]]]
86+
call @func_dim(%i, %arg0) : (index, memref<?xf32>) -> ()
87+
}
88+
return
89+
}
90+
91+
// -----
92+
93+
// Test that calls with vector operations are also inlined.
94+
func @func_vector_dim(%arg0: index, %arg1: memref<32xf32>) {
95+
affine.vector_load %arg1[%arg0] : memref<32xf32>, vector<4xf32>
96+
return
97+
}
98+
99+
// CHECK-LABEL: @inline_dimension_vector
100+
// CHECK: (%[[ARG0:.*]]: memref<32xf32>)
101+
func @inline_dimension_vector(%arg0: memref<32xf32>) {
102+
// CHECK: affine.for %[[IV:.*]] =
103+
affine.for %i = 1 to 42 {
104+
// CHECK-NOT: call @func_dim
105+
// CHECK: affine.vector_load %[[ARG0]][%[[IV]]]
106+
call @func_vector_dim(%i, %arg0) : (index, memref<32xf32>) -> ()
107+
}
108+
return
109+
}
110+
111+
// -----
112+
113+
// Test that calls that would result in violation of affine value
114+
// categorization (top-level value stop being top-level) are not inlined.
115+
func private @get_index() -> index
116+
117+
func @func_top_level(%arg0: memref<?xf32>) {
118+
%0 = call @get_index() : () -> index
119+
affine.load %arg0[%0] : memref<?xf32>
120+
return
121+
}
122+
123+
// CHECK-LABEL: @no_inline_not_top_level
124+
func @no_inline_not_top_level(%arg0: memref<?xf32>) {
125+
affine.for %i = 1 to 42 {
126+
// CHECK: call @func_top_level
127+
call @func_top_level(%arg0) : (memref<?xf32>) -> ()
128+
}
129+
return
130+
}

0 commit comments

Comments
 (0)