Skip to content

Commit 1d6b31f

Browse files
authored
[mlir]linalg][NFC]-Add lit test for tile and fuse transformation (#126216)
Add coverage for the fuse consumer transform for `linalg.generic` operation with projected permutation indexing maps.
1 parent 569e94f commit 1d6b31f

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,3 +676,65 @@ module attributes {transform.with_named_sequence} {
676676
// CHECK: }
677677
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
678678
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
679+
680+
// -----
681+
682+
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
683+
#map1 = affine_map<(d0, d1, d2) -> (d2)>
684+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
685+
module {
686+
func.func @fuse_with_tilable_consumer_with_projected_permutations(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<24xf32>) -> tensor<256x256x24xf32> {
687+
%c0 = arith.constant 0 : index
688+
%c64 = arith.constant 64 : index
689+
%c256 = arith.constant 256 : index
690+
%0 = tensor.empty() : tensor<256x256xf32>
691+
%1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %0) -> (tensor<256x256xf32>) {
692+
%extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
693+
%extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
694+
%extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
695+
%4 = linalg.add ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice : tensor<64x256xf32>) -> tensor<64x256xf32>
696+
%inserted_slice = tensor.insert_slice %4 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
697+
scf.yield %inserted_slice : tensor<256x256xf32>
698+
}
699+
%2 = tensor.empty() : tensor<256x256x24xf32>
700+
%3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg2 : tensor<256x256xf32>, tensor<24xf32>) outs(%2 : tensor<256x256x24xf32>) {
701+
^bb0(%in: f32, %in_0: f32, %out: f32):
702+
%4 = arith.addf %in, %in_0 : f32
703+
linalg.yield %4 : f32
704+
} -> tensor<256x256x24xf32>
705+
return %3 : tensor<256x256x24xf32>
706+
}
707+
}
708+
709+
// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
710+
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
711+
// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
712+
// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index
713+
// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
714+
// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
715+
// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
716+
// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
717+
// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
718+
// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
719+
// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
720+
// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
721+
// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
722+
// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
723+
// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
724+
// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
725+
// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
726+
// CHECK: linalg.yield %[[VAL_23]] : f32
727+
// CHECK: } -> tensor<64x256x24xf32>
728+
// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
729+
// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
730+
// CHECK: }
731+
732+
module attributes {transform.with_named_sequence} {
733+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
734+
%slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
735+
: (!transform.any_op) -> !transform.any_op
736+
%a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
737+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
738+
transform.yield
739+
}
740+
}

0 commit comments

Comments
 (0)