Commit 1e6c8b3
[mlir][linalg] Enable Vectorization of 0-D tensor.extract
This patch removes an assert in `vectorizeTensorExtract` that was blocking the vectorization of 0-D `tensor.extract` operations, e.g.:

```mlir
%1 = tensor.extract %src[] : tensor<f32>
```

As demonstrated by the included tests, this case is already effectively supported.

**Context**

The removed assert was introduced in llvm#109580 as a guard, pending proper support and testing for 0-D tensors. This PR addresses that previously undocumented TODO. Apologies for the oversight!

**Updates and Tests**

* Revised the existing test `@negative_no_loops` so that the `vectorize_nd_extract` attribute is included, allowing the vectorizer to process it. The test was renamed (to `@extract_scalar_from_0d_into_0d`) and its variables updated for clarity.
* Added a new test, `@extract_scalar_from_0d_into_1d`, to cover "mixed" 0-D/1-D tensor extraction, e.g.:

```mlir
%res = linalg.generic {
  indexing_maps = [#map],
  iterator_types = ["parallel"]
} outs(%init : tensor<1xf32>) {
^bb0(%in: f32):
  %1 = tensor.extract %src[] : tensor<f32>
  linalg.yield %1 : f32
} -> tensor<1xf32>

return %res : tensor<1xf32>
```

**Additional updates**

I also took the liberty of improving test coverage for 0-D tensors in the vectorizer tests:

* Added a dedicated test for 0-D `linalg.generic` in "vectorization-with-patterns.mlir".
* Renamed several tests in "vectorization-with-patterns.mlir" to clarify that the 0-D case is now covered.
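For reference, the scalar extract now vectorizes as a 0-D `vector.transfer_read` followed by a `vector.transfer_write`. A rough sketch of the resulting IR, reconstructed from the CHECK lines of `@extract_scalar_from_0d_into_0d` below (SSA names are illustrative):

```mlir
// Sketch only; names are illustrative, see the CHECK lines in the diff below.
%pad  = arith.constant 0.000000e+00 : f32
%read = vector.transfer_read %src[], %pad : tensor<f32>, vector<f32>
%out  = vector.transfer_write %read, %init[] : vector<f32>, tensor<f32>
```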
1 parent 1885886 commit 1e6c8b3

File tree

3 files changed: +93 -18 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 0 additions & 5 deletions
```diff
@@ -1115,11 +1115,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   // b. contiguous loads.
   // Both cases use vector.transfer_read.
 
-  assert(llvm::count_if(resultType.getShape(),
-                        [](uint64_t dim) { return dim != 1; }) &&
-         "Contiguous loads and scalar loads + broadcast only support 1-D "
-         "vectors ATM!");
-
   // Collect indices for `vector.transfer_read`. At this point, the indices will
   // either be scalars or would have been broadcast to vectors matching the
   // result type. For indices that are vectors, there are two options:
```
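Why this assert had to go: for a 0-D result vector such as `vector<f32>`, `resultType.getShape()` is empty, so the `llvm::count_if(...)` over it is trivially 0 and the assert fires unconditionally, even though (as the tests below show) the scalar-load-plus-broadcast path already handles this case.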

mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir

Lines changed: 45 additions & 3 deletions
```diff
@@ -122,6 +122,48 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<() -> ()>
+
+// CHECK-LABEL: func.func @generic_0d(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>, %[[ARG_2:.*]]: tensor<f32>)
+func.func @generic_0d(%arg0: tensor<f32>, %arg1: tensor<f32>,
+                      %arg2: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ_0:.*]] = vector.transfer_read %[[ARG_0]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: %[[ARG_0_AS_SCALAR:.*]] = vector.extract %[[READ_0]][] : f32 from vector<f32>
+// CHECK: %[[READ_1:.*]] = vector.transfer_read %[[ARG_1]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: %[[ARG_1_AS_SCALAR:.*]] = vector.extract %[[READ_1]][] : f32 from vector<f32>
+// CHECK: %[[READ_2:.*]] = vector.transfer_read %[[ARG_2]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: %[[ARG_2_AS_SCALAR:.*]] = vector.extract %[[READ_2]][] : f32 from vector<f32>
+// CHECK: %[[MULF:.*]] = arith.mulf %[[ARG_0_AS_SCALAR]], %[[ARG_1_AS_SCALAR]] : f32
+// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_2_AS_SCALAR]], %[[MULF]] : f32
+// CHECK: %[[ADDF_BCAST:.*]] = vector.broadcast %[[ADDF]] : f32 to vector<f32>
+// CHECK: vector.transfer_write %[[ADDF_BCAST]], %[[ARG_2]][] : vector<f32>, tensor<f32>
+  %res = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = []
+  } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
+    outs(%arg2 : tensor<f32>) {
+  ^bb(%a: f32, %b: f32, %c: f32) :
+    %d = arith.mulf %a, %b: f32
+    %e = arith.addf %c, %d: f32
+    linalg.yield %e : f32
+  } -> tensor<f32>
+
+  return %res : tensor<f32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 #matmul_transpose_out_trait = {
   indexing_maps = [
     affine_map<(m, n, k) -> (m, k)>,
@@ -372,7 +414,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func @test_vectorize_fill
-func.func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
+func.func @test_vectorize_fill_0d(%A : memref<f32>, %arg0 : f32) {
   // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
   // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
   // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32>
@@ -410,8 +452,8 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func @test_vectorize_copy_scalar
-func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+// CHECK-LABEL: func @test_vectorize_copy_0d
+func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) {
   // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
   // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
  // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
```
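As a usage note: these are transform-dialect FileCheck tests driven by `mlir-opt` with the transform interpreter. A sketch of the typical invocation, assuming the standard RUN line (the actual one lives in the file header, outside this diff):

```mlir
// Assumed RUN line, not shown in this diff; see the file header:
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
```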

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 48 additions & 10 deletions
```diff
@@ -39,29 +39,67 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 #map = affine_map<() -> ()>
-func.func @negative_no_loops(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
-  %1 = linalg.generic {
+func.func @extract_scalar_from_0d_into_0d(%src: tensor<f32>, %init: tensor<f32>) -> tensor<f32> {
+  %res = linalg.generic {
     indexing_maps = [#map],
     iterator_types = []
-  } outs(%arg1 : tensor<f32>) {
-  ^bb0(%arg4: f32):
-    %2 = tensor.extract %arg0[] : tensor<f32>
-    linalg.yield %2 : f32
+  } outs(%init : tensor<f32>) {
+  ^bb0(%in: f32):
+    %1 = tensor.extract %src[] : tensor<f32>
+    linalg.yield %1 : f32
   } -> tensor<f32>
-  return %1 : tensor<f32>
+
+  return %res : tensor<f32>
 }
-// CHECK-LABEL: func.func @negative_no_loops
-// CHECK: tensor.extract
+
+// CHECK-LABEL: func.func @extract_scalar_from_0d_into_0d(
+// CHECK-SAME: %[[SRC:.*]]: tensor<f32>,
+// CHECK-SAME: %[[INIT:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: vector.transfer_write %[[READ]], %[[INIT]][] : vector<f32>, tensor<f32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
-    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
+// -----
+
+#map = affine_map<(n) -> (n)>
+func.func @extract_scalar_from_0d_into_1d(%src: tensor<f32>, %init: tensor<1xf32>) -> tensor<1xf32> {
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel"]
+  } outs(%init : tensor<1xf32>) {
+  ^bb0(%in: f32):
+    %1 = tensor.extract %src[] : tensor<f32>
+    linalg.yield %1 : f32
+  } -> tensor<1xf32>
+
+  return %res : tensor<1xf32>
+}
+// CHECK-LABEL: func.func @extract_scalar_from_0d_into_1d(
+// CHECK-SAME: %[[SRC:.*]]: tensor<f32>,
+// CHECK-SAME: %[[INIT:.*]]: tensor<1xf32>) -> tensor<1xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<f32> to vector<1xf32>
+// CHECK: vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]]] {in_bounds = [true]} : vector<1xf32>, tensor<1xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
 
 // -----
 
```
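Note that both updated tests pass `vectorize_nd_extract` to `transform.structured.vectorize_children_and_apply_patterns`. Without that attribute the vectorizer declines to handle the `tensor.extract` inside the `linalg.generic`, which is exactly what the old `@negative_no_loops` test verified (`// CHECK: tensor.extract`).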
