Skip to content

[MLIR][Mesh] Add sharding propagation pass #71261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(Transforms)
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}

// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
// is partial.
def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"Parallel", 1, "parallel">,
I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
I32EnumAttrCase<"Invalid", 100, "invalid">
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mesh";
}

//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
$partial_axes^ `]`)? `>`
}];

let builders = [
AttrBuilder<(ins "SymbolRefAttr":$cluster,
"ArrayRef<SmallVector<int32_t>>":$split_axes,
"ArrayRef<int32_t>": $partial_axes,
"mesh::Partial": $partial_type), [{
SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
split_axes, [&](ArrayRef<int32_t> array) {
return DenseI32ArrayAttr::get($_ctxt, array);
});
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
partial_type);
}]>,
AttrBuilder<(ins "SymbolRefAttr":$cluster,
"ArrayRef<SmallVector<int32_t>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
}]>
];

let genVerifyDecl = 1;
}

Expand Down
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,22 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"

namespace mlir {
namespace mesh {

bool isReductionLoop(IteratorType iType);

bool areReductionAndPartialMatch(IteratorType iType, Partial partial);

template <typename T>
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
while (!array.empty() && array.back().empty())
array.pop_back();
}

Partial getPartialTypeFromReduction(IteratorType iType);

} // namespace mesh
} // namespace mlir

#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRShardingInterfaceIncGen)
68 changes: 68 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Support/LLVM.h"

namespace mlir {

class Operation;

namespace mesh {

using ShardingArray = SmallVector<SmallVector<int32_t>>;
using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;

struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
// mesh axes the i-th loop will be sharded on.
ShardingArray shardingArray = {};
SymbolRefAttr cluster = nullptr;
// `empty` being true indicates that no sharding information can be inferred
// at present. Note that it is different from the case where an operation is
// not sharded.
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
: shardingArray(std::move(shardingArray)), cluster(cluster) {}
};

// This method retrieves the 'MeshShardingAttr' attribute from a given operation
// result and includes the 'annotate_for_users' information.
FailureOr<std::pair<bool, MeshShardingAttr>>
getMeshShardingAttr(OpResult result);

// This method retrieves the 'MeshShardingAttr' attribute from a given operation
// operand and includes the 'annotate_for_users' information.
FailureOr<std::pair<bool, MeshShardingAttr>>
getMeshShardingAttr(OpOperand &opOperand);

namespace detail {

FailureOr<ShardingOption>
defaultGetShardingOption(Operation *op,
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings);

LogicalResult
defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
const ShardingOption &shardingOption);

} // namespace detail

} // namespace mesh

} // namespace mlir

/// Include the ODS generated interface header files.
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"

#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
102 changes: 102 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//===- ShardingInterfaces.td -------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD

include "mlir/IR/OpBase.td"

def ShardingInterface : OpInterface<"ShardingInterface"> {
let description = [{
Interface for allowing operations to expose information needed to
shard them.
}];
let cppNamespace = "::mlir::mesh";

let methods = [
InterfaceMethod<
/*desc=*/[{
Returns a list of iterator types that describe the number of loops.
The iterator types determine how the operation traverses its input and
output tensors.

Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
types are parallel, parallel, reduction-sum. This indicates that M and
N are traversed in parallel, while the K dimension is used for
reduction.

Example 2: A softmax op's loop iterator types are parallel and
invalid. The second dimension is considered as invalid because it is
neither parallel nor any kind of reduction.
}],
/*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
/*methodName=*/"getLoopIteratorTypes",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return {};"
>,
InterfaceMethod<
/*desc=*/[{
Return the indexing maps attribute within the current operation.
Indexing maps determine how indices in the iteration space map to
tensor indices. They are specified using `affine_map` in MLIR, which
provides an affine transformation of indices.
}],
/*retTy=*/"SmallVector<AffineMap>",
/*methodName=*/"getIndexingMaps",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return {};"
>,
InterfaceMethod<
/*desc=*/[{
Given that certain operands or results of the operation may have
sharding annotations, this method leverages this information to deduce
how the operation should be sharded.
}],
/*retTy=*/"FailureOr<ShardingOption>",
/*methodName=*/"getShardingOption",
/*args=*/(ins
"ArrayRef<MeshShardingAttr>": $operandShardings,
"ArrayRef<MeshShardingAttr>": $resultShardings
),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return detail::defaultGetShardingOption(
$_op.getOperation(), operandShardings, resultShardings);
}]
>,
InterfaceMethod<
/*desc=*/[{
Based on a given ShardingOption, this method adds `mesh.shard`
operations for the operands and results that previously lacked
sharding annotations.
}],
/*retTy=*/"LogicalResult",
/*methodName=*/"addShardingAnnotations",
/*args=*/(ins
"OpBuilder &":$b,
"const ShardingOption &":$shardingOption
),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return detail::defaultAddShardingAnnotations(
$_op.getOperation(), b, shardingOption);
}]
>
];

let extraClassDeclaration = [{
LogicalResult verifyShardingInterfaceImpl();

void printLoopTypesAndIndexingMaps(raw_ostream &os);
}];
}


#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
add_public_tablegen_target(MLIRMeshPassIncGen)
add_dependencies(mlir-headers MLIRMeshPassIncGen)

add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
39 changes: 39 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H

#include "mlir/Pass/Pass.h"

namespace mlir {

namespace func {
class FuncOp;
}

namespace mesh {

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//

#define GEN_PASS_DECL
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

} // namespace mesh
} // namespace mlir

#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
32 changes: 32 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//


#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD

include "mlir/Pass/PassBase.td"

//===----------------------------------------------------------------------===//
// ShardingPropagation
//===----------------------------------------------------------------------===//

def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
let summary = "sharding propagation";
let description = [{
Propagates sharding information throughout the graph. After this pass, each
of the operations' operands and results is annotated with a `mesh.shard`
operation, and the operations themselves are added with sharding option
attributes.
}];
let dependentDialects = [
"mesh::MeshDialect"
];
}

#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_

namespace mlir {

class DialectRegistry;

namespace tosa {

void registerShardingInterfaceExternalModels(DialectRegistry &registry);

} // namespace tosa
} // namespace mlir

#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
12 changes: 12 additions & 0 deletions mlir/include/mlir/IR/AffineMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ class AffineMap {
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);

/// Returns an affine map with `numDims` input dimensions and results
/// specified by `targets`.
///
/// Examples:
/// * getMultiDimMapWithTargets(3, [0, 2, 1])
/// -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
/// * getMultiDimMapWithTargets(3, [2, 1])
/// -> affine_map<(d0, d1, d2) -> (d2, d1)>
static AffineMap getMultiDimMapWithTargets(unsigned numDims,
ArrayRef<unsigned> targets,
MLIRContext *context);

/// Returns a vector of AffineMaps; each with as many results as
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
/// symbols as the largest symbol in `exprs`.
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
Expand Down Expand Up @@ -171,6 +172,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerSubsetOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
Expand Down
Loading