Skip to content

Commit 451b1ff

Browse files
committed
[mlir] Add lower-to-loops tests for linalg.map/reduce/transpose.
Differential Revision: https://reviews.llvm.org/D136691
1 parent 2f88268 commit 451b1ff

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,84 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
159159
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
160160
// CHECK: %[[T10:.+]] = arith.maxf %[[T9]], %[[T8]]
161161
// CHECK: memref.store %[[T10]], %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
162+
163+
// -----
164+
165+
func.func @map(%lhs: memref<64xf32>,
166+
%rhs: memref<64xf32>, %out: memref<64xf32>) {
167+
linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
168+
outs(%out : memref<64xf32>)
169+
(%in: f32, %in_0: f32) {
170+
%0 = arith.addf %in, %in_0 : f32
171+
linalg.yield %0 : f32
172+
}
173+
return
174+
}
175+
// CHECK-LABEL: func.func @map(
176+
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<64xf32>,
177+
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<64xf32>,
178+
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<64xf32>) {
179+
180+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
181+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
182+
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
183+
184+
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] {
185+
// CHECK: %[[LHS_ELEM:.*]] = memref.load %[[LHS]][%[[I]]]
186+
// CHECK: %[[RHS_ELEM:.*]] = memref.load %[[RHS]][%[[I]]]
187+
// CHECK: %[[ADD:.*]] = arith.addf %[[LHS_ELEM]], %[[RHS_ELEM]]
188+
// CHECK: memref.store %[[ADD]], %[[OUT]][%[[I]]]
189+
190+
// -----
191+
192+
func.func @transpose(%arg0: memref<16x32x64xf32>,
193+
%arg1: memref<32x64x16xf32>) {
194+
linalg.transpose ins(%arg0 : memref<16x32x64xf32>)
195+
outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0]
196+
return
197+
}
198+
// CHECK-LABEL: func.func @transpose(
199+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
200+
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<32x64x16xf32>)
201+
202+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
203+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
204+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
205+
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
206+
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
207+
208+
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C16]] step %[[C1]] {
209+
// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
210+
// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C64]] step %[[C1]] {
211+
// CHECK: %[[ELEM:.*]] = memref.load %[[OUT]][%[[J]], %[[K]], %[[I]]]
212+
// CHECK: memref.store %[[ELEM]], %[[OUT]][%[[J]], %[[K]], %[[I]]]
213+
214+
// -----
215+
216+
func.func @reduce(%arg0: memref<16x32x64xf32>,
217+
%arg1: memref<16x64xf32>) {
218+
linalg.reduce ins(%arg0 : memref<16x32x64xf32>)
219+
outs(%arg1 : memref<16x64xf32>) dimensions = [1]
220+
(%in: f32, %init: f32) {
221+
%0 = arith.addf %in, %init : f32
222+
linalg.yield %0 : f32
223+
}
224+
return
225+
}
226+
// CHECK-LABEL: func.func @reduce(
227+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
228+
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<16x64xf32>
229+
230+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
231+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
232+
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
233+
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
234+
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
235+
236+
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C16]] step %[[C1]] {
237+
// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
238+
// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C64]] step %[[C1]] {
239+
// CHECK: %[[IN_ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[J]], %[[K]]]
240+
// CHECK: %[[OUT_ELEM:.*]] = memref.load %[[OUT]][%[[I]], %[[K]]]
241+
// CHECK: %[[ADD:.*]] = arith.addf %[[IN_ELEM]], %[[OUT_ELEM]]
242+
// CHECK: memref.store %[[ADD]], %[[OUT]][%[[I]], %[[K]]]

0 commit comments

Comments
 (0)