Skip to content

Commit d30c022

Browse files
committed
[mlir] Split MLProgram global load and store to Graph variants
* Split ops into X_graph variants as discussed; * Remove tokens from non-Graph region variants and rely on side-effect modelling there while removing side-effect modelling from Graph variants and relying on explicit ordering there; * Make tokens required to be produced by Graph variants - but kept explicit token type specification given previous discussion on this potentially being configurable in future; This results in duplicating some code. I considered adding helper functions but decided against adding an abstraction there early given size of duplication and creating accidental coupling. Differential Revision: https://reviews.llvm.org/D127813
1 parent f2bcf33 commit d30c022

File tree

4 files changed

+173
-18
lines changed

4 files changed

+173
-18
lines changed

mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
171171
advanced cases.
172172

173173
This op is side effecting and may not be valid to use in graph regions
174-
without additional consideration to evaluation order constraints.
174+
without additional consideration to evaluation order constraints. See
175+
`global_load_graph` for op which allows for explicit ordering constraints.
175176

176177
Example:
177178

@@ -181,16 +182,14 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
181182
}];
182183

183184
let arguments = (ins
184-
Arg<SymbolRefAttr, "", [MemRead]>:$global,
185-
Variadic<MLProgram_TokenType>:$consumeTokens
185+
Arg<SymbolRefAttr, "", [MemRead]>:$global
186186
);
187187
let results = (outs
188-
AnyType:$result,
189-
Optional<MLProgram_TokenType>:$produceToken
188+
AnyType:$result
190189
);
191190

192191
let assemblyFormat = [{
193-
$global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
192+
$global `:` type($result) attr-dict
194193
}];
195194

196195
let extraClassDeclaration = [{
@@ -238,6 +237,52 @@ def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
238237
}];
239238
}
240239

240+
//===----------------------------------------------------------------------===//
241+
// GlobalLoadGraphOp
242+
//===----------------------------------------------------------------------===//
243+
244+
def MLProgram_GlobalLoadGraphOp : MLProgram_Op<"global_load_graph", [
245+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
246+
]> {
247+
let summary = "Direct load of a mutable value from a global in Graph region";
248+
let description = [{
249+
Performs a non-atomic, non-volatile, non-synchronized load from a global
250+
that may be mutable.
251+
252+
It is fully expected that these constraints are not suitable for all
253+
situations, and alternative ops should be defined and used for more advanced
254+
cases.
255+
256+
This op is side effecting and may not be valid to use in graph regions
257+
without additional consideration to evaluation order constraints.
258+
259+
Example:
260+
261+
```mlir
262+
%0, %cstr = ml_program.global_load_graph @foobar
263+
ordering (%token -> !ml_program.token) : tensor<?xi32>
264+
```
265+
}];
266+
267+
let arguments = (ins
268+
Arg<SymbolRefAttr, "", [MemRead]>:$global,
269+
Variadic<MLProgram_TokenType>:$consumeTokens
270+
);
271+
let results = (outs
272+
AnyType:$result,
273+
MLProgram_TokenType:$produceToken
274+
);
275+
276+
let assemblyFormat = [{
277+
$global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
278+
}];
279+
280+
let extraClassDeclaration = [{
281+
/// Gets the corresponding GlobalOp (or nullptr).
282+
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
283+
}];
284+
}
285+
241286
//===----------------------------------------------------------------------===//
242287
// GlobalStoreOp
243288
//===----------------------------------------------------------------------===//
@@ -255,7 +300,8 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
255300
advanced cases.
256301

257302
This op is side effecting and may not be valid to use in graph regions
258-
without additional consideration to evaluation order constraints.
303+
without additional consideration to evaluation order constraints. See
304+
`global_store_graph` for op which allows for explicit ordering constraints.
259305

260306
Example:
261307

@@ -266,11 +312,53 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
266312

267313
let arguments = (ins
268314
Arg<SymbolRefAttr, "", [MemWrite]>:$global,
315+
AnyType:$value
316+
);
317+
318+
let assemblyFormat = [{
319+
$global `=` $value `:` type($value) attr-dict
320+
}];
321+
322+
let extraClassDeclaration = [{
323+
/// Gets the corresponding GlobalOp (or nullptr).
324+
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
325+
}];
326+
}
327+
328+
//===----------------------------------------------------------------------===//
329+
// GlobalStoreGraphOp
330+
//===----------------------------------------------------------------------===//
331+
332+
def MLProgram_GlobalStoreGraphOp : MLProgram_Op<"global_store_graph", [
333+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
334+
]> {
335+
let summary = "Direct store of a value into a mutable global";
336+
let description = [{
337+
Performs a non-atomic, non-volatile, non-synchronized store to a mutable
338+
global.
339+
340+
It is fully expected that these constraints are not suitable for
341+
all situations, and alternative ops should be defined and used for more
342+
advanced cases.
343+
344+
This op is side effecting and may not be valid to use in graph regions
345+
without additional consideration to evaluation order constraints.
346+
347+
Example:
348+
349+
```mlir
350+
%token = ml_program.global_store @foobar = %0 : tensor<?xi32>
351+
ordering (%in_token -> !ml_program.token) : tensor<?xi32>
352+
```
353+
}];
354+
355+
let arguments = (ins
356+
Arg<SymbolRefAttr, "", [MemRead]>:$global,
269357
AnyType:$value,
270358
Variadic<MLProgram_TokenType>:$consumeTokens
271359
);
272360
let results = (outs
273-
Optional<MLProgram_TokenType>:$produceToken
361+
MLProgram_TokenType:$produceToken
274362
);
275363

276364
let assemblyFormat = [{

mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@ using namespace mlir::ml_program;
1818
//===----------------------------------------------------------------------===//
1919

2020
/// Parse and print an ordering clause for a variadic of consuming tokens
21-
/// and an optional producing token.
21+
/// and an producing token.
2222
///
2323
/// Syntax:
2424
/// ordering(%0, %1 -> !ml_program.token)
2525
/// ordering(() -> !ml_program.token)
26-
/// ordering(%0, %1)
2726
///
2827
/// If both the consuming and producing token are not present on the op, then
2928
/// the clause prints nothing.
@@ -46,10 +45,11 @@ static ParseResult parseTokenOrdering(
4645
return failure();
4746
}
4847

49-
// Parse optional producer token.
50-
if (succeeded(parser.parseOptionalArrow()))
51-
if (failed(parser.parseType(produceTokenType)))
52-
return failure();
48+
// Parse producer token.
49+
if (failed(parser.parseArrow()))
50+
return failure();
51+
if (failed(parser.parseType(produceTokenType)))
52+
return failure();
5353

5454
if (failed(parser.parseRParen()))
5555
return failure();
@@ -220,6 +220,30 @@ GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
220220
return success();
221221
}
222222

223+
//===----------------------------------------------------------------------===//
224+
// GlobalLoadGraphOp
225+
//===----------------------------------------------------------------------===//
226+
227+
GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
228+
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
229+
getOperation()->getParentOp(), getGlobalAttr());
230+
}
231+
232+
LogicalResult
233+
GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
234+
GlobalOp referrent = getGlobalOp(symbolTable);
235+
if (!referrent)
236+
return emitOpError() << "undefined global: " << getGlobal();
237+
238+
if (referrent.getType() != getResult().getType()) {
239+
return emitOpError() << "cannot load from global typed "
240+
<< referrent.getType() << " as "
241+
<< getResult().getType();
242+
}
243+
244+
return success();
245+
}
246+
223247
//===----------------------------------------------------------------------===//
224248
// GlobalStoreOp
225249
//===----------------------------------------------------------------------===//
@@ -249,6 +273,35 @@ GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
249273
return success();
250274
}
251275

276+
//===----------------------------------------------------------------------===//
277+
// GlobalStoreGraphOp
278+
//===----------------------------------------------------------------------===//
279+
280+
GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
281+
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
282+
getOperation()->getParentOp(), getGlobalAttr());
283+
}
284+
285+
LogicalResult
286+
GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
287+
GlobalOp referrent = getGlobalOp(symbolTable);
288+
if (!referrent)
289+
return emitOpError() << "undefined global: " << getGlobal();
290+
291+
if (!referrent.getIsMutable()) {
292+
return emitOpError() << "cannot store to an immutable global "
293+
<< getGlobal();
294+
}
295+
296+
if (referrent.getType() != getValue().getType()) {
297+
return emitOpError() << "cannot store to a global typed "
298+
<< referrent.getType() << " from "
299+
<< getValue().getType();
300+
}
301+
302+
return success();
303+
}
304+
252305
//===----------------------------------------------------------------------===//
253306
// SubgraphOp
254307
//===----------------------------------------------------------------------===//

mlir/test/Dialect/MLProgram/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,17 @@ ml_program.func @store_immutable(%arg0: i64) {
9696
ml_program.global_store @var = %arg0 : i64
9797
ml_program.return
9898
}
99+
100+
// -----
101+
102+
ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
103+
ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
104+
%token1 = ml_program.token
105+
%0, %token2 = ml_program.global_load_graph @global_mutable_undef
106+
ordering(() -> !ml_program.token) : tensor<?xi32>
107+
%token3 = ml_program.global_store_graph @global_mutable_undef = %0
108+
// expected-error @+1 {{expected '->'}}
109+
ordering(%token1, %token2) : tensor<?xi32>
110+
111+
ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
112+
}

mlir/test/Dialect/MLProgram/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ ml_program.func @global_load_store() {
4545
// CHECK-LABEL: @global_load_store_tokens
4646
ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
4747
%token1 = ml_program.token
48-
%0, %token2 = ml_program.global_load @global_mutable_undef
48+
%0, %token2 = ml_program.global_load_graph @global_mutable_undef
4949
ordering(() -> !ml_program.token) : tensor<?xi32>
50-
%token3 = ml_program.global_store @global_mutable_undef = %0
50+
%token3 = ml_program.global_store_graph @global_mutable_undef = %0
5151
ordering(%token1, %token2 -> !ml_program.token) : tensor<?xi32>
52-
ml_program.global_store @global_mutable_undef = %0
53-
ordering(%token3) : tensor<?xi32>
52+
%token4 = ml_program.global_store_graph @global_mutable_undef = %0
53+
ordering(%token3 -> !ml_program.token) : tensor<?xi32>
5454

5555
ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
5656
}

0 commit comments

Comments
 (0)