Skip to content

[mlir][mesh] Refactoring code organization, tests and docs #79606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 31, 2024

Conversation

sogartar
Copy link
Contributor

@sogartar sogartar commented Jan 26, 2024

  • Split out MeshDialect.h form MeshOps.h that defines the dialect class. Reduces include clutter if you care only about the dialect and not the ops.

  • Expose functions getMesh and collectiveProcessGroupSize. There functions are useful for outside users of the dialect.

  • Remove unused code.

  • Remove examples and tests of mesh.shard attribute in tensor encoding. Per the decision that Spmdization would be performed on sharding annotations and there will be no tensors with sharding specified in the type. For more info see this RFC comment:
    https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81

@sogartar sogartar requested a review from antiagainst January 26, 2024 15:29
@llvmbot llvmbot added the mlir label Jan 26, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 26, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes
  • Split out MeshDialect.h form MeshOps.h that defines the dialect class. Reduces include clutter if you care only about the dialect and not the ops.

  • Expose functions getMesh and collectiveProcessGroupSize. There functions are useful for outside users of the dialect.

  • Remove unused code.

  • Remove examples and tests of mesh.sharding attribute in tensor encoding. Per the decision that Spmdization would be performed on sharding annotations and there will be no tensors with sharding specified in the type. For more info see this RFC comment:
    https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81


Patch is 23.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79606.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt (+6-6)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+4-4)
  • (added) mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h (+16)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+32-5)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (-4)
  • (modified) mlir/include/mlir/InitAllDialects.h (+1-1)
  • (modified) mlir/lib/Dialect/Mesh/IR/CMakeLists.txt (+2-2)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+14-48)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+1)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+4-16)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+1)
  • (modified) mlir/test/Dialect/Mesh/invalid.mlir (+20-20)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+6-30)
  • (modified) mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp (-2)
  • (modified) mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 79a99e58b4e8f1..a73ec701dbadcc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -2,11 +2,11 @@ add_mlir_dialect(MeshOps mesh)
 add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
+mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRMeshAttrIncGen)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
-mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
+mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
+mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRMeshEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index e8353613cd0e12..04929f4869273d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -123,18 +123,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
     // if it's empty. Otherwise, a parsing error will occur.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+    #mesh.shard<@mesh0, [[]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
+    #mesh.shard<@mesh0, [[0]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
+    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
+    #mesh.shard<@mesh0, [[0]], partial = max[1]>
 
     // Could be used in the attribute of mesh.shard op
     %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
new file mode 100644
index 00000000000000..a1adbaa44406d6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
@@ -0,0 +1,16 @@
+//===- MeshOps.h - Mesh Dialect ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 83452dcc2e8abe..8e5e0f541ba5ee 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,7 +15,6 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include <algorithm>
 
 namespace mlir {
 namespace mesh {
@@ -26,12 +25,10 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 } // namespace mesh
 } // namespace mlir
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
-
-#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -51,6 +48,36 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 Partial getPartialTypeFromReduction(IteratorType iType);
 
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                            SymbolTableCollection &symbolTableCollection) {
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+      op, meshSymbol);
+}
+
+// Get the corresponding mesh op using the standard attribute nomenclature.
+template <typename Op>
+mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
+}
+
+// Get the number of processes that participate in each group
+// induced by `meshAxes`.
+template <typename MeshAxesRange, typename MeshShapeRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
+                                   MeshShapeRange &&meshShape) {
+  int64_t res = 1;
+
+  for (MeshAxis axis : meshAxes) {
+    auto axisSize = *(std::begin(meshShape) + axis);
+    if (ShapedType::isDynamic(axisSize)) {
+      return ShapedType::kDynamic;
+    }
+    res *= axisSize;
+  }
+
+  return res;
+}
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7b301025e687ae..c9efdcc3a68bb8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -61,10 +61,6 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
     // A device mesh with 2 axes, the number of devices along both axes
     // is unknown
     mesh.mesh @mesh3(shape = ?x?)
-
-    // Used in the mesh sharding attribute to extend the standard tensor to
-    // distributed
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
     ```
   }];
   let arguments = (ins
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04..d1ee32d7bac613 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -54,7 +54,7 @@
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 634a94f8cec879..140fb553f84817 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -5,8 +5,8 @@ add_mlir_dialect_library(MLIRMeshDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
 
   DEPENDS
-  MLIRMeshOpsAttrIncGen
-  MLIRMeshOpsEnumsIncGen
+  MLIRMeshAttrIncGen
+  MLIRMeshEnumsIncGen
   MLIRMeshOpsIncGen
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 994a017a1f46c2..b1cc36d1879e19 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -43,24 +45,6 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
-template <typename It>
-static It canonicalizeSetAsArray(It begin, It end) {
-  llvm::sort(begin, end);
-  return std::unique(begin, end);
-}
-
-template <typename R>
-static auto canonicalizeSetAsArray(R &&range) {
-  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
-}
-
-template <typename T>
-static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
-  auto newEnd = canonicalizeSetAsArray(vec);
-  vec.resize(newEnd - vec.begin());
-  return vec;
-}
-
 namespace {
 
 struct DimensionSize {
@@ -114,10 +98,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
-static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                 SymbolTableCollection &symbolTable) {
-  mesh::MeshOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
+                                          FlatSymbolRefAttr meshSymbol,
+                                          SymbolTableCollection &symbolTable) {
+  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -201,10 +185,6 @@ LogicalResult MeshOp::verify() {
   if (rank <= 0)
     return emitOpError("rank of mesh is expected to be a positive integer");
 
-  if (getShape().size() > size_t(rank))
-    return emitOpError(
-        "rank of shape is not expected to be larger than rank of mesh");
-
   for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
       return emitOpError("dimension size of a mesh is expected to be "
@@ -220,7 +200,7 @@ LogicalResult MeshOp::verify() {
 
 LogicalResult
 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -322,7 +302,7 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
 
 LogicalResult
 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -360,7 +340,7 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -428,7 +408,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
 template <typename Op>
 static FailureOr<MeshOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
+  auto mesh =
+      ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -450,21 +431,6 @@ static auto product(R &&range) {
   return product(adl_begin(range), adl_end(range));
 }
 
-static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
-                                         ArrayRef<int64_t> meshShape) {
-  int64_t res = 1;
-
-  for (MeshAxis axis : meshAxes) {
-    if (ShapedType::isDynamic(meshShape[axis])) {
-      return ShapedType::kDynamic;
-    }
-    assert(size_t(axis) < meshShape.size());
-    res *= meshShape[axis];
-  }
-
-  return res;
-}
-
 static LogicalResult verifyDimensionCompatibility(Location loc,
                                                   int64_t expectedDimSize,
                                                   int64_t resultDimSize,
@@ -495,7 +461,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
     auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
     auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -529,7 +495,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
   DimensionSize expectedResultConcatDimSize =
@@ -570,7 +536,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandScatterDimSize =
       DimensionSize(operandType.getDimSize(scatterAxis));
   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 3aed912fb43c63..f3cd12f38879d8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 593158d5f6d293..5554edac4d2f63 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -54,19 +55,6 @@ int64_t unshardDimension(int64_t dim, int64_t shardCount) {
   return dim * shardCount;
 }
 
-template <typename MeshShape, typename SplitAxes>
-int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
-  int64_t res = 1;
-  for (auto splitAxis : splitAxes) {
-    int64_t meshDimSize = meshShape[splitAxis];
-    if (ShapedType::isDynamic(meshDimSize)) {
-      return ShapedType::kDynamic;
-    }
-    res *= meshDimSize;
-  }
-  return res;
-}
-
 // Compute the shape for the tensor on each device in the mesh.
 // Example:
 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
@@ -78,9 +66,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
   for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-    outShape[tensorAxis] =
-        shardDimension(inShape[tensorAxis],
-                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
+    outShape[tensorAxis] = shardDimension(
+        inShape[tensorAxis],
+        collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
   }
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 5c2344651bf60d..03b1d9b3498028 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 259e4ebf76757c..3fa3ebd67b15e7 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -13,10 +13,10 @@ mesh.mesh @mesh0(shape = -1)
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
-    // expected-error@+1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error@+1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0], [0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -24,10 +24,10 @@ func.func @mesh_axis_duplicated_different_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
-    // expected-error@+1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error@+1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0, 0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -35,10 +35,10 @@ func.func @mesh_axis_duplicated_same_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
-    // expected-error@+1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error@+1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[0]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -46,10 +46,10 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
-    // expected-error@+1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error@+1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[-1]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -57,10 +57,10 @@ func.func @mesh_axis_negtive_in_split_part(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
-    // expected-error@+1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error@+1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[-1]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index dbaaff9c172fd9..40a8469b264643 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -17,36 +17,12 @@ mesh.mesh @mesh4(shape = 3)
 // CHECK: mesh.mesh @mesh5(shape = ?)
 mesh.mesh @mesh5(shape = ?)
 
-// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
-func.func @mesh_shard_encoding_fully_replicated(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
-    %arg0 : tensor<4x8xf32, #m...
[truncated]

@sogartar sogartar requested a review from joker-eph January 26, 2024 15:30
@sogartar
Copy link
Contributor Author

@yaochengji, I have another PR if you can check it out.

@sogartar sogartar marked this pull request as draft January 26, 2024 16:38
@sogartar sogartar marked this pull request as ready for review January 26, 2024 17:01
@yaochengji
Copy link
Member

yaochengji commented Jan 26, 2024

I'm thinking about do we really need an explicit sharding attr? Because it is currently only used in annotation op.

Then it will be easier to get the mesh cluster through mesh symbol.

@sogartar
Copy link
Contributor Author

I guess you mean that an attribute is not bound to an operational context, so if you have just the sharding attribute, you can't really use it without the operation.
There is merit to your proposal. I would like to think a bit more about it and maybe do this change in another PR?

@yaochengji
Copy link
Member

Yeah, previously the type of MeshShardingAttr is created because it is intended to be used for both the mesh.shard operation and tensor. But now it is only used by the op, that's why I think the explicit MeshShardingAttr is not needed.

Of course, it could be changed in another PR.

add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRMeshEnumsIncGen)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, but we should be consistent, can you also rename the first line: add_mlir_dialect(MeshDialect mesh)?

The generated files shouldn't be named "MeshOpsDialect.h.inc" for example but "MeshDialect.h.inc"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I change the naming and creation of the CMake targets.

@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this should be in the next group below, "mlir/Dialect/Mesh/IR/MeshOps.h" should be alone at the top.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. Moved it.

@sogartar sogartar requested a review from joker-eph January 30, 2024 16:29
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

* Split out `MeshDialect.h` form `MeshOps.h` that defines the dialect class.
Reduces include clutter if you care only about the dialect and not the ops.

* Expose functions `getMesh` and `collectiveProcessGroupSize`.
There functions are useful for outside users of the dialect.

* Remove unused code.

* Remove examples and tests of mesh.shard attribute in tensor encoding.
Per the decision that Spmdization would be performed on sharding annotations
and there will be no tensors with sharding specified in the type.
For more info see this RFC comment:
https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81
@sogartar
Copy link
Contributor Author

Thank you, I rebased and squashed to resolved the merge conflicts before merging.

@sogartar sogartar merged commit 31fc0a1 into llvm:main Jan 31, 2024
@tblah
Copy link
Contributor

tblah commented Jan 31, 2024

Since this commit I am getting a build error saying the include mlir/Dialect/Mesh/IR/MeshDialect.h.inc is missing (MeshDialect.h:14). Is some CMake missing?

@tblah
Copy link
Contributor

tblah commented Jan 31, 2024

Thank you for the fix in 16c4843. This resolves the issue for me

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants