Skip to content

Commit 30ff1a1

Browse files
committed
add a couple tests to convert to generic
1 parent f0678a4 commit 30ff1a1

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,63 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
864864

865865
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
866866
}
867+
868+
// -----
869+
870+
// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3 + d5)>
871+
// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
872+
// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
873+
// CHECK: func @gen_grouped_1D_ngcs_gfcs_ngfs_memref
874+
func.func @gen_grouped_1D_ngcs_gfcs_ngfs_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
875+
// CHECK: linalg.generic
876+
// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
877+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
878+
// CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
879+
// CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
880+
// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
881+
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
882+
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
883+
// CHECK-NEXT: }
884+
linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
885+
return
886+
}
887+
888+
// -----
889+
890+
// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 * 3 + d6 * 2, d4 * 3 + d7 * 2)>
891+
// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
892+
// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
893+
// CHECK: func @gen_grouped_2D_ngcs_gfcs_ngfs_memref
894+
func.func @gen_grouped_2D_ngcs_gfcs_ngfs_memref(%arg0: memref<64x2x16x26x26xf32>, %arg1: memref<2x20x16x3x3xf32>, %arg2: memref<64x2x20x8x8xf32>) {
895+
// CHECK: linalg.generic
896+
// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
897+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
898+
// CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
899+
// CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
900+
// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
901+
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
902+
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
903+
// CHECK-NEXT: }
904+
linalg.grouped_conv_nd {strides = dense<3> : memref<2xi64>, dilations = dense<2> : memref<2xi64>, layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1: memref<64x2x16x26x26xf32>, memref<2x20x16x3x3xf32>) outs(%arg2: memref<64x2x20x8x8xf32>)
905+
return
906+
}
907+
908+
// -----
909+
910+
// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3 * 3 + d6 * 2, d4 * 3 + d7 * 2, d5)>
911+
// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d6, d7, d2, d5)>
912+
// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3, d4, d2)>
913+
// CHECK: func @gen_grouped_2D_ngsc_gsfc_ngsf_memref
914+
func.func @gen_grouped_2D_ngsc_gsfc_ngsf_memref(%arg0: memref<64x2x26x26x16xf32>, %arg1: memref<2x3x3x20x16xf32>, %arg2: memref<64x2x8x8x20xf32>) {
915+
// CHECK: linalg.generic
916+
// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
917+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
918+
// CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
919+
// CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
920+
// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
921+
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
922+
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
923+
// CHECK-NEXT: }
924+
linalg.grouped_conv_nd {strides = dense<3> : memref<2xi64>, dilations = dense<2> : memref<2xi64>, layouts = ["ngsc", "gsfc", "ngsf"]} ins(%arg0, %arg1: memref<64x2x26x26x16xf32>, memref<2x3x3x20x16xf32>) outs(%arg2: memref<64x2x8x8x20xf32>)
925+
return
926+
}

0 commit comments

Comments
 (0)