Skip to content

[mlir][mesh] add support in spmdization for incomplete sharding annotations #82442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

sogartar
Copy link
Contributor

Don't require that mesh.shard operations come in pairs. If there is only a single mesh.shard operation we assume that the producer result and consumer operand have the same sharding.

@sogartar
Copy link
Contributor Author

For more context see #82375

@llvmbot llvmbot added the mlir label Feb 21, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Don't require that mesh.shard operations come in pairs. If there is only a single mesh.shard operation we assume that the producer result and consumer operand have the same sharding.


Full diff: https://github.com/llvm/llvm-project/pull/82442.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+38-15)
  • (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+14)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 7cbe0de048769b..287db5dd08c5fd 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
     Operation *definingOp = operand.getDefiningOp();
     assert(definingOp);
     ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
-    assert(shardOp.getAnnotateForUsers());
     return shardOp.getShard();
   });
   return res;
@@ -615,34 +614,58 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
                     assert(result.hasOneUse());
                     Operation *userOp = *result.getUsers().begin();
                     ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    assert(!shardOp.getAnnotateForUsers());
                     return shardOp.getShard();
                   });
   return res;
 }
 
+ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
+  Operation* srcOp = targetShardOp.getOperand().getDefiningOp();
+  if (!srcOp) {
+    return ShardOp();
+  }
+  ShardOp srcShardOp =
+      llvm::dyn_cast<ShardOp>(targetShardOp.getOperand().getDefiningOp());
+  if (!srcShardOp) {
+    return ShardOp();
+  }
+
+  return srcShardOp;
+}
+
 static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                  SymbolTableCollection &symbolTableCollection,
                  OpBuilder &builder) {
-  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
-  if (shardOp) {
-    if (!shardOp.getAnnotateForUsers()) {
-      return success();
-    }
+  Value targetSpmdValue;
 
+  // Check if 2 shard ops are chained. If not there is no need for resharding
+  // as the source and target shared the same sharding.
+  ShardOp srcShardOp = getSourceShardOpOrNull(shardOp);
+  if (!srcShardOp) {
+    targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
+  } else {
     // Insert resharding.
-    ShardOp srcShardOp =
-        llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp());
-    assert(!srcShardOp.getAnnotateForUsers());
+    assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
     TypedValue<ShapedType> srcSpmdValue =
         spmdizationMap.lookup(srcShardOp.getOperand())
             .cast<TypedValue<ShapedType>>();
-    Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
+    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                                     symbolTableCollection);
-    assert(!spmdizationMap.contains(shardOp.getResult()));
-    spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
-    return success();
+  }
+
+  assert(!spmdizationMap.contains(shardOp.getResult()));
+  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+  return success();
+}
+
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+                 SymbolTableCollection &symbolTableCollection,
+                 OpBuilder &builder) {
+  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
+  if (shardOp) {
+    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, builder);
   }
 
   SmallVector<Value> spmdizedOperands;
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2fb8029dfe64ae..258c3786e3518c 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
   // CHECK: return %[[RESHARD3]] : tensor<1xi8>
   return %7 : tensor<2xi8>
 }
+
+// // CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+  %arg0: tensor<8x16xf32>
+// CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
+  // CHECK: return %[[RES]] : tensor<4x16xf32>
+  return %2 : tensor<8x16xf32>
+}

@sogartar
Copy link
Contributor Author

@yaochengji, could you review this PR?

@sogartar sogartar force-pushed the mesh-spmdization-handle-incomplete-sharding branch from 90e03b9 to 91ff453 Compare February 21, 2024 00:20
Copy link

github-actions bot commented Feb 21, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
return %7 : tensor<2xi8>
}

// // CHECK-LABEL: func @incomplete_sharding
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant //

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed it.

return shardOp.getShard();
});
return res;
}

static ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
Operation *srcOp = targetShardOp.getOperand().getDefiningOp();
if (!srcOp) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By LLVM convention we don't nee { ... } for simple loops like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the function.

return ShardOp();
}
ShardOp srcShardOp =
llvm::dyn_cast<ShardOp>(targetShardOp.getOperand().getDefiningOp());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be merged with the above using dyn_cast_or_null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@sogartar sogartar requested a review from antiagainst February 21, 2024 14:37
// Check if 2 shard ops are chained. If not there is no need for resharding
// as the source and target shared the same sharding.
ShardOp srcShardOp =
llvm::dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we should not need the llvm:: prefix I believe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

…ations

Don't require that `mesh.shard` operations come in pairs.
If there is only a single `mesh.shard` operation we assume that the producer
result and consumer operand have the same sharding.
@sogartar sogartar force-pushed the mesh-spmdization-handle-incomplete-sharding branch from 7ca5d76 to d42f27f Compare February 22, 2024 17:24
@sogartar
Copy link
Contributor Author

Thank you for the review. I squashed and rebased to check that the CI is OK before merging.

@sogartar sogartar merged commit 4f7ab78 into llvm:main Feb 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants