53
53
54
54
using namespace mlir ;
55
55
56
+ static ParseResult parseApplyRegisteredPassOptions (
57
+ OpAsmParser &parser, ArrayAttr &options,
58
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
59
+ static void printApplyRegisteredPassOptions (OpAsmPrinter &printer,
60
+ Operation *op, ArrayAttr options,
61
+ ValueRange dynamicOptions);
56
62
static ParseResult parseSequenceOpOperands (
57
63
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
58
64
Type &rootType,
@@ -766,17 +772,53 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
766
772
// ApplyRegisteredPassOp
767
773
// ===----------------------------------------------------------------------===//
768
774
769
- DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne (
770
- transform::TransformRewriter &rewriter, Operation *target,
771
- ApplyToEachResultList &results, transform::TransformState &state) {
772
- // Make sure that this transform is not applied to itself. Modifying the
773
- // transform IR while it is being interpreted is generally dangerous. Even
774
- // more so when applying passes because they may perform a wide range of IR
775
- // modifications.
776
- DiagnosedSilenceableFailure payloadCheck =
777
- ensurePayloadIsSeparateFromTransform (*this , target);
778
- if (!payloadCheck.succeeded ())
779
- return payloadCheck;
775
+ void transform::ApplyRegisteredPassOp::getEffects (
776
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
777
+ consumesHandle (getTargetMutable (), effects);
778
+ onlyReadsHandle (getDynamicOptionsMutable (), effects);
779
+ producesHandle (getOperation ()->getOpResults (), effects);
780
+ modifiesPayload (effects);
781
+ }
782
+
783
+ DiagnosedSilenceableFailure
784
+ transform::ApplyRegisteredPassOp::apply (transform::TransformRewriter &rewriter,
785
+ transform::TransformResults &results,
786
+ transform::TransformState &state) {
787
+ // Obtain a single options-string from options passed statically as
788
+ // string attributes as well as "dynamically" through params.
789
+ std::string options;
790
+ OperandRange dynamicOptions = getDynamicOptions ();
791
+ size_t dynamicOptionsIdx = 0 ;
792
+ for (auto [idx, optionAttr] : llvm::enumerate (getOptions ())) {
793
+ if (idx > 0 )
794
+ options += " " ; // Interleave options seperator.
795
+
796
+ if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) {
797
+ options += strAttr.getValue ();
798
+ } else if (isa<UnitAttr>(optionAttr)) {
799
+ assert (dynamicOptionsIdx < dynamicOptions.size () &&
800
+ " number of dynamic option markers (UnitAttr) in options ArrayAttr "
801
+ " should be the same as the number of options passed as params" );
802
+ ArrayRef<Attribute> dynamicOption =
803
+ state.getParams (dynamicOptions[dynamicOptionsIdx++]);
804
+ if (dynamicOption.size () != 1 )
805
+ return emitSilenceableError () << " options passed as a param must have "
806
+ " a single value associated, param "
807
+ << dynamicOptionsIdx - 1 << " associates "
808
+ << dynamicOption.size ();
809
+
810
+ if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0 ])) {
811
+ options += dynamicOptionStr.getValue ();
812
+ } else {
813
+ return emitSilenceableError ()
814
+ << " options passed as a param must be a string, got "
815
+ << dynamicOption[0 ];
816
+ }
817
+ } else {
818
+ llvm_unreachable (
819
+ " expected options element to be either StringAttr or UnitAttr" );
820
+ }
821
+ }
780
822
781
823
// Get pass or pass pipeline from registry.
782
824
const PassRegistryEntry *info = PassPipelineInfo::lookup (getPassName ());
@@ -786,26 +828,124 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
786
828
return emitDefiniteFailure ()
787
829
<< " unknown pass or pass pipeline: " << getPassName ();
788
830
789
- // Create pass manager and run the pass or pass pipeline.
831
+ // Create pass manager and add the pass or pass pipeline.
790
832
PassManager pm (getContext ());
791
- if (failed (info->addToPipeline (pm, getOptions () , [&](const Twine &msg) {
833
+ if (failed (info->addToPipeline (pm, options , [&](const Twine &msg) {
792
834
emitError (msg);
793
835
return failure ();
794
836
}))) {
795
837
return emitDefiniteFailure ()
796
838
<< " failed to add pass or pass pipeline to pipeline: "
797
839
<< getPassName ();
798
840
}
799
- if (failed (pm.run (target))) {
800
- auto diag = emitSilenceableError () << " pass pipeline failed" ;
801
- diag.attachNote (target->getLoc ()) << " target op" ;
802
- return diag;
841
+
842
+ auto targets = SmallVector<Operation *>(state.getPayloadOps (getTarget ()));
843
+ for (Operation *target : targets) {
844
+ // Make sure that this transform is not applied to itself. Modifying the
845
+ // transform IR while it is being interpreted is generally dangerous. Even
846
+ // more so when applying passes because they may perform a wide range of IR
847
+ // modifications.
848
+ DiagnosedSilenceableFailure payloadCheck =
849
+ ensurePayloadIsSeparateFromTransform (*this , target);
850
+ if (!payloadCheck.succeeded ())
851
+ return payloadCheck;
852
+
853
+ // Run the pass or pass pipeline on the current target operation.
854
+ if (failed (pm.run (target))) {
855
+ auto diag = emitSilenceableError () << " pass pipeline failed" ;
856
+ diag.attachNote (target->getLoc ()) << " target op" ;
857
+ return diag;
858
+ }
803
859
}
804
860
805
- results.push_back (target);
861
+ // The applied pass will have directly modified the payload IR(s).
862
+ results.set (llvm::cast<OpResult>(getResult ()), targets);
806
863
return DiagnosedSilenceableFailure::success ();
807
864
}
808
865
866
+ static ParseResult parseApplyRegisteredPassOptions (
867
+ OpAsmParser &parser, ArrayAttr &options,
868
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
869
+ auto dynamicOptionMarker = UnitAttr::get (parser.getContext ());
870
+ SmallVector<Attribute> optionsArray;
871
+
872
+ auto parseOperandOrString = [&]() -> OptionalParseResult {
873
+ OpAsmParser::UnresolvedOperand operand;
874
+ OptionalParseResult parsedOperand = parser.parseOptionalOperand (operand);
875
+ if (parsedOperand.has_value ()) {
876
+ if (failed (parsedOperand.value ()))
877
+ return failure ();
878
+
879
+ dynamicOptions.push_back (operand);
880
+ optionsArray.push_back (
881
+ dynamicOptionMarker); // Placeholder for knowing where to
882
+ // inject the dynamic option-as-param.
883
+ return success ();
884
+ }
885
+
886
+ StringAttr stringAttr;
887
+ OptionalParseResult parsedStringAttr =
888
+ parser.parseOptionalAttribute (stringAttr);
889
+ if (parsedStringAttr.has_value ()) {
890
+ if (failed (parsedStringAttr.value ()))
891
+ return failure ();
892
+ optionsArray.push_back (stringAttr);
893
+ return success ();
894
+ }
895
+
896
+ return std::nullopt;
897
+ };
898
+
899
+ OptionalParseResult parsedOptionsElement = parseOperandOrString ();
900
+ while (parsedOptionsElement.has_value ()) {
901
+ if (failed (parsedOptionsElement.value ()))
902
+ return failure ();
903
+ parsedOptionsElement = parseOperandOrString ();
904
+ }
905
+
906
+ if (optionsArray.empty ()) {
907
+ return parser.emitError (parser.getCurrentLocation ())
908
+ << " expected at least one option (either a string or a param)" ;
909
+ }
910
+ options = parser.getBuilder ().getArrayAttr (optionsArray);
911
+ return success ();
912
+ }
913
+
914
+ static void printApplyRegisteredPassOptions (OpAsmPrinter &printer,
915
+ Operation *op, ArrayAttr options,
916
+ ValueRange dynamicOptions) {
917
+ size_t currentDynamicOptionIdx = 0 ;
918
+ for (auto [idx, optionAttr] : llvm::enumerate (options)) {
919
+ if (idx > 0 )
920
+ printer << " " ; // Interleave options separator.
921
+
922
+ if (isa<UnitAttr>(optionAttr))
923
+ printer.printOperand (dynamicOptions[currentDynamicOptionIdx++]);
924
+ else if (auto strAttr = dyn_cast<StringAttr>(optionAttr))
925
+ printer.printAttribute (strAttr);
926
+ else
927
+ llvm_unreachable (" each option should be either a StringAttr or UnitAttr" );
928
+ }
929
+ }
930
+
931
+ LogicalResult transform::ApplyRegisteredPassOp::verify () {
932
+ size_t numUnitsInOptions = 0 ;
933
+ for (Attribute optionsElement : getOptions ()) {
934
+ if (isa<UnitAttr>(optionsElement))
935
+ numUnitsInOptions++;
936
+ else if (!isa<StringAttr>(optionsElement))
937
+ return emitOpError () << " expected each option to be either a StringAttr "
938
+ << " or a UnitAttr, got " << optionsElement;
939
+ }
940
+
941
+ if (getDynamicOptions ().size () != numUnitsInOptions)
942
+ return emitOpError ()
943
+ << " expected the same number of options passed as params as "
944
+ << " UnitAttr elements in options ArrayAttr" ;
945
+
946
+ return success ();
947
+ }
948
+
809
949
// ===----------------------------------------------------------------------===//
810
950
// CastOp
811
951
// ===----------------------------------------------------------------------===//
0 commit comments