[mlir][linalg] Support ParamType in vector_sizes option of VectorizeOp transform #87557
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: None (srcarroll)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/87557.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c260fe3f7a46a5..7220e6e077e59c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2138,12 +2138,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<TransformHandleTypeInterface>:$vector_sizes,
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
- $scalable_sizes,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
- $static_vector_sizes);
+ $scalable_sizes);
let results = (outs);
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 88819cd964354b..9c284ca309a455 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3136,6 +3136,12 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
auto attr = sz.get<Attribute>();
vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
continue;
+ } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
+ ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
+ assert(params.size() == 1 && "expected a single param");
+ vectorSizes.push_back(
+ cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+ continue;
}
auto szPayloads = state.getPayloadOps(sz.get<Value>());
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 2d01d57304013c..64e5935a90a4c4 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -36,6 +36,43 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @vectorize_dynamic_identity_with_param(%arg0: tensor<?xf32>,
+ %arg1: tensor<?xf32>,
+ %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"] }
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb(%in0: f32, %in1: f32, %out: f32) :
+ %0 = arith.addf %in0, %in1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_identity_with_param
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
+// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<4xf32>
+// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %vector_size = transform.param.constant 4 : i64 -> !transform.param<i64>
+ transform.structured.vectorize %0 vector_sizes [%vector_size : !transform.param<i64>] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @vectorize_dynamic_1d_broadcast(%arg0: tensor<?xf32>,
%arg1: tensor<?xf32>,
%arg2: tensor<?xf32>) -> tensor<?xf32> {
@@ -231,6 +268,49 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @vectorize_dynamic_transpose_reduction_with_params(%arg0: tensor<?x?x?xf32>,
+ %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>],
+ iterator_types = ["reduction", "parallel", "parallel"] }
+ ins(%arg0 : tensor<?x?x?xf32>)
+ outs(%arg1 : tensor<?x?xf32>) {
+ ^bb(%in: f32, %out: f32) :
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %vector_size_0 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ %vector_size_2 = transform.param.constant 16 : i64 -> !transform.param<i64>
+ transform.structured.vectorize %0 vector_sizes
+ [%vector_size_0 : !transform.param<i64>, 8, %vector_size_2: !transform.param<i64>] : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: @vectorize_dynamic_transpose_reduction_with_params(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]], %[[VAL_7]] : vector<4x8x16xi1>
+// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[VAL_13:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_5]] : vector<16x8xi1>
+// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_10]] { vector.multi_reduction <add>, %[[VAL_11]], %[[VAL_14]] [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_13]] { vector.transfer_write %[[VAL_15]], %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+
+// -----
+
func.func @vectorize_partial_dynamic_identity(%arg0: tensor<8x?xf32>,
%arg1: tensor<8x?xf32>,
%arg2: tensor<8x?xf32>) -> tensor<8x?xf32> {
LGTM, feel free to incorporate parser changes here or send another PR.
Ok, I gave it a shot, but it was a bit more complex than I expected. I had a hard time figuring it out as I'm not experienced with more complex printing/parsing, so the implementation might be a bit subpar. I very much welcome and appreciate suggestions.
Thanks! Please address the nits and add tests for printing/parsing edge cases, e.g. an empty list of sizes, a list of sizes that fails to parse, and an incorrect number of trailing types.
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
  return ParseResult::failure();

if (succeeded(parser.parseOptionalKeyword("vector_sizes"))) {
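For context, the quoted snippet is from the PR's custom parser. Presumably the vector_sizes branch continues by delegating to MLIR's shared dynamic-index-list helper, roughly as in the sketch below; the use of parseDynamicIndexList and the exact plumbing are assumptions based on this thread and the diagnostics discussed later, not the literal PR code.

  // Sketch only: split entries like [%sz : !transform.param<i64>, 8] into SSA
  // operands (params or handles) plus a static integer array, recording the
  // per-element types so they can be resolved after parsing.
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
  DenseI64ArrayAttr staticSizes;
  SmallVector<Type> sizeTypes;
  if (parseDynamicIndexList(parser, dynamicSizes, staticSizes, &sizeTypes))
    return ParseResult::failure();
  result.addAttribute(getStaticVectorSizesAttrName(result.name), staticSizes);

On this path, a malformed entry produces the "expected SSA value or integer" and "expected ']' in dynamic index list" diagnostics exercised by the test further down.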
Nit: let's have these as static const StringLiteral kVectorSizesKeyword somewhere rather than inlined magic constants.
I wonder if it's a dumb idea to define these as static class methods of VectorizeOp?
I'd rather have them as constants local to this file. There is no need for them to be available to all users of VectorizeOp, and it would pollute the API and increase linking time.
Sounds good.
I made a constant for vector_sizes but used getVectorizeNdExtractAttrName() for vectorize_nd_extract.
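For reference, the resolved arrangement presumably looks something like this sketch; the constant's exact name and the surrounding parser code are assumptions based on this thread rather than the literal PR diff:

  // Sketch: a constant local to LinalgTransformOps.cpp, so it does not leak
  // into the VectorizeOp API (per the review discussion above).
  static const StringLiteral kVectorSizesKeyword = "vector_sizes";

  // Inside a hand-written VectorizeOp::parse(...): the custom keyword uses
  // the local constant, while vectorize_nd_extract reuses the name that ODS
  // already generates for the attribute.
  if (succeeded(parser.parseOptionalKeyword(kVectorSizesKeyword))) {
    // ... parse the mixed list of static and dynamic vector sizes ...
  }
  if (succeeded(parser.parseOptionalKeyword(
          getVectorizeNdExtractAttrName(result.name).getValue()))) {
    result.addAttribute(getVectorizeNdExtractAttrName(result.name),
                        parser.getBuilder().getUnitAttr());
  }

Keeping the keyword file-local avoids widening the generated op interface just for a parser implementation detail.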
I can't think of anything that would fail other than invalid IR to begin with. Do you have a test case in mind?
We can have IR that fails to parse in tests. I'm just looking to exercise the code path where parsing of the size list fails.
Oh yeah, I know. I just mean I can't think of anything that would trigger failure with the dynamic list parsing.
I guess I can just use the old format with the types inside the list. Edit: maybe not. I tried it, but I get an error.
I think I addressed everything with the exception of a test for dynamic list parsing.
✅ With the latest revision this PR passed the Python code formatter.
Something like:
That has the same issue I mentioned above.
Oh, actually the one I showed should work; I was just checking incorrectly. I guess I needed to check for two errors. The one you suggested gives:
%0 = transform.param.constant 2 : i64 -> !transform.param<i64>
// expected-error@below {{expected ']' in dynamic index list}}
// expected-error@below {{custom op 'transform.structured.vectorize' expected SSA value or integer}}
transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
@ftynse does this suffice for the dynamic list parsing test?
Yep. It's strange that we get two diagnostics for the same error, but it's not in your code. You are welcome to investigate that in a separate patch if you want.
Thanks! Yeah, I may look into that in my downtime.