@@ -2681,7 +2681,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
2681
2681
// / Register a legality action for the given operation.
2682
2682
void ConversionTarget::setOpAction (OperationName op,
2683
2683
LegalizationAction action) {
2684
- legalOperations[op] = { action, /* isRecursivelyLegal= */ false , nullptr } ;
2684
+ legalOperations[op]. action = action ;
2685
2685
}
2686
2686
2687
2687
// / Register a legality action for the given dialects.
@@ -2710,8 +2710,11 @@ auto ConversionTarget::isLegal(Operation *op) const
2710
2710
// Returns true if this operation instance is known to be legal.
2711
2711
auto isOpLegal = [&] {
2712
2712
// Handle dynamic legality either with the provided legality function.
2713
- if (info->action == LegalizationAction::Dynamic)
2714
- return info->legalityFn (op);
2713
+ if (info->action == LegalizationAction::Dynamic) {
2714
+ Optional<bool > result = info->legalityFn (op);
2715
+ if (result)
2716
+ return *result;
2717
+ }
2715
2718
2716
2719
// Otherwise, the operation is only legal if it was marked 'Legal'.
2717
2720
return info->action == LegalizationAction::Legal;
@@ -2723,14 +2726,32 @@ auto ConversionTarget::isLegal(Operation *op) const
2723
2726
LegalOpDetails legalityDetails;
2724
2727
if (info->isRecursivelyLegal ) {
2725
2728
auto legalityFnIt = opRecursiveLegalityFns.find (op->getName ());
2726
- if (legalityFnIt != opRecursiveLegalityFns.end ())
2727
- legalityDetails.isRecursivelyLegal = legalityFnIt->second (op);
2728
- else
2729
+ if (legalityFnIt != opRecursiveLegalityFns.end ()) {
2730
+ legalityDetails.isRecursivelyLegal =
2731
+ legalityFnIt->second (op).getValueOr (true );
2732
+ } else {
2729
2733
legalityDetails.isRecursivelyLegal = true ;
2734
+ }
2730
2735
}
2731
2736
return legalityDetails;
2732
2737
}
2733
2738
2739
+ static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks (
2740
+ ConversionTarget::DynamicLegalityCallbackFn oldCallback,
2741
+ ConversionTarget::DynamicLegalityCallbackFn newCallback) {
2742
+ if (!oldCallback)
2743
+ return newCallback;
2744
+
2745
+ auto chain = [oldCl = std::move (oldCallback), newCl = std::move (newCallback)](
2746
+ Operation *op) -> Optional<bool > {
2747
+ if (Optional<bool > result = newCl (op))
2748
+ return *result;
2749
+
2750
+ return oldCl (op);
2751
+ };
2752
+ return chain;
2753
+ }
2754
+
2734
2755
// / Set the dynamic legality callback for the given operation.
2735
2756
void ConversionTarget::setLegalityCallback (
2736
2757
OperationName name, const DynamicLegalityCallbackFn &callback) {
@@ -2739,7 +2760,8 @@ void ConversionTarget::setLegalityCallback(
2739
2760
assert (infoIt != legalOperations.end () &&
2740
2761
infoIt->second .action == LegalizationAction::Dynamic &&
2741
2762
" expected operation to already be marked as dynamically legal" );
2742
- infoIt->second .legalityFn = callback;
2763
+ infoIt->second .legalityFn =
2764
+ composeLegalityCallbacks (std::move (infoIt->second .legalityFn ), callback);
2743
2765
}
2744
2766
2745
2767
// / Set the recursive legality callback for the given operation and mark the
@@ -2752,7 +2774,8 @@ void ConversionTarget::markOpRecursivelyLegal(
2752
2774
" expected operation to already be marked as legal" );
2753
2775
infoIt->second .isRecursivelyLegal = true ;
2754
2776
if (callback)
2755
- opRecursiveLegalityFns[name] = callback;
2777
+ opRecursiveLegalityFns[name] = composeLegalityCallbacks (
2778
+ std::move (opRecursiveLegalityFns[name]), callback);
2756
2779
else
2757
2780
opRecursiveLegalityFns.erase (name);
2758
2781
}
@@ -2762,14 +2785,15 @@ void ConversionTarget::setLegalityCallback(
2762
2785
ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
2763
2786
assert (callback && " expected valid legality callback" );
2764
2787
for (StringRef dialect : dialects)
2765
- dialectLegalityFns[dialect] = callback;
2788
+ dialectLegalityFns[dialect] = composeLegalityCallbacks (
2789
+ std::move (dialectLegalityFns[dialect]), callback);
2766
2790
}
2767
2791
2768
2792
// / Set the dynamic legality callback for the unknown ops.
2769
2793
void ConversionTarget::setLegalityCallback (
2770
2794
const DynamicLegalityCallbackFn &callback) {
2771
2795
assert (callback && " expected valid legality callback" );
2772
- unknownLegalityFn = callback;
2796
+ unknownLegalityFn = composeLegalityCallbacks (unknownLegalityFn, callback) ;
2773
2797
}
2774
2798
2775
2799
// / Get the legalization information for the given operation.
0 commit comments