Skip to content

Commit c6828e0

Browse files
caitlyncanoButygin
authored andcommitted
[mlir] Make ConversionTarget dynamic legality callbacks composable
* Change callback signature `bool(Operation *)` -> `Optional<bool>(Operation *)` * addDynamicallyLegalOp add callback to the chain * If callback returned empty `Optional` next callback in chain will be called Differential Revision: https://reviews.llvm.org/D110487
1 parent 649cc16 commit c6828e0

File tree

5 files changed

+134
-13
lines changed

5 files changed

+134
-13
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ class ConversionTarget {
661661

662662
/// The signature of the callback used to determine if an operation is
663663
/// dynamically legal on the target.
664-
using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
664+
using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
665665

666666
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
667667
virtual ~ConversionTarget() = default;
@@ -827,10 +827,10 @@ class ConversionTarget {
827827
/// The set of information that configures the legalization of an operation.
828828
struct LegalizationInfo {
829829
/// The legality action this operation was given.
830-
LegalizationAction action;
830+
LegalizationAction action = LegalizationAction::Illegal;
831831

832832
/// If some legal instances of this operation may also be recursively legal.
833-
bool isRecursivelyLegal;
833+
bool isRecursivelyLegal = false;
834834

835835
/// The legality callback if this operation is dynamically legal.
836836
DynamicLegalityCallbackFn legalityFn;

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2681,7 +2681,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
26812681
/// Register a legality action for the given operation.
26822682
void ConversionTarget::setOpAction(OperationName op,
26832683
LegalizationAction action) {
2684-
legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
2684+
legalOperations[op].action = action;
26852685
}
26862686

26872687
/// Register a legality action for the given dialects.
@@ -2710,8 +2710,11 @@ auto ConversionTarget::isLegal(Operation *op) const
27102710
// Returns true if this operation instance is known to be legal.
27112711
auto isOpLegal = [&] {
27122712
// Handle dynamic legality either with the provided legality function.
2713-
if (info->action == LegalizationAction::Dynamic)
2714-
return info->legalityFn(op);
2713+
if (info->action == LegalizationAction::Dynamic) {
2714+
Optional<bool> result = info->legalityFn(op);
2715+
if (result)
2716+
return *result;
2717+
}
27152718

27162719
// Otherwise, the operation is only legal if it was marked 'Legal'.
27172720
return info->action == LegalizationAction::Legal;
@@ -2723,14 +2726,32 @@ auto ConversionTarget::isLegal(Operation *op) const
27232726
LegalOpDetails legalityDetails;
27242727
if (info->isRecursivelyLegal) {
27252728
auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2726-
if (legalityFnIt != opRecursiveLegalityFns.end())
2727-
legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
2728-
else
2729+
if (legalityFnIt != opRecursiveLegalityFns.end()) {
2730+
legalityDetails.isRecursivelyLegal =
2731+
legalityFnIt->second(op).getValueOr(true);
2732+
} else {
27292733
legalityDetails.isRecursivelyLegal = true;
2734+
}
27302735
}
27312736
return legalityDetails;
27322737
}
27332738

2739+
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
2740+
ConversionTarget::DynamicLegalityCallbackFn oldCallback,
2741+
ConversionTarget::DynamicLegalityCallbackFn newCallback) {
2742+
if (!oldCallback)
2743+
return newCallback;
2744+
2745+
auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
2746+
Operation *op) -> Optional<bool> {
2747+
if (Optional<bool> result = newCl(op))
2748+
return *result;
2749+
2750+
return oldCl(op);
2751+
};
2752+
return chain;
2753+
}
2754+
27342755
/// Set the dynamic legality callback for the given operation.
27352756
void ConversionTarget::setLegalityCallback(
27362757
OperationName name, const DynamicLegalityCallbackFn &callback) {
@@ -2739,7 +2760,8 @@ void ConversionTarget::setLegalityCallback(
27392760
assert(infoIt != legalOperations.end() &&
27402761
infoIt->second.action == LegalizationAction::Dynamic &&
27412762
"expected operation to already be marked as dynamically legal");
2742-
infoIt->second.legalityFn = callback;
2763+
infoIt->second.legalityFn =
2764+
composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
27432765
}
27442766

27452767
/// Set the recursive legality callback for the given operation and mark the
@@ -2752,7 +2774,8 @@ void ConversionTarget::markOpRecursivelyLegal(
27522774
"expected operation to already be marked as legal");
27532775
infoIt->second.isRecursivelyLegal = true;
27542776
if (callback)
2755-
opRecursiveLegalityFns[name] = callback;
2777+
opRecursiveLegalityFns[name] = composeLegalityCallbacks(
2778+
std::move(opRecursiveLegalityFns[name]), callback);
27562779
else
27572780
opRecursiveLegalityFns.erase(name);
27582781
}
@@ -2762,14 +2785,15 @@ void ConversionTarget::setLegalityCallback(
27622785
ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
27632786
assert(callback && "expected valid legality callback");
27642787
for (StringRef dialect : dialects)
2765-
dialectLegalityFns[dialect] = callback;
2788+
dialectLegalityFns[dialect] = composeLegalityCallbacks(
2789+
std::move(dialectLegalityFns[dialect]), callback);
27662790
}
27672791

27682792
/// Set the dynamic legality callback for the unknown ops.
27692793
void ConversionTarget::setLegalityCallback(
27702794
const DynamicLegalityCallbackFn &callback) {
27712795
assert(callback && "expected valid legality callback");
2772-
unknownLegalityFn = callback;
2796+
unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
27732797
}
27742798

27752799
/// Get the legalization information for the given operation.

mlir/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ add_subdirectory(IR)
1212
add_subdirectory(Pass)
1313
add_subdirectory(Rewrite)
1414
add_subdirectory(TableGen)
15+
add_subdirectory(Transforms)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
add_mlir_unittest(MLIRTransformsTests
2+
DialectConversion.cpp
3+
)
4+
target_link_libraries(MLIRTransformsTests
5+
PRIVATE
6+
MLIRTransforms)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Transforms/DialectConversion.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace mlir;
13+
14+
static Operation *createOp(MLIRContext *context) {
15+
context->allowUnregisteredDialects();
16+
return Operation::create(UnknownLoc::get(context),
17+
OperationName("foo.bar", context), llvm::None,
18+
llvm::None, llvm::None, llvm::None, 0);
19+
}
20+
21+
namespace {
22+
struct DummyOp {
23+
static StringRef getOperationName() { return "foo.bar"; }
24+
};
25+
26+
TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
27+
MLIRContext context;
28+
ConversionTarget target(context);
29+
30+
int index = 0;
31+
int callbackCalled1 = 0;
32+
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
33+
callbackCalled1 = ++index;
34+
return true;
35+
});
36+
37+
int callbackCalled2 = 0;
38+
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
39+
callbackCalled2 = ++index;
40+
return llvm::None;
41+
});
42+
43+
auto *op = createOp(&context);
44+
EXPECT_TRUE(target.isLegal(op));
45+
EXPECT_EQ(2, callbackCalled1);
46+
EXPECT_EQ(1, callbackCalled2);
47+
op->destroy();
48+
}
49+
50+
TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
51+
MLIRContext context;
52+
ConversionTarget target(context);
53+
54+
int index = 0;
55+
int callbackCalled = 0;
56+
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
57+
callbackCalled = ++index;
58+
return llvm::None;
59+
});
60+
61+
auto *op = createOp(&context);
62+
EXPECT_FALSE(target.isLegal(op));
63+
EXPECT_EQ(1, callbackCalled);
64+
op->destroy();
65+
}
66+
67+
TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
68+
MLIRContext context;
69+
ConversionTarget target(context);
70+
71+
int index = 0;
72+
int callbackCalled1 = 0;
73+
target.markUnknownOpDynamicallyLegal([&](Operation *) {
74+
callbackCalled1 = ++index;
75+
return true;
76+
});
77+
78+
int callbackCalled2 = 0;
79+
target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
80+
callbackCalled2 = ++index;
81+
return llvm::None;
82+
});
83+
84+
auto *op = createOp(&context);
85+
EXPECT_TRUE(target.isLegal(op));
86+
EXPECT_EQ(2, callbackCalled1);
87+
EXPECT_EQ(1, callbackCalled2);
88+
op->destroy();
89+
}
90+
} // namespace

0 commit comments

Comments
 (0)