Skip to content

Commit d42f27f

Browse files
committed
[mlir][mesh] add support in spmdization for incomplete sharding annotations
Don't require that `mesh.shard` operations come in pairs. If there is only a single `mesh.shard` operation, we assume that the producer result and consumer operand have the same sharding.
1 parent 87b1e73 commit d42f27f

File tree

2 files changed

+42
-17
lines changed

2 files changed

+42
-17
lines changed

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
593593
Operation *definingOp = operand.getDefiningOp();
594594
assert(definingOp);
595595
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
596-
assert(shardOp.getAnnotateForUsers());
597596
return shardOp.getShard();
598597
});
599598
return res;
@@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
615614
assert(result.hasOneUse());
616615
Operation *userOp = *result.getUsers().begin();
617616
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
618-
assert(!shardOp.getAnnotateForUsers());
619617
return shardOp.getShard();
620618
});
621619
return res;
622620
}
623621

624622
static LogicalResult
625-
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
623+
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
626624
SymbolTableCollection &symbolTableCollection,
627625
OpBuilder &builder) {
628-
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
629-
if (shardOp) {
630-
if (!shardOp.getAnnotateForUsers()) {
631-
return success();
632-
}
633-
626+
Value targetSpmdValue;
627+
628+
// Check if 2 shard ops are chained. If not, there is no need for resharding
629+
// as the source and target share the same sharding.
630+
ShardOp srcShardOp =
631+
dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
632+
if (!srcShardOp) {
633+
targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
634+
} else {
634635
// Insert resharding.
635-
ShardOp srcShardOp =
636-
llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp());
637-
assert(!srcShardOp.getAnnotateForUsers());
636+
assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
638637
TypedValue<ShapedType> srcSpmdValue =
639638
spmdizationMap.lookup(srcShardOp.getOperand())
640639
.cast<TypedValue<ShapedType>>();
641-
Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
642-
symbolTableCollection);
643-
assert(!spmdizationMap.contains(shardOp.getResult()));
644-
spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
645-
return success();
640+
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
641+
symbolTableCollection);
642+
}
643+
644+
assert(!spmdizationMap.contains(shardOp.getResult()));
645+
spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
646+
return success();
647+
}
648+
649+
static LogicalResult
650+
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
651+
SymbolTableCollection &symbolTableCollection,
652+
OpBuilder &builder) {
653+
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
654+
if (shardOp) {
655+
return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
656+
builder);
646657
}
647658

648659
SmallVector<Value> spmdizedOperands;

mlir/test/Dialect/Mesh/spmdization.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
127127
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
128128
return %7 : tensor<2xi8>
129129
}
130+
131+
// CHECK-LABEL: func @incomplete_sharding
132+
func.func @incomplete_sharding(
133+
// CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
134+
%arg0: tensor<8x16xf32>
135+
// CHECK-SAME: -> tensor<4x16xf32> {
136+
) -> tensor<8x16xf32> {
137+
%0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
138+
// CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
139+
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
140+
%2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
141+
// CHECK: return %[[RES]] : tensor<4x16xf32>
142+
return %2 : tensor<8x16xf32>
143+
}

0 commit comments

Comments
 (0)