Skip to content

Commit e2d39f7

Browse files
[mlir][Transform] Add updateConversionTarget to ConversionPatternDescriptorOpInterface
This change adds a method to modify the ConversionTarget used during `transform.apply_conversion_patterns` to the `ConversionPatternDescriptorOpInterface`. This is needed when the TypeConverter is used to dictate the dynamic legality of operations, as in "structural" conversion patterns present in, for example, the SCF and func dialects. As a first use case/test, this change also adds a `transform.apply_patterns.scf.structural_conversions` operation to the SCF dialect. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D158672
1 parent 5a58e98 commit e2d39f7

File tree

8 files changed

+105
-9
lines changed

8 files changed

+105
-9
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ def ApplyForLoopCanonicalizationPatternsOp : Op<Transform_Dialect,
2727
let assemblyFormat = "attr-dict";
2828
}
2929

30+
def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
31+
"apply_conversion_patterns.scf.structural_conversions",
32+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
33+
["populateConversionTargetRules"]>]> {
34+
let description = [{
35+
Collects patterns for performing structural conversions of SCF operations.
36+
}];
37+
38+
let assemblyFormat = "attr-dict";
39+
}
40+
3041
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
3142

3243
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
@@ -273,8 +284,8 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
273284
TransformOpInterface, TransformEachOpTrait]> {
274285
let description = [{
275286
Given an scf.if conditional, inject user-defined information that it is
276-
always safe to execute only the if or else branch.
277-
287+
always safe to execute only the if or else branch.
288+
278289
This is achieved by just replacing the scf.if by the content of one of its
279290
branches.
280291

mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ void populateSCFStructuralTypeConversionsAndLegality(
5353
TypeConverter &typeConverter, RewritePatternSet &patterns,
5454
ConversionTarget &target);
5555

56+
/// Similar to `populateSCFStructuralTypeConversionsAndLegality` but does not
57+
/// populate the conversion target.
58+
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter,
59+
RewritePatternSet &patterns);
60+
61+
/// Updates the ConversionTarget with dynamic legality of SCF operations based
62+
/// on the provided type converter.
63+
void populateSCFStructuralTypeConversionTarget(
64+
const TypeConverter &typeConverter, ConversionTarget &target);
65+
5666
/// Populates the provided pattern set with patterns that do 1:N type
5767
/// conversions on (some) SCF ops. This is intended to be used with
5868
/// applyPartialOneToNConversion.

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,23 @@ def ConversionPatternDescriptorOpInterface
333333
/*arguments=*/(ins "::mlir::TypeConverter &":$typeConverter,
334334
"::mlir::RewritePatternSet &":$patterns)
335335
>,
336+
InterfaceMethod<
337+
/*desc=*/[{
338+
Populate the ConversionTarget using the final TypeConverter. The default
339+
implementation is to do nothing. Overriding this method can be useful
340+
in order to setup the ConversionTarget for structural type conversions.
341+
In such a situation, an op's legality depends on using the TypeConverter
342+
to determine whether the op's operand and result types are legal
343+
(defined as converting to themselves).
344+
345+
}],
346+
/*returnType=*/"void",
347+
/*name=*/"populateConversionTargetRules",
348+
/*arguments=*/(ins "const ::mlir::TypeConverter &":$typeConverter,
349+
"::mlir::ConversionTarget &":$conversionTarget),
350+
/*methodBody=*/"",
351+
/*defaultImplementation=*/"return;"
352+
>,
336353
InterfaceMethod<
337354
/*desc=*/[{
338355
Return the type converter to be used with this pattern set. If no

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
3232
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
3333
}
3434

35+
void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
36+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
37+
scf::populateSCFStructuralTypeConversions(typeConverter, patterns);
38+
}
39+
40+
void transform::ApplySCFStructuralConversionPatternsOp::
41+
populateConversionTargetRules(const TypeConverter &typeConverter,
42+
ConversionTarget &conversionTarget) {
43+
scf::populateSCFStructuralTypeConversionTarget(typeConverter,
44+
conversionTarget);
45+
}
46+
3547
//===----------------------------------------------------------------------===//
3648
// GetParentForOp
3749
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,15 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
247247
};
248248
} // namespace
249249

250-
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
251-
TypeConverter &typeConverter, RewritePatternSet &patterns,
252-
ConversionTarget &target) {
250+
void mlir::scf::populateSCFStructuralTypeConversions(
251+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
253252
patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
254253
ConvertWhileOpTypes, ConvertConditionOpTypes>(
255254
typeConverter, patterns.getContext());
255+
}
256+
257+
void mlir::scf::populateSCFStructuralTypeConversionTarget(
258+
const TypeConverter &typeConverter, ConversionTarget &target) {
256259
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
257260
return typeConverter.isLegal(op->getResultTypes());
258261
});
@@ -266,3 +269,10 @@ void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
266269
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
267270
[&](Operation *op) { return typeConverter.isLegal(op); });
268271
}
272+
273+
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
274+
TypeConverter &typeConverter, RewritePatternSet &patterns,
275+
ConversionTarget &target) {
276+
populateSCFStructuralTypeConversions(typeConverter, patterns);
277+
populateSCFStructuralTypeConversionTarget(typeConverter, target);
278+
}

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,12 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
547547
}
548548
converter = defaultTypeConverter.get();
549549
}
550+
551+
// Add descriptor-specific updates to the conversion target, which may
552+
// depend on the final type converter. In structural converters, the
553+
// legality of types dictates the dynamic legality of an operation.
554+
descriptor.populateConversionTargetRules(*converter, conversionTarget);
555+
550556
descriptor.populatePatterns(*converter, patterns);
551557
}
552558
}

mlir/test/Dialect/SCF/transform-ops.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,30 @@ transform.sequence failures(propagate) {
280280
%0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
281281
transform.loop.promote_if_one_iteration %0 : !transform.any_op
282282
}
283+
284+
285+
// -----
286+
287+
// CHECK-LABEL: func @test_structural_conversion_patterns(
288+
// CHECK: scf.for {{.*}} -> (memref<f32>) {
289+
290+
func.func @test_structural_conversion_patterns(%a: tensor<f32>) -> tensor<f32> {
291+
%c0 = arith.constant 0 : index
292+
%c1 = arith.constant 1 : index
293+
%c10 = arith.constant 10 : index
294+
%0 = scf.for %j = %c0 to %c10 step %c1 iter_args(%arg0 = %a) -> tensor<f32> {
295+
%1 = "test.foo"(%arg0) : (tensor<f32>) -> (tensor<f32>)
296+
scf.yield %1 : tensor<f32>
297+
}
298+
return %0 : tensor<f32>
299+
}
300+
301+
transform.sequence failures(propagate) {
302+
^bb1(%arg1: !transform.any_op):
303+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
304+
transform.apply_conversion_patterns to %0 {
305+
transform.apply_conversion_patterns.scf.structural_conversions
306+
} with type_converter {
307+
transform.apply_conversion_patterns.transform.test_type_converter
308+
} { partial_conversion } : !transform.any_op
309+
}

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -956,17 +956,20 @@ namespace {
956956
class TestTypeConverter : public TypeConverter {
957957
public:
958958
TestTypeConverter() {
959+
addConversion([](Type t) { return t; });
959960
addConversion([](RankedTensorType type) -> Type {
960961
return MemRefType::get(type.getShape(), type.getElementType());
961962
});
962-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
963-
ValueRange inputs,
964-
Location loc) -> std::optional<Value> {
963+
auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
964+
ValueRange inputs,
965+
Location loc) -> std::optional<Value> {
965966
if (inputs.size() != 1)
966967
return std::nullopt;
967968
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
968969
.getResult(0);
969-
});
970+
};
971+
addSourceMaterialization(unrealizedCastConverter);
972+
addTargetMaterialization(unrealizedCastConverter);
970973
}
971974
};
972975
} // namespace

0 commit comments

Comments
 (0)