@@ -70,3 +70,49 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
70
70
// CHECK-LABEL: func @uncollapsable(
71
71
// CHECK: linalg.generic
72
72
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
73
+
74
+ // -----
75
+
76
+ // CHECK-LABEL: func.func private @collapsable_memref(
77
+ // CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
78
+ // CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
79
+ // CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
80
+ // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
81
+ // CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
82
+ // CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
83
+ // CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
84
+ // CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
85
+ // CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
86
+ // CHECK: linalg.yield %[[VAL_9]] : f32
87
+ // CHECK: }
88
+ // CHECK: return %[[VAL_2]] : memref<1x24x32x8xf32>
89
+ // CHECK: }
90
+
91
+ func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) { // Elementwise f32 add on identity-layout memrefs; the FileCheck expectations above require dims 2 and 3 to be collapsed.
92
+   %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32> // Output buffer, same shape as both inputs.
93
+   linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
94
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
95
+     %0 = arith.addf %in, %in_0 : f32 // %out is unused: the result depends only on the two inputs.
96
+     linalg.yield %0 : f32
97
+   }
98
+   return %alloc : memref<1x24x32x8xf32> // The uncollapsed buffer is returned, as the expectations above require.
99
+ }
100
+
101
+ // -----
102
+
103
+ // CHECK-LABEL: func @uncollapsable_strided_memref(
104
+ // CHECK: linalg.generic
105
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
106
+
107
+ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) { // Elementwise i32 add on strided subviews; the FileCheck expectations above require all four parallel iterators to remain.
108
+   %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
109
+   %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>> // 1x3x12x24 window; strides are still those of the 2x6x24x48 parent.
110
+   %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
111
+   %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
112
+   linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) {
113
+   ^bb0(%in: i32, %in_0: i32, %out: i32):
114
+     %0 = arith.addi %in, %in_0 : i32 // %out is unused: the result depends only on the two inputs.
115
+     linalg.yield %0 : i32
116
+   }
117
+   return %alloc : memref<2x6x24x48xi32> // Returns the full allocation, of which only the subview window was written.
118
+ }
0 commit comments