-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Add support of param type for transform.structured.tile_using_forall #72097
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: jinchen62 (jinchen62) ChangesMake transform.structured.tile_using_forall be able to take param type tile sizes. Examples:
Full diff: https://github.com/llvm/llvm-project/pull/72097.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..a24f6ff8308ba34 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -23,7 +23,7 @@ include "mlir/IR/RegionKindInterface.td"
// value in the payload IR.
def TransformParamTypeOrAnyHandle : Type<
Or<[TransformHandleTypeInterface.predicate,
- Transform_ParamType.predicate]>,
+ TransformParamTypeInterface.predicate]>,
"transform 'param' type or any handle type">;
//===----------------------------------------------------------------------===//
@@ -1924,10 +1924,10 @@ def TileUsingForallOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<TransformHandleTypeInterface>:$num_threads,
- Variadic<TransformHandleTypeInterface>:$tile_sizes,
- Optional<TransformHandleTypeInterface>:$packed_num_threads,
- Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+ Variadic<TransformParamTypeOrAnyHandle>:$num_threads,
+ Variadic<TransformParamTypeOrAnyHandle>:$tile_sizes,
+ Optional<TransformParamTypeOrAnyHandle>:$packed_num_threads,
+ Optional<TransformParamTypeOrAnyHandle>:$packed_tile_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..4bf4db3381fab79 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -98,12 +98,34 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
result.push_back(ofr);
continue;
}
- auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+ Value transformValue = ofr.get<Value>();
+ if (isa<ParamType>(transformValue.getType())) {
+ ArrayRef<Attribute> params = state.getParams(transformValue);
+ if (!isa<IntegerAttr>(params[0]))
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ result.push_back(params[0]);
+ continue;
+ }
+ if (isa<AnyParamType>(transformValue.getType())) {
+ ArrayRef<Attribute> params = state.getParams(transformValue);
+ if (!isa<ArrayAttr>(params[0]))
+ return transformOp.emitDefiniteFailure() << "expected ArrayAttr";
+ ArrayAttr paramsArray = cast<ArrayAttr>(params[0]);
+ for (Attribute param : paramsArray.getValue()) {
+ if (!isa<IntegerAttr>(param))
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ result.push_back(param);
+ }
+ continue;
+ }
+
+ auto payloadOps = state.getPayloadOps(transformValue);
if (!llvm::hasSingleElement(payloadOps)) {
DiagnosedSilenceableFailure diag =
transformOp.emitSilenceableError()
<< "handle must be mapped to exactly one payload op";
- diag.attachNote(ofr.get<Value>().getLoc())
+ diag.attachNote(transformValue.getLoc())
<< "mapped to " << llvm::range_size(payloadOps) << " payload ops";
return diag;
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for you PR! Please address inline comments.
Also please add tests. The good practice is to have at least one non-trivial test for desired behavior, tests for edge cases, and tests for all user-visible diagnostic messages.
for (Attribute param : paramsArray.getValue()) { | ||
if (!isa<IntegerAttr>(param)) | ||
return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; | ||
result.push_back(param); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's unclear to me what is the intended model here. Looking at the code, if the type is any_param
, it will accept array-of-integer parameters and expand its context as tile sizes. But it will keep doing so for other parameters, so we may end up having more sizes than loops, and a surprising mismatch between parameter positions and loop depth. We could accept one parameter, and unpack it, but I don't think we want to accept multiple here. Also, there's no need to rely on array attributes as a parameter can be associated with a list of attributes (unless we want to support some multi-level scheme where the parameter is a list-of-lists of tile sizes to use for each of the loops associated with the input).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't believe this was addressed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought moving the support of any_param to the second unpackSingleIndexResultPayloadOperations
that takes a Value
would address it since there won't be lists inside lists. Could you give me more guidence on this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As far as I recall, there are currently two modes of functioning here:
- there are N unpacked handles, where each handle corresponds to one dimension; each handle points to the same number of operations as the "target" handle;
- there is 1 packed handle, which points to as many operations as dimensions; same dimensions are used for tiling all operations associated with the "target" handle.
Parameters should behave similarly. Since parameters are associated with a list of attributes, we would treat that list identically to how we treat the list of operations for operation handle. Consequently, we should never expect an array attribute in parameters. There is already the implicit list and we must use that instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please review, thanks!
Also note that the utility function that is changed in this PR is used in more cases than |
✅ With the latest revision this PR passed the C/C++ code formatter. |
a04b3bd
to
8d09cc2
Compare
@ftynse Thanks for the review! I've addressed the comments, please check. |
e50f7c3
to
48c045e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add tests for user-visible error messages (maybe you forgot to git add
a file with those?)
c4be997
to
3a524ab
Compare
27a4e49
to
b0d4192
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! It's almost ready to go, please address the remaining comments.
fa63fb8
to
4851368
Compare
988b39e
to
dcb10ea
Compare
@ftynse Please review, thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for iterating on this!
Windows problems are due to a missing numpy installation and are irrelevant here. |
Make transform.structured.tile_using_forall be able to take param type tile sizes.
Examples: