Skip to content

Commit 91ff453

Browse files
committed
[mlir][mesh] add support in spmdization for incomplete sharding annotations
Don't require that `mesh.shard` operations come in pairs. If there is only a single `mesh.shard` operation, we assume that the producer result and consumer operand have the same sharding.
1 parent b1849a2 commit 91ff453

File tree

2 files changed

+54
-16
lines changed

2 files changed

+54
-16
lines changed

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
593593
Operation *definingOp = operand.getDefiningOp();
594594
assert(definingOp);
595595
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
596-
assert(shardOp.getAnnotateForUsers());
597596
return shardOp.getShard();
598597
});
599598
return res;
@@ -615,34 +614,59 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
615614
assert(result.hasOneUse());
616615
Operation *userOp = *result.getUsers().begin();
617616
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
618-
assert(!shardOp.getAnnotateForUsers());
619617
return shardOp.getShard();
620618
});
621619
return res;
622620
}
623621

622+
// Returns the `mesh.shard` op that produces the operand of `targetShardOp`,
// or a null ShardOp when the operand is not defined by a shard op (e.g. it is
// a block argument or the result of some other operation). A null result
// signals an "incomplete" sharding annotation: only one shard op is present
// instead of a producer/consumer pair.
ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
  // dyn_cast_or_null collapses the original null-defining-op check followed by
  // a dyn_cast (which re-queried getDefiningOp()) into a single expression.
  return llvm::dyn_cast_or_null<ShardOp>(
      targetShardOp.getOperand().getDefiningOp());
}
635+
624636
// Spmdize a `mesh.shard` op.
//
// When the shard op is chained to a producer shard op, the pair describes a
// (producer sharding, consumer sharding) transition and a resharding is
// materialized between the two spmdized values. When the annotation is
// incomplete — no producer shard op — the producer result and consumer
// operand are assumed to carry the same sharding, so the already-spmdized
// operand is forwarded unchanged.
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if two shard ops are chained. If not, there is no need for
  // resharding as the source and target share the same sharding.
  ShardOp srcShardOp = getSourceShardOpOrNull(shardOp);
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
  } else {
    // Insert resharding. A chained pair must be a producer annotation
    // followed by a consumer (annotate_for_users) annotation.
    assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
    TypedValue<ShapedType> srcSpmdValue =
        spmdizationMap.lookup(srcShardOp.getOperand())
            .cast<TypedValue<ShapedType>>();
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}
661+
662+
static LogicalResult
663+
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
664+
SymbolTableCollection &symbolTableCollection,
665+
OpBuilder &builder) {
666+
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
667+
if (shardOp) {
668+
return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
669+
builder);
646670
}
647671

648672
SmallVector<Value> spmdizedOperands;

mlir/test/Dialect/Mesh/spmdization.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
127127
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
128128
return %7 : tensor<2xi8>
129129
}
130+
131+
// CHECK-LABEL: func @incomplete_sharding
132+
func.func @incomplete_sharding(
133+
// CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
134+
%arg0: tensor<8x16xf32>
135+
// CHECK-SAME: -> tensor<4x16xf32> {
136+
) -> tensor<8x16xf32> {
137+
%0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
138+
// CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
139+
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
140+
%2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
141+
// CHECK: return %[[RES]] : tensor<4x16xf32>
142+
return %2 : tensor<8x16xf32>
143+
}

0 commit comments

Comments
 (0)