Commit 39d7b84

[mlir][nfc] Update vectorize-tensor-extract.mlir (1/N)
Tests in "vectorize-tensor-extract.mlir" are inconsistent and would benefit from refactoring to:

* Clearly categorize tests into "contiguous load", "gather load", and "scalar load + broadcast" cases, reflecting the structure of tensor.extract vectorization.
* Unify variable naming (both MLIR and FileCheck).
* Ensure all tests exercise unmasked vectorization (masked vectorization is covered in "vectorize-tensor-extract-masked.mlir").
* Improve and standardize formatting.

These changes will make it easier to identify the test cases being exercised and will simplify future maintenance and refactoring.

This is patch 1/N in the series. Below is a summary of the changes in this patch.

----------------------------------------------------------------------

This PR updates the `@vectorize_scalar_broadcast_column_tensor` test in "vectorize-tensor-extract.mlir", which exercises:

* vectorization of tensor.extract,
* a scalar read followed by a broadcast,
* reading from a constant column tensor.

Currently the test uses "masked" vectorization, but the file exclusively tests unmasked vectorization paths. To address this inconsistency, this PR removes the masking, aligning the test with the rest of the file. Masked vectorization scenarios remain covered in "vectorize-tensor-extract-masked.mlir".

To that end, the test switches from `transform.structured.vectorize` to `transform.structured.vectorize_children_and_apply_patterns`. The latter also applies canonicalization patterns, which significantly simplifies the generated output.

Additional improvements for readability:

* Renamed the test function for clarity.
* Updated variable names and removed unused variables.
* Added empty lines for better formatting.
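For context, a minimal sketch of what the three categories lower to. This is illustrative only: the operand names (%src, %pad, %idx_vec, %mask, %pass, %i) and the shapes are placeholder assumptions, not taken from the test file.

// 1. Contiguous load: indices advance linearly with the iteration
//    space, so the extract folds into a plain vector transfer read.
%v = vector.transfer_read %src[%c0, %c0], %pad : tensor<4x8xi32>, vector<4x8xi32>

// 2. Gather load: arbitrary per-element indices require vector.gather.
%g = vector.gather %src[%c0, %c0] [%idx_vec], %mask, %pass
       : tensor<4x8xi32>, vector<4xindex>, vector<4xi1>, vector<4xi32> into vector<4xi32>

// 3. Scalar load + broadcast: a loop-invariant index becomes a 0-D read
//    whose result is broadcast (the case exercised by this test).
%s = vector.transfer_read %src[%i, %c0], %pad : tensor<4x8xi32>, vector<i32>
%b = vector.broadcast %s : vector<i32> to vector<1x1x4xi32>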
1 parent: 1e6c8b3

1 file changed, +26 −41 lines:

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -845,56 +845,41 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+func.func @vectorize_scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
   %c4 = arith.constant 4 : index
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-
-  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
-  ^bb0(%out: i32):
-    %8 = linalg.index 0 : index
-    %idx_0 = linalg.index 0 : index
-    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
-    linalg.yield %extracted : i32
+  %src = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %res = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    outs(%init : tensor<1x1x4xi32>) {
+
+  ^bb0(%out: i32):
+    %idx = linalg.index 0 : index
+    %extracted = tensor.extract %src[%idx, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
   } -> tensor<1x1x4xi32>
 
-  return %out:tensor<1x1x4xi32>
+  return %res : tensor<1x1x4xi32>
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
-// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
-// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
-// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
-// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
-// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
-// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
-// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
-// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
-// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex>
-// CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1>
-// CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32>
-// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
-// CHECK: return %[[VAL_25]] : tensor<1x1x4xi32>
+// CHECK-LABEL: func.func @vectorize_scalar_read_with_broadcast_from_column_tensor(
+// CHECK-SAME: %[[INIT:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
+// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
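For reference, tests in this file are driven by the transform interpreter. The file's RUN line sits outside this hunk; it follows the usual pattern below (an assumption based on the file's conventions, not quoted from this commit):

// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

The -split-input-file flag splits the file at each "// -----" delimiter into independent test cases, which is why the payload function and its transform.named_sequence travel together in the hunk above.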
