Skip to content

Commit 6ce9369

Browse files
committed
[mlir][transform] Add transform.get_operand op
Similar to `transform.get_result`, except it returns a handle to the operand indicated by `operand_number`, or all operands if no index is given. Additionally updates `get_result` to make the `result_number` optional. This makes the use case of wanting to get all of the results of an operation easier by no longer requiring the user to reconstruct the list of results one-by-one.
1 parent 90bdf76 commit 6ce9369

File tree

3 files changed

+138
-10
lines changed

3 files changed

+138
-10
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -728,23 +728,43 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
728728
"functional-type(operands, results)";
729729
}
730730

731+
def GetOperandOp : TransformDialectOp<"get_operand",
732+
[DeclareOpInterfaceMethods<TransformOpInterface>,
733+
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
734+
let summary = "Get a handle to the operand(s) of the targeted op";
735+
let description = [{
736+
The handle defined by this Transform op corresponds to the Operands of the
737+
given `target` operation. Optionally `operand_number` can be specified to
738+
select a specific operand.
739+
740+
This transform fails silently if the targeted operation does not have enough
741+
operands. It reads the target handle and produces the result handle.
742+
}];
743+
744+
let arguments = (ins TransformHandleTypeInterface:$target,
745+
OptionalAttr<I64Attr>:$operand_number);
746+
let results = (outs TransformValueHandleTypeInterface:$result);
747+
let assemblyFormat = "$target (`[` $operand_number^ `]`)? attr-dict `:` "
748+
"functional-type(operands, results)";
749+
}
750+
731751
def GetResultOp : TransformDialectOp<"get_result",
732752
[DeclareOpInterfaceMethods<TransformOpInterface>,
733753
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
734-
let summary = "Get handle to the a result of the targeted op";
754+
let summary = "Get a handle to the result(s) of the targeted op";
735755
let description = [{
736-
The handle defined by this Transform op corresponds to the OpResult with
737-
`result_number` that is defined by the given `target` operation.
738-
739-
This transform produces a silenceable failure if the targeted operation
740-
does not have enough results. It reads the target handle and produces the
741-
result handle.
756+
The handle defined by this Transform op correspond to the OpResults of the
757+
given `target` operation. Optionally `result_number` can be specified to
758+
select a specific result.
759+
760+
This transform fails silently if the targeted operation does not have enough
761+
results. It reads the target handle and produces the result handle.
742762
}];
743763

744764
let arguments = (ins TransformHandleTypeInterface:$target,
745-
I64Attr:$result_number);
765+
OptionalAttr<I64Attr>:$result_number);
746766
let results = (outs TransformValueHandleTypeInterface:$result);
747-
let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
767+
let assemblyFormat = "$target (`[` $result_number^ `]`)? attr-dict `:` "
748768
"functional-type(operands, results)";
749769
}
750770

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1464,6 +1464,35 @@ transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
14641464
return DiagnosedSilenceableFailure::success();
14651465
}
14661466

1467+
//===----------------------------------------------------------------------===//
1468+
// GetOperandOp
1469+
//===----------------------------------------------------------------------===//
1470+
1471+
DiagnosedSilenceableFailure
1472+
transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1473+
transform::TransformResults &results,
1474+
transform::TransformState &state) {
1475+
std::optional<int64_t> maybeOperandNumber = getOperandNumber();
1476+
SmallVector<Value> operands;
1477+
for (Operation *target : state.getPayloadOps(getTarget())) {
1478+
if (!maybeOperandNumber) {
1479+
for (Value operand : target->getOperands())
1480+
operands.push_back(operand);
1481+
continue;
1482+
}
1483+
int64_t operandNumber = *maybeOperandNumber;
1484+
if (operandNumber >= target->getNumOperands()) {
1485+
DiagnosedSilenceableFailure diag =
1486+
emitSilenceableError() << "targeted op does not have enough operands";
1487+
diag.attachNote(target->getLoc()) << "target op";
1488+
return diag;
1489+
}
1490+
operands.push_back(target->getOperand(operandNumber));
1491+
}
1492+
results.setValues(llvm::cast<OpResult>(getResult()), operands);
1493+
return DiagnosedSilenceableFailure::success();
1494+
}
1495+
14671496
//===----------------------------------------------------------------------===//
14681497
// GetResultOp
14691498
//===----------------------------------------------------------------------===//
@@ -1472,9 +1501,15 @@ DiagnosedSilenceableFailure
14721501
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
14731502
transform::TransformResults &results,
14741503
transform::TransformState &state) {
1475-
int64_t resultNumber = getResultNumber();
1504+
std::optional<int64_t> maybeResultNumber = getResultNumber();
14761505
SmallVector<Value> opResults;
14771506
for (Operation *target : state.getPayloadOps(getTarget())) {
1507+
if (!maybeResultNumber) {
1508+
for (Value result : target->getResults())
1509+
opResults.push_back(result);
1510+
continue;
1511+
}
1512+
int64_t resultNumber = *maybeResultNumber;
14781513
if (resultNumber >= target->getNumResults()) {
14791514
DiagnosedSilenceableFailure diag =
14801515
emitSilenceableError() << "targeted op does not have enough results";

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,60 @@ module attributes {transform.with_named_sequence} {
14831483

14841484
// -----
14851485

1486+
// expected-remark @below {{addi operand}}
1487+
// expected-note @below {{value handle points to a block argument #0}}
1488+
func.func @get_operand_of_op(%arg0: index, %arg1: index) -> index {
1489+
%r = arith.addi %arg0, %arg1 : index
1490+
return %r : index
1491+
}
1492+
1493+
module attributes {transform.with_named_sequence} {
1494+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
1495+
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1496+
%operand = transform.get_operand %addi[0] : (!transform.any_op) -> !transform.any_value
1497+
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
1498+
transform.yield
1499+
}
1500+
}
1501+
1502+
// -----
1503+
1504+
func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
1505+
// expected-note @below {{target op}}
1506+
%r = arith.addi %arg0, %arg1 : index
1507+
return %r : index
1508+
}
1509+
1510+
module attributes {transform.with_named_sequence} {
1511+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
1512+
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1513+
// expected-error @below {{targeted op does not have enough operands}}
1514+
%operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
1515+
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
1516+
transform.yield
1517+
}
1518+
}
1519+
1520+
// -----
1521+
1522+
func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
1523+
%r = arith.addi %arg0, %arg1 : index
1524+
return %r : index
1525+
}
1526+
1527+
module attributes {transform.with_named_sequence} {
1528+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
1529+
%addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1530+
%operands = transform.get_operand %addui : (!transform.any_op) -> !transform.any_value
1531+
%p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
1532+
// expected-remark @below {{2}}
1533+
transform.debug.emit_param_as_remark %p : !transform.param<i64>
1534+
transform.yield
1535+
}
1536+
}
1537+
1538+
// -----
1539+
14861540
func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
14871541
// expected-remark @below {{addi result}}
14881542
// expected-note @below {{value handle points to an op result #0}}
@@ -1537,6 +1591,25 @@ module attributes {transform.with_named_sequence} {
15371591

15381592
// -----
15391593

1594+
func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
1595+
// expected-remark @below {{matched bool}}
1596+
%r, %b = arith.addui_extended %arg0, %arg1 : index, i1
1597+
return %r, %b : index, i1
1598+
}
1599+
1600+
module attributes {transform.with_named_sequence} {
1601+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
1602+
%addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1603+
%results = transform.get_result %addui : (!transform.any_op) -> !transform.any_value
1604+
%adds = transform.get_defining_op %results : (!transform.any_value) -> !transform.any_op
1605+
%_, %add_again = transform.split_handle %adds : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
1606+
transform.debug.emit_remark_at %add_again, "matched bool" : !transform.any_op
1607+
transform.yield
1608+
}
1609+
}
1610+
1611+
// -----
1612+
15401613
// expected-note @below {{target value}}
15411614
func.func @get_result_of_op_bbarg(%arg0: index, %arg1: index) -> index {
15421615
%r = arith.addi %arg0, %arg1 : index

0 commit comments

Comments
 (0)