Skip to content

[mlir][mesh] Deduplicate iterator type and partial type information #81920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 10 additions & 30 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16"
// Mesh Enums.
//===----------------------------------------------------------------------===//

def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
"Reduction of an iterator/mesh dimension.", [
I32EnumAttrCase<"Sum", 1, "sum">,
I32EnumAttrCase<"Max", 2, "max">,
I32EnumAttrCase<"Min", 3, "min">,
Expand All @@ -51,26 +52,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}

def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
let assemblyFormat = "`<` $value `>`";
}

// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
// is partial.
def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"Parallel", 1, "parallel">,
I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
I32EnumAttrCase<"Invalid", 100, "invalid">
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mesh";
}

//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
Expand All @@ -83,14 +68,15 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
"The mesh on which tensors are sharded.">:$mesh,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
);

let summary = "Attribute that extends tensor type to distributed tensor type.";

let description = [{
The MeshSharding attribute could be used in the encoding of a
`RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
The MeshSharding attribute is used in a `mesh.shard` operation.
It specifies how a tensor is sharded and distributed across the process
mesh.

1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
mesh where the distributed tensor is placed. The symbol must resolve to a
Expand All @@ -107,13 +93,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {

4. `partial_type`: indicates the reduction type of the possible all-reduce
op. It has 4 possible values:
- `partial_sum`: denotes it's an all-reduce-sum
- `partial_max`: denotes it's an all-reduce-max
- `partial_min`: denotes it's an all-reduce-min
- `partial_generic`: denotes that the all-reduce type is complex and cannot
be represented merely by a simple sum, max, or min. The exact reduction
computation may be derived from the semantics of the corresponding operation
or from the reduction computation IR
`generic`: is not an allowed value inside a shard attribute.

Example:

Expand Down Expand Up @@ -149,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
"mesh::Partial": $partial_type), [{
"mesh::ReductionKind": $partial_type), [{
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
split_axes, [&](ArrayRef<MeshAxis> array) {
return MeshAxesAttr::get($_ctxt, array);
Expand All @@ -159,7 +139,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}]>,
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
}]>
];

Expand Down
9 changes: 4 additions & 5 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -38,18 +39,16 @@ using MeshAxesAttr = DenseI16ArrayAttr;
namespace mlir {
namespace mesh {

bool isReductionLoop(IteratorType iType);

bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
inline bool isReductionLoop(utils::IteratorType iType) {
return iType == utils::IteratorType::reduction;
}

template <typename T>
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
while (!array.empty() && array.back().empty())
array.pop_back();
}

Partial getPartialTypeFromReduction(IteratorType iType);

// Is the same tensor replicated on all processes.
inline bool isFullReplication(MeshShardingAttr attr) {
return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
));
let results = (outs
AnyRankedTensor:$result
Expand Down Expand Up @@ -629,7 +629,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
Expand Down Expand Up @@ -692,7 +692,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
IndexAttr:$scatter_axis
));
let results = (outs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

Expand Down
31 changes: 25 additions & 6 deletions mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,39 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
output tensors.

Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
types are parallel, parallel, reduction-sum. This indicates that M and
types are parallel, parallel, reduction. This indicates that M and
N are traversed in parallel, while the K dimension is used for
reduction.

Example 2: A softmax op's loop iterator types are parallel and
invalid. The second dimension is considered as invalid because it is
neither parallel nor any kind of reduction.
}],
/*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
/*retType=*/"SmallVector<mlir::utils::IteratorType>",
/*methodName=*/"getLoopIteratorTypes",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return {};"
>,
InterfaceMethod<
/*desc=*/[{
Return the kind of all reduction loop iterators.
The order is the same as the result from
`getLoopIteratorTypes`.

Example 1:
iterator types = (parallel, reduction, parallel, reduction)
|| ||
reduction kinds = ( sum, max)

Example 2:
A softmax op's loop iterator types are parallel and
reduction.
The reduction iterator will be of kind `generic`, since it is none of
the available presets.
}],
/*retType=*/"SmallVector<ReductionKind>",
/*methodName=*/"getReductionLoopIteratorKinds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return {};"
>,
InterfaceMethod<
/*desc=*/[{
Return the indexing maps attribute within the current operation.
Expand Down
17 changes: 10 additions & 7 deletions mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ template <typename Op>
struct IndependentParallelIteratorDomainShardingInterface
: public ShardingInterface::ExternalModel<
IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
SmallVector<IteratorType> iterTypes;
SmallVector<utils::IteratorType>
getLoopIteratorTypes(Operation *operation) const {
SmallVector<utils::IteratorType> iterTypes;
for (Type t : operation->getOperandTypes()) {
populateIteratorTypes(t, iterTypes);
}
Expand Down Expand Up @@ -65,16 +66,17 @@ struct IndependentParallelIteratorDomainShardingInterface
}

private:
void populateIteratorTypes(Type t,
SmallVector<IteratorType> &iterTypes) const {
void
populateIteratorTypes(Type t,
SmallVector<utils::IteratorType> &iterTypes) const {
RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
if (!rankedTensorType) {
return;
}

iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
iterTypes.push_back(IteratorType::Parallel);
iterTypes.push_back(utils::IteratorType::parallel);
}
}
};
Expand All @@ -84,12 +86,13 @@ template <typename ElemwiseOp>
struct ElementwiseShardingInterface
: public ShardingInterface::ExternalModel<
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
Value val = op->getOperand(0);
auto type = val.getType().dyn_cast<RankedTensorType>();
if (!type)
return {};
SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
SmallVector<utils::IteratorType> types(type.getRank(),
utils::IteratorType::parallel);
return types;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace mesh {
// the algebraic structure.
template <typename AlgebraicOp>
void populateAllReduceEndomorphismSimplificationPatterns(
RewritePatternSet &patterns, Partial reduction) {
RewritePatternSet &patterns, ReductionKind reduction) {
auto getEndomorphismOpOperand = [](Operation *op) {
auto allReduceOp = llvm::cast<AllReduceOp>(op);
return &allReduceOp.getInputMutable();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRMeshDialect

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialectUtils
MLIRIR
MLIRSupport
MLIRViewLikeInterface
Expand Down
29 changes: 1 addition & 28 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,33 +148,6 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
return success();
}

bool mesh::isReductionLoop(IteratorType iType) {
return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
}

bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
return (partial == Partial::Generic &&
iType == IteratorType::ReductionGeneric) ||
(partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
(partial == Partial::Max && iType == IteratorType::ReductionMax) ||
(partial == Partial::Min && iType == IteratorType::ReductionMin);
}

Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
switch (iType) {
case IteratorType::ReductionGeneric:
return Partial::Generic;
case IteratorType::ReductionSum:
return Partial::Sum;
case IteratorType::ReductionMax:
return Partial::Max;
case IteratorType::ReductionMin:
return Partial::Min;
default:
llvm_unreachable("No corresponding partial type can be found");
}
}

template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
Expand Down Expand Up @@ -278,7 +251,7 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
ArrayRef<MeshAxis> partialAxes, Partial) {
ArrayRef<MeshAxis> partialAxes, ReductionKind) {
// TODO: At present mesh symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.

Expand Down
Loading