Skip to content

Commit f90b609

Browse files
authored
[mlir] introduce transform.num_associations (#76723)
Add a new transform operation that creates a new parameter containing the number of payload objects (operations, values or attributes) associated with the argument. This is useful in matching and for debugging purposes. This replaces three ad-hoc operations previously provided by the test extension.
1 parent df1b5ae commit f90b609

File tree

10 files changed

+152
-115
lines changed

10 files changed

+152
-115
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,28 @@ def CastOp : TransformDialectOp<"cast",
438438
}];
439439
}
440440

441+
def NumAssociationsOp : TransformDialectOp<"num_associations",
442+
[MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
443+
DeclareOpInterfaceMethods<TransformOpInterface>,
444+
MatchOpInterface]> {
445+
let summary =
446+
"Returns the number of payload objects associated with the argument";
447+
let description = [{
448+
Given an argument, handle or parameter, returns a new parameter associated
449+
with a single 64-bit number that corresponds to the number of payload
450+
objects (operations or values for a handle, attributes for a parameter)
451+
associated with the argument.
452+
453+
Always succeeds.
454+
}];
455+
let arguments = (ins Transform_AnyHandleOrParamType:$handle);
456+
let results = (outs TransformParamTypeInterface:$num);
457+
let assemblyFormat = [{
458+
$handle attr-dict `:` functional-type(operands, results)
459+
}];
460+
let hasVerifier = 1;
461+
}
462+
441463
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
442464
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
443465
DeclareOpInterfaceMethods<SymbolUserOpInterface>,

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/ADT/STLExtras.h"
3333
#include "llvm/ADT/ScopeExit.h"
3434
#include "llvm/ADT/SmallPtrSet.h"
35+
#include "llvm/ADT/TypeSwitch.h"
3536
#include "llvm/Support/Debug.h"
3637
#include <optional>
3738

@@ -1974,6 +1975,42 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
19741975
/*extraBindingTypes=*/TypeRange(), bodyBuilder);
19751976
}
19761977

1978+
//===----------------------------------------------------------------------===//
1979+
// NumAssociationsOp
1980+
//===----------------------------------------------------------------------===//
1981+
1982+
DiagnosedSilenceableFailure
1983+
transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
1984+
transform::TransformResults &results,
1985+
transform::TransformState &state) {
1986+
size_t numAssociations =
1987+
llvm::TypeSwitch<Type, size_t>(getHandle().getType())
1988+
.Case([&](TransformHandleTypeInterface opHandle) {
1989+
return llvm::range_size(state.getPayloadOps(getHandle()));
1990+
})
1991+
.Case([&](TransformValueHandleTypeInterface valueHandle) {
1992+
return llvm::range_size(state.getPayloadValues(getHandle()));
1993+
})
1994+
.Case([&](TransformParamTypeInterface param) {
1995+
return llvm::range_size(state.getParams(getHandle()));
1996+
})
1997+
.Default([](Type) {
1998+
llvm_unreachable("unknown kind of transform dialect type");
1999+
return 0;
2000+
});
2001+
results.setParams(getNum().cast<OpResult>(),
2002+
rewriter.getI64IntegerAttr(numAssociations));
2003+
return DiagnosedSilenceableFailure::success();
2004+
}
2005+
2006+
LogicalResult transform::NumAssociationsOp::verify() {
2007+
// Verify that the result type accepts an i64 attribute as payload.
2008+
auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
2009+
return resultType
2010+
.checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2011+
.checkAndReport();
2012+
}
2013+
19772014
//===----------------------------------------------------------------------===//
19782015
// SelectOp
19792016
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {
3636

3737
// Ensure that one linalg.fill was generated.
3838
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
39+
%p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
3940
// expected-remark @below{{1}}
40-
transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
41+
transform.test_print_param %p : !transform.param<i64>
4142

4243
// Ensure that one linalg.copy was generated.
4344
%mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
45+
%p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
4446
// expected-remark @below{{1}}
45-
transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
47+
transform.test_print_param %p2 : !transform.param<i64>
4648
transform.yield
4749
}
4850
}
@@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {
7375

7476
// Ensure that one linalg.fill was generated.
7577
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
78+
%p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
7679
// expected-remark @below{{1}}
77-
transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
80+
transform.test_print_param %p : !transform.param<i64>
7881

7982
// Ensure that one linalg.copy was generated.
8083
%linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
84+
%p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
8185
// expected-remark @below{{1}}
82-
transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
86+
transform.test_print_param %p2 : !transform.param<i64>
8387

8488
// Ensure that one memref.alloca was generated.
8589
%alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
90+
%p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
8691
// expected-remark @below{{1}}
87-
transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
92+
transform.test_print_param %p3 : !transform.param<i64>
8893

8994
// Make sure that One-Shot Bufferize can bufferize the rest.
9095
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op

mlir/test/Dialect/Linalg/transform-op-match.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
134134
#linalg.iterator_type<parallel>,
135135
#linalg.iterator_type<reduction>]}
136136
in %arg1 : (!transform.any_op) -> !transform.any_op
137-
// expected-remark @below {{0}}
138-
transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
137+
%p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
138+
// expected-remark @below {{0}}
139+
transform.test_print_param %p : !transform.param<i64>
139140
transform.yield
140141
}
141142
}

mlir/test/Dialect/Linalg/transform-op-pad.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
4141
padding_dimensions=[0, 1, 2],
4242
pack_paddings=[1, 1, 0]
4343
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
44+
%p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
4445
// expected-remark @below {{1}}
45-
transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
46+
transform.test_print_param %p : !transform.param<i64>
4647
transform.yield
4748
}
4849
}

mlir/test/Dialect/Transform/ops-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,3 +696,11 @@ transform.sequence failures(propagate) {
696696
transform.named_sequence @foo()
697697
} : !transform.any_op
698698
}
699+
700+
// -----
701+
702+
transform.sequence failures(propagate) {
703+
^bb0(%arg0: !transform.any_op):
704+
// expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
705+
transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
706+
}

0 commit comments

Comments
 (0)