Commit 7fe623b

fixup! [mlir][linalg] Add masked vectorisation for depthwise convolutions
Better LIT var names in tests
1 parent f996270 commit 7fe623b
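
Note on the rename: FileCheck capture names only affect how the check lines bind and reuse SSA values; the matched IR is unchanged. `%[[NAME:.*]]` binds NAME to whatever the regex matches, and a later `%[[NAME]]` must match that same text, so swapping autogenerated `VAL_<n>` captures for descriptive names such as `C0`, `PAD` or `MASK_IN` is purely a readability change. A minimal illustrative sketch of the mechanism (not taken from this test, names are hypothetical):

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: memref.dim %{{.*}}, %[[C0]] : memref<?xf32>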

1 file changed (+33 -33)

mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir

@@ -124,59 +124,59 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(%input: memref<3x5x?xf32>,
-                                                               %filter: memref<2x?xf32>,
-                                                               %output: memref<3x2x?xf32>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
+      %input: memref<3x5x?xf32>,
+      %filter: memref<2x?xf32>,
+      %output: memref<3x2x?xf32>) {
   linalg.depthwise_conv_1d_nwc_wc
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
     outs(%output : memref<3x2x?xf32>)
   return
 }
 
-// TODO - nice variable names
 // CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<2x?xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x2x?xf32>) {
+// CHECK-SAME: %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME: %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
 
-// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
 
 /// Create a mask for the input tensor
-// CHECK: %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_6]] : memref<3x5x?xf32>
-// CHECK: %[[VAL_8:.*]] = arith.constant 3 : index
-// CHECK: %[[VAL_9:.*]] = arith.constant 5 : index
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]], %[[VAL_7]] : vector<3x4x[4]xi1>
+// CHECK: %[[CH_DIM_IN:.*]] = memref.dim %[[INPUT]], %[[C2]] : memref<3x5x?xf32>
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C3]], %[[C5]], %[[CH_DIM_IN]] : vector<3x4x[4]xi1>
 /// Read the input tensor
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
+// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
 
 /// Create a mask for the filter tensor
-// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_3]] : memref<2x?xf32>
-// CHECK: %[[VAL_13:.*]] = vector.create_mask %[[VAL_6]], %[[VAL_12]] : vector<2x[4]xi1>
+// CHECK: %[[CH_DIM_FLT:.*]] = memref.dim %[[FILTER]], %[[C1]] : memref<2x?xf32>
+// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C2]], %[[CH_DIM_FLT]] : vector<2x[4]xi1>
 /// Read the filter tensor
-// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
+// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
 
 /// Create a mask for the output tensor
-// CHECK: %[[VAL_15:.*]] = memref.dim %[[VAL_2]], %[[VAL_6]] : memref<3x2x?xf32>
-// CHECK: %[[VAL_16:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]], %[[VAL_15]] : vector<3x2x[4]xi1>
+// CHECK: %[[CH_DIM_OUT:.*]] = memref.dim %[[OUTPUT]], %[[C2]] : memref<3x2x?xf32>
+// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C3]], %[[C2]], %[[CH_DIM_OUT]] : vector<3x2x[4]xi1>
 /// Read the output tensor
-// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
+// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
 
 /// Convolution
-// CHECK: %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_11]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK: %[[VAL_19:.*]] = vector.extract_strided_slice %[[VAL_11]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_14]][0] : vector<[4]xf32> from vector<2x[4]xf32>
-// CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_14]][1] : vector<[4]xf32> from vector<2x[4]xf32>
-// CHECK: %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_17]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_20]] : vector<[4]xf32> to vector<3x2x[4]xf32>
-// CHECK: %[[VAL_24:.*]] = vector.fma %[[VAL_18]], %[[VAL_23]], %[[VAL_22]] : vector<3x2x[4]xf32>
-// CHECK: %[[VAL_25:.*]] = vector.broadcast %[[VAL_21]] : vector<[4]xf32> to vector<3x2x[4]xf32>
-// CHECK: %[[VAL_26:.*]] = vector.fma %[[VAL_19]], %[[VAL_25]], %[[VAL_24]] : vector<3x2x[4]xf32>
-// CHECK: %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_17]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
-// CHECK: vector.mask %[[VAL_16]] { vector.transfer_write %[[VAL_27]], %[[VAL_2]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
+// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK: %[[IN_2:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK: %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK: %[[FLT_2:.*]] = vector.extract %[[VEC_FLT]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK: %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK: %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK: %[[FMA_1:.*]] = vector.fma %[[IN_1]], %[[FLT_1_B]], %[[OUT_1]] : vector<3x2x[4]xf32>
+// CHECK: %[[FLT_2_B:.*]] = vector.broadcast %[[FLT_2]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK: %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
+// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
+// CHECK: vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
