@@ -864,3 +864,63 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
864
864
865
865
return %0 , %1: tensor <f32 >, tensor <vector <2 x4 xf32 >>
866
866
}
867
+
868
+ // -----
869
+
870
+ // Generalization test for a 1-D grouped convolution on memref operands.
+ // The op carries only a layouts attribute (ngcs, gfcs, ngfs); no strides or
+ // dilations are given, and the expected first indexing map addresses the
+ // spatial dimension as d3 + d5. The expected output is a linalg.generic over
+ // six loops (four parallel, two reduction) whose body multiplies the two
+ // inputs and accumulates into the output element.
+ // NOTE(review): the layout strings presumably describe input, filter and
+ // output in that order, and the RUN line is outside this chunk - confirm.
+ // CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3 + d5)>
871
+ // CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
872
+ // CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
873
+ // CHECK: func @gen_grouped_1D_ngcs_gfcs_ngfs_memref
874
+ func.func @gen_grouped_1D_ngcs_gfcs_ngfs_memref (%arg0: memref <64 x8 x16 x10 xf32 >, %arg1: memref <8 x32 x16 x3 xf32 >, %arg2: memref <64 x8 x32 x8 xf32 >) {
875
+ // CHECK: linalg.generic
876
+ // CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
877
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
878
+ // CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
879
+ // CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
880
+ // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
881
+ // CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
882
+ // CHECK-NEXT: linalg.yield %[[ADD]] : f32
883
+ // CHECK-NEXT: }
884
+ linalg.grouped_conv_nd {layouts = [" ngcs" , " gfcs" , " ngfs" ]} ins (%arg0 , %arg1: memref <64 x8 x16 x10 xf32 >, memref <8 x32 x16 x3 xf32 >) outs (%arg2: memref <64 x8 x32 x8 xf32 >)
885
+ return
886
+ }
887
+
888
+ // -----
889
+
890
+ // Generalization test for a 2-D grouped convolution on memref operands with
+ // layouts ngcs, gfcs, ngfs. Strides of 3 and dilations of 2 are expected to
+ // appear in the first indexing map as d3 * 3 + d6 * 2 and d4 * 3 + d7 * 2.
+ // The expected output is a linalg.generic over eight loops (five parallel,
+ // three reduction) whose body multiplies the two inputs and accumulates into
+ // the output element.
+ // The strides/dilations literals are typed as tensors: a dense elements
+ // attribute takes a tensor (or vector) type even when the operands are memrefs.
+ // NOTE(review): the layout strings presumably describe input, filter and
+ // output in that order, and the RUN line is outside this chunk - confirm.
+ // CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 * 3 + d6 * 2, d4 * 3 + d7 * 2)>
891
+ // CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
892
+ // CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
893
+ // CHECK: func @gen_grouped_2D_ngcs_gfcs_ngfs_memref
894
+ func.func @gen_grouped_2D_ngcs_gfcs_ngfs_memref (%arg0: memref <64 x2 x16 x26 x26 xf32 >, %arg1: memref <2 x20 x16 x3 x3 xf32 >, %arg2: memref <64 x2 x20 x8 x8 xf32 >) {
895
+ // CHECK: linalg.generic
896
+ // CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
897
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
898
+ // CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
899
+ // CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
900
+ // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
901
+ // CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
902
+ // CHECK-NEXT: linalg.yield %[[ADD]] : f32
903
+ // CHECK-NEXT: }
904
+ linalg.grouped_conv_nd {strides = dense <3 > : tensor <2 xi64 >, dilations = dense <2 > : tensor <2 xi64 >, layouts = [" ngcs" , " gfcs" , " ngfs" ]} ins (%arg0 , %arg1: memref <64 x2 x16 x26 x26 xf32 >, memref <2 x20 x16 x3 x3 xf32 >) outs (%arg2: memref <64 x2 x20 x8 x8 xf32 >)
905
+ return
906
+ }
907
+
908
+ // -----
909
+
910
+ // Generalization test for a 2-D grouped convolution on memref operands with
+ // the channels-last layouts ngsc, gsfc, ngsf. Strides of 3 and dilations of 2
+ // are expected to appear in the first indexing map as d3 * 3 + d6 * 2 and
+ // d4 * 3 + d7 * 2, now ahead of the channel dimension d5. The expected output
+ // is a linalg.generic over eight loops (five parallel, three reduction) whose
+ // body multiplies the two inputs and accumulates into the output element.
+ // The strides/dilations literals are typed as tensors: a dense elements
+ // attribute takes a tensor (or vector) type even when the operands are memrefs.
+ // NOTE(review): the layout strings presumably describe input, filter and
+ // output in that order, and the RUN line is outside this chunk - confirm.
+ // CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3 * 3 + d6 * 2, d4 * 3 + d7 * 2, d5)>
911
+ // CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d6, d7, d2, d5)>
912
+ // CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3, d4, d2)>
913
+ // CHECK: func @gen_grouped_2D_ngsc_gsfc_ngsf_memref
914
+ func.func @gen_grouped_2D_ngsc_gsfc_ngsf_memref (%arg0: memref <64 x2 x26 x26 x16 xf32 >, %arg1: memref <2 x3 x3 x20 x16 xf32 >, %arg2: memref <64 x2 x8 x8 x20 xf32 >) {
915
+ // CHECK: linalg.generic
916
+ // CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
917
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
918
+ // CHECK-SAME: ins(%arg0, %arg1 : {{.*}}) outs(%arg2 : {{.*}})
919
+ // CHECK-NEXT: ^bb0(%[[IN_0:.*]]: f32, %[[IN_1:.*]]: f32, %[[OUT:.*]]: f32):
920
+ // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN_0]], %[[IN_1]] : f32
921
+ // CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] : f32
922
+ // CHECK-NEXT: linalg.yield %[[ADD]] : f32
923
+ // CHECK-NEXT: }
924
+ linalg.grouped_conv_nd {strides = dense <3 > : tensor <2 xi64 >, dilations = dense <2 > : tensor <2 xi64 >, layouts = [" ngsc" , " gsfc" , " ngsf" ]} ins (%arg0 , %arg1: memref <64 x2 x26 x26 x16 xf32 >, memref <2 x3 x3 x20 x16 xf32 >) outs (%arg2: memref <64 x2 x8 x8 x20 xf32 >)
925
+ return
926
+ }
0 commit comments