[mlir][mesh] add support in spmdization for incomplete sharding annotations #82442
Conversation
For more context see #82375.
@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes: Don't require that `mesh.shard` operations come in pairs. If there is only a single `mesh.shard` operation we assume that the producer result and consumer operand have the same sharding.

Full diff: https://github.com/llvm/llvm-project/pull/82442.diff

2 Files Affected:
- mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
- mlir/test/Dialect/Mesh/spmdization.mlir
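To make the two cases concrete, here is a minimal sketch in the dialect's textual IR. It reuses the `@mesh_1d` mesh and the sharding syntax from the test case below; the SSA value names are illustrative, not taken from the patch.

```mlir
// Complete annotation: a chained pair of mesh.shard ops, one on the
// producer side and one (annotate_for_users) on the consumer side.
%a = mesh.shard %v to <@mesh_1d, [[0]]> : tensor<8x16xf32>
%b = mesh.shard %a to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>

// Incomplete annotation: a single mesh.shard op. With this patch the
// producer result and the consumer operand are assumed to have the same
// sharding, so no resharding needs to be inserted.
%c = mesh.shard %w to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
```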
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 7cbe0de048769b..287db5dd08c5fd 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
Operation *definingOp = operand.getDefiningOp();
assert(definingOp);
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
- assert(shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
@@ -615,34 +614,58 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
assert(result.hasOneUse());
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
- assert(!shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
}
+ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
+ Operation* srcOp = targetShardOp.getOperand().getDefiningOp();
+ if (!srcOp) {
+ return ShardOp();
+ }
+ ShardOp srcShardOp =
+ llvm::dyn_cast<ShardOp>(targetShardOp.getOperand().getDefiningOp());
+ if (!srcShardOp) {
+ return ShardOp();
+ }
+
+ return srcShardOp;
+}
+
static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
- if (shardOp) {
- if (!shardOp.getAnnotateForUsers()) {
- return success();
- }
+ Value targetSpmdValue;
+ // Check if 2 shard ops are chained. If not there is no need for resharding
+ // as the source and target shared the same sharding.
+ ShardOp srcShardOp = getSourceShardOpOrNull(shardOp);
+ if (!srcShardOp) {
+ targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
+ } else {
// Insert resharding.
- ShardOp srcShardOp =
- llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp());
- assert(!srcShardOp.getAnnotateForUsers());
+ assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
TypedValue<ShapedType> srcSpmdValue =
spmdizationMap.lookup(srcShardOp.getOperand())
.cast<TypedValue<ShapedType>>();
- Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
+ targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
symbolTableCollection);
- assert(!spmdizationMap.contains(shardOp.getResult()));
- spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
- return success();
+ }
+
+ assert(!spmdizationMap.contains(shardOp.getResult()));
+ spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+ return success();
+}
+
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
+ if (shardOp) {
+ return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, builder);
}
SmallVector<Value> spmdizedOperands;
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2fb8029dfe64ae..258c3786e3518c 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
return %7 : tensor<2xi8>
}
+
+// // CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+ %arg0: tensor<8x16xf32>
+// CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<4x16xf32>
+ return %2 : tensor<8x16xf32>
+}
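Putting the CHECK lines together, the spmdized form of this test function would look roughly as follows. This is a sketch reconstructed from the expected-output checks, assuming `@mesh_1d` has two processes so the `8x16` tensor is split into `4x16` shards along dimension 0; it is not output copied from the patch.

```mlir
func.func @incomplete_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // Each process applies the op to its local 4x16 shard. Both mesh.shard
  // annotations carry the same sharding, so no resharding is materialized.
  %0 = tosa.sigmoid %arg0 : (tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}
```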
@yaochengji, could you review this PR?
Force-pushed 90e03b9 to 91ff453.
✅ With the latest revision this PR passed the C/C++ code formatter.
Review thread on mlir/test/Dialect/Mesh/spmdization.mlir:

> +// // CHECK-LABEL: func @incomplete_sharding

Redundant //

Removed it.
Review thread on mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp:

> static ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
>   Operation *srcOp = targetShardOp.getOperand().getDefiningOp();
>   if (!srcOp) {

By LLVM convention we don't need { ... } for simple loops like this.

I removed the function.
> ShardOp srcShardOp =
>     llvm::dyn_cast<ShardOp>(targetShardOp.getOperand().getDefiningOp());

This can be merged with the above using dyn_cast_or_null?

Done.
> ShardOp srcShardOp =
>     llvm::dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());

Nit: we should not need the llvm:: prefix I believe.

Done.
[mlir][mesh] add support in spmdization for incomplete sharding annotations

Don't require that `mesh.shard` operations come in pairs. If there is only a single `mesh.shard` operation we assume that the producer result and consumer operand have the same sharding.
Force-pushed 7ca5d76 to d42f27f.
Thank you for the review. I squashed and rebased to check that the CI is OK before merging.