Skip to content

Commit d1ca1d0

Browse files
[mlir] Makes zip_shortest an optional keyword in transform.foreach (#98492)
This PR addresses a [comment] made by @ftynse about the syntax for `ForeachOp`. The syntax was modified by @muneebkhan85 in #82792, where the attribute dictionary was moved to the middle. This patch moves it back to its original place at the end. And introduces an optional keyword for `zip_shortest`. [comment]: #82792 (review)
1 parent 2d69c36 commit d1ca1d0

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -649,11 +649,12 @@ def ForeachOp : TransformDialectOp<"foreach",
649649
}];
650650

651651
let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
652-
UnitAttr:$zip_shortest);
652+
UnitAttr:$with_zip_shortest);
653653
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
654654
let regions = (region SizedRegion<1>:$body);
655655
let assemblyFormat =
656-
"$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
656+
"$targets oilist(`with_zip_shortest` $with_zip_shortest) `:` "
657+
"type($targets) (`->` type($results)^)? $body attr-dict";
657658
let hasVerifier = 1;
658659

659660
let extraClassDeclaration = [{

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,11 +1396,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
13961396
SmallVector<SmallVector<MappedValue>> payloads;
13971397
detail::prepareValueMappings(payloads, getTargets(), state);
13981398
size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399-
bool isZipShortest = getZipShortest();
1399+
bool withZipShortest = getWithZipShortest();
14001400

14011401
// In case of `zip_shortest`, set the number of iterations to the
14021402
// smallest payload in the targets.
1403-
if (isZipShortest) {
1403+
if (withZipShortest) {
14041404
numIterations =
14051405
llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
14061406
const SmallVector<MappedValue> &B) {
@@ -1414,7 +1414,7 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
14141414
// As we will be "zipping" over them, check all payloads have the same size.
14151415
// `zip_shortest` adjusts all payloads to the same size, so skip this check
14161416
// when true.
1417-
for (size_t argIdx = 1; !isZipShortest && argIdx < payloads.size();
1417+
for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
14181418
argIdx++) {
14191419
if (payloads[argIdx].size() != numIterations) {
14201420
return emitSilenceableError()

mlir/test/Dialect/Linalg/continuous-tiling-full.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ module attributes {transform.with_named_sequence} {
127127
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
128128
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
129129
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
130-
transform.foreach %linalg_splits, %tile_sizes {zip_shortest} : !transform.any_op, !transform.any_op {
130+
transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op {
131131
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
132132
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
133133
transform.yield

0 commit comments

Comments
 (0)