@@ -124,59 +124,59 @@ module attributes {transform.with_named_sequence} {
124
124
125
125
// -----
126
126
127
- func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2 (%input: memref <3 x5 x?xf32 >,
128
- %filter: memref <2 x?xf32 >,
129
- %output: memref <3 x2 x?xf32 >) {
127
+ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2 (
128
+ %input: memref <3 x5 x?xf32 >,
129
+ %filter: memref <2 x?xf32 >,
130
+ %output: memref <3 x2 x?xf32 >) {
130
131
linalg.depthwise_conv_1d_nwc_wc
131
132
{dilations = dense <2 > : tensor <1 xi64 >, strides = dense <1 > : tensor <1 xi64 >}
132
133
ins (%input , %filter : memref <3 x5 x?xf32 >, memref <2 x?xf32 >)
133
134
outs (%output : memref <3 x2 x?xf32 >)
134
135
return
135
136
}
136
137
137
- // TODO - nice variable names
138
138
// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
139
- // CHECK-SAME: %[[VAL_0 :.*]]: memref<3x5x?xf32>,
140
- // CHECK-SAME: %[[VAL_1 :.*]]: memref<2x?xf32>,
141
- // CHECK-SAME: %[[VAL_2 :.*]]: memref<3x2x?xf32>) {
139
+ // CHECK-SAME: %[[INPUT :.*]]: memref<3x5x?xf32>,
140
+ // CHECK-SAME: %[[FILTER :.*]]: memref<2x?xf32>,
141
+ // CHECK-SAME: %[[OUTPUT :.*]]: memref<3x2x?xf32>) {
142
142
143
- // CHECK: %[[VAL_3 :.*]] = arith.constant 1 : index
144
- // CHECK: %[[VAL_4 :.*]] = arith.constant 0 : index
145
- // CHECK: %[[VAL_5 :.*]] = arith.constant 0.000000e+00 : f32
146
- // CHECK: %[[VAL_6 :.*]] = arith.constant 2 : index
143
+ // CHECK: %[[C1 :.*]] = arith.constant 1 : index
144
+ // CHECK: %[[C0 :.*]] = arith.constant 0 : index
145
+ // CHECK: %[[PAD :.*]] = arith.constant 0.000000e+00 : f32
146
+ // CHECK: %[[C2 :.*]] = arith.constant 2 : index
147
147
148
148
/// Create a mask for the input tensor
149
- // CHECK: %[[VAL_7 :.*]] = memref.dim %[[VAL_0 ]], %[[VAL_6 ]] : memref<3x5x?xf32>
150
- // CHECK: %[[VAL_8 :.*]] = arith.constant 3 : index
151
- // CHECK: %[[VAL_9 :.*]] = arith.constant 5 : index
152
- // CHECK: %[[VAL_10 :.*]] = vector.create_mask %[[VAL_8 ]], %[[VAL_9 ]], %[[VAL_7 ]] : vector<3x4x[4]xi1>
149
+ // CHECK: %[[CH_DIM_IN :.*]] = memref.dim %[[INPUT ]], %[[C2 ]] : memref<3x5x?xf32>
150
+ // CHECK: %[[C3 :.*]] = arith.constant 3 : index
151
+ // CHECK: %[[C5 :.*]] = arith.constant 5 : index
152
+ // CHECK: %[[MASK_IN :.*]] = vector.create_mask %[[C3 ]], %[[C5 ]], %[[CH_DIM_IN ]] : vector<3x4x[4]xi1>
153
153
/// Read the input tensor
154
- // CHECK: %[[VAL_11 :.*]] = vector.mask %[[VAL_10 ]] { vector.transfer_read %[[VAL_0 ]]{{\[}}%[[VAL_4 ]], %[[VAL_4 ]], %[[VAL_4 ]]], %[[VAL_5 ]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
154
+ // CHECK: %[[VEC_IN :.*]] = vector.mask %[[MASK_IN ]] { vector.transfer_read %[[INPUT ]]{{\[}}%[[C0 ]], %[[C0 ]], %[[C0 ]]], %[[PAD ]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
155
155
156
156
/// Create a mask for the filter tensor
157
- // CHECK: %[[VAL_12 :.*]] = memref.dim %[[VAL_1 ]], %[[VAL_3 ]] : memref<2x?xf32>
158
- // CHECK: %[[VAL_13 :.*]] = vector.create_mask %[[VAL_6 ]], %[[VAL_12 ]] : vector<2x[4]xi1>
157
+ // CHECK: %[[CH_DIM_FLT :.*]] = memref.dim %[[FILTER ]], %[[C1 ]] : memref<2x?xf32>
158
+ // CHECK: %[[MASK_FLT :.*]] = vector.create_mask %[[C2 ]], %[[CH_DIM_FLT ]] : vector<2x[4]xi1>
159
159
/// Read the filter tensor
160
- // CHECK: %[[VAL_14 :.*]] = vector.mask %[[VAL_13 ]] { vector.transfer_read %[[VAL_1 ]]{{\[}}%[[VAL_4 ]], %[[VAL_4 ]]], %[[VAL_5 ]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
160
+ // CHECK: %[[VEC_FLT :.*]] = vector.mask %[[MASK_FLT ]] { vector.transfer_read %[[FILTER ]]{{\[}}%[[C0 ]], %[[C0 ]]], %[[PAD ]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
161
161
162
162
/// Create a mask for the output tensor
163
- // CHECK: %[[VAL_15 :.*]] = memref.dim %[[VAL_2 ]], %[[VAL_6 ]] : memref<3x2x?xf32>
164
- // CHECK: %[[VAL_16 :.*]] = vector.create_mask %[[VAL_8 ]], %[[VAL_6 ]], %[[VAL_15 ]] : vector<3x2x[4]xi1>
163
+ // CHECK: %[[CH_DIM_OUT :.*]] = memref.dim %[[OUTPUT ]], %[[C2 ]] : memref<3x2x?xf32>
164
+ // CHECK: %[[MASK_OUT :.*]] = vector.create_mask %[[C3 ]], %[[C2 ]], %[[CH_DIM_OUT ]] : vector<3x2x[4]xi1>
165
165
/// Read the output tensor
166
- // CHECK: %[[VAL_17 :.*]] = vector.mask %[[VAL_16 ]] { vector.transfer_read %[[VAL_2 ]]{{\[}}%[[VAL_4 ]], %[[VAL_4 ]], %[[VAL_4 ]]], %[[VAL_5 ]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
166
+ // CHECK: %[[VEC_OUT :.*]] = vector.mask %[[MASK_OUT ]] { vector.transfer_read %[[OUTPUT ]]{{\[}}%[[C0 ]], %[[C0 ]], %[[C0 ]]], %[[PAD ]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
167
167
168
168
/// Convolution
169
- // CHECK: %[[VAL_18 :.*]] = vector.extract_strided_slice %[[VAL_11 ]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
170
- // CHECK: %[[VAL_19 :.*]] = vector.extract_strided_slice %[[VAL_11 ]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
171
- // CHECK: %[[VAL_20 :.*]] = vector.extract %[[VAL_14 ]][0] : vector<[4]xf32> from vector<2x[4]xf32>
172
- // CHECK: %[[VAL_21 :.*]] = vector.extract %[[VAL_14 ]][1] : vector<[4]xf32> from vector<2x[4]xf32>
173
- // CHECK: %[[VAL_22 :.*]] = vector.extract_strided_slice %[[VAL_17 ]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
174
- // CHECK: %[[VAL_23 :.*]] = vector.broadcast %[[VAL_20 ]] : vector<[4]xf32> to vector<3x2x[4]xf32>
175
- // CHECK: %[[VAL_24 :.*]] = vector.fma %[[VAL_18 ]], %[[VAL_23 ]], %[[VAL_22 ]] : vector<3x2x[4]xf32>
176
- // CHECK: %[[VAL_25 :.*]] = vector.broadcast %[[VAL_21 ]] : vector<[4]xf32> to vector<3x2x[4]xf32>
177
- // CHECK: %[[VAL_26 :.*]] = vector.fma %[[VAL_19 ]], %[[VAL_25 ]], %[[VAL_24 ]] : vector<3x2x[4]xf32>
178
- // CHECK: %[[VAL_27 :.*]] = vector.insert_strided_slice %[[VAL_26 ]], %[[VAL_17 ]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
179
- // CHECK: vector.mask %[[VAL_16 ]] { vector.transfer_write %[[VAL_27 ]], %[[VAL_2 ]]{{\[}}%[[VAL_4 ]], %[[VAL_4 ]], %[[VAL_4 ]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
169
+ // CHECK: %[[IN_1 :.*]] = vector.extract_strided_slice %[[VEC_IN ]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
170
+ // CHECK: %[[IN_2 :.*]] = vector.extract_strided_slice %[[VEC_IN ]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
171
+ // CHECK: %[[FLT_1 :.*]] = vector.extract %[[VEC_FLT ]][0] : vector<[4]xf32> from vector<2x[4]xf32>
172
+ // CHECK: %[[FLT_2 :.*]] = vector.extract %[[VEC_FLT ]][1] : vector<[4]xf32> from vector<2x[4]xf32>
173
+ // CHECK: %[[OUT_1 :.*]] = vector.extract_strided_slice %[[VEC_OUT ]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
174
+ // CHECK: %[[FLT_1_B :.*]] = vector.broadcast %[[FLT_1 ]] : vector<[4]xf32> to vector<3x2x[4]xf32>
175
+ // CHECK: %[[FMA_1 :.*]] = vector.fma %[[IN_1 ]], %[[FLT_1_B ]], %[[OUT_1 ]] : vector<3x2x[4]xf32>
176
+ // CHECK: %[[FLT_2_B :.*]] = vector.broadcast %[[FLT_2 ]] : vector<[4]xf32> to vector<3x2x[4]xf32>
177
+ // CHECK: %[[FMA_2 :.*]] = vector.fma %[[IN_2 ]], %[[FLT_2_B ]], %[[FMA_1 ]] : vector<3x2x[4]xf32>
178
+ // CHECK: %[[OUT_INS :.*]] = vector.insert_strided_slice %[[FMA_2 ]], %[[VEC_OUT ]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
179
+ // CHECK: vector.mask %[[MASK_OUT ]] { vector.transfer_write %[[OUT_INS ]], %[[OUTPUT ]]{{\[}}%[[C0 ]], %[[C0 ]], %[[C0 ]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
180
180
181
181
module attributes {transform.with_named_sequence } {
182
182
transform.named_sequence @__transform_main (%arg0: !transform.any_op {transform.readonly }) {
0 commit comments