
[mlir][linalg] Support ParamType in vector_sizes option of VectorizeOp transform #87557


Merged: 7 commits into llvm:main on Apr 9, 2024

Conversation

@srcarroll (Contributor)

No description provided.

@llvmbot (Member) commented Apr 3, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (srcarroll)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/87557.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+3-4)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+6)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+80)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c260fe3f7a46a5..7220e6e077e59c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2138,12 +2138,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       Variadic<TransformHandleTypeInterface>:$vector_sizes,
+                       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
                        OptionalAttr<UnitAttr>:$vectorize_nd_extract,
                        DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                          $scalable_sizes,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes);
+                          $scalable_sizes);
 
   let results = (outs);
   let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 88819cd964354b..9c284ca309a455 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3136,6 +3136,12 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
       auto attr = sz.get<Attribute>();
       vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
       continue;
+    } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
+      ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
+      assert(params.size() == 1 && "expected a single param");
+      vectorSizes.push_back(
+          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+      continue;
     }
 
     auto szPayloads = state.getPayloadOps(sz.get<Value>());
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 2d01d57304013c..64e5935a90a4c4 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -36,6 +36,43 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @vectorize_dynamic_identity_with_param(%arg0: tensor<?xf32>,
+                                                 %arg1: tensor<?xf32>,
+                                                 %arg2: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>],
+                   iterator_types = ["parallel"] }
+    ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+    outs(%arg2 : tensor<?xf32>) {
+    ^bb(%in0: f32, %in1: f32, %out: f32) :
+      %0 = arith.addf %in0, %in1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_identity_with_param
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
+// CHECK:           %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<4xf32>
+// CHECK:           %[[VAL_14:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %vector_size = transform.param.constant 4 : i64 -> !transform.param<i64>
+    transform.structured.vectorize %0 vector_sizes [%vector_size : !transform.param<i64>] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @vectorize_dynamic_1d_broadcast(%arg0: tensor<?xf32>,
                                           %arg1: tensor<?xf32>,
                                           %arg2: tensor<?xf32>) -> tensor<?xf32> {
@@ -231,6 +268,49 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @vectorize_dynamic_transpose_reduction_with_params(%arg0: tensor<?x?x?xf32>,
+                                                             %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                         affine_map<(d0, d1, d2) -> (d2, d1)>],
+                        iterator_types = ["reduction", "parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?x?xf32>)
+    outs(%arg1 : tensor<?x?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %vector_size_0 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    %vector_size_2 = transform.param.constant 16 : i64 -> !transform.param<i64>
+    transform.structured.vectorize %0 vector_sizes
+      [%vector_size_0 : !transform.param<i64>, 8, %vector_size_2: !transform.param<i64>] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_transpose_reduction_with_params(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]], %[[VAL_7]] : vector<4x8x16xi1>
+// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_5]] : vector<16x8xi1>
+// CHECK:           %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_10]] { vector.multi_reduction <add>, %[[VAL_11]], %[[VAL_14]] [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_13]] { vector.transfer_write %[[VAL_15]], %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+
+// -----
+
 func.func @vectorize_partial_dynamic_identity(%arg0: tensor<8x?xf32>,
                                               %arg1: tensor<8x?xf32>,
                                               %arg2: tensor<8x?xf32>) -> tensor<8x?xf32> {

@srcarroll changed the title to "[mlir][linalg]Support ParamType in vector_sizes option of VectorizeOp transform" (Apr 3, 2024)
@srcarroll changed the title to "[mlir][linalg] Support ParamType in vector_sizes option of VectorizeOp transform" (Apr 3, 2024)
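
To situate the new branch in the diff above: the size-resolution loop in VectorizeOp::apply now accepts three kinds of entries in the mixed vector_sizes list. Below is a condensed sketch; it paraphrases the hunk above, and the surrounding loop, the getMixedVectorSizes() helper name, and the handle-resolution details are reconstructed assumptions rather than the verbatim upstream code.

  // Sketch: resolve each vector size entry to an int64_t. `state` is the
  // TransformState passed to apply().
  SmallVector<int64_t> vectorSizes;
  for (OpFoldResult sz : getMixedVectorSizes()) { // assumed helper name
    if (sz.is<Attribute>()) {
      // Static size spelled directly in the list, e.g. `8`.
      vectorSizes.push_back(cast<IntegerAttr>(sz.get<Attribute>()).getInt());
      continue;
    }
    Value val = sz.get<Value>();
    if (isa<ParamType>(val.getType())) {
      // New in this PR: a !transform.param carries its integer value in the
      // transform state, so no payload IR needs to be inspected.
      ArrayRef<Attribute> params = state.getParams(val);
      assert(params.size() == 1 && "expected a single param");
      vectorSizes.push_back(
          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
      continue;
    }
    // Otherwise `val` is an op handle: it must map to exactly one payload op
    // defining a constant, whose value becomes the size (diagnostics elided
    // in this sketch).
    auto payloadOps = state.getPayloadOps(val);
    assert(llvm::hasSingleElement(payloadOps) && "expected one payload op");
    IntegerAttr intAttr;
    if (matchPattern((*payloadOps.begin())->getResult(0), m_Constant(&intAttr)))
      vectorSizes.push_back(intAttr.getInt());
  }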
@ftynse (Member) left a comment:

LGTM, feel free to incorporate parser changes here or send another PR.

@srcarroll (Contributor, Author):

> It should be possible to reuse parts of the parser between this and tile_using_for, and it is actually preferable since it ensures the syntax remains consistent over time. The other op is relying on mlir::parseDynamicIndexList AFAICS, so there isn't much to implement.

OK, I gave it a shot, but it was a bit more complex than I expected. I had a hard time figuring it out, as I'm not experienced with more complex printing/parsing, so the implementation might be a bit sub-par. I very much welcome and appreciate suggestions.
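
A minimal sketch of what such a parser can look like when it delegates the mixed static/dynamic list to mlir::parseDynamicIndexList. The names kVectorSizesKeyword and parseOptionalVectorSizes are illustrative, the exact parseDynamicIndexList overload has varied across MLIR revisions, and the landed code also threads scalable-size flags that this sketch omits:

  // Keyword kept as a file-local constant (see the review nit below).
  static const StringLiteral kVectorSizesKeyword = "vector_sizes";

  static ParseResult parseOptionalVectorSizes(
      OpAsmParser &parser,
      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicSizes,
      DenseI64ArrayAttr &staticSizes, SmallVectorImpl<Type> &sizeTypes) {
    // No `vector_sizes` keyword means no sizes were provided.
    if (failed(parser.parseOptionalKeyword(kVectorSizesKeyword)))
      return success();
    // Parses `[%sz : !transform.param<i64>, 8, ...]`, splitting entries into
    // dynamic operands (with their inline types) and a static integer array
    // that holds a sentinel at each dynamic position.
    return parseDynamicIndexList(parser, dynamicSizes, staticSizes,
                                 &sizeTypes);
  }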

@ftynse (Member) left a comment:

Thanks! Please address the nits and add tests for printing/parsing edge cases, e.g. an empty list of sizes, a list of sizes that fails to parse, and an incorrect number of trailing types.

if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
  return ParseResult::failure();

if (succeeded(parser.parseOptionalKeyword("vector_sizes"))) {
@ftynse (Member):

Nit: let's have these as static const StringLiteral kVectorSizesKeyword somewhere rather than inlined magic constants.

@srcarroll (Contributor, Author) commented Apr 5, 2024:

I wonder if it's a dumb idea to define these as static class methods of VectorizeOp?

@ftynse (Member):

I'd rather have them as constants local to this file. There is no need for them to be available to all users of VectorizeOp, and it would pollute the API and increase linking time.

@srcarroll (Contributor, Author):

Sounds good.

@srcarroll (Contributor, Author):

I made a constant for vector_sizes but used getVectorizeNdExtractAttrName() for vectorize_nd_extract.

@srcarroll (Contributor, Author):

> list of sizes that fails to parse

I can't think of anything that would fail other than IR that is invalid to begin with. Do you have a test case in mind?

@ftynse (Member) commented Apr 5, 2024:

> I can't think of anything that would fail other than IR that is invalid to begin with. Do you have a test case in mind?

We can have IR that fails to parse in tests. I'm just looking to exercise the code path where parseDynamicIndexList returns failure, to see if there is a diagnostic emitted.

@srcarroll (Contributor, Author):

Oh yeah, I know. I just mean I can't think of anything that would trigger failure with parseDynamicIndexList specifically.

@srcarroll (Contributor, Author) commented Apr 5, 2024:

> Oh yeah, I know. I just mean I can't think of anything that would trigger failure with parseDynamicIndexList specifically.

I guess I can just use the old format with the types inside the list.

Edit: maybe not. I tried this:

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
  %0 = transform.param.constant 2 : i64 -> !transform.param<i64>
  // expected-error@below {{expected ']' in dynamic index list}}
  transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
}

but I get a custom op 'transform.structured.vectorize' expected SSA value or integer error before the dynamic index list error happens.

@srcarroll (Contributor, Author):

I think I addressed everything, with the exception of a test for dynamic list parsing.

github-actions (bot) commented Apr 5, 2024:

✅ With the latest revision this PR passed the Python code formatter.

@ftynse (Member) commented Apr 8, 2024:

Something like

  transform.structured.vectorize %arg0 vector_sizes [!transform.any_type] : !transform.any_op, !transform.param<i64>

@srcarroll (Contributor, Author):

> [!transform.any_type] : !transform.any_op, !transform.param<i64>

That has the same issue I mentioned above.

@srcarroll (Contributor, Author) commented Apr 8, 2024:

Oh, actually the one I showed should work; I was just checking incorrectly. I guess I needed to check for two errors. The one you suggested gives custom op 'transform.structured.vectorize' expected integer value, which I don't think is directly related to the dynamic list parsing:

%0 = transform.param.constant 2 : i64 -> !transform.param<i64>
// expected-error@below {{expected ']' in dynamic index list}}
// expected-error@below {{custom op 'transform.structured.vectorize' expected SSA value or integer}}
transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>

@srcarroll (Contributor, Author):

@ftynse does this suffice for the dynamic list parsing test?

@ftynse (Member):

Yep. It's strange that we get two diagnostics for the same error, but it's not in your code. You are welcome to investigate that in a separate patch if you want.

@srcarroll (Contributor, Author):

Thanks! Yeah, I may look into that in my downtime.

@srcarroll merged commit b79db39 into llvm:main on Apr 9, 2024