Skip to content

Commit 5caab8b

Browse files
authored
[mlir][transform] Add transform.get_operand op (#78397)
Similar to `transform.get_result`, except it returns a handle to the operand indicated by a positional specification, same as is defined for the linalg match ops. Additionally updates `get_result` to take the same positional specification. This makes the use case of wanting to get all of the results of an operation easier by no longer requiring the user to reconstruct the list of results one-by-one.
1 parent e90e43f commit 5caab8b

File tree

8 files changed

+398
-191
lines changed

8 files changed

+398
-191
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
288288
let results = (outs Optional<TransformParamTypeInterface>:$result);
289289
let assemblyFormat =
290290
"$operand_handle `[`"
291-
"custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
291+
"custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
292292
"`]` attr-dict `:` "
293293
"custom<SemiFunctionType>(type($operand_handle), type($result))";
294294

@@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
347347
(outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
348348
let assemblyFormat =
349349
"$operand_handle `[`"
350-
"custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
350+
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
351351
"`]` attr-dict "
352352
"`:` custom<SemiFunctionType>(type($operand_handle), type($result))";
353353

mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
1010
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
1111

12+
#include <optional>
13+
#include <type_traits>
14+
1215
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1316
#include "mlir/IR/OpDefinition.h"
1417
#include "llvm/ADT/STLExtras.h"
15-
#include <optional>
16-
#include <type_traits>
1718

1819
namespace mlir {
1920
namespace transform {
@@ -168,6 +169,52 @@ class SingleValueMatcherOpTrait
168169
}
169170
};
170171

172+
//===----------------------------------------------------------------------===//
173+
// Printing/parsing for positional specification matchers
174+
//===----------------------------------------------------------------------===//
175+
176+
/// Parses a positional index specification for transform match operations.
177+
/// The following forms are accepted:
178+
///
179+
/// - `all`: sets `isAll` and returns;
180+
/// - comma-separated-integer-list: populates `rawDimList` with the values;
181+
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
182+
/// with the values and sets `isInverted`.
183+
ParseResult parseTransformMatchDims(OpAsmParser &parser,
184+
DenseI64ArrayAttr &rawDimList,
185+
UnitAttr &isInverted, UnitAttr &isAll);
186+
187+
/// Prints a positional index specification for transform match operations.
188+
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
189+
DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
190+
UnitAttr isAll);
191+
192+
//===----------------------------------------------------------------------===//
193+
// Utilities for positional specification matchers
194+
//===----------------------------------------------------------------------===//
195+
196+
/// Checks if the positional specification defined is valid and reports errors
197+
/// otherwise.
198+
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
199+
bool inverted, bool all);
200+
201+
/// Populates `result` with the positional identifiers relative to `maxNumber`.
202+
/// If `isAll` is set, the result will contain all numbers from `0` to
203+
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
204+
/// values from `rawList` are are interpreted as counting backwards from
205+
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
206+
/// numbers remain as is. If `isInverted` is set, populates `result` with those
207+
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
208+
/// `rawList`. If `rawList` contains values that are greater than or equal to
209+
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
210+
/// given location. `maxNumber` must be positive. If `rawList` contains
211+
/// duplicate numbers or numbers that become duplicate after negative value
212+
/// remapping, emits a silenceable error.
213+
DiagnosedSilenceableFailure
214+
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
215+
ArrayRef<int64_t> rawList, int64_t maxNumber,
216+
SmallVectorImpl<int64_t> &result);
217+
171218
} // namespace transform
172219
} // namespace mlir
173220

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -728,24 +728,75 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
728728
"functional-type(operands, results)";
729729
}
730730

731+
def GetOperandOp : TransformDialectOp<"get_operand",
732+
[DeclareOpInterfaceMethods<TransformOpInterface>,
733+
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
734+
let summary = "Get a handle to the operand(s) of the targeted op";
735+
let description = [{
736+
The handle defined by this Transform op corresponds to the operands of the
737+
given `target` operation specified by the given set of positions. There are
738+
three possible modes:
739+
740+
- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
741+
operands at the specified positions.
742+
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
743+
all operands except those at the given positions.
744+
- All, i.e. `%target[all]`. This will return all operands of the operation.
745+
746+
This transform produces a silenceable failure if any of the operand indices
747+
exceeds the number of operands in the target. It reads the target handle and
748+
produces the result handle.
749+
}];
750+
751+
let arguments = (ins TransformHandleTypeInterface:$target,
752+
DenseI64ArrayAttr:$raw_position_list,
753+
UnitAttr:$is_inverted,
754+
UnitAttr:$is_all);
755+
let results = (outs TransformValueHandleTypeInterface:$result);
756+
let assemblyFormat =
757+
"$target `[`"
758+
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
759+
"`]` attr-dict `:` functional-type(operands, results)";
760+
let hasVerifier = 1;
761+
}
762+
731763
def GetResultOp : TransformDialectOp<"get_result",
732764
[DeclareOpInterfaceMethods<TransformOpInterface>,
733765
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
734-
let summary = "Get handle to the a result of the targeted op";
766+
let summary = "Get a handle to the result(s) of the targeted op";
735767
let description = [{
736-
The handle defined by this Transform op corresponds to the OpResult with
737-
`result_number` that is defined by the given `target` operation.
738-
739-
This transform produces a silenceable failure if the targeted operation
740-
does not have enough results. It reads the target handle and produces the
741-
result handle.
768+
The handle defined by this Transform op correspond to the OpResults of the
769+
given `target` operation. Optionally `result_number` can be specified to
770+
select a specific result.
771+
772+
This transform fails silently if the targeted operation does not have enough
773+
results. It reads the target handle and produces the result handle.
774+
775+
The handle defined by this Transform op corresponds to the results of the
776+
given `target` operation specified by the given set of positions. There are
777+
three possible modes:
778+
779+
- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
780+
results at the specified positions.
781+
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
782+
all results except those at the given positions.
783+
- All, i.e. `%target[all]`. This will return all results of the operation.
784+
785+
This transform produces a silenceable failure if any of the result indices
786+
exceeds the number of results returned by the target. It reads the target
787+
handle and produces the result handle.
742788
}];
743789

744790
let arguments = (ins TransformHandleTypeInterface:$target,
745-
I64Attr:$result_number);
791+
DenseI64ArrayAttr:$raw_position_list,
792+
UnitAttr:$is_inverted,
793+
UnitAttr:$is_all);
746794
let results = (outs TransformValueHandleTypeInterface:$result);
747-
let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
748-
"functional-type(operands, results)";
795+
let assemblyFormat =
796+
"$target `[`"
797+
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
798+
"`]` attr-dict `:` functional-type(operands, results)";
799+
let hasVerifier = 1;
749800
}
750801

751802
def GetTypeOp : TransformDialectOp<"get_type",

mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp

Lines changed: 6 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -330,91 +330,6 @@ static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
330330
return DiagnosedSilenceableFailure::success();
331331
}
332332

333-
/// Populates `result` with the positional identifiers relative to `maxNumber`.
334-
/// If `isAll` is set, the result will contain all numbers from `0` to
335-
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
336-
/// values from `rawList` are are interpreted as counting backwards from
337-
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
338-
/// numbers remain as is. If `isInverted` is set, populates `result` with those
339-
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
340-
/// `rawList`. If `rawList` contains values that are greater than or equal to
341-
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
342-
/// given location. `maxNumber` must be positive. If `rawList` contains
343-
/// duplicate numbers or numbers that become duplicate after negative value
344-
/// remapping, emits a silenceable error.
345-
static DiagnosedSilenceableFailure
346-
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
347-
ArrayRef<int64_t> rawList, int64_t maxNumber,
348-
SmallVectorImpl<int64_t> &result) {
349-
assert(maxNumber > 0 && "expected size to be positive");
350-
assert(!(isAll && isInverted) && "cannot invert all");
351-
if (isAll) {
352-
result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
353-
return DiagnosedSilenceableFailure::success();
354-
}
355-
356-
SmallVector<int64_t> expanded;
357-
llvm::SmallDenseSet<int64_t> visited;
358-
expanded.reserve(rawList.size());
359-
SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
360-
for (int64_t raw : rawList) {
361-
int64_t updated = raw < 0 ? maxNumber + raw : raw;
362-
if (updated >= maxNumber) {
363-
return emitSilenceableFailure(loc)
364-
<< "position overflow " << updated << " (updated from " << raw
365-
<< ") for maximum " << maxNumber;
366-
}
367-
if (updated < 0) {
368-
return emitSilenceableFailure(loc) << "position underflow " << updated
369-
<< " (updated from " << raw << ")";
370-
}
371-
if (!visited.insert(updated).second) {
372-
return emitSilenceableFailure(loc) << "repeated position " << updated
373-
<< " (updated from " << raw << ")";
374-
}
375-
target.push_back(updated);
376-
}
377-
378-
if (!isInverted)
379-
return DiagnosedSilenceableFailure::success();
380-
381-
result.reserve(result.size() + (maxNumber - expanded.size()));
382-
for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
383-
if (llvm::is_contained(expanded, candidate))
384-
continue;
385-
result.push_back(candidate);
386-
}
387-
388-
return DiagnosedSilenceableFailure::success();
389-
}
390-
391-
/// Checks if the positional specification defined is valid and reports errors
392-
/// otherwise.
393-
LogicalResult verifyStructuredTransformDimsOp(Operation *op,
394-
ArrayRef<int64_t> raw,
395-
bool inverted, bool all) {
396-
if (all) {
397-
if (inverted) {
398-
return op->emitOpError()
399-
<< "cannot request both 'all' and 'inverted' values in the list";
400-
}
401-
if (!raw.empty()) {
402-
return op->emitOpError()
403-
<< "cannot both request 'all' and specific values in the list";
404-
}
405-
}
406-
if (!all && raw.empty()) {
407-
return op->emitOpError() << "must request specific values in the list if "
408-
"'all' is not specified";
409-
}
410-
SmallVector<int64_t> rawVector = llvm::to_vector(raw);
411-
auto *it = std::unique(rawVector.begin(), rawVector.end());
412-
if (it != rawVector.end())
413-
return op->emitOpError() << "expected the listed values to be unique";
414-
415-
return success();
416-
}
417-
418333
//===----------------------------------------------------------------------===//
419334
// MatchStructuredDimOp
420335
//===----------------------------------------------------------------------===//
@@ -475,8 +390,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
475390
return emitOpError() << "cannot request the same dimension to be both "
476391
"parallel and reduction";
477392
}
478-
return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
479-
getIsInverted(), getIsAll());
393+
return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
394+
getIsInverted(), getIsAll());
480395
}
481396

482397
//===----------------------------------------------------------------------===//
@@ -592,8 +507,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
592507
LogicalResult transform::MatchStructuredInputOp::verify() {
593508
if (failed(verifyStructuredOperandOp(*this)))
594509
return failure();
595-
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
596-
getIsInverted(), getIsAll());
510+
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
511+
getIsInverted(), getIsAll());
597512
}
598513

599514
//===----------------------------------------------------------------------===//
@@ -665,8 +580,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
665580
LogicalResult transform::MatchStructuredInitOp::verify() {
666581
if (failed(verifyStructuredOperandOp(*this)))
667582
return failure();
668-
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
669-
getIsInverted(), getIsAll());
583+
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
584+
getIsInverted(), getIsAll());
670585
}
671586

672587
//===----------------------------------------------------------------------===//
@@ -793,78 +708,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
793708
build(builder, state, ValueRange());
794709
}
795710

796-
//===----------------------------------------------------------------------===//
797-
// Printing and parsing for structured match ops.
798-
//===----------------------------------------------------------------------===//
799-
800-
/// Keyword syntax for positional specification inversion.
801-
constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
802-
803-
/// Keyword syntax for full inclusion in positional specification.
804-
constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
805-
806-
/// Parses a positional specification for structured transform operations. The
807-
/// following forms are accepted:
808-
///
809-
/// - `all`: sets `isAll` and returns;
810-
/// - comma-separated-integer-list: populates `rawDimList` with the values;
811-
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
812-
/// with the values and sets `isInverted`.
813-
static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
814-
DenseI64ArrayAttr &rawDimList,
815-
UnitAttr &isInverted,
816-
UnitAttr &isAll) {
817-
Builder &builder = parser.getBuilder();
818-
if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
819-
rawDimList = builder.getDenseI64ArrayAttr({});
820-
isInverted = nullptr;
821-
isAll = builder.getUnitAttr();
822-
return success();
823-
}
824-
825-
isAll = nullptr;
826-
isInverted = nullptr;
827-
if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
828-
isInverted = builder.getUnitAttr();
829-
}
830-
831-
if (isInverted) {
832-
if (parser.parseLParen().failed())
833-
return failure();
834-
}
835-
836-
SmallVector<int64_t> values;
837-
ParseResult listResult = parser.parseCommaSeparatedList(
838-
[&]() { return parser.parseInteger(values.emplace_back()); });
839-
if (listResult.failed())
840-
return failure();
841-
842-
rawDimList = builder.getDenseI64ArrayAttr(values);
843-
844-
if (isInverted) {
845-
if (parser.parseRParen().failed())
846-
return failure();
847-
}
848-
return success();
849-
}
850-
851-
/// Prints a positional specification for structured transform operations.
852-
static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
853-
DenseI64ArrayAttr rawDimList,
854-
UnitAttr isInverted, UnitAttr isAll) {
855-
if (isAll) {
856-
printer << kDimAllKeyword;
857-
return;
858-
}
859-
if (isInverted) {
860-
printer << kDimExceptKeyword << "(";
861-
}
862-
llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
863-
[&](int64_t value) { printer << value; });
864-
if (isInverted) {
865-
printer << ")";
866-
}
867-
}
868-
869711
#define GET_OP_CLASSES
870712
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"

0 commit comments

Comments
 (0)