Skip to content

Add support of param type for transform.structured.tile_using_forall #72097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 31, 2024

Conversation

jinchen62
Copy link
Contributor

@jinchen62 jinchen62 commented Nov 13, 2023

Make transform.structured.tile_using_forall be able to take param type tile sizes.

Examples:

%tile_sizes = transform.param.constant 16 : i64 -> !transform.param<i64>
transform.structured.tile_using_forall %matmul tile_sizes [%tile_sizes : !transform.param<i64>, 32] ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%c10 = transform.param.constant 10 : i64 -> !transform.any_param
%c20 = transform.param.constant 20 : i64 -> !transform.any_param
%tile_sizes = transform.merge_handles %c10, %c20 : !transform.any_param
transform.structured.tile_using_forall %matmul tile_sizes *(%tile_sizes : !transform.any_param) ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: jinchen62 (jinchen62)

Changes

Make transform.structured.tile_using_forall be able to take param type tile sizes.

Examples:

%tile_size1 = transform.param.constant 16 : i64 -&gt; !transform.param&lt;i64&gt;
transform.structured.tile_using_forall %matmul tile_sizes [%tile_size1 : !transform.param&lt;i64&gt;, 32] ( mapping = [#gpu.block&lt;x&gt;, #gpu.block&lt;y&gt;] ) : (!transform.any_op) -&gt; (!transform.any_op, !transform.any_op)
%tile_sizes = transform.param.constant [16 : i64, 32 : i64] -&gt; !transform.any_param
transform.structured.tile_using_forall %matmul tile_sizes [%tile_sizes : !transform.any_param] ( mapping = [#gpu.block&lt;x&gt;, #gpu.block&lt;y&gt;] ) : (!transform.any_op) -&gt; (!transform.any_op, !transform.any_op)

Full diff: https://github.com/llvm/llvm-project/pull/72097.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+5-5)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+24-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..a24f6ff8308ba34 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -23,7 +23,7 @@ include "mlir/IR/RegionKindInterface.td"
 // value in the payload IR.
 def TransformParamTypeOrAnyHandle : Type<
     Or<[TransformHandleTypeInterface.predicate,
-        Transform_ParamType.predicate]>,
+        TransformParamTypeInterface.predicate]>,
     "transform 'param' type or any handle type">;
 
 //===----------------------------------------------------------------------===//
@@ -1924,10 +1924,10 @@ def TileUsingForallOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformHandleTypeInterface>:$num_threads,
-                   Variadic<TransformHandleTypeInterface>:$tile_sizes,
-                   Optional<TransformHandleTypeInterface>:$packed_num_threads,
-                   Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+                   Variadic<TransformParamTypeOrAnyHandle>:$num_threads,
+                   Variadic<TransformParamTypeOrAnyHandle>:$tile_sizes,
+                   Optional<TransformParamTypeOrAnyHandle>:$packed_num_threads,
+                   Optional<TransformParamTypeOrAnyHandle>:$packed_tile_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..4bf4db3381fab79 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -98,12 +98,34 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
       result.push_back(ofr);
       continue;
     }
-    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+    Value transformValue = ofr.get<Value>();
+    if (isa<ParamType>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (!isa<IntegerAttr>(params[0]))
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      result.push_back(params[0]);
+      continue;
+    }
+    if (isa<AnyParamType>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (!isa<ArrayAttr>(params[0]))
+        return transformOp.emitDefiniteFailure() << "expected ArrayAttr";
+      ArrayAttr paramsArray = cast<ArrayAttr>(params[0]);
+      for (Attribute param : paramsArray.getValue()) {
+        if (!isa<IntegerAttr>(param))
+          return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+        result.push_back(param);
+      }
+      continue;
+    }
+
+    auto payloadOps = state.getPayloadOps(transformValue);
     if (!llvm::hasSingleElement(payloadOps)) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "handle must be mapped to exactly one payload op";
-      diag.attachNote(ofr.get<Value>().getLoc())
+      diag.attachNote(transformValue.getLoc())
           << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
       return diag;
     }

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for you PR! Please address inline comments.

Also please add tests. The good practice is to have at least one non-trivial test for desired behavior, tests for edge cases, and tests for all user-visible diagnostic messages.

for (Attribute param : paramsArray.getValue()) {
if (!isa<IntegerAttr>(param))
return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
result.push_back(param);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear to me what is the intended model here. Looking at the code, if the type is any_param, it will accept array-of-integer parameters and expand its context as tile sizes. But it will keep doing so for other parameters, so we may end up having more sizes than loops, and a surprising mismatch between parameter positions and loop depth. We could accept one parameter, and unpack it, but I don't think we want to accept multiple here. Also, there's no need to rely on array attributes as a parameter can be associated with a list of attributes (unless we want to support some multi-level scheme where the parameter is a list-of-lists of tile sizes to use for each of the loops associated with the input).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe this was addressed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought moving the support of any_param to the second unpackSingleIndexResultPayloadOperations that takes a Value would address it since there won't be lists inside lists. Could you give me more guidence on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I recall, there are currently two modes of functioning here:

  • there are N unpacked handles, where each handle corresponds to one dimension; each handle points to the same number of operations as the "target" handle;
  • there is 1 packed handle, which points to as many operations as dimensions; same dimensions are used for tiling all operations associated with the "target" handle.

Parameters should behave similarly. Since parameters are associated with a list of attributes, we would treat that list identically to how we treat the list of operations for operation handle. Consequently, we should never expect an array attribute in parameters. There is already the implicit list and we must use that instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please review, thanks!

@ftynse
Copy link
Member

ftynse commented Nov 13, 2023

Also note that the utility function that is changed in this PR is used in more cases than TileUsingForall, make sure to test those as well.

Copy link

github-actions bot commented Nov 15, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@jinchen62 jinchen62 force-pushed the param branch 3 times, most recently from a04b3bd to 8d09cc2 Compare November 16, 2023 22:28
@jinchen62 jinchen62 requested a review from ftynse November 16, 2023 22:29
@jinchen62
Copy link
Contributor Author

@ftynse Thanks for the review! I've addressed the comments, please check.

@jinchen62 jinchen62 force-pushed the param branch 2 times, most recently from e50f7c3 to 48c045e Compare November 17, 2023 05:12
Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests for user-visible error messages (maybe you forgot to git add a file with those?)

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! It's almost ready to go, please address the remaining comments.

@jinchen62 jinchen62 force-pushed the param branch 4 times, most recently from fa63fb8 to 4851368 Compare January 24, 2024 10:17
@jinchen62 jinchen62 requested a review from ftynse January 24, 2024 10:18
@jinchen62 jinchen62 force-pushed the param branch 3 times, most recently from 988b39e to dcb10ea Compare January 25, 2024 18:02
@jinchen62
Copy link
Contributor Author

@ftynse Please review, thanks!

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating on this!

@ftynse
Copy link
Member

ftynse commented Jan 31, 2024

Windows problems are due to a missing numpy installation and are irrelevant here.

@ftynse ftynse merged commit d439f36 into llvm:main Jan 31, 2024
@jinchen62 jinchen62 deleted the param branch January 31, 2024 21:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants