
[mlir][mesh] Deduplicate iterator type and partial type information #81920


Merged

Conversation

sogartar
Contributor

The two types mostly duplicated the same values.
Here they are decomposed to carry orthogonal, complementary information.

Use `utils::IteratorType` instead of `mesh::IteratorType`;
loop iterator types now carry only `parallel` and `reduction` values.

Rename `Partial` to `ReductionKind`.

Add a `getReductionLoopIteratorKinds` method to `ShardingInterface`.
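
For illustration, here is a minimal sketch of how an op-specific implementation might report the two now-separate pieces of information for a matmul-like op with loops (M, N, K). The struct and its name are hypothetical and not part of this patch; only the method signatures and enum values follow the change.

```cpp
#include "mlir/Dialect/Mesh/IR/MeshOps.h"          // mesh::ReductionKind, mesh::isReductionLoop
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" // utils::IteratorType

using namespace mlir;

// Hypothetical sharding model for a matmul-like op; illustrative only.
struct MyMatmulShardingModel {
  // Loop structure only: M and N are parallel, K is a reduction.
  SmallVector<utils::IteratorType> getLoopIteratorTypes() const {
    return {utils::IteratorType::parallel, utils::IteratorType::parallel,
            utils::IteratorType::reduction};
  }

  // Reduction semantics, one entry per reduction iterator above, in order.
  SmallVector<mesh::ReductionKind> getReductionLoopIteratorKinds() const {
    return {mesh::ReductionKind::Sum};
  }
};
```
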
@llvmbot
Member

llvmbot commented Feb 15, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

The two types mostly duplicated the same values.
Here they are decomposed to carry orthogonal, complementary information.

Use utils::IteratorType instead of mesh::IteratorType. It now has only parallel and reduction values.

Rename Partial to ReductionKind.

Add getReductionLoopIteratorKinds method to ShardingInterface.


Patch is 24.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81920.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+10-30)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+4-5)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+1)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+25-6)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h (+10-7)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+1-1)
  • (modified) mlir/lib/Dialect/Mesh/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+1-28)
  • (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+20-16)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+8-8)
  • (modified) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+9-4)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 04929f4869273d..fc2acc70381ef7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -41,7 +41,8 @@ def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16"
 // Mesh Enums.
 //===----------------------------------------------------------------------===//
 
-def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
+def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
+  "Reduction of an iterator/mesh dimension.", [
   I32EnumAttrCase<"Sum", 1, "sum">,
   I32EnumAttrCase<"Max", 2, "max">,
   I32EnumAttrCase<"Min", 3, "min">,
@@ -51,26 +52,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
-def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
-// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
-// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
-// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
-// is partial.
-def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
-  I32EnumAttrCase<"Parallel", 1, "parallel">,
-  I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
-  I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
-  I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
-  I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
-  I32EnumAttrCase<"Invalid", 100, "invalid">
-]> {
-    let genSpecializedAttr = 0;
-    let cppNamespace = "::mlir::mesh";
-}
-
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
@@ -83,14 +68,15 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
      "The mesh on which tensors are sharded.">:$mesh,
     ArrayRefParameter<"MeshAxesAttr">:$split_axes,
     OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
-    OptionalParameter<"::mlir::mesh::Partial">:$partial_type
+    OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
   );
 
   let summary = "Attribute that extends tensor type to distributed tensor type.";
 
   let description = [{
-    The MeshSharding attribute could be used in the encoding of a
-    `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
+    The MeshSharding attribute is used in a `mesh.shard` operation.
+    It specifies how a tensor is sharded and distributed across the process
+    mesh.
 
     1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
     mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -107,13 +93,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
     4. `partial_type`: indicates the reduction type of the possible all-reduce
     op. It has 4 possible values:
-    - `partial_sum`: denotes it's an all-reduce-sum
-    - `partial_max`: denotes it's an all-reduce-max
-    - `partial_min`: denotes it's an all-reduce-min
-    - `partial_generic`: denotes that the all-reduce type is complex and cannot
-    be represented merely by a simple sum, max, or min. The exact reduction
-    computation may be derived from the semantics of the corresponding operation
-    or from the reduction computation IR
+    `generic`: is not an allowed value inside a shard attribute.
 
     Example:
 
@@ -149,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
                      "ArrayRef<MeshAxis>": $partial_axes,
-                     "mesh::Partial": $partial_type), [{
+                     "mesh::ReductionKind": $partial_type), [{
       SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
                   split_axes, [&](ArrayRef<MeshAxis> array) {
           return MeshAxesAttr::get($_ctxt, array);
@@ -159,7 +139,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     }]>,
     AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
-      return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
+      return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
     }]>
   ];
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index fb9425b96e68e2..4569b77441c3f3 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -38,9 +39,9 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 namespace mlir {
 namespace mesh {
 
-bool isReductionLoop(IteratorType iType);
-
-bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+inline bool isReductionLoop(utils::IteratorType iType) {
+  return iType == utils::IteratorType::reduction;
+}
 
 template <typename T>
 void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
@@ -48,8 +49,6 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
     array.pop_back();
 }
 
-Partial getPartialTypeFromReduction(IteratorType iType);
-
 // Is the same tensor replicated on all processes.
 inline bool isFullReplication(MeshShardingAttr attr) {
   return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 96636d5347ff6e..8ba7c111aea6bb 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -330,7 +330,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyRankedTensor:$input,
-    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
   ));
   let results = (outs
     AnyRankedTensor:$result
@@ -629,7 +629,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyRankedTensor:$input,
-    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
     DenseI64ArrayAttr:$root,
     Variadic<Index>:$root_dynamic
   ));
@@ -692,7 +692,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
     IndexAttr:$scatter_axis
   ));
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index cc90ddd40a6222..c47a7ddd3f9cc3 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 4afb1c36a72f7b..1f75135f42882f 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -26,20 +26,39 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           output tensors.
 
           Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
-          types are parallel, parallel, reduction-sum. This indicates that M and
+          types are parallel, parallel, reduction. This indicates that M and
           N are traversed in parallel, while the K dimension is used for
           reduction.
-
-          Example 2: A softmax op's loop iterator types are parallel and
-          invalid. The second dimension is considered as invalid because it is
-          neither parallel nor any kind of reduction. 
         }],
-        /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+        /*retType=*/"SmallVector<mlir::utils::IteratorType>",
         /*methodName=*/"getLoopIteratorTypes",
         /*args=*/(ins),
         /*methodBody=*/"",
         /*defaultImplementation=*/"return {};"
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the kind of all reduction loop iterators.
+          The order is the same as the same as the result from
+          `getLoopIteratorTypes`.
+
+          Example 1:
+          iterator types =  (parallel, reduction, parallel, reduction)
+                                             ||                   ||
+          reduction kinds = (                sum,                 max)
+
+          Example 2:
+          A softmax op's loop iterator types are parallel and
+          reduction.
+          The reduction iterator will be of kind `generic`, since it is non of
+          the available presets.
+        }],
+        /*retType=*/"SmallVector<ReductionKind>",
+        /*methodName=*/"getReductionLoopIteratorKinds",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Return the indexing maps attribute within the current operation.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index 8108386c2e0437..ffc9b6fb18be53 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -36,8 +36,9 @@ template <typename Op>
 struct IndependentParallelIteratorDomainShardingInterface
     : public ShardingInterface::ExternalModel<
           IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
-  SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
-    SmallVector<IteratorType> iterTypes;
+  SmallVector<utils::IteratorType>
+  getLoopIteratorTypes(Operation *operation) const {
+    SmallVector<utils::IteratorType> iterTypes;
     for (Type t : operation->getOperandTypes()) {
       populateIteratorTypes(t, iterTypes);
     }
@@ -65,8 +66,9 @@ struct IndependentParallelIteratorDomainShardingInterface
   }
 
 private:
-  void populateIteratorTypes(Type t,
-                             SmallVector<IteratorType> &iterTypes) const {
+  void
+  populateIteratorTypes(Type t,
+                        SmallVector<utils::IteratorType> &iterTypes) const {
     RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
     if (!rankedTensorType) {
       return;
@@ -74,7 +76,7 @@ struct IndependentParallelIteratorDomainShardingInterface
 
     iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
     for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
-      iterTypes.push_back(IteratorType::Parallel);
+      iterTypes.push_back(utils::IteratorType::parallel);
     }
   }
 };
@@ -84,12 +86,13 @@ template <typename ElemwiseOp>
 struct ElementwiseShardingInterface
     : public ShardingInterface::ExternalModel<
           ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
-  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     Value val = op->getOperand(0);
     auto type = val.getType().dyn_cast<RankedTensorType>();
     if (!type)
       return {};
-    SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+    SmallVector<utils::IteratorType> types(type.getRank(),
+                                           utils::IteratorType::parallel);
     return types;
   }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f438465251bb06..c64da29ca64123 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -38,7 +38,7 @@ namespace mesh {
 // the algebraic structure.
 template <typename AlgebraicOp>
 void populateAllReduceEndomorphismSimplificationPatterns(
-    RewritePatternSet &patterns, Partial reduction) {
+    RewritePatternSet &patterns, ReductionKind reduction) {
   auto getEndomorphismOpOperand = [](Operation *op) {
     auto allReduceOp = llvm::cast<AllReduceOp>(op);
     return &allReduceOp.getInputMutable();
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 678a25f1c3cf58..45ac9edb280bc9 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRMeshDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRDialectUtils
   MLIRIR
   MLIRSupport
   MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3291010d27428a..838255cf5a5ba3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -148,33 +148,6 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
   return success();
 }
 
-bool mesh::isReductionLoop(IteratorType iType) {
-  return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
-}
-
-bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
-  return (partial == Partial::Generic &&
-          iType == IteratorType::ReductionGeneric) ||
-         (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
-         (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
-         (partial == Partial::Min && iType == IteratorType::ReductionMin);
-}
-
-Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
-  switch (iType) {
-  case IteratorType::ReductionGeneric:
-    return Partial::Generic;
-  case IteratorType::ReductionSum:
-    return Partial::Sum;
-  case IteratorType::ReductionMax:
-    return Partial::Max;
-  case IteratorType::ReductionMin:
-    return Partial::Min;
-  default:
-    llvm_unreachable("No corresponding partial type can be found");
-  }
-}
-
 template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
@@ -278,7 +251,7 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                          FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
-                         ArrayRef<MeshAxis> partialAxes, Partial) {
+                         ArrayRef<MeshAxis> partialAxes, ReductionKind) {
   // TODO: At present mesh symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index b8b3841d947abd..fe3d7c44413fef 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/Support/Debug.h"
@@ -163,7 +164,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
       return failure();
 
   // check loop types
-  SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
+  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
   if (loopTypes.size() == 0)
     return failure();
 
@@ -198,7 +199,7 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
   getOperation()->print(os);
   os << "\n";
   os << "loop types: [";
-  for (IteratorType type : getLoopIteratorTypes()) {
+  for (utils::IteratorType type : getLoopIteratorTypes()) {
     os << stringifyEnum(type) << " ";
   }
   os << "]\n";
@@ -257,12 +258,12 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 
   if (failed(shardingOp.verifyShardingInterfaceImpl()))
     return op->emitOpError() << "invalid sharding interface implementation";
-  SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<utils::IteratorType> loopTypes =
+      shardingOp.getLoopIteratorTypes();
   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
   unsigned numOperands = op->getNumOperands();
   shardingOption.shardingArray.resize(loopTypes.size());
   llvm::SmallVector<MeshAxis> partialMeshAxes;
-  Partial partialType;
   llvm::SmallSet<unsigned, 4> visitedLoopIndices;
   bool anyShardingInResultsOrOperands = false;
 
@@ -294,7 +295,6 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
       if (!partialMeshAxes.empty())
         return op->emitOpError() << "at most one result with partial axes is "
                                     "supported at present";
-      partialType = shardAttr.getPartialType();
       partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
       // Add all the reduction loop indices to `visitedLoopIndices` if
       // `partialAxes` is not empty
@@ -370,8 +370,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
     if (!anyNonEmptyReductionLoop) {
       bool filled = false;
       for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
-        if (isReductionLoop(loopTypes[idx]) &&
-            areReductionAndPartialMatch(loopTypes[idx], partialType)) {
+        if (isReductionLoop(loopTypes[idx])) {
           std::ignore = fillShardingOption(op, shardingOption, nullptr,
                                            partialMeshAxes, idx);
           filled = true;
@@ -398,7 +397,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 const ShardingOption &shardingOption,
                                 AffineMap map,
-                                ArrayRef<IteratorType> loopTypes) {
+                                ArrayRef<utils::IteratorType> loopTypes,
+                                ArrayRef<ReductionKind> reductionLoopKinds) {
   FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
       getMeshShardingAttr(result);
   if (succeeded(maybeSharding) && !maybeSharding->first)
@@ -421,11 +421,13 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
 
   // process the partial axes
   // partialType will be ignored if partialAxes is empty
-  Partial partialType = Partial::Sum;
+  ReductionKind partialType = ReductionKind::Sum;
+  size_t reductionLoopKindsIdx = 0;
   for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
-    IteratorType iType = std::get<0>(it);
+    utils::IteratorType iType = std::get<0>(it);
     if (isReductionLoop(iType)) {
-      Partial curPartialType = getPartialTypeFromReduction(iType);
+      ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
+      ++reductionLoopKindsIdx;
       if (!partialAxes.empty())
         assert(partialType == curPartialType &&
                "Only one reduction type is supported");
@@ -450,8 +452,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
 // in `shardingO...
[truncated]
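
Downstream code switches from the old `Partial` spelling to `ReductionKind` one-for-one. As a rough usage sketch based on the `Simplifications.h` hunk above (the registration helper and the op/kind pairings are illustrative assumptions, not taken from this patch), the all-reduce endomorphism simplification patterns would now be populated like this:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical registration helper; the op/kind pairings are examples only.
void populateExampleMeshSimplifications(RewritePatternSet &patterns) {
  // A sum all-reduce commutes with elementwise addition.
  mesh::populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
      patterns, mesh::ReductionKind::Sum);
  // A max all-reduce commutes with elementwise maximum.
  mesh::populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
      patterns, mesh::ReductionKind::Max);
}
```
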

@llvmbot
Member

llvmbot commented Feb 15, 2024

@llvm/pr-subscribers-mlir-tosa


@sogartar
Contributor Author

@yaochengji, could you review this PR?

@sogartar sogartar requested a review from joker-eph February 15, 2024 21:09
@sogartar sogartar merged commit ff2720d into llvm:main Feb 16, 2024