// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
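// Exercises forward-only traversal: sharding annotations are propagated from
// producers to their users, never backwards.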

#map = affine_map<(d0, d1) -> (d0, d1)>
module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
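  // A minimal 1-D mesh of size 1; the split axes below all refer to mesh axis 0.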
  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
    %c1_i32 = arith.constant 1 : i32
    // CHECK: [[v0:%.*]] = tensor.empty() : tensor<6x6xi32>
    %0 = tensor.empty() : tensor<6x6xi32>
    // CHECK: [[v1:%.*]] = linalg.fill ins
    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
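    // Seed annotation: split tensor dimension 0 across mesh axis 0. The
    // forward pass pushes this sharding to every use of %sharding_annotated.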
    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
    %3 = tensor.empty() : tensor<6x6xi32>
    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %9 = arith.addi %in, %in_2 : i32
      linalg.yield %9 : i32
    } -> tensor<6x6xi32>
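    // The elementwise add is fully parallel, so the [[0]] split carries over
    // unchanged to both inputs, the init operand, and the result.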
    %c0_i32 = arith.constant 0 : i32
    %6 = tensor.empty() : tensor<i32>
    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
      (%in: i32, %init: i32) {
        %9 = arith.addi %in, %init : i32
        linalg.yield %9 : i32
      }
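    // Reducing over the split dimension leaves each shard with only a partial
    // result, hence the "partial = sum [0]" annotation on the reduced value.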
    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
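    // annotate_for_users attaches the unsplit (replicated) sharding to the
    // consumers of %reduced (here the return) rather than to its producer.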
    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
  }
}