-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Mesh] Add sharding propagation pass #69665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
25d68a6
[MLIR][Mesh] Add sharding propagation pass
yaochengji bd81c39
format code
yaochengji bf1dfdd
fix comments, 1st
yaochengji 95ed43c
fix comments, 2nd
yaochengji f4290d5
remove sharding option attr
yaochengji 50205fc
clang format
yaochengji 2d1c8c6
fix comments, 3rd
yaochengji d254e84
change getShardingOption signature
yaochengji b58581b
minor fix
yaochengji eb477e3
delete unused methods
yaochengji 1b739b2
modify getMultiDimMapWithTargets method
yaochengji File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
add_subdirectory(Interfaces) | ||
add_subdirectory(IR) | ||
add_subdirectory(Transforms) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
set(LLVM_TARGET_DEFINITIONS ShardingInterface.td) | ||
mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls) | ||
mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs) | ||
add_public_tablegen_target(MLIRShardingInterfaceIncGen) |
68 changes: 68 additions & 0 deletions
68
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
//===- ShardingInterface.h --------------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ | ||
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ | ||
|
||
#include "mlir/Dialect/Mesh/IR/MeshOps.h" | ||
#include "mlir/Support/LLVM.h" | ||
|
||
namespace mlir { | ||
|
||
class Operation; | ||
|
||
namespace mesh { | ||
|
||
using ShardingArray = SmallVector<SmallVector<int32_t>>; | ||
using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>; | ||
|
||
struct ShardingOption { | ||
// An array of int array. The sub-array at the i-th position signifies the | ||
// mesh axes the i-th loop will be sharded on. | ||
ShardingArray shardingArray; | ||
SymbolRefAttr cluster; | ||
// `empty` being true indicates that no sharding information can be inferred | ||
// at present. Note that it is different from the case where an operation is | ||
// not sharded. | ||
bool empty = false; | ||
ShardingOption() = default; | ||
ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster) | ||
: shardingArray(std::move(shardingArray)), cluster(cluster) {} | ||
}; | ||
|
||
// This method retrieves the 'MeshShardingAttr' attribute from a given operation | ||
// result and includes the 'annotate_for_users' information. | ||
FailureOr<std::pair<bool, MeshShardingAttr>> | ||
getMeshShardingAttr(OpResult result); | ||
|
||
// This method retrieves the 'MeshShardingAttr' attribute from a given operation | ||
// operand and includes the 'annotate_for_users' information. | ||
FailureOr<std::pair<bool, MeshShardingAttr>> | ||
getMeshShardingAttr(OpOperand &opOperand); | ||
|
||
namespace detail { | ||
|
||
FailureOr<ShardingOption> | ||
defaultGetShardingOption(Operation *op, | ||
ArrayRef<MeshShardingAttr> operandShardings, | ||
ArrayRef<MeshShardingAttr> resultShardings); | ||
|
||
LogicalResult | ||
defaultAddShardingAnnotations(Operation *op, OpBuilder &b, | ||
const ShardingOption &shardingOption); | ||
|
||
} // namespace detail | ||
|
||
} // namespace mesh | ||
|
||
} // namespace mlir | ||
|
||
/// Include the ODS generated interface header files. | ||
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc" | ||
|
||
#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ |
102 changes: 102 additions & 0 deletions
102
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
//===- ShardingInterfaces.td -------------------------------*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD | ||
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD | ||
|
||
include "mlir/IR/OpBase.td" | ||
|
||
def ShardingInterface : OpInterface<"ShardingInterface"> { | ||
let description = [{ | ||
Interface for allowing operations to expose information needed to | ||
shard them. | ||
}]; | ||
let cppNamespace = "::mlir::mesh"; | ||
|
||
let methods = [ | ||
InterfaceMethod< | ||
/*desc=*/[{ | ||
Returns a list of iterator types that describe the number of loops. | ||
The iterator types determine how the operation traverses its input and | ||
output tensors. | ||
|
||
Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator | ||
types are parallel, parallel, reduction-sum. This indicates that M and | ||
N are traversed in parallel, while the K dimension is used for | ||
reduction. | ||
|
||
Example 2: A softmax op's loop iterator types are parallel and | ||
invalid. The second dimension is considered as invalid because it is | ||
neither parallel nor any kind of reduction. | ||
}], | ||
/*retType=*/"SmallVector<::mlir::mesh::IteratorType>", | ||
/*methodName=*/"getLoopIteratorTypes", | ||
/*args=*/(ins), | ||
/*methodBody=*/"", | ||
/*defaultImplementation=*/"return {};" | ||
>, | ||
InterfaceMethod< | ||
/*desc=*/[{ | ||
Return the indexing maps attribute within the current operation. | ||
joker-eph marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Indexing maps determine how indices in the iteration space map to | ||
tensor indices. They are specified using `affine_map` in MLIR, which | ||
provides an affine transformation of indices. | ||
}], | ||
/*retTy=*/"SmallVector<AffineMap>", | ||
/*methodName=*/"getIndexingMaps", | ||
/*args=*/(ins), | ||
/*methodBody=*/"", | ||
/*defaultImplementation=*/"return {};" | ||
>, | ||
InterfaceMethod< | ||
/*desc=*/[{ | ||
Given that certain operands or results of the operation may have | ||
sharding annotations, this method leverages this information to deduce | ||
how the operation should be sharded. | ||
joker-eph marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}], | ||
/*retTy=*/"FailureOr<ShardingOption>", | ||
/*methodName=*/"getShardingOption", | ||
sogartar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/*args=*/(ins | ||
"ArrayRef<MeshShardingAttr>": $operandShardings, | ||
"ArrayRef<MeshShardingAttr>": $resultShardings | ||
), | ||
/*methodBody=*/"", | ||
/*defaultImplementation=*/[{ | ||
return detail::defaultGetShardingOption( | ||
$_op.getOperation(), operandShardings, resultShardings); | ||
}] | ||
>, | ||
InterfaceMethod< | ||
sogartar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/*desc=*/[{ | ||
Based on a given ShardingOption, this method adds `mesh.shard` | ||
operations for the operands and results that previously lacked | ||
sharding annotations. | ||
}], | ||
/*retTy=*/"LogicalResult", | ||
/*methodName=*/"addShardingAnnotations", | ||
/*args=*/(ins | ||
"OpBuilder &":$b, | ||
"const ShardingOption &":$shardingOption | ||
), | ||
/*methodBody=*/"", | ||
/*defaultImplementation=*/[{ | ||
return detail::defaultAddShardingAnnotations( | ||
$_op.getOperation(), b, shardingOption); | ||
}] | ||
> | ||
]; | ||
|
||
let extraClassDeclaration = [{ | ||
LogicalResult verifyShardingInterfaceImpl(); | ||
|
||
void printLoopTypesAndIndexingMaps(raw_ostream &os); | ||
}]; | ||
} | ||
|
||
|
||
#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh) | ||
add_public_tablegen_target(MLIRMeshPassIncGen) | ||
add_dependencies(mlir-headers MLIRMeshPassIncGen) | ||
|
||
add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H | ||
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H | ||
|
||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir { | ||
|
||
namespace func { | ||
class FuncOp; | ||
} | ||
|
||
namespace mesh { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Passes | ||
//===----------------------------------------------------------------------===// | ||
|
||
#define GEN_PASS_DECL | ||
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Registration | ||
//===----------------------------------------------------------------------===// | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" | ||
|
||
} // namespace mesh | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
|
||
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD | ||
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ShardingPropagation | ||
//===----------------------------------------------------------------------===// | ||
|
||
def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> { | ||
let summary = "sharding propagation"; | ||
let description = [{ | ||
Propagates sharding information throughout the graph. After this pass, each | ||
of the operations' operands and results is annotated with a `mesh.shard` | ||
operation, and the operations themselves are added with sharding option | ||
attributes. | ||
}]; | ||
let dependentDialects = [ | ||
"mesh::MeshDialect" | ||
]; | ||
} | ||
|
||
#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
//===- ShardingInterfaceImpl.h - ------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ | ||
#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ | ||
|
||
namespace mlir { | ||
|
||
class DialectRegistry; | ||
|
||
namespace tosa { | ||
|
||
void registerShardingInterfaceExternalModels(DialectRegistry ®istry); | ||
|
||
} // namespace tosa | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.