-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Mesh] Add sharding propagation pass #69665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Chengji Yao (yaochengji) Changes
Patch is 66.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69665.diff 27 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..b6623ed818f0770 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
+// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
+// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
+// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
+// is partial.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+ I32EnumAttrCase<"Parallel", 1, "parallel">,
+ I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+ I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+ I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+ I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+ I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::mesh";
+}
+
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
$partial_axes^ `]`)? `>`
}];
+ let builders = [
+ AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ "ArrayRef<SmallVector<int32_t>>":$split_axes,
+ "ArrayRef<int32_t>": $partial_axes,
+ "mesh::Partial": $partial_type), [{
+ SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::to_vector(
+ llvm::map_range(split_axes, [&](ArrayRef<int32_t> array) {
+ return DenseI32ArrayAttr::get($_ctxt, array);
+ }));
+ return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+ partial_type);
+ }]>,
+ AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+ return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+ }]>
+ ];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..cb86887091330c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,26 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+ for (int64_t i = array.size() - 1; i >= 0; i--) {
+ if (array[i].empty())
+ array.pop_back();
+ else
+ break;
+ }
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..1d19e41ac1fc555
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,58 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+struct ShardingOption {
+ // An array of int array. The sub-array at the i-th position signifies the
+ // mesh axes the i-th loop will be sharded on.
+ ShardingArray shardingArray;
+ SymbolRefAttr cluster;
+ // `empty` is true indicates that no sharding infomation can be inferred at
+ // present. Note that it is different from that an operation is not sharded.
+ bool empty = false;
+ ShardingOption() = default;
+ ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
+ : shardingArray(shardingArray), cluster(cluster) {}
+};
+
+constexpr StringRef getShardingArrayName() { return "sharding_array"; }
+
+constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
+
+namespace detail {
+
+FailureOr<ShardingOption> defaultGetShardingOption(Operation *op, OpBuilder &b);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+ const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..c98b9f081492997
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,87 @@
+//===- ShardingInterfaces.td -------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+ let description = [{
+ Interface for allowing operations to expose information needed to
+ shard them.
+ }];
+ let cppNamespace = "::mlir::mesh";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of iterator types that describe the number of loops.
+ }],
+ /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+ /*methodName=*/"getLoopIteratorTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Given that certain operands or results of the operation may have
+ sharding annotations, this method leverages this information to deduce
+ how the operation should be sharded.
+ }],
+ /*retTy=*/"FailureOr<ShardingOption>",
+ /*methodName=*/"getShardingOption",
+ /*args=*/(ins
+ "OpBuilder &":$b
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultGetShardingOption(
+ $_op.getOperation(), b);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Based on a given ShardingOption, this method adds `mesh.shard`
+ operations for the operands and results that previously lacked
+ sharding annotations.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"addShardingAnnotations",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "const ShardingOption &":$shardingOption
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultAddShardingAnnotations(
+ $_op.getOperation(), b, shardingOption);
+ }]
+ >
+ ];
+
+ let extraClassDeclaration = [{
+ LogicalResult verifyShardingInterfaceImpl();
+
+ void printLoopTypesAndIndexingMaps(raw_ostream &os);
+ }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..aa3555f7f186f24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,41 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+}
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+std::unique_ptr<OperationPass<func::FuncOp>> createShardingPropagationPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..d36adfe476a72ac
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,33 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+ let summary = "sharding propagation";
+ let description = [{
+ Propagates sharding information throughout the graph. After this pass, each
+ of the operations' operands and results is annotated with a `mesh.shard`
+ operation, and the operations themselves are added with sharding option
+ attributes.
+ }];
+ let constructor = "mlir::mesh::createShardingPropagationPass()";
+ let dependentDialects = [
+ "mesh::MeshDialect"
+ ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+
+void registerShardingInterfaceExternalModels(DialectRegistry ®istry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548ed..b24164cfb552b4f 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -245,6 +245,9 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
unsigned dropBack = 0);
+/// Helper to return a vector of sub-vector of int64_t
+SmallVector<SmallVector<int32_t>> getArrayOfI32Array(ArrayAttr arrayAttr);
+
/// Compute linear index from provided strides and indices, assuming strided
/// layout.
/// Returns AffineExpr and list of values to apply to it, e.g.:
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 3430db2b99c3f2e..18e2313ef2b446b 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -101,6 +101,18 @@ class AffineMap {
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ /// Returns an affine map with `numDims` input dimensions and results
+ /// specified by `targets`.
+ ///
+ /// Examples:
+ /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+ /// -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+ /// * getMultiDimMapWithTargets(3, [2, 1])
+ /// -> affine_map<(d0, d1, d2) -> (d2, d1)>
+ static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+ ArrayRef<int64_t> targets,
+ MLIRContext *context);
+
/// Returns a vector of AffineMaps; each with as many results as
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
/// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5e54d4ea49e8251..3988835622b7629 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -168,6 +168,7 @@ class Builder {
ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
ArrayAttr getTypeArrayAttr(TypeRange values);
+ ArrayAttr getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..3556f82023828b2 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -170,6 +171,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+ tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..27711417ed91a8c 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -29,6 +29,7 @@
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -73,6 +74,7 @@ inline void registerAllPasses() {
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
+ mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+bool mesh::isReductionLoop(IteratorType iType) {
+ return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
+}
+
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+ return (partial == Partial::Generic &&
+ iType == IteratorType::ReductionGeneric) ||
+ (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
+ (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
+ (partial == Partial::Min && iType == IteratorType::ReductionMin);
+}
+
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+ switch (iType) {
+ case IteratorType::ReductionGeneric:
+ return Partial::Generic;
+ case IteratorType::ReductionSum:
+ return Partial::Sum;
+ case IteratorType::ReductionMax:
+ return Partial::Max;
+ case Iterator...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
661c516
to
bd81c39
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
A few early comments after a quick scan.
mlir/lib/IR/Builders.cpp
Outdated
Builder::getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values) { | ||
auto attrs = | ||
llvm::map_to_vector<8>(values, [this](ArrayRef<int32_t> v) -> Attribute { | ||
return getI32ArrayAttr(v); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we use DenseI32Array here? This is pretty sparse and inefficient storage actually.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
return failure(); | ||
auto symbolRefAttr = op->getAttrOfType<SymbolRefAttr>(getMeshClusterName()); | ||
if (!symbolRefAttr) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should move directly behind an API, the getAttr
method of Operation are intended to be deprecated (to use getDiscardableAttr
/getInherentAttr
instead) and implementer of this interface could put these in the Properties storage, and so the best way is to expose all these accessors through the OpInterface
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added getShardingOptionFromAttr
and setShardingOptionAttr
in extraClassDeclaration
|
||
// 1. propagate in reversed order | ||
{ | ||
std::vector<Operation *> curOps = getReversedOperationsVector(block); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this copy?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because new operations (mesh.shard
ops) will be added during propagation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only see setShardingOptionAttr(builder, *shardingOption);
and shardingOp.addShardingAnnotations
, where are these added?
Also, wouldn't for (Operation &op : llvm::make_early_inc(block.getOperations())) {
still work? Assuming the shard
ops are added immediately after the current operation (otherwise they can be also just skipped ...)
|
||
// 2. propagate in original order | ||
{ | ||
std::vector<Operation *> curOps = getOperationsVector(block); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same question: why the copy here?
if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op)) { | ||
ops.insert(shardingOp); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite get what the DenseSet is doing here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It shouldn't be there, my bad.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some more superficial comment, I haven't dug deeply in the logic yet. Would be good to also have another pair of eyes.
Also what about control flow operations?
InterfaceMethod< | ||
/*desc=*/[{ | ||
Returns a list of iterator types that describe the number of loops. | ||
The iterator types determine how the operation tranverses its input |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The iterator types determine how the operation tranverses its input | |
The iterator types determine how the operation traverses its input |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry it seems too many typos in this PR. I have proofread the entire PR for spelling errors, and I hope there will be fewer typos in the future.
|
||
FailureOr<ShardingOption> getShardingOptionFromAttr(); | ||
|
||
void setShardingOptionAttr(Builder &b, const ShardingOption& option); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There shoulder be InterfaceMethod
: I believe we should not use discardable attributes, Operation
implementing the interface should declare these as inherent and use direct property access to store the info.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed the logic of adding ShardingOption
as an attribute first. Because in my second thought, this information might be more suitable in an analysis.
BTW, the reason they were in extraClassDeclaration
not InterfaceMethod
is because they are not supposed to be customized for different operations.
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) { | ||
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); | ||
if (!shardOp) | ||
return false; | ||
return !shardOp.getAnnotateForUsers(); | ||
}); | ||
|
||
if (anyShardedForDef) { | ||
assert(val.hasOneUse() && | ||
"expected to has exact one use if it has a use of mesh.shard " | ||
"without unit attr annotate_for_users"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assertions seems dangerous to me: we can't have a verifier for this, what about bailing out instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change to return failure()
if val.hasOneUse()
.
|
||
// 1. propagate in reversed order | ||
{ | ||
std::vector<Operation *> curOps = getReversedOperationsVector(block); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only see setShardingOptionAttr(builder, *shardingOption);
and shardingOp.addShardingAnnotations
, where are these added?
Also, wouldn't for (Operation &op : llvm::make_early_inc(block.getOperations())) {
still work? Assuming the shard
ops are added immediately after the current operation (otherwise they can be also just skipped ...)
They are defined in
Yeah, |
@sogartar could you also review this PR if you have time?
The sharding propagation pass is currently in an early stage. I suggest that it be completed progressively so that other developers can participate. I believe that more comprehensive support needs to be added in the future, including but not limited to control flow support. (Maybe I should add more TODOs in the code) |
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs()); | ||
} | ||
); | ||
// clang-format on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clang-format seems to work fine here for me?
LLVM_DEBUG(
DBGS() << "print all the ops' iterator types and indexing maps in the "
"block.\n";
for (Operation &op
: block.getOperations()) {
if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
(it fixes your indentation issue line 90 where the stmt inside the if
isn't indented right now)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
I left some comments. |
@sogartar, Thank you for your careful review. I have responded to all of your comments. |
@yaochengji, I marked most of the discussions as resolved. |
if (!llvm::isa<RankedTensorType>(type)) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't we assume that if you have a non-tensor argument it will be replicated across all devices?
9d9b9d7
to
b58581b
Compare
Hi @joker-eph , do you think this PR could be merged? As @sogartar already approved this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks fine, just minor things to cleanup and I can merge this for you.
@@ -235,6 +235,19 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation, | |||
return permutationMap; | |||
} | |||
|
|||
AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Groverkss wanna double check this? (or @ftynse maybe?)
mlir/lib/IR/AffineMap.cpp
Outdated
AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0, context); | ||
int64_t pos = 0; | ||
for (int64_t t : targets) { | ||
result = result.insertResult(getAffineDimExpr(t, context), pos); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can do this the same way getPermutationMap does. You could even make getPermutationMap call to this function for some clean up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your suggestion. I have modified the code accordingly.
I tweaked the description and landed it, thanks! |
Some tests failed. It seems that |
Seems like the CI does not merge before testing? |
See also the fix in bda763a to add to your next PR |
Thanks @joker-eph, the new PR is created, and the fix in bda763a is added. |
ShardingInterface
and the methods' default implementationShardingInterface
implementation for element-wise and matmul ops in TOSA dialectSmall changes in non-mesh folder:
getMultiDimMapWithTargets
method in AffineMap.h