Skip to content

Commit 2c1c676

Browse files
srcarrollftynse
andauthored
[mlir][transform] Consistent linalg transform op syntax for dynamic index lists (llvm#90897)
This patch is a first pass at making consistent syntax across the `LinalgTransformOp`s that use dynamic index lists for size parameters. Previously, there were two different forms: inline types in the list, or place them in the functional style tuple. This patch goes for the latter. In order to do this, the `printPackedOrDynamicIndexList`, `printDynamicIndexList` and their `parse` counterparts were modified so that the types can be optionally provided to the corresponding custom directives. All affected ops now use tablegen `assemblyFormat`, so custom `parse`/`print` functions have been removed. There are a couple ops that will likely add dynamic size support, and once that happens it should be made sure that the assembly remains consistent with the changes in this patch. The affected ops are as follows: `pack`, `pack_greedily`, `tile_using_forall`. The `tile_using_for` and `vectorize` ops already used this syntax, but their custom assembly was removed. --------- Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
1 parent c6efcc9 commit 2c1c676

File tree

53 files changed

+210
-323
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+210
-323
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -783,10 +783,9 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
783783
let assemblyFormat = [{
784784
$target
785785
`packed_sizes` `=` custom<DynamicIndexList>($packed_sizes,
786-
$static_packed_sizes,
787-
type($packed_sizes))
786+
$static_packed_sizes)
788787
attr-dict
789-
`:` functional-type($target, results)
788+
`:` functional-type(operands, results)
790789
}];
791790

792791
let builders = [
@@ -890,14 +889,13 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
890889
$target
891890
oilist(
892891
`matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
893-
$static_matmul_packed_sizes,
894-
type($matmul_packed_sizes))
892+
$static_matmul_packed_sizes)
895893
(`matmul_padded_sizes_next_multiple_of` `=`
896894
$matmul_padded_sizes_next_multiple_of^)?
897895
`matmul_inner_dims_order` `=` $matmul_inner_dims_order
898896
)
899897
attr-dict
900-
`:` functional-type($target, results)
898+
`:` functional-type(operands, results)
901899
}];
902900
let hasVerifier = 1;
903901

@@ -1899,7 +1897,17 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
18991897
$scalableSizes)>,
19001898
];
19011899

1902-
let hasCustomAssemblyFormat = 1;
1900+
let assemblyFormat = [{
1901+
$target
1902+
`tile_sizes` custom<DynamicIndexList>(
1903+
$dynamic_sizes,
1904+
$static_sizes,
1905+
$scalable_sizes)
1906+
(`interchange` `=` $interchange^)?
1907+
attr-dict
1908+
`:` functional-type(operands, results)
1909+
}];
1910+
19031911
let hasVerifier = 1;
19041912

19051913
let extraClassDeclaration = [{
@@ -2017,17 +2025,13 @@ def TileUsingForallOp :
20172025
let assemblyFormat = [{
20182026
$target oilist(
20192027
`num_threads` custom<PackedOrDynamicIndexList>($packed_num_threads,
2020-
type($packed_num_threads),
20212028
$num_threads,
2022-
type($num_threads),
20232029
$static_num_threads) |
20242030
`tile_sizes` custom<PackedOrDynamicIndexList>($packed_tile_sizes,
2025-
type($packed_tile_sizes),
20262031
$tile_sizes,
2027-
type($tile_sizes),
20282032
$static_tile_sizes))
20292033
(`(` `mapping` `=` $mapping^ `)`)? attr-dict
2030-
`:` functional-type($target, results)
2034+
`:` functional-type(operands, results)
20312035
}];
20322036
let hasVerifier = 1;
20332037

@@ -2162,7 +2166,18 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
21622166

21632167
let results = (outs);
21642168

2165-
let hasCustomAssemblyFormat = 1;
2169+
// We use oilist here to elide the optional `vector_sizes` when empty list
2170+
// is passed.
2171+
let assemblyFormat = [{
2172+
$target oilist(
2173+
`vector_sizes` custom<DynamicIndexList>(
2174+
$vector_sizes,
2175+
$static_vector_sizes,
2176+
$scalable_sizes))
2177+
attr-dict
2178+
`:` type($target)(`,`type($vector_sizes)^)?
2179+
}];
2180+
21662181
let hasVerifier = 1;
21672182

21682183
let extraClassDeclaration = [{

mlir/include/mlir/Dialect/Transform/Utils/Utils.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
3737
Value packed, Type packedType,
3838
OperandRange values, TypeRange valueTypes,
3939
DenseI64ArrayAttr integers);
40+
inline void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
41+
Value packed, OperandRange values,
42+
DenseI64ArrayAttr integers) {
43+
printPackedOrDynamicIndexList(printer, op, packed, Type(), values,
44+
TypeRange{}, integers);
45+
}
4046

4147
/// Parser hook for custom directive in assemblyFormat.
4248
///
@@ -47,7 +53,15 @@ void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
4753
ParseResult parsePackedOrDynamicIndexList(
4854
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
4955
Type &packedType, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
50-
SmallVectorImpl<Type> &valueTypes, DenseI64ArrayAttr &integers);
56+
SmallVectorImpl<Type> *valueTypes, DenseI64ArrayAttr &integers);
57+
inline ParseResult parsePackedOrDynamicIndexList(
58+
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
59+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
60+
DenseI64ArrayAttr &integers) {
61+
Type packedType;
62+
return parsePackedOrDynamicIndexList(parser, packed, packedType, values,
63+
nullptr, integers);
64+
}
5165
} // namespace transform
5266
} // namespace mlir
5367

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,16 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
106106
/// empty then assume that all indices are non-scalable.
107107
void printDynamicIndexList(
108108
OpAsmPrinter &printer, Operation *op, OperandRange values,
109-
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
110-
ArrayRef<bool> scalables = {},
109+
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
110+
TypeRange valueTypes = TypeRange(),
111111
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
112+
inline void printDynamicIndexList(
113+
OpAsmPrinter &printer, Operation *op, OperandRange values,
114+
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
115+
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
116+
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
117+
delimiter);
118+
}
112119

113120
/// Parser hook for custom directive in assemblyFormat.
114121
///

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 0 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,86 +2823,6 @@ SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
28232823
return results;
28242824
}
28252825

2826-
// We want to parse `DenseI64ArrayAttr` using the short form without the
2827-
// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
2828-
ParseResult parseOptionalInterchange(OpAsmParser &parser,
2829-
OperationState &result) {
2830-
if (failed(parser.parseOptionalKeyword("interchange")))
2831-
return success();
2832-
if (failed(parser.parseEqual()))
2833-
return failure();
2834-
result.addAttribute(
2835-
transform::TileUsingForOp::getInterchangeAttrName(result.name),
2836-
DenseI64ArrayAttr::parse(parser, Type{}));
2837-
return success();
2838-
}
2839-
2840-
void printOptionalInterchange(OpAsmPrinter &p,
2841-
ArrayRef<int64_t> interchangeVals) {
2842-
if (!interchangeVals.empty()) {
2843-
p << " interchange = [";
2844-
llvm::interleaveComma(interchangeVals, p,
2845-
[&](int64_t integer) { p << integer; });
2846-
p << "]";
2847-
}
2848-
}
2849-
2850-
ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
2851-
OperationState &result) {
2852-
OpAsmParser::UnresolvedOperand target;
2853-
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
2854-
DenseI64ArrayAttr staticSizes;
2855-
FunctionType functionalType;
2856-
llvm::SMLoc operandLoc;
2857-
DenseBoolArrayAttr scalableVals;
2858-
2859-
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
2860-
parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
2861-
parseOptionalInterchange(parser, result) ||
2862-
parser.parseOptionalAttrDict(result.attributes) ||
2863-
parser.parseColonType(functionalType))
2864-
return ParseResult::failure();
2865-
2866-
size_t numExpectedLoops =
2867-
staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
2868-
if (functionalType.getNumResults() != numExpectedLoops + 1) {
2869-
return parser.emitError(parser.getNameLoc())
2870-
<< "expected " << (numExpectedLoops + 1) << " result type(s)";
2871-
}
2872-
if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
2873-
return parser.emitError(operandLoc)
2874-
<< "expected " << dynamicSizes.size() + 1 << " operand type(s)";
2875-
}
2876-
if (parser.resolveOperand(target, functionalType.getInputs().front(),
2877-
result.operands) ||
2878-
parser.resolveOperands(dynamicSizes,
2879-
functionalType.getInputs().drop_front(),
2880-
operandLoc, result.operands)) {
2881-
return failure();
2882-
}
2883-
2884-
result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
2885-
2886-
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
2887-
result.addTypes(functionalType.getResults());
2888-
return success();
2889-
}
2890-
2891-
void TileUsingForOp::print(OpAsmPrinter &p) {
2892-
p << ' ' << getTarget();
2893-
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
2894-
/*valueTypes=*/{}, getScalableSizesAttr(),
2895-
OpAsmParser::Delimiter::Square);
2896-
printOptionalInterchange(p, getInterchange());
2897-
p.printOptionalAttrDict(
2898-
(*this)->getAttrs(),
2899-
/*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()),
2900-
getScalableSizesAttrName(getOperation()->getName()),
2901-
getStaticSizesAttrName(getOperation()->getName())});
2902-
p << " : ";
2903-
p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
2904-
}
2905-
29062826
void transform::TileUsingForOp::getEffects(
29072827
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
29082828
consumesHandle(getTarget(), effects);
@@ -3219,80 +3139,6 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
32193139
// VectorizeOp
32203140
//===----------------------------------------------------------------------===//
32213141

3222-
static const StringLiteral kVectorSizesKeyword = "vector_sizes";
3223-
3224-
ParseResult transform::VectorizeOp::parse(OpAsmParser &parser,
3225-
OperationState &result) {
3226-
OpAsmParser::UnresolvedOperand target;
3227-
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
3228-
DenseI64ArrayAttr staticSizes;
3229-
SmallVector<Type> operandTypes;
3230-
llvm::SMLoc operandLoc;
3231-
DenseBoolArrayAttr scalableVals;
3232-
3233-
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
3234-
return ParseResult::failure();
3235-
3236-
if (succeeded(parser.parseOptionalKeyword(kVectorSizesKeyword))) {
3237-
if (failed(parseDynamicIndexList(parser, dynamicSizes, staticSizes,
3238-
scalableVals)))
3239-
return ParseResult::failure();
3240-
}
3241-
3242-
if (succeeded(parser.parseOptionalKeyword(
3243-
getVectorizeNdExtractAttrName(result.name))))
3244-
result.addAttribute(getVectorizeNdExtractAttrName(result.name),
3245-
parser.getBuilder().getUnitAttr());
3246-
3247-
if (parser.parseOptionalAttrDict(result.attributes) ||
3248-
parser.parseColonTypeList(operandTypes))
3249-
return ParseResult::failure();
3250-
3251-
if (operandTypes.size() != dynamicSizes.size() + 1) {
3252-
return parser.emitError(operandLoc)
3253-
<< "expected " << dynamicSizes.size() + 1 << " operand type(s)";
3254-
}
3255-
if (parser.resolveOperand(target, operandTypes.front(), result.operands) ||
3256-
parser.resolveOperands(dynamicSizes, ArrayRef(operandTypes).drop_front(),
3257-
operandLoc, result.operands)) {
3258-
return failure();
3259-
}
3260-
3261-
if (scalableVals)
3262-
result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
3263-
if (staticSizes)
3264-
result.addAttribute(getStaticVectorSizesAttrName(result.name), staticSizes);
3265-
3266-
return success();
3267-
}
3268-
3269-
void transform::VectorizeOp::print(OpAsmPrinter &p) {
3270-
p << ' ' << getTarget() << ' ';
3271-
if (!getMixedVectorSizes().empty()) {
3272-
p << kVectorSizesKeyword << ' ';
3273-
printDynamicIndexList(p, getOperation(), getVectorSizes(),
3274-
getStaticVectorSizesAttr(),
3275-
/*valueTypes=*/{}, getScalableSizesAttr(),
3276-
OpAsmParser::Delimiter::Square);
3277-
}
3278-
3279-
if (getVectorizeNdExtract())
3280-
p << getVectorizeNdExtractAttrName() << ' ';
3281-
3282-
p.printOptionalAttrDict(
3283-
(*this)->getAttrs(),
3284-
/*elidedAttrs=*/{
3285-
getScalableSizesAttrName(getOperation()->getName()),
3286-
getStaticVectorSizesAttrName(getOperation()->getName())});
3287-
p << " : ";
3288-
p << getTarget().getType();
3289-
if (!getVectorSizes().empty()) {
3290-
p << ", ";
3291-
llvm::interleaveComma(getVectorSizes(), p,
3292-
[&](Value operand) { p << operand.getType(); });
3293-
}
3294-
}
3295-
32963142
DiagnosedSilenceableFailure transform::VectorizeOp::apply(
32973143
transform::TransformRewriter &rewriter,
32983144
mlir::transform::TransformResults &transformResults,

mlir/lib/Dialect/Transform/Utils/Utils.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ void mlir::transform::printPackedOrDynamicIndexList(
2020
if (packed) {
2121
assert(values.empty() && (!integers || integers.empty()) &&
2222
"expected no values/integers");
23-
printer << "*(" << packed << " : " << packedType << ")";
23+
printer << "*(" << packed;
24+
if (packedType) {
25+
printer << " : " << packedType;
26+
}
27+
printer << ")";
2428
return;
2529
}
2630
printDynamicIndexList(printer, op, values, integers, valueTypes);
@@ -29,19 +33,20 @@ void mlir::transform::printPackedOrDynamicIndexList(
2933
ParseResult mlir::transform::parsePackedOrDynamicIndexList(
3034
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
3135
Type &packedType, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
32-
SmallVectorImpl<Type> &valueTypes, DenseI64ArrayAttr &integers) {
36+
SmallVectorImpl<Type> *valueTypes, DenseI64ArrayAttr &integers) {
3337
OpAsmParser::UnresolvedOperand packedOperand;
3438
if (parser.parseOptionalStar().succeeded()) {
3539
if (parser.parseLParen().failed() ||
36-
parser.parseOperand(packedOperand).failed() ||
37-
parser.parseColonType(packedType).failed() ||
38-
parser.parseRParen().failed()) {
40+
parser.parseOperand(packedOperand).failed())
41+
return failure();
42+
if (packedType && (parser.parseColonType(packedType).failed()))
43+
return failure();
44+
if (parser.parseRParen().failed())
3945
return failure();
40-
}
4146
packed.emplace(packedOperand);
4247
integers = parser.getBuilder().getDenseI64ArrayAttr({});
4348
return success();
4449
}
4550

46-
return parseDynamicIndexList(parser, values, integers, &valueTypes);
51+
return parseDynamicIndexList(parser, values, integers, valueTypes);
4752
}

mlir/lib/Interfaces/ViewLikeInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ static char getRightDelimiter(AsmParser::Delimiter delimiter) {
113113
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
114114
OperandRange values,
115115
ArrayRef<int64_t> integers,
116-
TypeRange valueTypes, ArrayRef<bool> scalables,
116+
ArrayRef<bool> scalables, TypeRange valueTypes,
117117
AsmParser::Delimiter delimiter) {
118118
char leftDelimiter = getLeftDelimiter(delimiter);
119119
char rightDelimiter = getRightDelimiter(delimiter);

mlir/test/Dialect/LLVM/transform-e2e.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func.func @matmul_tensors(
1515
module attributes {transform.with_named_sequence} {
1616
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.consumed}) {
1717
%0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
18-
%1, %loops:3 = transform.structured.tile_using_for %0 [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
18+
%1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
1919
%2 = transform.get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
2020
transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
2121
%b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}

mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8
2727
module attributes {transform.with_named_sequence} {
2828
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
2929
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
30-
%1, %loops:4 = transform.structured.tile_using_for %0 [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
30+
%1, %loops:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
3131
transform.yield
3232
}
3333
}
@@ -54,7 +54,7 @@ func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %a
5454
module attributes {transform.with_named_sequence} {
5555
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
5656
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
57-
%1, %loops:2 = transform.structured.tile_using_for %0 [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
57+
%1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
5858
transform.yield
5959
}
6060
}
@@ -85,7 +85,7 @@ func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>)
8585
module attributes {transform.with_named_sequence} {
8686
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
8787
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
88-
%1, %loops:2 = transform.structured.tile_using_for %0 [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
88+
%1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
8989
transform.yield
9090
}
9191
}

0 commit comments

Comments
 (0)