Skip to content

Commit ec03bbe

Browse files
author
Vladislav Vinogradov
committed
[mlir] Fix bug in partial dialect conversion
The discussion on forum: https://llvm.discourse.group/t/bug-in-partial-dialect-conversion/4115 The `applyPartialConversion` didn't handle the operations, that were marked as illegal inside dynamic legality callback. Instead of reporting error, if such operation was not converted to legal set, the method just added it to `unconvertedSet` in the same way as unknown operations. This patch fixes that and handle dynamically illegal operations as well. The patch includes 2 fixes for existing passes: * `tensor-bufferize` - explicitly mark `std.return` as legal. * `convert-parallel-loops-to-gpu` - ugly fix with marking visited operations to avoid recursive legality checks. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D108505
1 parent 9a2255d commit ec03bbe

File tree

9 files changed

+194
-65
lines changed

9 files changed

+194
-65
lines changed

mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class ConversionTarget;
1616
struct LogicalResult;
1717
class MLIRContext;
1818
class Value;
19+
class Operation;
1920
class RewritePatternSet;
2021
using OwningRewritePatternList = RewritePatternSet;
2122

@@ -49,6 +50,9 @@ void populateParallelLoopToGPUPatterns(RewritePatternSet &patterns);
4950
/// are not rewritten by the provided patterns are legal.
5051
void configureParallelLoopToGPULegality(ConversionTarget &target);
5152

53+
/// Clean up after applyPartialConversion/applyFullConversion call.
54+
void finalizeParallelLoopToGPUConversion(Operation *op);
55+
5256
} // namespace mlir
5357

5458
#endif // MLIR_CONVERSION_SCFTOGPU_SCFTOGPU_H_

mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,24 @@
3737
using namespace mlir;
3838
using namespace mlir::scf;
3939

40+
// Name of internal attribute to mark visited operations during conversion.
41+
//
42+
// NOTE: The conversion originally used the following legality criteria:
43+
// `!parallelOp->hasAttr(gpu::getMappingAttrName())`
44+
// But the provided pattern might reject some cases based on more detailed
45+
// analysis of the `mapping` attribute.
46+
// To avoid dialect conversion failure due to non-converted illegal operation
47+
// we use this extra Unit attribute as a marker, that the operation was checked
48+
// by the pattern and is should be considered as legal in the following legality
49+
// checks. The `finalizeParallelLoopToGPUConversion` function performs clean up
50+
// of this extra attributes ans is supposed to be called after the dialect
51+
// conversion.
52+
//
53+
// TODO: Implement a cleaner solution, factoring out the "matching" logic
54+
// from the pattern and its callees into a separate function that can be called
55+
// from both the pattern and the op legality check.
56+
static constexpr StringLiteral kVisitedAttrName = "SCFToGPU_visited";
57+
4058
// Extract an indexed value from KernelDim3.
4159
static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) {
4260
switch (pos) {
@@ -567,6 +585,9 @@ static LogicalResult processParallelLoop(
567585
LogicalResult
568586
ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
569587
PatternRewriter &rewriter) const {
588+
// Mark the operation as visited for recursive legality check.
589+
parallelOp->setAttr(kVisitedAttrName, rewriter.getUnitAttr());
590+
570591
// We can only transform starting at the outer-most loop. Launches inside of
571592
// parallel loops are not supported.
572593
if (auto parentLoop = parallelOp->getParentOfType<ParallelOp>())
@@ -649,6 +670,13 @@ void mlir::populateParallelLoopToGPUPatterns(RewritePatternSet &patterns) {
649670
void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
650671
target.addLegalDialect<memref::MemRefDialect>();
651672
target.addDynamicallyLegalOp<scf::ParallelOp>([](scf::ParallelOp parallelOp) {
652-
return !parallelOp->getAttr(gpu::getMappingAttrName());
673+
return !parallelOp->hasAttr(gpu::getMappingAttrName()) ||
674+
parallelOp->hasAttr(kVisitedAttrName);
675+
});
676+
}
677+
678+
void mlir::finalizeParallelLoopToGPUConversion(Operation *op) {
679+
op->walk([](scf::ParallelOp parallelOp) {
680+
parallelOp->removeAttr(kVisitedAttrName);
653681
});
654682
}

mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct ParallelLoopToGpuPass
5555
if (failed(applyPartialConversion(getOperation(), target,
5656
std::move(patterns))))
5757
signalPassFailure();
58+
finalizeParallelLoopToGPUConversion(getOperation());
5859
}
5960
};
6061

mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
175175
target.addLegalDialect<memref::MemRefDialect>();
176176
target.addDynamicallyLegalDialect<StandardOpsDialect>(
177177
[&](Operation *op) { return typeConverter.isLegal(op); });
178+
target.addLegalOp<ReturnOp>();
178179
target.addLegalDialect<scf::SCFDialect>();
179180

180181
if (failed(

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1650,7 +1650,13 @@ OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
16501650

16511651
bool OperationLegalizer::isIllegal(Operation *op) const {
16521652
// Check if the target explicitly marked this operation as illegal.
1653-
return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
1653+
if (auto info = target.getOpAction(op->getName())) {
1654+
if (*info == LegalizationAction::Dynamic)
1655+
return !target.isLegal(op);
1656+
return *info == LegalizationAction::Illegal;
1657+
}
1658+
1659+
return false;
16541660
}
16551661

16561662
LogicalResult

mlir/test/Transforms/test-legalizer-full.mlir

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -47,55 +47,88 @@ func @recursively_legal_invalid_op() {
4747

4848
// -----
4949

50-
// Test that region cloning can be properly undone.
51-
func @test_undo_region_clone() {
52-
"test.region"() ({
53-
^bb1(%i0: i64):
54-
"test.invalid"(%i0) : (i64) -> ()
55-
}) {legalizer.should_clone} : () -> ()
56-
57-
// expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}}
58-
%ignored = "test.illegal_op_f"() : () -> (i32)
59-
"test.return"() : () -> ()
50+
// expected-remark@+1 {{applyFullConversion failed}}
51+
builtin.module {
52+
53+
// Test that region cloning can be properly undone.
54+
func @test_undo_region_clone() {
55+
"test.region"() ({
56+
^bb1(%i0: i64):
57+
"test.invalid"(%i0) : (i64) -> ()
58+
}) {legalizer.should_clone} : () -> ()
59+
60+
// expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}}
61+
%ignored = "test.illegal_op_f"() : () -> (i32)
62+
"test.return"() : () -> ()
63+
}
64+
6065
}
6166

6267
// -----
6368

64-
// Test that unknown operations can be dynamically legal.
65-
func @test_unknown_dynamically_legal() {
66-
"foo.unknown_op"() {test.dynamically_legal} : () -> ()
69+
// expected-remark@+1 {{applyFullConversion failed}}
70+
builtin.module {
71+
72+
// Test that unknown operations can be dynamically legal.
73+
func @test_unknown_dynamically_legal() {
74+
"foo.unknown_op"() {test.dynamically_legal} : () -> ()
75+
76+
// expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}}
77+
"foo.unknown_op"() {} : () -> ()
78+
"test.return"() : () -> ()
79+
}
6780

68-
// expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}}
69-
"foo.unknown_op"() {} : () -> ()
70-
"test.return"() : () -> ()
7181
}
7282

7383
// -----
7484

75-
// Test that region inlining can be properly undone.
76-
func @test_undo_region_inline() {
77-
"test.region"() ({
78-
^bb1(%i0: i64):
79-
// expected-error@+1 {{failed to legalize operation 'std.br'}}
80-
br ^bb2(%i0 : i64)
81-
^bb2(%i1: i64):
82-
"test.invalid"(%i1) : (i64) -> ()
83-
}) {} : () -> ()
85+
// expected-remark@+1 {{applyFullConversion failed}}
86+
builtin.module {
87+
88+
// Test that region inlining can be properly undone.
89+
func @test_undo_region_inline() {
90+
"test.region"() ({
91+
^bb1(%i0: i64):
92+
// expected-error@+1 {{failed to legalize operation 'std.br'}}
93+
br ^bb2(%i0 : i64)
94+
^bb2(%i1: i64):
95+
"test.invalid"(%i1) : (i64) -> ()
96+
}) {} : () -> ()
97+
98+
"test.return"() : () -> ()
99+
}
84100

85-
"test.return"() : () -> ()
86101
}
87102

88103
// -----
89104

90-
// Test that multiple block erases can be properly undone.
91-
func @test_undo_block_erase() {
92-
// expected-error@+1 {{failed to legalize operation 'test.region'}}
93-
"test.region"() ({
94-
^bb1(%i0: i64):
95-
br ^bb2(%i0 : i64)
96-
^bb2(%i1: i64):
97-
"test.invalid"(%i1) : (i64) -> ()
98-
}) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
105+
// expected-remark@+1 {{applyFullConversion failed}}
106+
builtin.module {
107+
108+
// Test that multiple block erases can be properly undone.
109+
func @test_undo_block_erase() {
110+
// expected-error@+1 {{failed to legalize operation 'test.region'}}
111+
"test.region"() ({
112+
^bb1(%i0: i64):
113+
br ^bb2(%i0 : i64)
114+
^bb2(%i1: i64):
115+
"test.invalid"(%i1) : (i64) -> ()
116+
}) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
117+
118+
"test.return"() : () -> ()
119+
}
120+
121+
}
122+
123+
// -----
124+
125+
// expected-remark@+1 {{applyFullConversion failed}}
126+
builtin.module {
127+
128+
func @create_unregistered_op_in_pattern() -> i32 {
129+
// expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}}
130+
%0 = "test.illegal_op_g"() : () -> (i32)
131+
"test.return"(%0) : (i32) -> ()
132+
}
99133

100-
"test.return"() : () -> ()
101134
}

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -173,36 +173,50 @@ func @bounded_recursion() {
173173

174174
// -----
175175

176-
func @fail_to_convert_illegal_op() -> i32 {
177-
// expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}}
178-
%result = "test.illegal_op_f"() : () -> (i32)
179-
return %result : i32
176+
// expected-remark@+1 {{applyPartialConversion failed}}
177+
builtin.module {
178+
179+
func @fail_to_convert_illegal_op() -> i32 {
180+
// expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}}
181+
%result = "test.illegal_op_f"() : () -> (i32)
182+
return %result : i32
183+
}
184+
180185
}
181186

182187
// -----
183188

184-
func @fail_to_convert_illegal_op_in_region() {
185-
// expected-error@+1 {{failed to legalize operation 'test.region_builder'}}
186-
"test.region_builder"() : () -> ()
187-
return
189+
// expected-remark@+1 {{applyPartialConversion failed}}
190+
builtin.module {
191+
192+
func @fail_to_convert_illegal_op_in_region() {
193+
// expected-error@+1 {{failed to legalize operation 'test.region_builder'}}
194+
"test.region_builder"() : () -> ()
195+
return
196+
}
197+
188198
}
189199

190200
// -----
191201

192202
// Check that the entry block arguments of a region are untouched in the case
193203
// of failure.
194204

195-
// CHECK-LABEL: func @fail_to_convert_region
196-
func @fail_to_convert_region() {
197-
// CHECK-NEXT: "test.region"
198-
// CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64):
199-
"test.region"() ({
200-
^bb1(%i0: i64):
201-
// expected-error@+1 {{failed to legalize operation 'test.region_builder'}}
202-
"test.region_builder"() : () -> ()
203-
"test.valid"() : () -> ()
204-
}) : () -> ()
205-
return
205+
// expected-remark@+1 {{applyPartialConversion failed}}
206+
builtin.module {
207+
208+
func @fail_to_convert_region() {
209+
// CHECK: "test.region"
210+
// CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64):
211+
"test.region"() ({
212+
^bb1(%i0: i64):
213+
// expected-error@+1 {{failed to legalize operation 'test.region_builder'}}
214+
"test.region_builder"() : () -> ()
215+
"test.valid"() : () -> ()
216+
}) : () -> ()
217+
return
218+
}
219+
206220
}
207221

208222
// -----
@@ -271,10 +285,8 @@ func @undo_child_created_before_parent() {
271285
return
272286
}
273287

274-
275288
// -----
276289

277-
278290
// Check that a conversion pattern on `test.blackhole` can mark the producer
279291
// for deletion.
280292
// CHECK-LABEL: @blackhole
@@ -284,3 +296,16 @@ func @blackhole() {
284296
// expected-remark@+1 {{op 'std.return' is not legalizable}}
285297
return
286298
}
299+
300+
// -----
301+
302+
// expected-remark@+1 {{applyPartialConversion failed}}
303+
builtin.module {
304+
305+
func @create_unregistered_op_in_pattern() -> i32 {
306+
// expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}}
307+
%0 = "test.illegal_op_g"() : () -> (i32)
308+
"test.return"(%0) : (i32) -> ()
309+
}
310+
311+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,9 +1415,12 @@ def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32)>;
14151415
def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>;
14161416
def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
14171417
def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
1418+
def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>;
14181419
def LegalOpA : TEST_Op<"legal_op_a">,
14191420
Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
14201421
def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
1422+
def LegalOpC : TEST_Op<"legal_op_c">,
1423+
Arguments<(ins I32)>, Results<(outs I32)>;
14211424

14221425
// Check that the conversion infrastructure can properly undo the creation of
14231426
// operations where an operation was created before its parent, in this case,

0 commit comments

Comments
 (0)