Skip to content

Commit 59fae09

Browse files
committed
Address comments and switch to mirroring the linalg.match positional spec
1 parent 6ce9369 commit 59fae09

File tree

7 files changed

+295
-157
lines changed

7 files changed

+295
-157
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
288288
let results = (outs Optional<TransformParamTypeInterface>:$result);
289289
let assemblyFormat =
290290
"$operand_handle `[`"
291-
"custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
291+
"custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
292292
"`]` attr-dict `:` "
293293
"custom<SemiFunctionType>(type($operand_handle), type($result))";
294294

@@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
347347
(outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
348348
let assemblyFormat =
349349
"$operand_handle `[`"
350-
"custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
350+
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
351351
"`]` attr-dict "
352352
"`:` custom<SemiFunctionType>(type($operand_handle), type($result))";
353353

mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
1010
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
1111

12+
#include <optional>
13+
#include <type_traits>
14+
1215
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1316
#include "mlir/IR/OpDefinition.h"
1417
#include "llvm/ADT/STLExtras.h"
15-
#include <optional>
16-
#include <type_traits>
1718

1819
namespace mlir {
1920
namespace transform {
@@ -168,6 +169,52 @@ class SingleValueMatcherOpTrait
168169
}
169170
};
170171

172+
//===----------------------------------------------------------------------===//
173+
// Printing/parsing for positional specification matchers
174+
//===----------------------------------------------------------------------===//
175+
176+
/// Parses a positional index specification for transform match operations.
177+
/// The following forms are accepted:
178+
///
179+
/// - `all`: sets `isAll` and returns;
180+
/// - comma-separated-integer-list: populates `rawDimList` with the values;
181+
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
182+
/// with the values and sets `isInverted`.
183+
ParseResult parseTransformMatchDims(OpAsmParser &parser,
184+
DenseI64ArrayAttr &rawDimList,
185+
UnitAttr &isInverted, UnitAttr &isAll);
186+
187+
/// Prints a positional index specification for transform match operations.
188+
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
189+
DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
190+
UnitAttr isAll);
191+
192+
//===----------------------------------------------------------------------===//
193+
// Utilities for positional specification matchers
194+
//===----------------------------------------------------------------------===//
195+
196+
/// Checks if the positional specification defined is valid and reports errors
197+
/// otherwise.
198+
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
199+
bool inverted, bool all);
200+
201+
/// Populates `result` with the positional identifiers relative to `maxNumber`.
202+
/// If `isAll` is set, the result will contain all numbers from `0` to
203+
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
204+
/// values from `rawList` are are interpreted as counting backwards from
205+
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
206+
/// numbers remain as is. If `isInverted` is set, populates `result` with those
207+
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
208+
/// `rawList`. If `rawList` contains values that are greater than or equal to
209+
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
210+
/// given location. `maxNumber` must be positive. If `rawList` contains
211+
/// duplicate numbers or numbers that become duplicate after negative value
212+
/// remapping, emits a silenceable error.
213+
DiagnosedSilenceableFailure
214+
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
215+
ArrayRef<int64_t> rawList, int64_t maxNumber,
216+
SmallVectorImpl<int64_t> &result);
217+
171218
} // namespace transform
172219
} // namespace mlir
173220

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -733,19 +733,31 @@ def GetOperandOp : TransformDialectOp<"get_operand",
733733
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
734734
let summary = "Get a handle to the operand(s) of the targeted op";
735735
let description = [{
736-
The handle defined by this Transform op corresponds to the Operands of the
737-
given `target` operation. Optionally `operand_number` can be specified to
738-
select a specific operand.
736+
The handle defined by this Transform op corresponds to the operands of the
737+
given `target` operation specified by the given set of positions. There are
738+
three possible modes:
739+
740+
- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
741+
operands at the specified positions.
742+
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
743+
all operands except those at the given positions.
744+
- All, i.e. `%target[all]`. This will return all operands of the operation.
739745

740-
This transform fails silently if the targeted operation does not have enough
741-
operands. It reads the target handle and produces the result handle.
746+
This transform produces a silenceable failure if any of the operand indices
747+
exceeds the number of operands in the target. It reads the target handle and
748+
produces the result handle.
742749
}];
743750

744751
let arguments = (ins TransformHandleTypeInterface:$target,
745-
OptionalAttr<I64Attr>:$operand_number);
752+
DenseI64ArrayAttr:$raw_position_list,
753+
UnitAttr:$is_inverted,
754+
UnitAttr:$is_all);
746755
let results = (outs TransformValueHandleTypeInterface:$result);
747-
let assemblyFormat = "$target (`[` $operand_number^ `]`)? attr-dict `:` "
748-
"functional-type(operands, results)";
756+
let assemblyFormat =
757+
"$target `[`"
758+
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
759+
"`]` attr-dict `:` functional-type(operands, results)";
760+
let hasVerifier = 1;
749761
}
750762

751763
def GetResultOp : TransformDialectOp<"get_result",
@@ -759,13 +771,32 @@ def GetResultOp : TransformDialectOp<"get_result",
759771

760772
This transform fails silently if the targeted operation does not have enough
761773
results. It reads the target handle and produces the result handle.
774+
775+
The handle defined by this Transform op corresponds to the results of the
776+
given `target` operation specified by the given set of positions. There are
777+
three possible modes:
778+
779+
- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
780+
results at the specified positions.
781+
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
782+
all results except those at the given positions.
783+
- All, i.e. `%target[all]`. This will return all results of the operation.
784+
785+
This transform produces a silenceable failure if any of the result indices
786+
exceeds the number of results returned by the target. It reads the target
787+
handle and produces the result handle.
762788
}];
763789

764790
let arguments = (ins TransformHandleTypeInterface:$target,
765-
OptionalAttr<I64Attr>:$result_number);
791+
DenseI64ArrayAttr:$raw_position_list,
792+
UnitAttr:$is_inverted,
793+
UnitAttr:$is_all);
766794
let results = (outs TransformValueHandleTypeInterface:$result);
767-
let assemblyFormat = "$target (`[` $result_number^ `]`)? attr-dict `:` "
768-
"functional-type(operands, results)";
795+
let assemblyFormat =
796+
"$target `[`"
797+
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
798+
"`]` attr-dict `:` functional-type(operands, results)";
799+
let hasVerifier = 1;
769800
}
770801

771802
def GetTypeOp : TransformDialectOp<"get_type",

mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp

Lines changed: 6 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -388,33 +388,6 @@ expandTargetSpecification(Location loc, bool isAll, bool isInverted,
388388
return DiagnosedSilenceableFailure::success();
389389
}
390390

391-
/// Checks if the positional specification defined is valid and reports errors
392-
/// otherwise.
393-
LogicalResult verifyStructuredTransformDimsOp(Operation *op,
394-
ArrayRef<int64_t> raw,
395-
bool inverted, bool all) {
396-
if (all) {
397-
if (inverted) {
398-
return op->emitOpError()
399-
<< "cannot request both 'all' and 'inverted' values in the list";
400-
}
401-
if (!raw.empty()) {
402-
return op->emitOpError()
403-
<< "cannot both request 'all' and specific values in the list";
404-
}
405-
}
406-
if (!all && raw.empty()) {
407-
return op->emitOpError() << "must request specific values in the list if "
408-
"'all' is not specified";
409-
}
410-
SmallVector<int64_t> rawVector = llvm::to_vector(raw);
411-
auto *it = std::unique(rawVector.begin(), rawVector.end());
412-
if (it != rawVector.end())
413-
return op->emitOpError() << "expected the listed values to be unique";
414-
415-
return success();
416-
}
417-
418391
//===----------------------------------------------------------------------===//
419392
// MatchStructuredDimOp
420393
//===----------------------------------------------------------------------===//
@@ -475,8 +448,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
475448
return emitOpError() << "cannot request the same dimension to be both "
476449
"parallel and reduction";
477450
}
478-
return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
479-
getIsInverted(), getIsAll());
451+
return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
452+
getIsInverted(), getIsAll());
480453
}
481454

482455
//===----------------------------------------------------------------------===//
@@ -592,8 +565,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
592565
LogicalResult transform::MatchStructuredInputOp::verify() {
593566
if (failed(verifyStructuredOperandOp(*this)))
594567
return failure();
595-
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
596-
getIsInverted(), getIsAll());
568+
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
569+
getIsInverted(), getIsAll());
597570
}
598571

599572
//===----------------------------------------------------------------------===//
@@ -665,8 +638,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
665638
LogicalResult transform::MatchStructuredInitOp::verify() {
666639
if (failed(verifyStructuredOperandOp(*this)))
667640
return failure();
668-
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
669-
getIsInverted(), getIsAll());
641+
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
642+
getIsInverted(), getIsAll());
670643
}
671644

672645
//===----------------------------------------------------------------------===//
@@ -793,78 +766,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
793766
build(builder, state, ValueRange());
794767
}
795768

796-
//===----------------------------------------------------------------------===//
797-
// Printing and parsing for structured match ops.
798-
//===----------------------------------------------------------------------===//
799-
800-
/// Keyword syntax for positional specification inversion.
801-
constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
802-
803-
/// Keyword syntax for full inclusion in positional specification.
804-
constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
805-
806-
/// Parses a positional specification for structured transform operations. The
807-
/// following forms are accepted:
808-
///
809-
/// - `all`: sets `isAll` and returns;
810-
/// - comma-separated-integer-list: populates `rawDimList` with the values;
811-
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
812-
/// with the values and sets `isInverted`.
813-
static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
814-
DenseI64ArrayAttr &rawDimList,
815-
UnitAttr &isInverted,
816-
UnitAttr &isAll) {
817-
Builder &builder = parser.getBuilder();
818-
if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
819-
rawDimList = builder.getDenseI64ArrayAttr({});
820-
isInverted = nullptr;
821-
isAll = builder.getUnitAttr();
822-
return success();
823-
}
824-
825-
isAll = nullptr;
826-
isInverted = nullptr;
827-
if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
828-
isInverted = builder.getUnitAttr();
829-
}
830-
831-
if (isInverted) {
832-
if (parser.parseLParen().failed())
833-
return failure();
834-
}
835-
836-
SmallVector<int64_t> values;
837-
ParseResult listResult = parser.parseCommaSeparatedList(
838-
[&]() { return parser.parseInteger(values.emplace_back()); });
839-
if (listResult.failed())
840-
return failure();
841-
842-
rawDimList = builder.getDenseI64ArrayAttr(values);
843-
844-
if (isInverted) {
845-
if (parser.parseRParen().failed())
846-
return failure();
847-
}
848-
return success();
849-
}
850-
851-
/// Prints a positional specification for structured transform operations.
852-
static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
853-
DenseI64ArrayAttr rawDimList,
854-
UnitAttr isInverted, UnitAttr isAll) {
855-
if (isAll) {
856-
printer << kDimAllKeyword;
857-
return;
858-
}
859-
if (isInverted) {
860-
printer << kDimExceptKeyword << "(";
861-
}
862-
llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
863-
[&](int64_t value) { printer << value; });
864-
if (isInverted) {
865-
printer << ")";
866-
}
867-
}
868-
869769
#define GET_OP_CLASSES
870770
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"

0 commit comments

Comments
 (0)