Skip to content

Commit 8b209ce

Browse files
committed
clang-format
1 parent 13c590c commit 8b209ce

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -489,18 +489,22 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
489489
tgtCoreOffs, coreShape, strides);
490490

491491
// finally update the halo
492-
auto updateHaloResult = builder.create<UpdateHaloOp>(
493-
sourceShard.getLoc(),
494-
RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
495-
sourceShard, initOprnd, mesh.getSymName(),
496-
MeshAxesArrayAttr::get(builder.getContext(),
497-
sourceSharding.getSplitAxes()),
498-
sourceSharding.getDynamicHaloSizes(),
499-
sourceSharding.getStaticHaloSizes(),
500-
targetSharding.getDynamicHaloSizes(),
501-
targetSharding.getStaticHaloSizes()).getResult();
502-
return std::make_tuple(
503-
cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
492+
auto updateHaloResult =
493+
builder
494+
.create<UpdateHaloOp>(
495+
sourceShard.getLoc(),
496+
RankedTensorType::get(outShape,
497+
sourceShard.getType().getElementType()),
498+
sourceShard, initOprnd, mesh.getSymName(),
499+
MeshAxesArrayAttr::get(builder.getContext(),
500+
sourceSharding.getSplitAxes()),
501+
sourceSharding.getDynamicHaloSizes(),
502+
sourceSharding.getStaticHaloSizes(),
503+
targetSharding.getDynamicHaloSizes(),
504+
targetSharding.getStaticHaloSizes())
505+
.getResult();
506+
return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
507+
targetSharding);
504508
}
505509
return std::nullopt;
506510
}
@@ -725,8 +729,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
725729
targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
726730
} else {
727731
// Insert resharding.
728-
TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
729-
spmdizationMap.lookup(srcShardOp));
732+
TypedValue<ShapedType> srcSpmdValue =
733+
cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
730734
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
731735
symbolTableCollection);
732736
}

0 commit comments

Comments
 (0)