Skip to content

[mlir] introduce transform.num_associations #76723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,28 @@ def CastOp : TransformDialectOp<"cast",
}];
}

def NumAssociationsOp : TransformDialectOp<"num_associations",
[MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>,
MatchOpInterface]> {
let summary =
"Returns the number of payload objects associated with the argument";
let description = [{
Given an argument, handle or parameter, returns a new parameter associated
with a single 64-bit number that corresponds to the number of payload
objects (operations or values for a handle, attributes for a parameter)
associated with the argument.

Always succeeds.
}];
let arguments = (ins Transform_AnyHandleOrParamType:$handle);
let results = (outs TransformParamTypeInterface:$num);
let assemblyFormat = [{
$handle attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
}

def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
Expand Down
37 changes: 37 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

Expand Down Expand Up @@ -1974,6 +1975,42 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
/*extraBindingTypes=*/TypeRange(), bodyBuilder);
}

//===----------------------------------------------------------------------===//
// NumAssociationsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
size_t numAssociations =
llvm::TypeSwitch<Type, size_t>(getHandle().getType())
.Case([&](TransformHandleTypeInterface opHandle) {
return llvm::range_size(state.getPayloadOps(getHandle()));
})
.Case([&](TransformValueHandleTypeInterface valueHandle) {
return llvm::range_size(state.getPayloadValues(getHandle()));
})
.Case([&](TransformParamTypeInterface param) {
return llvm::range_size(state.getParams(getHandle()));
})
.Default([](Type) {
llvm_unreachable("unknown kind of transform dialect type");
return 0;
});
results.setParams(getNum().cast<OpResult>(),
rewriter.getI64IntegerAttr(numAssociations));
return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::NumAssociationsOp::verify() {
// Verify that the result type accepts an i64 attribute as payload.
auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
return resultType
.checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
.checkAndReport();
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 10 additions & 5 deletions mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {

// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
%p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
transform.test_print_param %p : !transform.param<i64>

// Ensure that one linalg.copy was generated.
%mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
%p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
transform.test_print_param %p2 : !transform.param<i64>
transform.yield
}
}
Expand Down Expand Up @@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {

// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
%p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
transform.test_print_param %p : !transform.param<i64>

// Ensure that one linalg.copy was generated.
%linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
%p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
transform.test_print_param %p2 : !transform.param<i64>

// Ensure that one memref.alloca was generated.
%alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
%p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
transform.test_print_param %p3 : !transform.param<i64>

// Make sure that One-Shot Bufferize can bufferize the rest.
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
Expand Down
5 changes: 3 additions & 2 deletions mlir/test/Dialect/Linalg/transform-op-match.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
#linalg.iterator_type<parallel>,
#linalg.iterator_type<reduction>]}
in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-remark @below {{0}}
transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
%p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/Dialect/Linalg/transform-op-pad.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
padding_dimensions=[0, 1, 2],
pack_paddings=[1, 1, 0]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
%p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
// expected-remark @below {{1}}
transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Transform/ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,11 @@ transform.sequence failures(propagate) {
transform.named_sequence @foo()
} : !transform.any_op
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
}
Loading