[mlir][mesh] Deduplicate iterator type and partial type information #81920
Conversation
The two types duplicated mostly the same values. Here they are decomposed to carry orthogonal and complementary information. Use `utils::IteratorType` instead of `mesh::IteratorType`; it now has only parallel and reduction values. Rename `Partial` to `ReductionKind`. Add a `getReductionLoopIteratorKinds` method to `ShardingInterface`.
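As a rough illustration (not part of the patch), here is a minimal sketch of what an op's interface implementation could look like after this change, assuming a gemm-like op and the external-model style used in `ShardingInterfaceImpl.h` in the diff below; the exact op and values are hypothetical:

```c++
// Hypothetical external-model methods for a gemm-like op (illustrative only).
// Loop iterator types now come from the shared utils enum: M and N are
// parallel, K is a reduction.
SmallVector<mlir::utils::IteratorType>
getLoopIteratorTypes(Operation *op) const {
  return {utils::IteratorType::parallel, utils::IteratorType::parallel,
          utils::IteratorType::reduction};
}

// One entry per reduction iterator above, in the same order: the K loop
// accumulates with a sum, so its kind is ReductionKind::Sum.
SmallVector<mesh::ReductionKind>
getReductionLoopIteratorKinds(Operation *op) const {
  return {mesh::ReductionKind::Sum};
}
```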
@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes: The two types duplicated mostly the same values. Here they are decomposed to carry orthogonal and complementary information. Use `utils::IteratorType` instead of `mesh::IteratorType`; it now has only parallel and reduction values. Rename `Partial` to `ReductionKind`. Add a `getReductionLoopIteratorKinds` method to `ShardingInterface`.

Patch is 24.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81920.diff

12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 04929f4869273d..fc2acc70381ef7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -41,7 +41,8 @@ def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16"
// Mesh Enums.
//===----------------------------------------------------------------------===//
-def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
+def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
+ "Reduction of an iterator/mesh dimension.", [
I32EnumAttrCase<"Sum", 1, "sum">,
I32EnumAttrCase<"Max", 2, "max">,
I32EnumAttrCase<"Min", 3, "min">,
@@ -51,26 +52,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
-def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
let assemblyFormat = "`<` $value `>`";
}
-// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
-// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
-// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
-// is partial.
-def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
- I32EnumAttrCase<"Parallel", 1, "parallel">,
- I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
- I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
- I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
- I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
- I32EnumAttrCase<"Invalid", 100, "invalid">
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::mesh";
-}
-
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
@@ -83,14 +68,15 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
"The mesh on which tensors are sharded.">:$mesh,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
- OptionalParameter<"::mlir::mesh::Partial">:$partial_type
+ OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
);
let summary = "Attribute that extends tensor type to distributed tensor type.";
let description = [{
- The MeshSharding attribute could be used in the encoding of a
- `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
+ The MeshSharding attribute is used in a `mesh.shard` operation.
+ It specifies how a tensor is sharded and distributed across the process
+ mesh.
1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -107,13 +93,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
4. `partial_type`: indicates the reduction type of the possible all-reduce
op. It has 4 possible values:
- - `partial_sum`: denotes it's an all-reduce-sum
- - `partial_max`: denotes it's an all-reduce-max
- - `partial_min`: denotes it's an all-reduce-min
- - `partial_generic`: denotes that the all-reduce type is complex and cannot
- be represented merely by a simple sum, max, or min. The exact reduction
- computation may be derived from the semantics of the corresponding operation
- or from the reduction computation IR
+ `generic`: is not an allowed value inside a shard attribute.
Example:
@@ -149,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
- "mesh::Partial": $partial_type), [{
+ "mesh::ReductionKind": $partial_type), [{
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
split_axes, [&](ArrayRef<MeshAxis> array) {
return MeshAxesAttr::get($_ctxt, array);
@@ -159,7 +139,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}]>,
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
- return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
+ return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index fb9425b96e68e2..4569b77441c3f3 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -38,9 +39,9 @@ using MeshAxesAttr = DenseI16ArrayAttr;
namespace mlir {
namespace mesh {
-bool isReductionLoop(IteratorType iType);
-
-bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+inline bool isReductionLoop(utils::IteratorType iType) {
+ return iType == utils::IteratorType::reduction;
+}
template <typename T>
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
@@ -48,8 +49,6 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
array.pop_back();
}
-Partial getPartialTypeFromReduction(IteratorType iType);
-
// Is the same tensor replicated on all processes.
inline bool isFullReplication(MeshShardingAttr attr) {
return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 96636d5347ff6e..8ba7c111aea6bb 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -330,7 +330,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
- DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+ DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
));
let results = (outs
AnyRankedTensor:$result
@@ -629,7 +629,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
- DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+ DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
@@ -692,7 +692,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
- DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+ DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
IndexAttr:$scatter_axis
));
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index cc90ddd40a6222..c47a7ddd3f9cc3 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 4afb1c36a72f7b..1f75135f42882f 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -26,20 +26,39 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
output tensors.
Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
- types are parallel, parallel, reduction-sum. This indicates that M and
+ types are parallel, parallel, reduction. This indicates that M and
N are traversed in parallel, while the K dimension is used for
reduction.
-
- Example 2: A softmax op's loop iterator types are parallel and
- invalid. The second dimension is considered as invalid because it is
- neither parallel nor any kind of reduction.
}],
- /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+ /*retType=*/"SmallVector<mlir::utils::IteratorType>",
/*methodName=*/"getLoopIteratorTypes",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return {};"
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the kind of all reduction loop iterators.
+ The order is the same as the order of the result from
+ `getLoopIteratorTypes`.
+
+ Example 1:
+ iterator types = (parallel, reduction, parallel, reduction)
+ || ||
+ reduction kinds = ( sum, max)
+
+ Example 2:
+ A softmax op's loop iterator types are parallel and
+ reduction.
+ The reduction iterator will be of kind `generic`, since it is none of
+ the available presets.
+ }],
+ /*retType=*/"SmallVector<ReductionKind>",
+ /*methodName=*/"getReductionLoopIteratorKinds",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
InterfaceMethod<
/*desc=*/[{
Return the indexing maps attribute within the current operation.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index 8108386c2e0437..ffc9b6fb18be53 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -36,8 +36,9 @@ template <typename Op>
struct IndependentParallelIteratorDomainShardingInterface
: public ShardingInterface::ExternalModel<
IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
- SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
- SmallVector<IteratorType> iterTypes;
+ SmallVector<utils::IteratorType>
+ getLoopIteratorTypes(Operation *operation) const {
+ SmallVector<utils::IteratorType> iterTypes;
for (Type t : operation->getOperandTypes()) {
populateIteratorTypes(t, iterTypes);
}
@@ -65,8 +66,9 @@ struct IndependentParallelIteratorDomainShardingInterface
}
private:
- void populateIteratorTypes(Type t,
- SmallVector<IteratorType> &iterTypes) const {
+ void
+ populateIteratorTypes(Type t,
+ SmallVector<utils::IteratorType> &iterTypes) const {
RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
if (!rankedTensorType) {
return;
@@ -74,7 +76,7 @@ struct IndependentParallelIteratorDomainShardingInterface
iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
- iterTypes.push_back(IteratorType::Parallel);
+ iterTypes.push_back(utils::IteratorType::parallel);
}
}
};
@@ -84,12 +86,13 @@ template <typename ElemwiseOp>
struct ElementwiseShardingInterface
: public ShardingInterface::ExternalModel<
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
- SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
Value val = op->getOperand(0);
auto type = val.getType().dyn_cast<RankedTensorType>();
if (!type)
return {};
- SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+ SmallVector<utils::IteratorType> types(type.getRank(),
+ utils::IteratorType::parallel);
return types;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f438465251bb06..c64da29ca64123 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -38,7 +38,7 @@ namespace mesh {
// the algebraic structure.
template <typename AlgebraicOp>
void populateAllReduceEndomorphismSimplificationPatterns(
- RewritePatternSet &patterns, Partial reduction) {
+ RewritePatternSet &patterns, ReductionKind reduction) {
auto getEndomorphismOpOperand = [](Operation *op) {
auto allReduceOp = llvm::cast<AllReduceOp>(op);
return &allReduceOp.getInputMutable();
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 678a25f1c3cf58..45ac9edb280bc9 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRMeshDialect
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRDialectUtils
MLIRIR
MLIRSupport
MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3291010d27428a..838255cf5a5ba3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -148,33 +148,6 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
return success();
}
-bool mesh::isReductionLoop(IteratorType iType) {
- return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
-}
-
-bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
- return (partial == Partial::Generic &&
- iType == IteratorType::ReductionGeneric) ||
- (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
- (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
- (partial == Partial::Min && iType == IteratorType::ReductionMin);
-}
-
-Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
- switch (iType) {
- case IteratorType::ReductionGeneric:
- return Partial::Generic;
- case IteratorType::ReductionSum:
- return Partial::Sum;
- case IteratorType::ReductionMax:
- return Partial::Max;
- case IteratorType::ReductionMin:
- return Partial::Min;
- default:
- llvm_unreachable("No corresponding partial type can be found");
- }
-}
-
template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
@@ -278,7 +251,7 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
- ArrayRef<MeshAxis> partialAxes, Partial) {
+ ArrayRef<MeshAxis> partialAxes, ReductionKind) {
// TODO: At present mesh symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index b8b3841d947abd..fe3d7c44413fef 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
@@ -163,7 +164,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
return failure();
// check loop types
- SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
+ SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
if (loopTypes.size() == 0)
return failure();
@@ -198,7 +199,7 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
getOperation()->print(os);
os << "\n";
os << "loop types: [";
- for (IteratorType type : getLoopIteratorTypes()) {
+ for (utils::IteratorType type : getLoopIteratorTypes()) {
os << stringifyEnum(type) << " ";
}
os << "]\n";
@@ -257,12 +258,12 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (failed(shardingOp.verifyShardingInterfaceImpl()))
return op->emitOpError() << "invalid sharding interface implementation";
- SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+ SmallVector<utils::IteratorType> loopTypes =
+ shardingOp.getLoopIteratorTypes();
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
shardingOption.shardingArray.resize(loopTypes.size());
llvm::SmallVector<MeshAxis> partialMeshAxes;
- Partial partialType;
llvm::SmallSet<unsigned, 4> visitedLoopIndices;
bool anyShardingInResultsOrOperands = false;
@@ -294,7 +295,6 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (!partialMeshAxes.empty())
return op->emitOpError() << "at most one result with partial axes is "
"supported at present";
- partialType = shardAttr.getPartialType();
partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
// Add all the reduction loop indices to `visitedLoopIndices` if
// `partialAxes` is not empty
@@ -370,8 +370,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (!anyNonEmptyReductionLoop) {
bool filled = false;
for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
- if (isReductionLoop(loopTypes[idx]) &&
- areReductionAndPartialMatch(loopTypes[idx], partialType)) {
+ if (isReductionLoop(loopTypes[idx])) {
std::ignore = fillShardingOption(op, shardingOption, nullptr,
partialMeshAxes, idx);
filled = true;
@@ -398,7 +397,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
const ShardingOption &shardingOption,
AffineMap map,
- ArrayRef<IteratorType> loopTypes) {
+ ArrayRef<utils::IteratorType> loopTypes,
+ ArrayRef<ReductionKind> reductionLoopKinds) {
FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
getMeshShardingAttr(result);
if (succeeded(maybeSharding) && !maybeSharding->first)
@@ -421,11 +421,13 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
// process the partial axes
// partialType will be ignored if partialAxes is empty
- Partial partialType = Partial::Sum;
+ ReductionKind partialType = ReductionKind::Sum;
+ size_t reductionLoopKindsIdx = 0;
for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
- IteratorType iType = std::get<0>(it);
+ utils::IteratorType iType = std::get<0>(it);
if (isReductionLoop(iType)) {
- Partial curPartialType = getPartialTypeFromReduction(iType);
+ ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
+ ++reductionLoopKindsIdx;
if (!partialAxes.empty())
assert(partialType == curPartialType &&
"Only one reduction type is supported");
@@ -450,8 +452,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
// in `shardingO...
[truncated]
@llvm/pr-subscribers-mlir-tosa

Author: Boian Petkantchin (sogartar)
@yaochengji, could you review this PR?
The two types duplicated mostly the same values. Here they are decomposed to carry orthogonal and complementary information. Use `utils::IteratorType` instead of `mesh::IteratorType`; it now has only parallel and reduction values. Rename `Partial` to `ReductionKind`. Add a `getReductionLoopIteratorKinds` method to `ShardingInterface`.
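For context, a minimal consumer-side sketch (assuming `shardingOp` is a value implementing `ShardingInterface`) of how the two queries can be paired to recover the kind of each reduction loop, mirroring the updated logic in `addShardOp`:

```c++
// Sketch only: iterate loop types and pull the matching kind for each
// reduction loop from the parallel array returned by the new method.
SmallVector<utils::IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
SmallVector<mesh::ReductionKind> kinds =
    shardingOp.getReductionLoopIteratorKinds();
size_t kindIdx = 0;
for (utils::IteratorType t : loopTypes) {
  if (mesh::isReductionLoop(t)) {
    mesh::ReductionKind kind = kinds[kindIdx++];
    // Use `kind`, e.g. as the partial_type of a MeshShardingAttr.
    (void)kind;
  }
}
```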