Skip to content

Commit b087fdc

Browse files
committed
[MLIR][DLTI] Enable types as keys in DLTI-query utils
Enable support for query functions - include transform.dlti.query - to take types as keys. As the data layout specific attributes already supported types as keys, this change enables querying such attributes in the expected way.
1 parent 1c46fc0 commit b087fdc

File tree

7 files changed

+149
-10
lines changed

7 files changed

+149
-10
lines changed

mlir/include/mlir/Dialect/DLTI/DLTI.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace mlir {
2626
namespace dlti {
2727
/// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
2828
/// query interface-implementing attrs, starting from attr obtained from `op`.
29-
FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
29+
FailureOr<Attribute> query(Operation *op, ArrayRef<DataLayoutEntryKey> keys,
3030
bool emitError = false);
3131
} // namespace dlti
3232
} // namespace mlir

mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
2626

2727
A lookup is performed for the given `keys` at `target` op - or its closest
2828
interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
29-
returns an attribute for a key. If more than one key is provided, the lookup
30-
continues recursively, now on the returned attributes, with the condition
31-
that these implement the above interface. For example if the payload IR is
29+
returns an attribute for a key. Each key should be either a (quoted) string
30+
or a type. If more than one key is provided, the lookup continues
31+
recursively, now on the returned attributes, with the condition that these
32+
implement the above interface. For example if the payload IR is
3233

3334
```
3435
module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
@@ -52,7 +53,7 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
5253
}];
5354

5455
let arguments = (ins TransformHandleTypeInterface:$target,
55-
StrArrayAttr:$keys);
56+
ArrayAttr:$keys);
5657
let results = (outs TransformParamTypeInterface:$associated_attr);
5758
let assemblyFormat =
5859
"$keys `at` $target attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/DLTI/DLTI.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,8 @@ getClosestQueryable(Operation *op) {
424424
return std::pair(queryable, op);
425425
}
426426

427-
FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
428-
bool emitError) {
427+
FailureOr<Attribute>
428+
dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
429429
auto [queryable, queryOp] = getClosestQueryable(op);
430430
Operation *reportOp = (queryOp ? queryOp : op);
431431

@@ -438,6 +438,17 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
438438
return failure();
439439
}
440440

441+
auto keyToStr = [](DataLayoutEntryKey key) -> std::string {
442+
if (auto strKey = llvm::dyn_cast<StringAttr>(key))
443+
return "\"" + std::string(strKey.getValue()) + "\"";
444+
if (auto typeKey = llvm::dyn_cast<Type>(key)) {
445+
std::string buf;
446+
llvm::raw_string_ostream(buf) << typeKey;
447+
return buf;
448+
}
449+
llvm_unreachable("DataLayoutEntryKey was not `StringAttr` or `Type`");
450+
};
451+
441452
Attribute currentAttr = queryable;
442453
for (auto &&[idx, key] : llvm::enumerate(keys)) {
443454
if (auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr)) {
@@ -446,17 +457,24 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
446457
if (emitError) {
447458
auto diag = op->emitError() << "target op of failed DLTI query";
448459
diag.attachNote(reportOp->getLoc())
449-
<< "key " << key << " has no DLTI-mapping per attr: " << map;
460+
<< "key " << keyToStr(key)
461+
<< " has no DLTI-mapping per attr: " << map;
450462
}
451463
return failure();
452464
}
453465
currentAttr = *maybeAttr;
454466
} else {
455467
if (emitError) {
468+
std::string commaSeparatedKeys;
469+
llvm::interleave(
470+
keys.take_front(idx), // All prior keys.
471+
[&](auto key) { commaSeparatedKeys += keyToStr(key); },
472+
[&]() { commaSeparatedKeys += ","; });
473+
456474
auto diag = op->emitError() << "target op of failed DLTI query";
457475
diag.attachNote(reportOp->getLoc())
458476
<< "got non-DLTI-queryable attribute upon looking up keys ["
459-
<< keys.take_front(idx) << "] at op";
477+
<< commaSeparatedKeys << "] at op";
460478
}
461479
return failure();
462480
}

mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,16 @@ void transform::QueryOp::getEffects(
3333
DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
3434
transform::TransformRewriter &rewriter, Operation *target,
3535
transform::ApplyToEachResultList &results, TransformState &state) {
36-
auto keys = SmallVector<StringAttr>(getKeys().getAsRange<StringAttr>());
36+
auto keys = SmallVector<DataLayoutEntryKey>();
37+
for (Attribute key : getKeys()) {
38+
if (auto strKey = dyn_cast<StringAttr>(key))
39+
keys.push_back(strKey);
40+
else if (auto typeKey = dyn_cast<TypeAttr>(key))
41+
keys.push_back(typeKey.getValue());
42+
else
43+
return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: "
44+
"only StringAttr and TypeAttr are allowed");
45+
}
3746

3847
FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true);
3948

mlir/test/Dialect/DLTI/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@
3333

3434
// -----
3535

36+
// expected-error@below {{repeated layout entry key: 'i32'}}
37+
"test.unknown_op"() { test.unknown_attr = #dlti.map<
38+
#dlti.dl_entry<i32, 42>,
39+
#dlti.dl_entry<i32, 42>
40+
>} : () -> ()
41+
42+
// -----
43+
3644
// expected-error@below {{repeated layout entry key: 'i32'}}
3745
"test.unknown_op"() { test.unknown_attr = #dlti.dl_spec<
3846
#dlti.dl_entry<i32, 42>,

mlir/test/Dialect/DLTI/query.mlir

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,60 @@ module attributes {transform.with_named_sequence} {
1717

1818
// -----
1919

20+
// expected-remark @below {{associated attr 42 : i32}}
21+
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, 42 : i32>>} {
22+
func.func private @f()
23+
}
24+
25+
module attributes {transform.with_named_sequence} {
26+
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
27+
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
28+
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
29+
%param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param
30+
transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
31+
transform.yield
32+
}
33+
}
34+
35+
// -----
36+
37+
// expected-remark @below {{associated attr 32 : i32}}
38+
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, #dlti.map<#dlti.dl_entry<"width_in_bits", 32 : i32>>>>} {
39+
func.func private @f()
40+
}
41+
42+
module attributes {transform.with_named_sequence} {
43+
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
44+
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
45+
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
46+
%param = transform.dlti.query [i32,"width_in_bits"] at %module : (!transform.any_op) -> !transform.any_param
47+
transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
48+
transform.yield
49+
}
50+
}
51+
52+
// -----
53+
54+
// expected-remark @below {{width in bits of i32 = 32 : i64}}
55+
// expected-remark @below {{width in bits of f64 = 64 : i64}}
56+
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>, #dlti.dl_entry<f64, 64>>>>} {
57+
func.func private @f()
58+
}
59+
60+
module attributes {transform.with_named_sequence} {
61+
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
62+
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
63+
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
64+
%i32bits = transform.dlti.query ["width_in_bits",i32] at %module : (!transform.any_op) -> !transform.any_param
65+
%f64bits = transform.dlti.query ["width_in_bits",f64] at %module : (!transform.any_op) -> !transform.any_param
66+
transform.debug.emit_param_as_remark %i32bits, "width in bits of i32 =" at %module : !transform.any_param, !transform.any_op
67+
transform.debug.emit_param_as_remark %f64bits, "width in bits of f64 =" at %module : !transform.any_param, !transform.any_op
68+
transform.yield
69+
}
70+
}
71+
72+
// -----
73+
2074
// expected-remark @below {{associated attr 42 : i32}}
2175
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
2276
func.func private @f()
@@ -336,6 +390,23 @@ module attributes {transform.with_named_sequence} {
336390

337391
// -----
338392

393+
// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys [i32]}}
394+
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<i32, 32 : i32>>} {
395+
// expected-error @below {{target op of failed DLTI query}}
396+
func.func private @f()
397+
}
398+
399+
module attributes {transform.with_named_sequence} {
400+
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
401+
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
402+
// expected-error @below {{'transform.dlti.query' op failed to apply}}
403+
%param = transform.dlti.query [i32,"width_in_bits"] at %func : (!transform.any_op) -> !transform.any_param
404+
transform.yield
405+
}
406+
}
407+
408+
// -----
409+
339410
module {
340411
// expected-error @below {{target op of failed DLTI query}}
341412
// expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}}
@@ -353,6 +424,23 @@ module attributes {transform.with_named_sequence} {
353424

354425
// -----
355426

427+
// expected-note @below {{key i64 has no DLTI-mapping per attr: #dlti.map<#dlti.dl_entry<i32, 32 : i64>>}}
428+
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>>>>} {
429+
// expected-error @below {{target op of failed DLTI query}}
430+
func.func private @f()
431+
}
432+
433+
module attributes {transform.with_named_sequence} {
434+
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
435+
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
436+
// expected-error @below {{'transform.dlti.query' op failed to apply}}
437+
%param = transform.dlti.query ["width_in_bits",i64] at %func : (!transform.any_op) -> !transform.any_param
438+
transform.yield
439+
}
440+
}
441+
442+
// -----
443+
356444
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
357445
func.func private @f()
358446
}

mlir/test/Dialect/DLTI/valid.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,18 @@ module attributes {
206206
"GPU": #dlti.target_device_spec<
207207
#dlti.dl_entry<"L1_cache_size_in_bytes", "128">>
208208
>} {}
209+
210+
211+
// -----
212+
213+
// CHECK: "test.op_with_dlti_map"() ({
214+
// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42 : i64>>}
215+
"test.op_with_dlti_map"() ({
216+
}) { dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42>> } : () -> ()
217+
218+
// -----
219+
220+
// CHECK: "test.op_with_dlti_map"() ({
221+
// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<i32, 42 : i64>>}
222+
"test.op_with_dlti_map"() ({
223+
}) { dlti.map = #dlti.map<#dlti.dl_entry<i32, 42>> } : () -> ()

0 commit comments

Comments
 (0)