Skip to content

Commit fa63fb8

Browse files
committed
Add support of param type for transform.structured.tile_using_forall
1 parent 4684507 commit fa63fb8

File tree

3 files changed

+176
-22
lines changed

3 files changed

+176
-22
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ include "mlir/IR/RegionKindInterface.td"
2121

2222
// This is roughly similar to OpFoldResult assuming the handle produces a single
2323
// value in the payload IR.
24-
def TransformParamTypeOrAnyHandle : Type<
24+
def TransformAnyParamTypeOrAnyHandle : Type<
2525
Or<[TransformHandleTypeInterface.predicate,
26-
Transform_ParamType.predicate]>,
27-
"transform 'param' type or any handle type">;
26+
TransformParamTypeInterface.predicate]>,
27+
"transform any param type or any handle type">;
2828

2929
//===----------------------------------------------------------------------===//
3030
// Apply...PatternsOp
@@ -691,9 +691,9 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
691691
I64Attr:$dimension,
692692
I64Attr:$target_size,
693693
DefaultValuedAttr<I64Attr, "1">:$divisor);
694-
let results = (outs TransformParamTypeOrAnyHandle:$low_size,
695-
TransformParamTypeOrAnyHandle:$high_size,
696-
TransformParamTypeOrAnyHandle:$split_point);
694+
let results = (outs TransformAnyParamTypeOrAnyHandle:$low_size,
695+
TransformAnyParamTypeOrAnyHandle:$high_size,
696+
TransformAnyParamTypeOrAnyHandle:$split_point);
697697
let hasVerifier = 1;
698698
let assemblyFormat =
699699
"$target attr-dict `:` custom<MultitileSizesTypes>("
@@ -1408,7 +1408,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
14081408

14091409
let arguments = (ins TransformHandleTypeInterface:$target,
14101410
I64Attr:$dimension,
1411-
Optional<TransformParamTypeOrAnyHandle>:$dynamic_split_point,
1411+
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
14121412
I64Attr:$static_split_point);
14131413
let results = (outs TransformHandleTypeInterface:$first,
14141414
TransformHandleTypeInterface:$second);
@@ -1857,7 +1857,7 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
18571857
}];
18581858

18591859
let arguments = (ins TransformHandleTypeInterface:$target,
1860-
Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
1860+
Variadic<TransformAnyParamTypeOrAnyHandle>:$dynamic_sizes,
18611861
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
18621862
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
18631863
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
@@ -1968,10 +1968,10 @@ def TileUsingForallOp :
19681968
}];
19691969

19701970
let arguments = (ins TransformHandleTypeInterface:$target,
1971-
Variadic<TransformHandleTypeInterface>:$num_threads,
1972-
Variadic<TransformHandleTypeInterface>:$tile_sizes,
1973-
Optional<TransformHandleTypeInterface>:$packed_num_threads,
1974-
Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
1971+
Variadic<TransformAnyParamTypeOrAnyHandle>:$num_threads,
1972+
Variadic<TransformAnyParamTypeOrAnyHandle>:$tile_sizes,
1973+
Optional<TransformAnyParamTypeOrAnyHandle>:$packed_num_threads,
1974+
Optional<TransformAnyParamTypeOrAnyHandle>:$packed_tile_sizes,
19751975
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
19761976
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
19771977
OptionalAttr<DeviceMappingArrayAttr>:$mapping);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
8686
return cast<LinalgOp>(result->getOperation());
8787
}
8888

89-
/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
90-
/// to exactly one op with one index result, return that value.
89+
/// Assuming that `ofr` is an index attr or a param of index type
90+
/// or a transform dialect handle mapped to exactly one op
91+
/// with one index result, return that value.
9192
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
9293
transform::TransformState &state, TransformOpInterface transformOp,
9394
SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
@@ -98,12 +99,23 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
9899
result.push_back(ofr);
99100
continue;
100101
}
101-
auto payloadOps = state.getPayloadOps(ofr.get<Value>());
102+
103+
Value transformValue = ofr.get<Value>();
104+
if (isa<ParamType>(transformValue.getType())) {
105+
ArrayRef<Attribute> params = state.getParams(transformValue);
106+
if (params.size() != 1)
107+
return transformOp.emitDefiniteFailure()
108+
<< "requires exactly one parameter associated";
109+
result.push_back(params[0]);
110+
continue;
111+
}
112+
113+
auto payloadOps = state.getPayloadOps(transformValue);
102114
if (!llvm::hasSingleElement(payloadOps)) {
103115
DiagnosedSilenceableFailure diag =
104116
transformOp.emitSilenceableError()
105117
<< "handle must be mapped to exactly one payload op";
106-
diag.attachNote(ofr.get<Value>().getLoc())
118+
diag.attachNote(transformValue.getLoc())
107119
<< "mapped to " << llvm::range_size(payloadOps) << " payload ops";
108120
return diag;
109121
}
@@ -123,14 +135,27 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
123135
return DiagnosedSilenceableFailure::success();
124136
}
125137

126-
// Given a list of OpFoldResults that are either index attrs or op
127-
// handles, return a list of OpFoldResults where all op handles are
128-
// replaced with the first (and only) OpResult of that payload op. (There
129-
// must be exactly one mapped payload op and it must have exactly one
130-
// index result.)
138+
// Given a list of params that are index attrs or a list of OpFoldResults
139+
// that are either index attrs or op handles, return a list of OpFoldResults
140+
// of index attrs or a list of OpFoldResults where all op handles are
141+
// replaced with the first (and only) OpResult of that payload op.
142+
// (There must be exactly one parameter associated with the AnyParamType or
143+
// one mapped payload op which must have exactly one index result.)
131144
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
132145
transform::TransformState &state, TransformOpInterface transformOp,
133146
SmallVector<OpFoldResult> &result, Value packedHandle) {
147+
if (isa<AnyParamType>(packedHandle.getType())) {
148+
ArrayRef<Attribute> params = state.getParams(packedHandle);
149+
for (auto param : params) {
150+
if (!isa<IntegerAttr>(param))
151+
return transformOp.emitDefiniteFailure()
152+
<< "expected the parameter to be associated with an integer "
153+
"attribute";
154+
result.push_back(param);
155+
}
156+
return DiagnosedSilenceableFailure::success();
157+
}
158+
134159
for (Operation *op : state.getPayloadOps(packedHandle)) {
135160
if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
136161
DiagnosedSilenceableFailure diag =

mlir/test/Dialect/Linalg/tile-to-forall.mlir

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
22

33
// Offset per thread:
44
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -451,3 +451,132 @@ module attributes {transform.with_named_sequence} {
451451
}
452452
}
453453

454+
// -----
455+
456+
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
457+
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
458+
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
459+
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
460+
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
461+
// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
462+
463+
// CHECK-LABEL: matmul_tile_size_dynamic(
464+
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
465+
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
466+
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
467+
func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
468+
// CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[C0]] :
469+
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[C1]] :
470+
// CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
471+
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
472+
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
473+
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
474+
// CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
475+
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
476+
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
477+
// CHECK: tensor.extract_slice %[[A]]
478+
// CHECK: tensor.extract_slice %[[B]]
479+
// CHECK: tensor.extract_slice %[[C_BLK]]
480+
// CHECK: linalg.matmul
481+
// CHECK: scf.forall.in_parallel
482+
// CHECK-NEXT: tensor.parallel_insert_slice
483+
%0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
484+
outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
485+
return %0 : tensor<?x?xf32>
486+
}
487+
488+
module attributes {transform.with_named_sequence} {
489+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
490+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
491+
%sz = transform.param.constant 10 : i64 -> !transform.param<i64>
492+
%1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
493+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
494+
transform.yield
495+
}
496+
}
497+
498+
// -----
499+
500+
func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
501+
%0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
502+
outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
503+
return %0 : tensor<?x?xf32>
504+
}
505+
506+
module attributes {transform.with_named_sequence} {
507+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
508+
%0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
509+
%c10 = transform.param.constant 10 : i64 -> !transform.param<i64>
510+
%c20 = transform.param.constant 20 : i64 -> !transform.param<i64>
511+
%sz = transform.merge_handles %c10, %c20 : !transform.param<i64>
512+
// expected-error @below {{requires exactly one parameter associated}}
513+
%1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
514+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
515+
transform.yield
516+
}
517+
}
518+
519+
// -----
520+
521+
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
522+
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
523+
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
524+
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
525+
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
526+
// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
527+
528+
// CHECK-LABEL: matmul_tile_size_dynamic(
529+
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
530+
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
531+
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
532+
func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
533+
// CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[C0]] :
534+
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[C1]] :
535+
// CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
536+
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
537+
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
538+
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
539+
// CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
540+
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
541+
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
542+
// CHECK: tensor.extract_slice %[[A]]
543+
// CHECK: tensor.extract_slice %[[B]]
544+
// CHECK: tensor.extract_slice %[[C_BLK]]
545+
// CHECK: linalg.matmul
546+
// CHECK: scf.forall.in_parallel
547+
// CHECK-NEXT: tensor.parallel_insert_slice
548+
%0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
549+
outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
550+
return %0 : tensor<?x?xf32>
551+
}
552+
553+
module attributes {transform.with_named_sequence} {
554+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
555+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
556+
%c10 = transform.param.constant 10 : i64 -> !transform.any_param
557+
%c20 = transform.param.constant 20 : i64 -> !transform.any_param
558+
%sz = transform.merge_handles %c10, %c20 : !transform.any_param
559+
%1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
560+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
561+
transform.yield
562+
}
563+
}
564+
565+
// -----
566+
567+
func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
568+
%0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
569+
outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
570+
return %0 : tensor<?x?xf32>
571+
}
572+
573+
module attributes {transform.with_named_sequence} {
574+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
575+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
576+
%sz = transform.param.constant "[10 : i64, 20 : i64]" -> !transform.any_param
577+
// expected-error @below {{expected IntegerAttr}}
578+
%1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
579+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
580+
transform.yield
581+
}
582+
}

0 commit comments

Comments
 (0)