Skip to content

Commit 4c689e4

Browse files
committed
removing source from UpdateHaloOp, because not required for destination passing style
1 parent 2cf1c8a commit 4c689e4

File tree

5 files changed

+23
-30
lines changed

5 files changed

+23
-30
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,8 +1093,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10931093
TypesMatchWith<
10941094
"result has same type as destination",
10951095
"result", "destination", "$_self">,
1096-
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
1097-
AttrSizedOperandSegments
1096+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
10981097
]> {
10991098
let summary = "Update halo data.";
11001099
let description = [{
@@ -1103,7 +1102,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
11031102
on the remote devices. Changes might be caused by mutating operations
11041103
and/or if the new halo regions are larger than the existing ones.
11051104

1106-
Source and destination might have different halo sizes.
1105+
Destination is supposed to be initialized with the local data (not halos).
11071106

11081107
Assumes all devices hold tensors with same-sized halo data as specified
11091108
by `source_halo_sizes/static_source_halo_sizes` and
@@ -1115,25 +1114,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
11151114

11161115
}];
11171116
let arguments = (ins
1118-
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
11191117
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
11201118
FlatSymbolRefAttr:$mesh,
11211119
Mesh_MeshAxesArrayAttr:$split_axes,
1122-
Variadic<I64>:$source_halo_sizes,
1123-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
1124-
Variadic<I64>:$destination_halo_sizes,
1125-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
1120+
Variadic<I64>:$halo_sizes,
1121+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
11261122
);
11271123
let results = (outs
11281124
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
11291125
);
11301126
let assemblyFormat = [{
1131-
$source `into` $destination
1127+
$destination
11321128
`on` $mesh
11331129
`split_axes` `=` $split_axes
1134-
(`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
1135-
(`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
1136-
attr-dict `:` type($source) `->` type($result)
1130+
(`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
1131+
attr-dict `:` type($result)
11371132
}];
11381133
let extraClassDeclaration = [{
11391134
MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ struct ConvertUpdateHaloOp
7070
cast<IntegerAttr>(v.get<Attribute>()).getInt()));
7171
};
7272

73-
auto array = op.getInput();
74-
auto rank = array.getType().getRank();
73+
auto array = op.getDestination();
74+
auto rank = cast<ShapedType>(array.getType()).getRank();
7575
auto opSplitAxes = op.getSplitAxes().getAxes();
7676
auto mesh = op.getMesh();
7777
auto meshOp = getMesh(op, symbolTableCollection);
7878
auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
79-
op.getDynamicHaloSizes(), rewriter);
79+
op.getHaloSizes(), rewriter);
8080
// subviews need Index values
8181
for (auto &sz : haloSizes) {
8282
if (sz.is<Value>()) {
@@ -94,7 +94,7 @@ struct ConvertUpdateHaloOp
9494
auto currHaloDim = -1; // halo sizes are provided for split dimensions only
9595
// we need the actual shape to compute offsets and sizes
9696
for (auto i = 0; i < rank; ++i) {
97-
auto s = array.getType().getShape()[i];
97+
auto s = cast<ShapedType>(array.getType()).getShape()[i];
9898
if (ShapedType::isDynamic(s)) {
9999
shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
100100
} else {
@@ -176,7 +176,7 @@ struct ConvertUpdateHaloOp
176176
auto hasTo = rewriter.create<arith::CmpIOp>(
177177
loc, arith::CmpIPredicate::sge, to, zero);
178178
auto buffer = rewriter.create<memref::AllocOp>(
179-
loc, dimSizes, array.getType().getElementType());
179+
loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
180180
// if has neighbor: copy halo data from array to buffer and send
181181
rewriter.create<scf::IfOp>(
182182
loc, hasTo, [&](OpBuilder &builder, Location loc) {

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -495,11 +495,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
495495
sourceShard.getLoc(),
496496
RankedTensorType::get(outShape,
497497
sourceShard.getType().getElementType()),
498-
sourceShard, initOprnd, mesh.getSymName(),
498+
initOprnd, mesh.getSymName(),
499499
MeshAxesArrayAttr::get(builder.getContext(),
500500
sourceSharding.getSplitAxes()),
501-
sourceSharding.getDynamicHaloSizes(),
502-
sourceSharding.getStaticHaloSizes(),
503501
targetSharding.getDynamicHaloSizes(),
504502
targetSharding.getStaticHaloSizes())
505503
.getResult();

mlir/test/Dialect/Mesh/ops.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -615,16 +615,16 @@ func.func @update_halo(
615615
// CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
616616
%arg0 : memref<12x12xi8>) {
617617
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
618-
// CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0
618+
// CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
619619
// CHECK-SAME: split_axes = {{\[\[}}0]]
620-
// CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8>
620+
// CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
621621
%c2 = arith.constant 2 : i64
622-
%uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]]
623-
source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8>
624-
// CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0
622+
%uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
623+
halo_sizes = [2, %c2] : memref<12x12xi8>
624+
// CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
625625
// CHECK-SAME: split_axes = {{\[\[}}0], [1]]
626-
// CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8>
627-
%uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]]
628-
source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8>
626+
// CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
627+
%uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
628+
halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
629629
return
630630
}

mlir/test/Dialect/Mesh/spmdization.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1
226226
%sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
227227
// CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
228228
// CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
229-
// CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
229+
// CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
230230
%sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
231231
%sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
232232
%sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
@@ -242,7 +242,7 @@ func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200
242242
%sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
243243
// CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
244244
// CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
245-
// CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
245+
// CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
246246
%sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
247247
%sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
248248
%sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>

0 commit comments

Comments
 (0)