
Commit f9cb859

[mlir][tensor][linalg] Move Pack/Unpack Ops to Linalg (2/4)
This is merely moving code around; no new functionality is added.

PATCH 2: To verify the newly added Ops (and to make the subsequent change smaller), this PR:

1. Moves tests from:
   * "mlir/test/Dialect/Tensor/ops.mlir"
   to:
   * "mlir/test/Dialect/Linalg/named-ops.mlir"
2. Moves tests from:
   * "mlir/test/Dialect/Tensor/invalid.mlir"
   to:
   * "mlir/test/Dialect/Linalg/invalid.mlir"

In addition, I grouped the "invalid" tests for `linalg.pack` and `linalg.unpack` into two separate sets (as opposed to mixing them together).

CONTEXT: This change was discussed in the following RFC:
* https://discourse.llvm.org/t/rfc-move-tensor-pack-and-tensor-unpack-into-linalg
1 parent c60231c commit f9cb859
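For context, these are the ops whose tests are being moved. A minimal round-trip sketch, adapted from the tests in this diff (the function name and shapes are illustrative):

func.func @pack_unpack_roundtrip(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) -> tensor<128x256xf32> {
  // Tile dims 0 and 1 by 32 and 16: 128x256 packs into a 4x16 grid of 32x16 tiles.
  %packed = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
  // Unpack back into a fresh buffer of the original shape.
  %empty = tensor.empty() : tensor<128x256xf32>
  %unpacked = linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %empty : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
  return %unpacked : tensor<128x256xf32>
}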

4 files changed: +288 -278 lines changed


mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 183 additions & 0 deletions
@@ -1142,3 +1142,186 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
   %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
   return %0 : tensor<2x12x11x2xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.pack
+//===----------------------------------------------------------------------===//
+
+func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32>
+  return %0 : tensor<8x8x16x33xf32>
+}
+
+// -----
+
+func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> {
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<10x8x?x?xf32>
+  return %0 : tensor<10x8x?x?xf32>
+}
+
+// -----
+
+func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles_outperm(%input: tensor<256x128xf32>, %output: tensor<8x10x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<8x10x?x?xf32> {
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = linalg.pack %input outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<8x10x?x?xf32>
+  return %0 : tensor<8x10x?x?xf32>
+}
+
+// -----
+
+func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
+  // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}}
+  %0 = linalg.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+// -----
+
+func.func @pack_invalid_inner_dims_pos_vector(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid inner_dims_pos vector}}
+  %0 = linalg.pack %input inner_dims_pos = [2, 0] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @pack_invalid_duplicate_element_in_inner_dims(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid inner_dims_pos vector}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @pack_invalid_duplicate_element_in_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid outer_dims_perm vector}}
+  %0 = linalg.pack %input outer_dims_perm = [1, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<64x32x16xf32> {
+  // expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
+  %0 = linalg.pack %input inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<64x32x16xf32>
+  return %0 : tensor<64x32x16xf32>
+}
+
+// -----
+
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid zero tile factor}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [0, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+func.func @pack_mismatch_inner_tile_size_and_output_shape(
+    %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = linalg.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
+  return %0 : tensor<?x?x8x8xf32>
+}
+
+// -----
+
+func.func @pack_dynamic_inner_tile_size_and_static_output_shape(
+    %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = linalg.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
+  return %0 : tensor<?x?x8x8xf32>
+}
+
+// -----
+
+func.func @pack_static_inner_tile_size_and_dynamic_output_shape(
+    %input : tensor<?x?xf32>, %output : tensor<?x?x8x?xf32>) -> tensor<?x?x8x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = linalg.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %output : tensor<?x?xf32> -> tensor<?x?x8x?xf32>
+  return %0 : tensor<?x?x8x?xf32>
+}
+
+// -----
+
+func.func @pack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<16x4x32x16xf32> {
+  // expected-error@+1 {{outer_dims_perm must be a permutation or empty}}
+  %0 = linalg.pack %source outer_dims_perm = [0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<16x4x32x16xf32>
+  return %0 : tensor<16x4x32x16xf32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.unpack
+//===----------------------------------------------------------------------===//
+
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+  // expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
+  %0 = linalg.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+
+// -----
+
+func.func @unpack_invalid_out_of_bound_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid outer_dims_perm vector}}
+  %0 = linalg.unpack %output outer_dims_perm = [2, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %input : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+
+// -----
+
+func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<128x256xf32> {
+  // expected-error@+1 {{outer_dims_perm must be a permutation or empty}}
+  %0 = linalg.unpack %dest outer_dims_perm = [1] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<16x4x32x16xf32> -> tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
+
+// -----
+
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+
+// -----
+
+func.func @unpack_mismatch_inner_tile_size_and_output_shape(
+    %input : tensor<?x?x8x8xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = linalg.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func.func @unpack_dynamic_inner_tile_size_and_static_output_shape(
+    %input : tensor<?x?x8x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = linalg.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8, 4] into %output : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
+    %input : tensor<?x?x?x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = linalg.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
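Not shown in this hunk: "invalid" tests like these are driven by a RUN line at the top of the file. A sketch of what such a line typically looks like (the file's actual RUN line is outside this diff):

// RUN: mlir-opt %s -split-input-file -verify-diagnostics

With -split-input-file, mlir-opt processes each "// -----"-delimited snippet independently, and -verify-diagnostics checks that every expected-error@+1 annotation matches a diagnostic emitted on the line that follows it.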

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 105 additions & 0 deletions
@@ -2248,3 +2248,108 @@ func.func @select_tensor(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xf32>, %a
   %1 = linalg.select ins(%arg0, %arg1, %arg2 : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
   return %1 : tensor<4x8x16xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// linalg.pack + linalg.unpack
+//===----------------------------------------------------------------------===//
+
+func.func @pack_nc_to_ncnc(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) -> tensor<128x256xf32> {
+  %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+  %1 = tensor.empty() : tensor<128x256xf32>
+  %2 = linalg.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+  return %2 : tensor<128x256xf32>
+}
+
+// CHECK-LABEL: func.func @pack_nc_to_ncnc(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<128x256xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<4x16x32x16xf32>)
+// CHECK: %[[PACKED:.*]] = linalg.pack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[DEST]] : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32>
+// CHECK: %{{.*}} = linalg.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+
+// -----
+
+func.func @pack_nc_to_ncnc_with_padding(%source: tensor<13x15xf32>, %dest: tensor<2x8x8x2xf32>, %padding: f32) -> tensor<13x15xf32> {
+  %0 = linalg.pack %source padding_value(%padding : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  %1 = tensor.empty() : tensor<13x15xf32>
+  %2 = linalg.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+  return %2 : tensor<13x15xf32>
+}
+
+// CHECK-LABEL: func.func @pack_nc_to_ncnc_with_padding(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<13x15xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<2x8x8x2xf32>,
+// CHECK-SAME: %[[PADDING:.*]]: f32)
+// CHECK: %[[PACKED:.*]] = linalg.pack %[[SOURCE]] padding_value(%[[PADDING]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<13x15xf32>
+// CHECK: %{{.*}} = linalg.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[BUFF]] : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+
+// -----
+
+func.func @pack_ck_to_kcck(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<128x256xf32> {
+  %0 = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<16x4x32x16xf32>
+  %1 = tensor.empty() : tensor<128x256xf32>
+  %2 = linalg.unpack %0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<16x4x32x16xf32> -> tensor<128x256xf32>
+  return %2 : tensor<128x256xf32>
+}
+
+// CHECK-LABEL: func.func @pack_ck_to_kcck(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<128x256xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<16x4x32x16xf32>)
+// CHECK: %[[PACKED:.*]] = linalg.pack %[[SOURCE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[DEST]] : tensor<128x256xf32> -> tensor<16x4x32x16xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32>
+// CHECK: %{{.*}} = linalg.unpack %[[PACKED]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<16x4x32x16xf32> -> tensor<128x256xf32>
+
+// -----
+
+func.func @pad_and_pack_fully_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x?x?xf32>, %pad: f32, %tile_n : index, %tile_m : index) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @pad_and_pack_fully_dynamic(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32,
+// CHECK-SAME: %[[TILE_N:.*]]: index,
+// CHECK-SAME: %[[TILE_M:.*]]: index)
+// CHECK: %{{.*}} = linalg.pack %[[SOURCE]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_N]], %[[TILE_M]]] into %[[DEST]] : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+
+// -----
+
+func.func @pad_and_pack_partially_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>, %pad: f32) -> tensor<?x?x8x2xf32> {
+  %0 = linalg.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %0 : tensor<?x?x8x2xf32>
+}
+
+// CHECK-LABEL: func.func @pad_and_pack_partially_dynamic(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32)
+// CHECK: %{{.*}} = linalg.pack %[[SOURCE]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+
+// -----
+
+func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>, %tile_n : index, %tile_m : index) -> tensor<?x?xf32> {
+  %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_fully_dynamic(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TILE_N:.*]]: index,
+// CHECK-SAME: %[[TILE_M:.*]]: index)
+// CHECK: %{{.*}} = linalg.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_N]], %[[TILE_M]]] into %[[DEST]] : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+
+// -----
+
+func.func @unpack_partially_dynamic(%source: tensor<?x?x8x2xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<?x?x8x2xf32> -> tensor<?x?xf32>
+  return %0: tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_partially_dynamic(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>)
+// CHECK: %{{.*}} = linalg.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<?x?x8x2xf32> -> tensor<?x?xf32>
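Round-trip tests like the ones above are typically verified by re-printing the parsed IR and matching it with FileCheck; a sketch of the driver line (again, the file's actual RUN line is outside this diff):

// RUN: mlir-opt %s -split-input-file | FileCheck %s

CHECK-LABEL re-anchors matching at each function, and CHECK-SAME continues matching on the same printed line, so the %[[SOURCE]]/%[[DEST]] captures stay local to a single test case.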
