Skip to content

Commit 2fe4d90

Browse files
committed
[mlir] make structured transform ops use types
Types have been introduced a while ago and provide for better readability and transform-time verification. Use them in the ops from the structured transform dialect extension. In most cases, the types are appended as trailing functional types or a derived format of the functional type that allows for an empty right hand size without the annoying `-> ()` syntax (similarly to `func.func` declaration that may omit the arrow). When handles are used inside mixed static/dynamic lists, such as tile sizes, types of those handles follow them immediately as in `sizes [%0 : !transform.any_value, 42]`. This allows for better readability than matching the trailing type. Update code to remove hardcoded PDL dependencies and expunge PDL from structured transform op code. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D144515
1 parent af0121f commit 2fe4d90

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1304
-1125
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 107 additions & 80 deletions
Large diffs are not rendered by default.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===- Syntax.h - Custom syntax for Linalg transform ops --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
10+
#define MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
11+
12+
#include "mlir/Support/LLVM.h"
13+
14+
namespace mlir {
15+
class ParseResult;
16+
class OpAsmParser;
17+
class OpAsmPrinter;
18+
class Type;
19+
class TypeRange;
20+
class Operation;
21+
22+
/// Parses a single non-function type or a function type with at least one
23+
/// argument. This allows for the following syntax:
24+
///
25+
/// - type: just the argument type;
26+
/// - `(` type `)` `->` type: one argument and one result type;
27+
/// - `(` type `)` `->` `(` comma-separated-type-list `)`: one argument and
28+
/// multiple result types.
29+
///
30+
/// Unlike FunctionType, this allows and requires one to omit the parens around
31+
/// the argument type in absence of result types, and does not accept the
32+
/// trailing `-> ()` construct, which makes the syntax nicer for operations.
33+
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
34+
Type &resultType);
35+
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
36+
SmallVectorImpl<Type> &resultTypes);
37+
38+
/// Prints argument and result types in a syntax similar to that of FunctionType
39+
/// but allowing and requiring one to omit the parens around the argument type
40+
/// in absence of result types, and without the trailing `-> ()`.
41+
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
42+
Type argumentType, TypeRange resultType);
43+
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
44+
Type argumentType, Type resultType);
45+
} // namespace mlir
46+
47+
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H

mlir/include/mlir/Dialect/Transform/Utils/Utils.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,32 @@ class TransformState;
2222

2323
/// Printer hook for custom directive in assemblyFormat.
2424
///
25-
/// custom<PackedOrDynamicIndexList>($packed, $values, $integers)
25+
/// custom<PackedOrDynamicIndexList>($packed, type($packed), $values,
26+
/// type($values), $integers)
2627
///
2728
/// where `values` are variadic Index values, `integers` is an `I64ArrayAttr`
2829
/// and `packed` is a single transform dialect handle who's mapped payload ops
2930
/// have a single Index result and represent the index list. Either `packed`
3031
/// or the other two parameters may be specified.
3132
///
3233
/// This allows idiomatic printing of mixed value and integer attributes in a
33-
/// list or with a single handle. E.g., `[%arg0, 7, 42, %arg42]` or just `%h`.
34+
/// list or with a single handle. E.g., `[%arg0 : !transform.any_op, 7, 42,
35+
/// %arg42 : !transform.param<i64>]` or just `%h : !transform.any_op`.
3436
void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
35-
Value packed, OperandRange values,
37+
Value packed, Type packedType,
38+
OperandRange values, TypeRange valueTypes,
3639
ArrayRef<int64_t> integers);
3740

38-
/// Pasrer hook for custom directive in assemblyFormat.
41+
/// Parser hook for custom directive in assemblyFormat.
3942
///
40-
/// custom<PackedOrDynamicIndexList>($packed, $values, $integers)
43+
/// custom<PackedOrDynamicIndexList>($packed, type($packed), $values,
44+
/// type($values), $integers)
4145
///
4246
/// See `printPackedOrDynamicIndexList` for details.
4347
ParseResult parsePackedOrDynamicIndexList(
4448
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
45-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
46-
DenseI64ArrayAttr &integers);
49+
Type &packedType, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
50+
SmallVectorImpl<Type> &valueTypes, DenseI64ArrayAttr &integers);
4751
} // namespace transform
4852
} // namespace mlir
4953

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,35 +42,49 @@ namespace mlir {
4242
/// Printer hook for custom directive in assemblyFormat.
4343
///
4444
/// custom<DynamicIndexList>($values, $integers)
45+
/// custom<DynamicIndexList>($values, $integers, type($values))
4546
///
46-
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
47+
/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
4748
/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value
48-
/// in `integers` is `dynVal` or (2) the next value otherwise. This allows
49-
/// idiomatic printing of mixed value and integer attributes in a list. E.g.
50-
/// `[%arg0, 7, 42, %arg42]`.
49+
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
50+
/// is non-empty, it is expected to contain as many elements as `values`
51+
/// indicating their types. This allows idiomatic printing of mixed value and
52+
/// integer attributes in a list. E.g.
53+
/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
5154
void printDynamicIndexList(
5255
OpAsmPrinter &printer, Operation *op, OperandRange values,
53-
ArrayRef<int64_t> integers,
56+
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
5457
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
5558

56-
/// Pasrer hook for custom directive in assemblyFormat.
59+
/// Parser hook for custom directive in assemblyFormat.
5760
///
5861
/// custom<DynamicIndexList>($values, $integers)
62+
/// custom<DynamicIndexList>($values, $integers, type($values))
5963
///
60-
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
64+
/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
6165
/// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer
6266
/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where
63-
/// `dynVal` encodes the position of SSA values. Add the parsed SSA values
64-
/// to `values` in-order.
65-
//
66-
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
67-
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
67+
/// `kDynamic` encodes the position of SSA values. Add the parsed SSA values
68+
/// to `values` in-order. If `valueTypes` is non-null, fill it with types
69+
/// corresponding to values; otherwise the caller must handle the types.
70+
///
71+
/// E.g. after parsing "[%arg0 : index, 7, 42, %arg42 : i32]":
72+
/// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42,
73+
/// `kDynamic`]"
6874
/// 2. `ssa` is filled with "[%arg0, %arg1]".
6975
ParseResult parseDynamicIndexList(
7076
OpAsmParser &parser,
7177
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
72-
DenseI64ArrayAttr &integers,
78+
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
7379
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
80+
inline ParseResult parseDynamicIndexList(
81+
OpAsmParser &parser,
82+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
83+
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
84+
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
85+
return parseDynamicIndexList(parser, values, integers, &valueTypes,
86+
delimiter);
87+
}
7488

7589
/// Verify that a the `values` has as many elements as the number of entries in
7690
/// `attr` for which `isDynamic` evaluates to true.

mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
22
DialectExtension.cpp
33
LinalgMatchOps.cpp
44
LinalgTransformOps.cpp
5+
Syntax.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/TransformOps
@@ -19,7 +20,6 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
1920
MLIRLinalgDialect
2021
MLIRLinalgTransforms
2122
MLIRParser
22-
MLIRPDLDialect
2323
MLIRSCFDialect
2424
MLIRSideEffectInterfaces
2525
MLIRTransformDialect

mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1414
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
1515
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
16-
#include "mlir/Dialect/PDL/IR/PDL.h"
1716
#include "mlir/Dialect/SCF/IR/SCF.h"
1817
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1918
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -31,7 +30,6 @@ class LinalgTransformDialectExtension
3130
using Base::Base;
3231

3332
void init() {
34-
declareDependentDialect<pdl::PDLDialect>();
3533
declareDependentDialect<linalg::LinalgDialect>();
3634

3735
declareGeneratedDialect<affine::AffineDialect>();

mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
1010
#include "mlir/Analysis/SliceAnalysis.h"
1111
#include "mlir/Dialect/Linalg/IR/Linalg.h"
12+
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
1213
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
1314
#include "mlir/IR/BuiltinAttributes.h"
1415
#include "mlir/IR/FunctionImplementation.h"
@@ -745,82 +746,6 @@ static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
745746
printer << ")";
746747
}
747748
}
748-
/// Parses a single non-function type or a function type with at least one
749-
/// argument. This allows for the following syntax:
750-
///
751-
/// - type: just the argument type;
752-
/// - `(` type `)` `->` type: one argument and one result type;
753-
/// - `(` type `)` `->` `(` comma-separated-type-list `)`: one argument and
754-
/// multiple result types.
755-
///
756-
/// Unlike FunctionType, this allows and requires one to omit the parens around
757-
/// the argument type in absence of result types, and does not accept the
758-
/// trailing `-> ()` construct, which makes the syntax nicer for operations.
759-
static ParseResult parseSemiFunctionType(OpAsmParser &parser,
760-
Type &argumentType, Type &resultType) {
761-
argumentType = resultType = nullptr;
762-
bool hasLParen = parser.parseOptionalLParen().succeeded();
763-
if (parser.parseType(argumentType).failed())
764-
return failure();
765-
if (!hasLParen)
766-
return success();
767-
768-
return failure(parser.parseRParen().failed() ||
769-
parser.parseArrow().failed() ||
770-
parser.parseType(resultType).failed());
771-
}
772-
static ParseResult parseSemiFunctionType(OpAsmParser &parser,
773-
Type &argumentType,
774-
SmallVectorImpl<Type> &resultTypes) {
775-
argumentType = nullptr;
776-
bool hasLParen = parser.parseOptionalLParen().succeeded();
777-
if (parser.parseType(argumentType).failed())
778-
return failure();
779-
if (!hasLParen)
780-
return success();
781-
782-
if (parser.parseRParen().failed() || parser.parseArrow().failed())
783-
return failure();
784-
785-
if (parser.parseOptionalLParen().failed()) {
786-
Type type;
787-
if (parser.parseType(type).failed())
788-
return failure();
789-
resultTypes.push_back(type);
790-
return success();
791-
}
792-
if (parser.parseTypeList(resultTypes).failed() ||
793-
parser.parseRParen().failed()) {
794-
resultTypes.clear();
795-
return failure();
796-
}
797-
return success();
798-
}
799-
800-
/// Prints argument and result types in a syntax similar to that of FunctionType
801-
/// but allowing and requiring one to omit the parens around the argument type
802-
/// in absence of result types, and without the trailing `-> ()`.
803-
static void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
804-
Type argumentType, TypeRange resultType) {
805-
if (!resultType.empty())
806-
printer << "(";
807-
printer << argumentType;
808-
if (resultType.empty())
809-
return;
810-
printer << ") -> ";
811-
812-
if (resultType.size() > 1)
813-
printer << "(";
814-
llvm::interleaveComma(resultType, printer.getStream());
815-
if (resultType.size() > 1)
816-
printer << ")";
817-
}
818-
static void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
819-
Type argumentType, Type resultType) {
820-
return printSemiFunctionType(printer, op, argumentType,
821-
resultType ? TypeRange(resultType)
822-
: TypeRange());
823-
}
824749

825750
#define GET_OP_CLASSES
826751
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"

0 commit comments

Comments
 (0)