Skip to content

Commit 4af01bf

Browse files
committed
[mlir:bytecode] Support lazy loading dynamically isolated regions
We currently only support lazy loading for regions that statically implement the IsolatedFromAbove trait, but that limits the amount of operations that can be lazily loaded. This review lifts that restriction by computing which operations have isolated regions when numbering, allowing any operation to be lazily loaded as long as it doesn't use values defined above. Differential Revision: https://reviews.llvm.org/D156199
1 parent 5ab6589 commit 4af01bf

File tree

4 files changed

+193
-25
lines changed

4 files changed

+193
-25
lines changed

mlir/lib/Bytecode/Writer/BytecodeWriter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
942942
// emitting the regions first (e.g. if the regions are huge, backpatching the
943943
// op encoding mask is more annoying).
944944
if (numRegions) {
945-
bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>();
945+
bool isIsolatedFromAbove = numberingState.isIsolatedFromAbove(op);
946946
emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove);
947947

948948
// If the region is not isolated from above, or we are emitting bytecode

mlir/lib/Bytecode/Writer/IRNumbering.cpp

Lines changed: 130 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,29 @@ static void groupByDialectPerByte(T range) {
115115
IRNumberingState::IRNumberingState(Operation *op,
116116
const BytecodeWriterConfig &config)
117117
: config(config) {
118-
// Compute a global operation ID numbering according to the pre-order walk of
119-
// the IR. This is used as reference to construct use-list orders.
120-
unsigned operationID = 0;
121-
op->walk<WalkOrder::PreOrder>(
122-
[&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
118+
computeGlobalNumberingState(op);
123119

124120
// Number the root operation.
125121
number(*op);
126122

127-
// Push all of the regions of the root operation onto the worklist.
123+
// A worklist of region contexts to number and the next value id before that
124+
// region.
128125
SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
129-
for (Region &region : op->getRegions())
130-
numberContext.emplace_back(&region, nextValueID);
126+
127+
// Functor to push the regions of the given operation onto the numbering
128+
// context.
129+
auto addOpRegionsToNumber = [&](Operation *op) {
130+
MutableArrayRef<Region> regions = op->getRegions();
131+
if (regions.empty())
132+
return;
133+
134+
// Isolated regions don't share value numbers with their parent, so we can
135+
// start numbering these regions at zero.
136+
unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
137+
for (Region &region : regions)
138+
numberContext.emplace_back(&region, opFirstValueID);
139+
};
140+
addOpRegionsToNumber(op);
131141

132142
// Iteratively process each of the nested regions.
133143
while (!numberContext.empty()) {
@@ -136,14 +146,8 @@ IRNumberingState::IRNumberingState(Operation *op,
136146
number(*region);
137147

138148
// Traverse into nested regions.
139-
for (Operation &op : region->getOps()) {
140-
// Isolated regions don't share value numbers with their parent, so we can
141-
// start numbering these regions at zero.
142-
unsigned opFirstValueID =
143-
op.hasTrait<OpTrait::IsIsolatedFromAbove>() ? 0 : nextValueID;
144-
for (Region &region : op.getRegions())
145-
numberContext.emplace_back(&region, opFirstValueID);
146-
}
149+
for (Operation &op : region->getOps())
150+
addOpRegionsToNumber(&op);
147151
}
148152

149153
// Number each of the dialects. For now this is just in the order they were
@@ -178,6 +182,116 @@ IRNumberingState::IRNumberingState(Operation *op,
178182
finalizeDialectResourceNumberings(op);
179183
}
180184

185+
void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
186+
// A simple state struct tracking data used when walking operations.
187+
struct StackState {
188+
/// The operation currently being walked.
189+
Operation *op;
190+
191+
/// The numbering of the operation.
192+
OperationNumbering *numbering;
193+
194+
/// A flag indicating if the current state or one of its parents has
195+
/// unresolved isolation status. This is tracked separately from the
196+
/// isIsolatedFromAbove bit on `numbering` because we need to be able to
197+
/// handle the given case:
198+
/// top.op {
199+
/// %value = ...
200+
/// middle.op {
201+
/// %value2 = ...
202+
/// inner.op {
203+
/// // Here we mark `inner.op` as not isolated. Note `middle.op`
204+
/// // isn't known not isolated yet.
205+
/// use.op %value2
206+
///
207+
/// // Here inner.op is already known to be non-isolated, but
208+
/// // `middle.op` is now also discovered to be non-isolated.
209+
/// use.op %value
210+
/// }
211+
/// }
212+
/// }
213+
bool hasUnresolvedIsolation;
214+
};
215+
216+
// Compute a global operation ID numbering according to the pre-order walk of
217+
// the IR. This is used as reference to construct use-list orders.
218+
unsigned operationID = 0;
219+
220+
// Walk each of the operations within the IR, tracking a stack of operations
221+
// as we recurse into nested regions. This walk method hooks in at two stages
222+
// during the walk:
223+
//
224+
// BeforeAllRegions:
225+
// Here we generate a numbering for the operation and push it onto the
226+
// stack if it has regions. We also compute the isolation status of parent
227+
// regions at this stage. This is done by checking the parent regions of
228+
// operands used by the operation, and marking each region between the
229+
// the operand region and the current as not isolated. See
230+
// StackState::hasUnresolvedIsolation above for an example.
231+
//
232+
// AfterAllRegions:
233+
// Here we pop the operation from the stack, and if it hasn't been marked
234+
// as non-isolated, we mark it as so. A non-isolated use would have been
235+
// found while walking the regions, so it is safe to mark the operation at
236+
// this point.
237+
//
238+
SmallVector<StackState> opStack;
239+
rootOp->walk([&](Operation *op, const WalkStage &stage) {
240+
// After visiting all nested regions, we pop the operation from the stack.
241+
if (stage.isAfterAllRegions()) {
242+
// If no non-isolated uses were found, we can safely mark this operation
243+
// as isolated from above.
244+
OperationNumbering *numbering = opStack.pop_back_val().numbering;
245+
if (!numbering->isIsolatedFromAbove.has_value())
246+
numbering->isIsolatedFromAbove = true;
247+
return;
248+
}
249+
250+
// When visiting before nested regions, we process "IsolatedFromAbove"
251+
// checks and compute the number for this operation.
252+
if (!stage.isBeforeAllRegions())
253+
return;
254+
// Update the isolation status of parent regions if any have yet to be
255+
// resolved.
256+
if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
257+
Region *parentRegion = op->getParentRegion();
258+
for (Value operand : op->getOperands()) {
259+
Region *operandRegion = operand.getParentRegion();
260+
if (operandRegion == parentRegion)
261+
continue;
262+
// We've found a use of an operand outside of the current region,
263+
// walk the operation stack searching for the parent operation,
264+
// marking every region on the way as not isolated.
265+
Operation *operandContainerOp = operandRegion->getParentOp();
266+
auto it = std::find_if(
267+
opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
268+
// We only need to mark up to the container region, or the first
269+
// that has an unresolved status.
270+
return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
271+
});
272+
assert(it != opStack.rend() && "expected to find the container");
273+
for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
274+
// If we stopped at a region that knows its isolation status, we can
275+
// stop updating the isolation status for the parent regions.
276+
state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
277+
state.numbering->isIsolatedFromAbove = false;
278+
}
279+
}
280+
}
281+
282+
// Compute the number for this op and push it onto the stack.
283+
auto *numbering =
284+
new (opAllocator.Allocate()) OperationNumbering(operationID++);
285+
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
286+
numbering->isIsolatedFromAbove = true;
287+
operations.try_emplace(op, numbering);
288+
if (op->getNumRegions()) {
289+
opStack.emplace_back(StackState{
290+
op, numbering, !numbering->isIsolatedFromAbove.has_value()});
291+
}
292+
});
293+
}
294+
181295
void IRNumberingState::number(Attribute attr) {
182296
auto it = attrs.insert({attr, nullptr});
183297
if (!it.second) {

mlir/lib/Bytecode/Writer/IRNumbering.h

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,22 @@ struct DialectNumbering {
126126
llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
127127
};
128128

129+
//===----------------------------------------------------------------------===//
130+
// Operation Numbering
131+
//===----------------------------------------------------------------------===//
132+
133+
/// This class represents the numbering entry of an operation.
134+
struct OperationNumbering {
135+
OperationNumbering(unsigned number) : number(number) {}
136+
137+
/// The number assigned to this operation.
138+
unsigned number;
139+
140+
/// A flag indicating if this operation's regions are isolated. If unset, the
141+
/// operation isn't yet known to be isolated.
142+
std::optional<bool> isIsolatedFromAbove;
143+
};
144+
129145
//===----------------------------------------------------------------------===//
130146
// IRNumberingState
131147
//===----------------------------------------------------------------------===//
@@ -154,8 +170,8 @@ class IRNumberingState {
154170
return blockIDs[block];
155171
}
156172
unsigned getNumber(Operation *op) {
157-
assert(operationIDs.count(op) && "operation not numbered");
158-
return operationIDs[op];
173+
assert(operations.count(op) && "operation not numbered");
174+
return operations[op]->number;
159175
}
160176
unsigned getNumber(OperationName opName) {
161177
assert(opNames.count(opName) && "opName not numbered");
@@ -186,14 +202,23 @@ class IRNumberingState {
186202
return blockOperationCounts[block];
187203
}
188204

205+
/// Return if the given operation is isolated from above.
206+
bool isIsolatedFromAbove(Operation *op) {
207+
assert(operations.count(op) && "operation not numbered");
208+
return operations[op]->isIsolatedFromAbove.value_or(false);
209+
}
210+
189211
/// Get the set desired bytecode version to emit.
190212
int64_t getDesiredBytecodeVersion() const;
191-
213+
192214
private:
193215
/// This class is used to provide a fake dialect writer for numbering nested
194216
/// attributes and types.
195217
struct NumberingDialectWriter;
196218

219+
/// Compute the global numbering state for the given root operation.
220+
void computeGlobalNumberingState(Operation *rootOp);
221+
197222
/// Number the given IR unit for bytecode emission.
198223
void number(Attribute attr);
199224
void number(Block &block);
@@ -212,6 +237,7 @@ class IRNumberingState {
212237

213238
/// Mapping from IR to the respective numbering entries.
214239
DenseMap<Attribute, AttributeNumbering *> attrs;
240+
DenseMap<Operation *, OperationNumbering *> operations;
215241
DenseMap<OperationName, OpNameNumbering *> opNames;
216242
DenseMap<Type, TypeNumbering *> types;
217243
DenseMap<Dialect *, DialectNumbering *> registeredDialects;
@@ -228,12 +254,12 @@ class IRNumberingState {
228254
/// Allocators used for the various numbering entries.
229255
llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
230256
llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
257+
llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator;
231258
llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
232259
llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
233260
llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
234261

235-
/// The value ID for each Operation, Block and Value.
236-
DenseMap<Operation *, unsigned> operationIDs;
262+
/// The value ID for each Block and Value.
237263
DenseMap<Block *, unsigned> blockIDs;
238264
DenseMap<Value, unsigned> valueIDs;
239265

mlir/test/Bytecode/bytecode-lazy-loading.mlir

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,21 @@ func.func @op_with_passthrough_region_args() {
2323
}, {
2424
"test.unknown_op"() : () -> ()
2525
}
26+
27+
// Ensure operations that aren't tagged as IsolatedFromAbove can
28+
// still be lazy loaded if they don't have references to values
29+
// defined above.
30+
"test.one_region_op"() ({
31+
"test.unknown_op"() : () -> ()
32+
}) : () -> ()
33+
34+
// Similar test as above, but check that if one region has a reference
35+
// to a value defined above, we don't lazy load the operation.
36+
"test.two_region_op"() ({
37+
"test.unknown_op"() : () -> ()
38+
}, {
39+
"test.consumer"(%0) : (index) -> ()
40+
}) : () -> ()
2641
return
2742
}
2843

@@ -53,7 +68,12 @@ func.func @op_with_passthrough_region_args() {
5368
// CHECK: test.consumer
5469
// CHECK: isolated_region
5570
// CHECK-NOT: test.consumer
56-
// CHECK: Has 3 ops to materialize
71+
// CHECK: test.one_region_op
72+
// CHECK-NOT: test.op
73+
// CHECK: test.two_region_op
74+
// CHECK: test.unknown_op
75+
// CHECK: test.consumer
76+
// CHECK: Has 4 ops to materialize
5777

5878
// CHECK: Before Materializing...
5979
// CHECK: test.isolated_region
@@ -62,15 +82,15 @@ func.func @op_with_passthrough_region_args() {
6282
// CHECK: test.isolated_region
6383
// CHECK: ^bb0(%arg0: index):
6484
// CHECK: test.consumer
65-
// CHECK: Has 2 ops to materialize
85+
// CHECK: Has 3 ops to materialize
6686

6787
// CHECK: Before Materializing...
6888
// CHECK: test.isolated_region
6989
// CHECK-NOT: test.consumer
7090
// CHECK: Materializing...
7191
// CHECK: test.isolated_region
7292
// CHECK: test.consumer
73-
// CHECK: Has 1 ops to materialize
93+
// CHECK: Has 2 ops to materialize
7494

7595
// CHECK: Before Materializing...
7696
// CHECK: test.isolated_regions
@@ -79,4 +99,12 @@ func.func @op_with_passthrough_region_args() {
7999
// CHECK: test.isolated_regions
80100
// CHECK: test.unknown_op
81101
// CHECK: test.unknown_op
102+
// CHECK: Has 1 ops to materialize
103+
104+
// CHECK: Before Materializing...
105+
// CHECK: test.one_region_op
106+
// CHECK-NOT: test.unknown_op
107+
// CHECK: Materializing...
108+
// CHECK: test.one_region_op
109+
// CHECK: test.unknown_op
82110
// CHECK: Has 0 ops to materialize

0 commit comments

Comments
 (0)