Skip to content

[mlir][transform] Add transform.get_operand op #78397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
let results = (outs Optional<TransformParamTypeInterface>:$result);
let assemblyFormat =
"$operand_handle `[`"
"custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
"custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
"`]` attr-dict `:` "
"custom<SemiFunctionType>(type($operand_handle), type($result))";

Expand Down Expand Up @@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
(outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
let assemblyFormat =
"$operand_handle `[`"
"custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
"`]` attr-dict "
"`:` custom<SemiFunctionType>(type($operand_handle), type($result))";

Expand Down
51 changes: 49 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"
#include <optional>
#include <type_traits>

namespace mlir {
namespace transform {
Expand Down Expand Up @@ -168,6 +169,52 @@ class SingleValueMatcherOpTrait
}
};

//===----------------------------------------------------------------------===//
// Printing/parsing for positional specification matchers
//===----------------------------------------------------------------------===//

/// Parses a positional index specification for transform match operations.
/// The following forms are accepted:
///
/// - `all`: sets `isAll` and returns;
/// - comma-separated-integer-list: populates `rawDimList` with the values;
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
/// with the values and sets `isInverted`.
ParseResult parseTransformMatchDims(OpAsmParser &parser,
DenseI64ArrayAttr &rawDimList,
UnitAttr &isInverted, UnitAttr &isAll);

/// Prints a positional index specification for transform match operations.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
UnitAttr isAll);

//===----------------------------------------------------------------------===//
// Utilities for positional specification matchers
//===----------------------------------------------------------------------===//

/// Checks if the positional specification defined is valid and reports errors
/// otherwise.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
bool inverted, bool all);

/// Populates `result` with the positional identifiers relative to `maxNumber`.
/// If `isAll` is set, the result will contain all numbers from `0` to
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
/// values from `rawList` are are interpreted as counting backwards from
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
/// numbers remain as is. If `isInverted` is set, populates `result` with those
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
/// `rawList`. If `rawList` contains values that are greater than or equal to
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
/// given location. `maxNumber` must be positive. If `rawList` contains
/// duplicate numbers or numbers that become duplicate after negative value
/// remapping, emits a silenceable error.
DiagnosedSilenceableFailure
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
ArrayRef<int64_t> rawList, int64_t maxNumber,
SmallVectorImpl<int64_t> &result);

} // namespace transform
} // namespace mlir

Expand Down
71 changes: 61 additions & 10 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -728,24 +728,75 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
"functional-type(operands, results)";
}

def GetOperandOp : TransformDialectOp<"get_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
let summary = "Get a handle to the operand(s) of the targeted op";
let description = [{
The handle defined by this Transform op corresponds to the operands of the
given `target` operation specified by the given set of positions. There are
three possible modes:

- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
operands at the specified positions.
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
all operands except those at the given positions.
- All, i.e. `%target[all]`. This will return all operands of the operation.

This transform produces a silenceable failure if any of the operand indices
exceeds the number of operands in the target. It reads the target handle and
produces the result handle.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
DenseI64ArrayAttr:$raw_position_list,
UnitAttr:$is_inverted,
UnitAttr:$is_all);
let results = (outs TransformValueHandleTypeInterface:$result);
let assemblyFormat =
"$target `[`"
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
"`]` attr-dict `:` functional-type(operands, results)";
let hasVerifier = 1;
}

def GetResultOp : TransformDialectOp<"get_result",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
let summary = "Get handle to the a result of the targeted op";
let summary = "Get a handle to the result(s) of the targeted op";
let description = [{
The handle defined by this Transform op corresponds to the OpResult with
`result_number` that is defined by the given `target` operation.

This transform produces a silenceable failure if the targeted operation
does not have enough results. It reads the target handle and produces the
result handle.
The handle defined by this Transform op correspond to the OpResults of the
given `target` operation. Optionally `result_number` can be specified to
select a specific result.

This transform fails silently if the targeted operation does not have enough
results. It reads the target handle and produces the result handle.

The handle defined by this Transform op corresponds to the results of the
given `target` operation specified by the given set of positions. There are
three possible modes:

- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
results at the specified positions.
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
all results except those at the given positions.
- All, i.e. `%target[all]`. This will return all results of the operation.

This transform produces a silenceable failure if any of the result indices
exceeds the number of results returned by the target. It reads the target
handle and produces the result handle.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$result_number);
DenseI64ArrayAttr:$raw_position_list,
UnitAttr:$is_inverted,
UnitAttr:$is_all);
let results = (outs TransformValueHandleTypeInterface:$result);
let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
"functional-type(operands, results)";
let assemblyFormat =
"$target `[`"
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
"`]` attr-dict `:` functional-type(operands, results)";
let hasVerifier = 1;
}

def GetTypeOp : TransformDialectOp<"get_type",
Expand Down
170 changes: 6 additions & 164 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,91 +330,6 @@ static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
return DiagnosedSilenceableFailure::success();
}

/// Populates `result` with the positional identifiers relative to `maxNumber`.
/// If `isAll` is set, the result will contain all numbers from `0` to
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
/// values from `rawList` are are interpreted as counting backwards from
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
/// numbers remain as is. If `isInverted` is set, populates `result` with those
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
/// `rawList`. If `rawList` contains values that are greater than or equal to
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
/// given location. `maxNumber` must be positive. If `rawList` contains
/// duplicate numbers or numbers that become duplicate after negative value
/// remapping, emits a silenceable error.
static DiagnosedSilenceableFailure
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
ArrayRef<int64_t> rawList, int64_t maxNumber,
SmallVectorImpl<int64_t> &result) {
assert(maxNumber > 0 && "expected size to be positive");
assert(!(isAll && isInverted) && "cannot invert all");
if (isAll) {
result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
return DiagnosedSilenceableFailure::success();
}

SmallVector<int64_t> expanded;
llvm::SmallDenseSet<int64_t> visited;
expanded.reserve(rawList.size());
SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
for (int64_t raw : rawList) {
int64_t updated = raw < 0 ? maxNumber + raw : raw;
if (updated >= maxNumber) {
return emitSilenceableFailure(loc)
<< "position overflow " << updated << " (updated from " << raw
<< ") for maximum " << maxNumber;
}
if (updated < 0) {
return emitSilenceableFailure(loc) << "position underflow " << updated
<< " (updated from " << raw << ")";
}
if (!visited.insert(updated).second) {
return emitSilenceableFailure(loc) << "repeated position " << updated
<< " (updated from " << raw << ")";
}
target.push_back(updated);
}

if (!isInverted)
return DiagnosedSilenceableFailure::success();

result.reserve(result.size() + (maxNumber - expanded.size()));
for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
if (llvm::is_contained(expanded, candidate))
continue;
result.push_back(candidate);
}

return DiagnosedSilenceableFailure::success();
}

/// Checks if the positional specification defined is valid and reports errors
/// otherwise.
LogicalResult verifyStructuredTransformDimsOp(Operation *op,
ArrayRef<int64_t> raw,
bool inverted, bool all) {
if (all) {
if (inverted) {
return op->emitOpError()
<< "cannot request both 'all' and 'inverted' values in the list";
}
if (!raw.empty()) {
return op->emitOpError()
<< "cannot both request 'all' and specific values in the list";
}
}
if (!all && raw.empty()) {
return op->emitOpError() << "must request specific values in the list if "
"'all' is not specified";
}
SmallVector<int64_t> rawVector = llvm::to_vector(raw);
auto *it = std::unique(rawVector.begin(), rawVector.end());
if (it != rawVector.end())
return op->emitOpError() << "expected the listed values to be unique";

return success();
}

//===----------------------------------------------------------------------===//
// MatchStructuredDimOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -475,8 +390,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
return emitOpError() << "cannot request the same dimension to be both "
"parallel and reduction";
}
return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
getIsInverted(), getIsAll());
return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
getIsInverted(), getIsAll());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -592,8 +507,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
LogicalResult transform::MatchStructuredInputOp::verify() {
if (failed(verifyStructuredOperandOp(*this)))
return failure();
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -665,8 +580,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
LogicalResult transform::MatchStructuredInitOp::verify() {
if (failed(verifyStructuredOperandOp(*this)))
return failure();
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -793,78 +708,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
build(builder, state, ValueRange());
}

//===----------------------------------------------------------------------===//
// Printing and parsing for structured match ops.
//===----------------------------------------------------------------------===//

/// Keyword syntax for positional specification inversion.
constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";

/// Keyword syntax for full inclusion in positional specification.
constexpr const static llvm::StringLiteral kDimAllKeyword = "all";

/// Parses a positional specification for structured transform operations. The
/// following forms are accepted:
///
/// - `all`: sets `isAll` and returns;
/// - comma-separated-integer-list: populates `rawDimList` with the values;
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
/// with the values and sets `isInverted`.
static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
DenseI64ArrayAttr &rawDimList,
UnitAttr &isInverted,
UnitAttr &isAll) {
Builder &builder = parser.getBuilder();
if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
rawDimList = builder.getDenseI64ArrayAttr({});
isInverted = nullptr;
isAll = builder.getUnitAttr();
return success();
}

isAll = nullptr;
isInverted = nullptr;
if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
isInverted = builder.getUnitAttr();
}

if (isInverted) {
if (parser.parseLParen().failed())
return failure();
}

SmallVector<int64_t> values;
ParseResult listResult = parser.parseCommaSeparatedList(
[&]() { return parser.parseInteger(values.emplace_back()); });
if (listResult.failed())
return failure();

rawDimList = builder.getDenseI64ArrayAttr(values);

if (isInverted) {
if (parser.parseRParen().failed())
return failure();
}
return success();
}

/// Prints a positional specification for structured transform operations.
static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
DenseI64ArrayAttr rawDimList,
UnitAttr isInverted, UnitAttr isAll) {
if (isAll) {
printer << kDimAllKeyword;
return;
}
if (isInverted) {
printer << kDimExceptKeyword << "(";
}
llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
[&](int64_t value) { printer << value; });
if (isInverted) {
printer << ")";
}
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
Loading