Skip to content

Commit e438414

Browse files
authored
[mlir] use transform-interpreter in test passes (#70040)
Update most test passes to use the transform-interpreter pass instead of the test-transform-dialect-interpreter-pass. The new "main" interpreter pass has a named entry point instead of looking up the top-level op with `PossibleTopLevelOpTrait`, which is arguably a more understandable interface. The change is mechanical, rewriting an unnamed sequence into a named one and wrapping the transform IR in to a module when necessary. Add an option to the transform-interpreter pass to target a tagged payload op instead of the root anchor op, which is also useful for repro generation. Only the test in the transform dialect proper and the examples have not been updated yet. These will be updated separately after a more careful consideration of testing coverage of the transform interpreter logic.
1 parent f364a7a commit e438414

File tree

131 files changed

+5096
-3956
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

131 files changed

+5096
-3956
lines changed

mlir/include/mlir/Dialect/Transform/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ def InterpreterPass : Pass<"transform-interpreter"> {
7070
}];
7171
let dependentDialects = ["::mlir::transform::TransformDialect"];
7272
let options = [
73+
Option<"debugPayloadRootTag", "debug-payload-root-tag", "std::string",
74+
/*default=*/[{""}],
75+
"Select the operation with 'transform.target_tag' attribute having "
76+
"the given value as payload IR root. If empty select the pass "
77+
"anchor operation as the payload IR root.">,
78+
Option<"disableExpensiveChecks", "disable-expensive-checks", "bool",
79+
"false",
80+
"Disable expensive checks in the interpreter for a faster run.">,
7381
Option<"entryPoint", "entry-point", "std::string",
7482
/*default=*/[{"__transform_main"}],
7583
"Entry point of the pass pipeline.">,

mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,18 @@ LogicalResult mergeSymbolsInto(Operation *target,
8585
OwningOpRef<Operation *> other);
8686
} // namespace detail
8787

88-
/// Standalone util to apply the named sequence `entryPoint` to the payload.
89-
/// This is done in 3 steps:
90-
/// 1. lookup the `entryPoint` symbol in `{payload, sharedTransformModule}` by
91-
/// calling detail::findTransformEntryPoint.
92-
/// 2. if the entry point is found and not nested under
93-
/// `sharedTransformModule`, call `detail::defineDeclaredSymbols` to "link" in
94-
/// the `sharedTransformModule`. Note: this may modify the transform IR
95-
/// embedded with the payload IR.
96-
/// 3. apply the transform IR to the payload IR, relaxing the requirement that
97-
/// the transform IR is a top-level transform op. We are applying a named
98-
/// sequence anyway.
99-
LogicalResult applyTransformNamedSequence(
100-
Operation *payload, ModuleOp transformModule,
101-
const TransformOptions &options,
102-
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
88+
/// Standalone util to apply the named sequence `transformRoot` to `payload` IR.
89+
/// This is done in 2 steps:
90+
/// 1. If `transformModule` is provided and is not nested under
91+
/// `transformRoot`, it will be "linked into" the IR containing
92+
/// `transformRoot` to resolve undefined named sequences.
93+
/// 2. The transforms specified in `transformRoot` are applied to `payload`,
94+
/// assuming the named sequence has a single argument handle that will be
95+
/// associated with `payload` on run.
96+
LogicalResult applyTransformNamedSequence(Operation *payload,
97+
Operation *transformRoot,
98+
ModuleOp transformModule,
99+
const TransformOptions &options);
103100

104101
} // namespace transform
105102
} // namespace mlir

mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,39 @@ namespace transform {
1919
} // namespace transform
2020
} // namespace mlir
2121

22+
/// Returns the payload operation to be used as payload root:
23+
/// - the operation nested under `passRoot` that has the given tag attribute,
24+
/// must be unique;
25+
/// - the `passRoot` itself if the tag is empty.
26+
static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
27+
// Fast return.
28+
if (tag.empty())
29+
return passRoot;
30+
31+
// Walk to do a lookup.
32+
Operation *target = nullptr;
33+
auto tagAttrName = StringAttr::get(
34+
passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
35+
WalkResult walkResult = passRoot->walk([&](Operation *op) {
36+
auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
37+
if (!attr || attr.getValue() != tag)
38+
return WalkResult::advance();
39+
40+
if (!target) {
41+
target = op;
42+
return WalkResult::advance();
43+
}
44+
45+
InFlightDiagnostic diag = op->emitError()
46+
<< "repeated operation with the target tag '"
47+
<< tag << "'";
48+
diag.attachNote(target->getLoc()) << "previously seen operation";
49+
return WalkResult::interrupt();
50+
});
51+
52+
return walkResult.wasInterrupted() ? nullptr : target;
53+
}
54+
2255
namespace {
2356
class InterpreterPass
2457
: public transform::impl::InterpreterPassBase<InterpreterPass> {
@@ -29,10 +62,22 @@ class InterpreterPass
2962
MLIRContext *context = &getContext();
3063
ModuleOp transformModule =
3164
transform::detail::getPreloadedTransformModule(context);
65+
Operation *payloadRoot =
66+
findPayloadRoot(getOperation(), debugPayloadRootTag);
67+
Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
68+
getOperation(), transformModule, entryPoint);
69+
if (!transformEntryPoint) {
70+
getOperation()->emitError()
71+
<< "could not find transform entry point: " << entryPoint
72+
<< " in either payload or transform module";
73+
return signalPassFailure();
74+
}
75+
3276
if (failed(transform::applyTransformNamedSequence(
33-
getOperation(), transformModule,
34-
options.enableExpensiveChecks(true), entryPoint)))
77+
payloadRoot, transformEntryPoint, transformModule,
78+
options.enableExpensiveChecks(!disableExpensiveChecks)))) {
3579
return signalPassFailure();
80+
}
3681
}
3782

3883
private:

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -409,16 +409,8 @@ transform::detail::mergeSymbolsInto(Operation *target,
409409
}
410410

411411
LogicalResult transform::applyTransformNamedSequence(
412-
Operation *payload, ModuleOp transformModule,
413-
const TransformOptions &options, StringRef entryPoint) {
414-
Operation *transformRoot =
415-
detail::findTransformEntryPoint(payload, transformModule, entryPoint);
416-
if (!transformRoot) {
417-
return payload->emitError()
418-
<< "could not find transform entry point: " << entryPoint
419-
<< " in either payload or transform module";
420-
}
421-
412+
Operation *payload, Operation *transformRoot, ModuleOp transformModule,
413+
const TransformOptions &options) {
422414
// `transformModule` may not be modified.
423415
if (transformModule && !transformModule->isAncestor(transformRoot)) {
424416
OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());

mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1 use-opaque-pointers=1' %s | FileCheck %s --check-prefix=BAREPTR
44

5-
// RUN: mlir-opt -test-transform-dialect-interpreter %s | FileCheck %s --check-prefix=BAREPTR
5+
// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
66

77
// These tests were separated from func-memref.mlir because applying
88
// -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
@@ -110,17 +110,20 @@ func.func @unranked_memref(%arg0:memref<*xi32>) {
110110
}
111111
func.func private @printMemrefI32(memref<*xi32>)
112112

113-
transform.sequence failures(propagate) {
114-
^bb1(%toplevel_module: !transform.any_op):
115-
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
116-
: (!transform.any_op) -> !transform.any_op
117-
transform.apply_conversion_patterns to %func {
118-
transform.apply_conversion_patterns.func.func_to_llvm
119-
} with type_converter {
120-
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
121-
{use_bare_ptr_call_conv = true, use_opaque_pointers = true}
122-
} {
123-
legal_dialects = ["llvm"],
124-
partial_conversion
125-
} : !transform.any_op
113+
module attributes {transform.with_named_sequence} {
114+
transform.named_sequence @__transform_main(
115+
%toplevel_module: !transform.any_op {transform.readonly}) {
116+
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
117+
: (!transform.any_op) -> !transform.any_op
118+
transform.apply_conversion_patterns to %func {
119+
transform.apply_conversion_patterns.func.func_to_llvm
120+
} with type_converter {
121+
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
122+
{use_bare_ptr_call_conv = true, use_opaque_pointers = true}
123+
} {
124+
legal_dialects = ["llvm"],
125+
partial_conversion
126+
} : !transform.any_op
127+
transform.yield
128+
}
126129
}

mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-math-to-llvm,convert-arith-to-llvm{index-bitwidth=32}),convert-func-to-llvm{index-bitwidth=32 use-opaque-pointers=1},reconcile-unrealized-casts)" %s | FileCheck --check-prefix=CHECK32 %s
44

5-
// RUN: mlir-opt -test-transform-dialect-interpreter %s | FileCheck --check-prefix=CHECK32 %s
5+
// RUN: mlir-opt -transform-interpreter %s | FileCheck --check-prefix=CHECK32 %s
66

77
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
88
// and the generic `convert-to-llvm` pass.
@@ -537,20 +537,22 @@ func.func @switchi8(%arg0 : i8) -> i32 {
537537
// CHECK-NEXT: llvm.return %[[E1]] : i32
538538
// CHECK-NEXT: }
539539

540-
transform.sequence failures(propagate) {
541-
^bb1(%toplevel_module: !transform.any_op):
542-
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
543-
: (!transform.any_op) -> !transform.any_op
544-
transform.apply_conversion_patterns to %func {
545-
transform.apply_conversion_patterns.dialect_to_llvm "math"
546-
transform.apply_conversion_patterns.dialect_to_llvm "arith"
547-
transform.apply_conversion_patterns.dialect_to_llvm "cf"
548-
transform.apply_conversion_patterns.func.func_to_llvm
549-
} with type_converter {
550-
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
551-
{index_bitwidth = 32, use_opaque_pointers = true}
552-
} {
553-
legal_dialects = ["llvm"],
554-
partial_conversion
555-
} : !transform.any_op
540+
module attributes {transform.with_named_sequence} {
541+
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
542+
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
543+
: (!transform.any_op) -> !transform.any_op
544+
transform.apply_conversion_patterns to %func {
545+
transform.apply_conversion_patterns.dialect_to_llvm "math"
546+
transform.apply_conversion_patterns.dialect_to_llvm "arith"
547+
transform.apply_conversion_patterns.dialect_to_llvm "cf"
548+
transform.apply_conversion_patterns.func.func_to_llvm
549+
} with type_converter {
550+
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
551+
{index_bitwidth = 32, use_opaque_pointers = true}
552+
} {
553+
legal_dialects = ["llvm"],
554+
partial_conversion
555+
} : !transform.any_op
556+
transform.yield
557+
}
556558
}

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32 use-opaque-pointers=1' -split-input-file | FileCheck %s
22

3-
// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
3+
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
44

55
gpu.module @test_module_0 {
66
// CHECK-LABEL: func @gpu_index_ops()
@@ -48,30 +48,32 @@ gpu.module @test_module_1 {
4848
}
4949
}
5050

51-
transform.sequence failures(propagate) {
52-
^bb1(%toplevel_module: !transform.any_op):
53-
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
54-
: (!transform.any_op) -> !transform.any_op
55-
transform.apply_conversion_patterns to %gpu_module {
56-
transform.apply_conversion_patterns.dialect_to_llvm "arith"
57-
transform.apply_conversion_patterns.dialect_to_llvm "cf"
58-
transform.apply_conversion_patterns.vector.vector_to_llvm
59-
transform.apply_conversion_patterns.func.func_to_llvm
60-
transform.apply_conversion_patterns.dialect_to_llvm "memref"
61-
transform.apply_conversion_patterns.gpu.gpu_to_nvvm
62-
transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm
63-
transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm {has_redux = true}
64-
transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm
65-
} with type_converter {
66-
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
67-
{index_bitwidth = 32, use_opaque_pointers = true}
68-
} {
69-
legal_dialects = ["llvm", "memref", "nvvm"],
70-
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
71-
illegal_dialects = ["gpu"],
72-
illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
73-
"llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",
74-
"llvm.sin", "llvm.sqrt"],
75-
partial_conversion
76-
} : !transform.any_op
51+
module attributes {transform.with_named_sequence} {
52+
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
53+
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
54+
: (!transform.any_op) -> !transform.any_op
55+
transform.apply_conversion_patterns to %gpu_module {
56+
transform.apply_conversion_patterns.dialect_to_llvm "arith"
57+
transform.apply_conversion_patterns.dialect_to_llvm "cf"
58+
transform.apply_conversion_patterns.vector.vector_to_llvm
59+
transform.apply_conversion_patterns.func.func_to_llvm
60+
transform.apply_conversion_patterns.dialect_to_llvm "memref"
61+
transform.apply_conversion_patterns.gpu.gpu_to_nvvm
62+
transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm
63+
transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm {has_redux = true}
64+
transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm
65+
} with type_converter {
66+
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
67+
{index_bitwidth = 32, use_opaque_pointers = true}
68+
} {
69+
legal_dialects = ["llvm", "memref", "nvvm"],
70+
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
71+
illegal_dialects = ["gpu"],
72+
illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
73+
"llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",
74+
"llvm.sin", "llvm.sqrt"],
75+
partial_conversion
76+
} : !transform.any_op
77+
transform.yield
78+
}
7779
}

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-opaque-pointers=1' -split-input-file | FileCheck %s
2-
// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
2+
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
33

44
gpu.module @test_module_0 {
55
// CHECK-LABEL: func @gpu_index_ops()
@@ -627,38 +627,40 @@ gpu.module @test_module_31 {
627627
}
628628
}
629629

630-
transform.sequence failures(propagate) {
631-
^bb1(%toplevel_module: !transform.any_op):
632-
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
633-
: (!transform.any_op) -> !transform.any_op
634-
635-
transform.apply_patterns to %gpu_module {
636-
transform.apply_patterns.gpu.gpu_rewrite_patterns
637-
} : !transform.any_op
638-
639-
transform.apply_conversion_patterns to %gpu_module {
640-
transform.apply_conversion_patterns.dialect_to_llvm "arith"
641-
transform.apply_conversion_patterns.dialect_to_llvm "cf"
642-
transform.apply_conversion_patterns.vector.vector_to_llvm
643-
transform.apply_conversion_patterns.func.func_to_llvm
644-
transform.apply_conversion_patterns.dialect_to_llvm "memref"
645-
transform.apply_conversion_patterns.gpu.gpu_to_nvvm
646-
transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm
647-
transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm
648-
transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm
649-
} with type_converter {
650-
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
651-
{index_bitwidth = 64,
652-
use_bare_ptr = true,
653-
use_bare_ptr_memref_call_conv = true,
654-
use_opaque_pointers = true}
655-
} {
656-
legal_dialects = ["llvm", "memref", "nvvm", "test"],
657-
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
658-
illegal_dialects = ["gpu"],
659-
illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
660-
"llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2","llvm.pow",
661-
"llvm.sin", "llvm.sqrt"],
662-
partial_conversion
663-
} : !transform.any_op
630+
module attributes {transform.with_named_sequence} {
631+
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
632+
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
633+
: (!transform.any_op) -> !transform.any_op
634+
635+
transform.apply_patterns to %gpu_module {
636+
transform.apply_patterns.gpu.gpu_rewrite_patterns
637+
} : !transform.any_op
638+
639+
transform.apply_conversion_patterns to %gpu_module {
640+
transform.apply_conversion_patterns.dialect_to_llvm "arith"
641+
transform.apply_conversion_patterns.dialect_to_llvm "cf"
642+
transform.apply_conversion_patterns.vector.vector_to_llvm
643+
transform.apply_conversion_patterns.func.func_to_llvm
644+
transform.apply_conversion_patterns.dialect_to_llvm "memref"
645+
transform.apply_conversion_patterns.gpu.gpu_to_nvvm
646+
transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm
647+
transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm
648+
transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm
649+
} with type_converter {
650+
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
651+
{index_bitwidth = 64,
652+
use_bare_ptr = true,
653+
use_bare_ptr_memref_call_conv = true,
654+
use_opaque_pointers = true}
655+
} {
656+
legal_dialects = ["llvm", "memref", "nvvm", "test"],
657+
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
658+
illegal_dialects = ["gpu"],
659+
illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
660+
"llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2","llvm.pow",
661+
"llvm.sin", "llvm.sqrt"],
662+
partial_conversion
663+
} : !transform.any_op
664+
transform.yield
665+
}
664666
}

0 commit comments

Comments
 (0)