Skip to content

[MLIR][Transform] Allow ApplyRegisteredPassOp to take options as a param #142683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,16 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
}

def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
[TransformOpInterface, TransformEachOpTrait,
FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Applies the specified registered pass or pass pipeline";
let description = [{
This transform applies the specified pass or pass pipeline to the targeted
ops. The name of the pass/pipeline is specified as a string attribute, as
set during pass/pipeline registration. Optionally, pass options may be
specified as a string attribute. The pass options syntax is identical to the
one used with "mlir-opt".
specified as (space-separated) string attributes with the option to pass
these attributes via params. The pass options syntax is identical to the one
used with "mlir-opt".

This op first looks for a pass pipeline with the specified name. If no such
pipeline exists, it looks for a pass with the specified name. If no such
Expand All @@ -420,21 +421,17 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
of targeted ops.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
StrAttr:$pass_name,
DefaultValuedAttr<StrAttr, "\"\"">:$options);
let arguments = (ins StrAttr:$pass_name,
DefaultValuedAttr<ArrayAttr, "{}">:$options,
Variadic<TransformParamTypeInterface>:$dynamic_options,
TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
$pass_name `to` $target attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
$pass_name (`with` `options` `=`
custom<ApplyRegisteredPassOptions>($options, $dynamic_options)^)?
`to` $target attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
}

def CastOp : TransformDialectOp<"cast",
Expand Down
176 changes: 158 additions & 18 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@

using namespace mlir;

static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser &parser, ArrayAttr &options,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
Operation *op, ArrayAttr options,
ValueRange dynamicOptions);
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
Type &rootType,
Expand Down Expand Up @@ -766,17 +772,53 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
// ApplyRegisteredPassOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
ApplyToEachResultList &results, transform::TransformState &state) {
// Make sure that this transform is not applied to itself. Modifying the
// transform IR while it is being interpreted is generally dangerous. Even
// more so when applying passes because they may perform a wide range of IR
// modifications.
DiagnosedSilenceableFailure payloadCheck =
ensurePayloadIsSeparateFromTransform(*this, target);
if (!payloadCheck.succeeded())
return payloadCheck;
void transform::ApplyRegisteredPassOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
onlyReadsHandle(getDynamicOptionsMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
// Obtain a single options-string from options passed statically as
// string attributes as well as "dynamically" through params.
std::string options;
OperandRange dynamicOptions = getDynamicOptions();
size_t dynamicOptionsIdx = 0;
for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) {
if (idx > 0)
options += " "; // Interleave options seperator.

if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) {
options += strAttr.getValue();
} else if (isa<UnitAttr>(optionAttr)) {
assert(dynamicOptionsIdx < dynamicOptions.size() &&
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
"should be the same as the number of options passed as params");
ArrayRef<Attribute> dynamicOption =
state.getParams(dynamicOptions[dynamicOptionsIdx++]);
if (dynamicOption.size() != 1)
return emitSilenceableError() << "options passed as a param must have "
"a single value associated, param "
<< dynamicOptionsIdx - 1 << " associates "
<< dynamicOption.size();

if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0])) {
options += dynamicOptionStr.getValue();
} else {
return emitSilenceableError()
<< "options passed as a param must be a string, got "
<< dynamicOption[0];
}
} else {
llvm_unreachable(
"expected options element to be either StringAttr or UnitAttr");
}
}

// Get pass or pass pipeline from registry.
const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
Expand All @@ -786,26 +828,124 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
return emitDefiniteFailure()
<< "unknown pass or pass pipeline: " << getPassName();

// Create pass manager and run the pass or pass pipeline.
// Create pass manager and add the pass or pass pipeline.
PassManager pm(getContext());
if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
emitError(msg);
return failure();
}))) {
return emitDefiniteFailure()
<< "failed to add pass or pass pipeline to pipeline: "
<< getPassName();
}
if (failed(pm.run(target))) {
auto diag = emitSilenceableError() << "pass pipeline failed";
diag.attachNote(target->getLoc()) << "target op";
return diag;

auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
for (Operation *target : targets) {
// Make sure that this transform is not applied to itself. Modifying the
// transform IR while it is being interpreted is generally dangerous. Even
// more so when applying passes because they may perform a wide range of IR
// modifications.
DiagnosedSilenceableFailure payloadCheck =
ensurePayloadIsSeparateFromTransform(*this, target);
if (!payloadCheck.succeeded())
return payloadCheck;

// Run the pass or pass pipeline on the current target operation.
if (failed(pm.run(target))) {
auto diag = emitSilenceableError() << "pass pipeline failed";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
}

results.push_back(target);
// The applied pass will have directly modified the payload IR(s).
results.set(llvm::cast<OpResult>(getResult()), targets);
return DiagnosedSilenceableFailure::success();
}

static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser &parser, ArrayAttr &options,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
auto dynamicOptionMarker = UnitAttr::get(parser.getContext());
SmallVector<Attribute> optionsArray;

auto parseOperandOrString = [&]() -> OptionalParseResult {
OpAsmParser::UnresolvedOperand operand;
OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand);
if (parsedOperand.has_value()) {
if (failed(parsedOperand.value()))
return failure();

dynamicOptions.push_back(operand);
optionsArray.push_back(
dynamicOptionMarker); // Placeholder for knowing where to
// inject the dynamic option-as-param.
return success();
}

StringAttr stringAttr;
OptionalParseResult parsedStringAttr =
parser.parseOptionalAttribute(stringAttr);
if (parsedStringAttr.has_value()) {
if (failed(parsedStringAttr.value()))
return failure();
optionsArray.push_back(stringAttr);
return success();
}

return std::nullopt;
};

OptionalParseResult parsedOptionsElement = parseOperandOrString();
while (parsedOptionsElement.has_value()) {
if (failed(parsedOptionsElement.value()))
return failure();
parsedOptionsElement = parseOperandOrString();
}

if (optionsArray.empty()) {
return parser.emitError(parser.getCurrentLocation())
<< "expected at least one option (either a string or a param)";
}
options = parser.getBuilder().getArrayAttr(optionsArray);
return success();
}

static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
Operation *op, ArrayAttr options,
ValueRange dynamicOptions) {
size_t currentDynamicOptionIdx = 0;
for (auto [idx, optionAttr] : llvm::enumerate(options)) {
if (idx > 0)
printer << " "; // Interleave options separator.

if (isa<UnitAttr>(optionAttr))
printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]);
else if (auto strAttr = dyn_cast<StringAttr>(optionAttr))
printer.printAttribute(strAttr);
else
llvm_unreachable("each option should be either a StringAttr or UnitAttr");
}
}

LogicalResult transform::ApplyRegisteredPassOp::verify() {
size_t numUnitsInOptions = 0;
for (Attribute optionsElement : getOptions()) {
if (isa<UnitAttr>(optionsElement))
numUnitsInOptions++;
else if (!isa<StringAttr>(optionsElement))
return emitOpError() << "expected each option to be either a StringAttr "
<< "or a UnitAttr, got " << optionsElement;
}

if (getDynamicOptions().size() != numUnitsInOptions)
return emitOpError()
<< "expected the same number of options passed as params as "
<< "UnitAttr elements in options ArrayAttr";

return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading