Skip to content

Commit 4eeee41

Browse files
authored
[MLIR][Transform] Allow ApplyRegisteredPassOp to take options as a param (#142683)
Makes it possible to pass around the options to a pass inside a schedule. The refactoring also makes it so that the pass manager and pass are only constructed once per `apply()` of the transform op versus for each target payload given to the op's `apply()`.
1 parent b9d3a64 commit 4eeee41

File tree

3 files changed

+304
-36
lines changed

3 files changed

+304
-36
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -399,15 +399,16 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
399399
}
400400

401401
def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
402-
[TransformOpInterface, TransformEachOpTrait,
403-
FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
402+
[DeclareOpInterfaceMethods<TransformOpInterface>,
403+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
404404
let summary = "Applies the specified registered pass or pass pipeline";
405405
let description = [{
406406
This transform applies the specified pass or pass pipeline to the targeted
407407
ops. The name of the pass/pipeline is specified as a string attribute, as
408408
set during pass/pipeline registration. Optionally, pass options may be
409-
specified as a string attribute. The pass options syntax is identical to the
410-
one used with "mlir-opt".
409+
specified as (space-separated) string attributes with the option to pass
410+
these attributes via params. The pass options syntax is identical to the one
411+
used with "mlir-opt".
411412

412413
This op first looks for a pass pipeline with the specified name. If no such
413414
pipeline exists, it looks for a pass with the specified name. If no such
@@ -420,21 +421,17 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
420421
of targeted ops.
421422
}];
422423

423-
let arguments = (ins TransformHandleTypeInterface:$target,
424-
StrAttr:$pass_name,
425-
DefaultValuedAttr<StrAttr, "\"\"">:$options);
424+
let arguments = (ins StrAttr:$pass_name,
425+
DefaultValuedAttr<ArrayAttr, "{}">:$options,
426+
Variadic<TransformParamTypeInterface>:$dynamic_options,
427+
TransformHandleTypeInterface:$target);
426428
let results = (outs TransformHandleTypeInterface:$result);
427429
let assemblyFormat = [{
428-
$pass_name `to` $target attr-dict `:` functional-type(operands, results)
429-
}];
430-
431-
let extraClassDeclaration = [{
432-
::mlir::DiagnosedSilenceableFailure applyToOne(
433-
::mlir::transform::TransformRewriter &rewriter,
434-
::mlir::Operation *target,
435-
::mlir::transform::ApplyToEachResultList &results,
436-
::mlir::transform::TransformState &state);
430+
$pass_name (`with` `options` `=`
431+
custom<ApplyRegisteredPassOptions>($options, $dynamic_options)^)?
432+
`to` $target attr-dict `:` functional-type(operands, results)
437433
}];
434+
let hasVerifier = 1;
438435
}
439436

440437
def CastOp : TransformDialectOp<"cast",

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 158 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@
5353

5454
using namespace mlir;
5555

56+
static ParseResult parseApplyRegisteredPassOptions(
57+
OpAsmParser &parser, ArrayAttr &options,
58+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
59+
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
60+
Operation *op, ArrayAttr options,
61+
ValueRange dynamicOptions);
5662
static ParseResult parseSequenceOpOperands(
5763
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
5864
Type &rootType,
@@ -766,17 +772,53 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
766772
// ApplyRegisteredPassOp
767773
//===----------------------------------------------------------------------===//
768774

769-
DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
770-
transform::TransformRewriter &rewriter, Operation *target,
771-
ApplyToEachResultList &results, transform::TransformState &state) {
772-
// Make sure that this transform is not applied to itself. Modifying the
773-
// transform IR while it is being interpreted is generally dangerous. Even
774-
// more so when applying passes because they may perform a wide range of IR
775-
// modifications.
776-
DiagnosedSilenceableFailure payloadCheck =
777-
ensurePayloadIsSeparateFromTransform(*this, target);
778-
if (!payloadCheck.succeeded())
779-
return payloadCheck;
775+
void transform::ApplyRegisteredPassOp::getEffects(
776+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
777+
consumesHandle(getTargetMutable(), effects);
778+
onlyReadsHandle(getDynamicOptionsMutable(), effects);
779+
producesHandle(getOperation()->getOpResults(), effects);
780+
modifiesPayload(effects);
781+
}
782+
783+
DiagnosedSilenceableFailure
784+
transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
785+
transform::TransformResults &results,
786+
transform::TransformState &state) {
787+
// Obtain a single options-string from options passed statically as
788+
// string attributes as well as "dynamically" through params.
789+
std::string options;
790+
OperandRange dynamicOptions = getDynamicOptions();
791+
size_t dynamicOptionsIdx = 0;
792+
for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) {
793+
if (idx > 0)
794+
options += " "; // Interleave options seperator.
795+
796+
if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) {
797+
options += strAttr.getValue();
798+
} else if (isa<UnitAttr>(optionAttr)) {
799+
assert(dynamicOptionsIdx < dynamicOptions.size() &&
800+
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
801+
"should be the same as the number of options passed as params");
802+
ArrayRef<Attribute> dynamicOption =
803+
state.getParams(dynamicOptions[dynamicOptionsIdx++]);
804+
if (dynamicOption.size() != 1)
805+
return emitSilenceableError() << "options passed as a param must have "
806+
"a single value associated, param "
807+
<< dynamicOptionsIdx - 1 << " associates "
808+
<< dynamicOption.size();
809+
810+
if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0])) {
811+
options += dynamicOptionStr.getValue();
812+
} else {
813+
return emitSilenceableError()
814+
<< "options passed as a param must be a string, got "
815+
<< dynamicOption[0];
816+
}
817+
} else {
818+
llvm_unreachable(
819+
"expected options element to be either StringAttr or UnitAttr");
820+
}
821+
}
780822

781823
// Get pass or pass pipeline from registry.
782824
const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -786,26 +828,124 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
786828
return emitDefiniteFailure()
787829
<< "unknown pass or pass pipeline: " << getPassName();
788830

789-
// Create pass manager and run the pass or pass pipeline.
831+
// Create pass manager and add the pass or pass pipeline.
790832
PassManager pm(getContext());
791-
if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
833+
if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
792834
emitError(msg);
793835
return failure();
794836
}))) {
795837
return emitDefiniteFailure()
796838
<< "failed to add pass or pass pipeline to pipeline: "
797839
<< getPassName();
798840
}
799-
if (failed(pm.run(target))) {
800-
auto diag = emitSilenceableError() << "pass pipeline failed";
801-
diag.attachNote(target->getLoc()) << "target op";
802-
return diag;
841+
842+
auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
843+
for (Operation *target : targets) {
844+
// Make sure that this transform is not applied to itself. Modifying the
845+
// transform IR while it is being interpreted is generally dangerous. Even
846+
// more so when applying passes because they may perform a wide range of IR
847+
// modifications.
848+
DiagnosedSilenceableFailure payloadCheck =
849+
ensurePayloadIsSeparateFromTransform(*this, target);
850+
if (!payloadCheck.succeeded())
851+
return payloadCheck;
852+
853+
// Run the pass or pass pipeline on the current target operation.
854+
if (failed(pm.run(target))) {
855+
auto diag = emitSilenceableError() << "pass pipeline failed";
856+
diag.attachNote(target->getLoc()) << "target op";
857+
return diag;
858+
}
803859
}
804860

805-
results.push_back(target);
861+
// The applied pass will have directly modified the payload IR(s).
862+
results.set(llvm::cast<OpResult>(getResult()), targets);
806863
return DiagnosedSilenceableFailure::success();
807864
}
808865

866+
static ParseResult parseApplyRegisteredPassOptions(
867+
OpAsmParser &parser, ArrayAttr &options,
868+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
869+
auto dynamicOptionMarker = UnitAttr::get(parser.getContext());
870+
SmallVector<Attribute> optionsArray;
871+
872+
auto parseOperandOrString = [&]() -> OptionalParseResult {
873+
OpAsmParser::UnresolvedOperand operand;
874+
OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand);
875+
if (parsedOperand.has_value()) {
876+
if (failed(parsedOperand.value()))
877+
return failure();
878+
879+
dynamicOptions.push_back(operand);
880+
optionsArray.push_back(
881+
dynamicOptionMarker); // Placeholder for knowing where to
882+
// inject the dynamic option-as-param.
883+
return success();
884+
}
885+
886+
StringAttr stringAttr;
887+
OptionalParseResult parsedStringAttr =
888+
parser.parseOptionalAttribute(stringAttr);
889+
if (parsedStringAttr.has_value()) {
890+
if (failed(parsedStringAttr.value()))
891+
return failure();
892+
optionsArray.push_back(stringAttr);
893+
return success();
894+
}
895+
896+
return std::nullopt;
897+
};
898+
899+
OptionalParseResult parsedOptionsElement = parseOperandOrString();
900+
while (parsedOptionsElement.has_value()) {
901+
if (failed(parsedOptionsElement.value()))
902+
return failure();
903+
parsedOptionsElement = parseOperandOrString();
904+
}
905+
906+
if (optionsArray.empty()) {
907+
return parser.emitError(parser.getCurrentLocation())
908+
<< "expected at least one option (either a string or a param)";
909+
}
910+
options = parser.getBuilder().getArrayAttr(optionsArray);
911+
return success();
912+
}
913+
914+
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
915+
Operation *op, ArrayAttr options,
916+
ValueRange dynamicOptions) {
917+
size_t currentDynamicOptionIdx = 0;
918+
for (auto [idx, optionAttr] : llvm::enumerate(options)) {
919+
if (idx > 0)
920+
printer << " "; // Interleave options separator.
921+
922+
if (isa<UnitAttr>(optionAttr))
923+
printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]);
924+
else if (auto strAttr = dyn_cast<StringAttr>(optionAttr))
925+
printer.printAttribute(strAttr);
926+
else
927+
llvm_unreachable("each option should be either a StringAttr or UnitAttr");
928+
}
929+
}
930+
931+
LogicalResult transform::ApplyRegisteredPassOp::verify() {
932+
size_t numUnitsInOptions = 0;
933+
for (Attribute optionsElement : getOptions()) {
934+
if (isa<UnitAttr>(optionsElement))
935+
numUnitsInOptions++;
936+
else if (!isa<StringAttr>(optionsElement))
937+
return emitOpError() << "expected each option to be either a StringAttr "
938+
<< "or a UnitAttr, got " << optionsElement;
939+
}
940+
941+
if (getDynamicOptions().size() != numUnitsInOptions)
942+
return emitOpError()
943+
<< "expected the same number of options passed as params as "
944+
<< "UnitAttr elements in options ArrayAttr";
945+
946+
return success();
947+
}
948+
809949
//===----------------------------------------------------------------------===//
810950
// CastOp
811951
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)