Skip to content

Commit 0fd50ec

Browse files
authored
[MLIR][mesh] Mesh fixes (#124724)
A collection of fixes to the mesh dialect: (1) allow constants in sharding propagation/spmdization; (2) fixes to tensor replication (e.g. 0d tensors); (3) improved canonicalization; (4) sharding propagation incorrectly generated too many ShardOps. The new operation `mesh.get_sharding` (`GetShardingOp`) enables exchanging sharding information (e.g. on function boundaries).
1 parent 0e779ad commit 0fd50ec

File tree

17 files changed

+525
-89
lines changed

17 files changed

+525
-89
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
10+
#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
11+
12+
namespace mlir {
13+
14+
// Forward declaration: this header only needs DialectRegistry by reference.
class DialectRegistry;
15+
16+
namespace arith {
17+
18+
/// Registers the sharding-interface external models provided for the arith
/// dialect with `registry`.
/// @param registry The dialect registry the models are added to.
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
19+
20+
} // namespace arith
21+
} // namespace mlir
22+
23+
#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class MeshSharding {
5151
SmallVector<Value> dynamic_sharded_dims_offsets;
5252

5353
public:
54-
MeshSharding() = default;
54+
MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
5555
MeshSharding(Value rhs);
5656
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
5757
ArrayRef<MeshAxesAttr> split_axes_,
@@ -62,7 +62,7 @@ class MeshSharding {
6262
ArrayRef<Value> dynamic_halo_sizes_ = {},
6363
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
6464
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
65-
::llvm::StringRef getMesh() const { return mesh.getValue(); }
65+
::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
6666
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
6767
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
6868
ReductionKind getPartialType() const { return partial_type; }
@@ -201,10 +201,12 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
201201
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
202202

203203
// Insert shard op if there is not one that already has the same sharding.
204+
// Use newShardOp if it is not null. Otherwise create a new one.
204205
// May insert resharding if required.
206+
// Potentially updates newShardOp.
205207
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
206-
OpOperand &operand,
207-
OpBuilder &builder);
208+
OpOperand &operand, OpBuilder &builder,
209+
ShardOp &newShardOp);
208210
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
209211
OpBuilder &builder);
210212
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
2828
Op<Mesh_Dialect, mnemonic, traits> {
2929
}
3030

31-
def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
31+
def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
3232
let summary = "Description of a device/process mesh.";
3333
let description = [{
3434
The mesh.mesh operation is a symbol operation that identifies a specific
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
318318
"ArrayRef<MeshAxesAttr>":$split_axes,
319319
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
320320
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
321+
OpBuilder<(ins "llvm::StringRef":$mesh,
322+
"ArrayRef<MeshAxesAttr>":$split_axes,
323+
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
324+
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
325+
)>,
321326
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
322327
];
323328
let hasVerifier = 1;
324329
let hasCanonicalizer = 1;
325330
}
326331

332+
// mesh.get_sharding: takes a single ranked-tensor operand (`$source`) and
// produces a single Mesh_Sharding result (`$result`). Declared with the Pure
// trait.
def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
333+
let summary = "Get the sharding of the given tensor.";
334+
let description = [{
335+
This operation returns the sharding of the given tensor as a MeshSharding.
336+
}];
337+
let arguments = (ins
338+
AnyRankedTensor:$source
339+
);
340+
let results = (outs
341+
Mesh_Sharding:$result
342+
);
343+
// Textual form: `<source> <attr-dict> : <source type> -> <result type>`.
let assemblyFormat = [{
344+
$source attr-dict `:` type($source) `->` type($result)
345+
}];
346+
}
347+
327348
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
328349
let summary = "Get the shard shape of a given process/device.";
329350
let description = [{
@@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
460481
(`annotate_for_users` $annotate_for_users^)?
461482
attr-dict `:` type($result)
462483
}];
484+
let hasCanonicalizer = 1;
463485
}
464486

465487
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ struct ShardingOption {
3636
bool empty = false;
3737
ShardingOption() = default;
3838
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
39-
: shardingArray(std::move(shardingArray)), mesh(mesh) {}
39+
: shardingArray(std::move(shardingArray)), mesh(mesh) {
40+
assert(this->mesh);
41+
}
4042
static ShardingOption makeEmpty() {
4143
auto res = ShardingOption();
4244
res.empty = true;

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
2424
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
2525
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
26+
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
2627
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
2728
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
2829
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
158159
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
159160
arith::registerBufferizableOpInterfaceExternalModels(registry);
160161
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
162+
arith::registerShardingInterfaceExternalModels(registry);
161163
arith::registerValueBoundsOpInterfaceExternalModels(registry);
162164
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
163165
registry);

mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
88
ExpandOps.cpp
99
IntRangeOptimizations.cpp
1010
ReifyValueBounds.cpp
11+
ShardingInterfaceImpl.cpp
1112
UnsignedWhenEquivalent.cpp
1213

1314
ADDITIONAL_HEADER_DIRS
@@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
2627
MLIRInferIntRangeInterface
2728
MLIRIR
2829
MLIRMemRefDialect
30+
MLIRMeshDialect
2931
MLIRPass
32+
MLIRShardingInterface
3033
MLIRTensorDialect
3134
MLIRTransforms
3235
MLIRTransformUtils
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
12+
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
13+
#include "mlir/IR/DialectRegistry.h"
14+
#include "llvm/Support/Debug.h"
15+
16+
using namespace mlir;
17+
using namespace mlir::arith;
18+
using namespace mlir::mesh;
19+
20+
namespace {
21+
22+
// Sharding of arith.constant
23+
// RankedTensor constants can be sharded like any other tensor.
24+
// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
25+
// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
26+
// Scalar constants are always replicated and need no sharding annotation.
27+
28+
// External model that implements ShardingInterface for arith's ConstantOp.
// Ranked-tensor constants expose one parallel loop per dimension and an
// identity indexing map; any other result type exposes no loops and no maps.
struct ConstantShardingInterface
29+
: public ShardingInterface::ExternalModel<ConstantShardingInterface,
30+
ConstantOp> {
31+
// Returns one `parallel` iterator type per dimension of the result when the
// result is a RankedTensorType; returns an empty vector otherwise.
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
32+
// NOTE(review): `auto` deduces `int` here; if getRank() returns a wider
// integer type the assignment below narrows - confirm this is intended.
auto ndims = 0;
33+
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
34+
ndims = type.getRank();
35+
}
36+
return SmallVector<utils::IteratorType>(ndims,
37+
utils::IteratorType::parallel);
38+
}
39+
40+
// Returns a single multi-dimensional identity map whose rank equals the
// result rank when the result is a RankedTensorType; returns an empty vector
// otherwise.
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
41+
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
42+
return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
43+
type.getRank(), op->getContext())});
44+
}
45+
return {};
46+
}
47+
48+
// Computes the sharding option from the single result sharding:
// - fails if the result sharding evaluates to false (no sharding present);
49+
// - for a ranked-tensor constant, copies the result sharding's split axes;
50+
// - otherwise returns an option with an empty sharding array on the same
//   mesh (i.e. the replication option).
// `operandShardings` is not used.
51+
FailureOr<ShardingOption>
52+
getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
53+
ArrayRef<MeshSharding> resultShardings) const {
54+
assert(resultShardings.size() == 1 &&
55+
"Expecting exactly one result sharding for arith.constant");
56+
auto resultSharding = resultShardings[0];
57+
if (!resultSharding) {
58+
return failure();
59+
}
60+
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
61+
// One entry per split-axes group of the result sharding.
ShardingArray axesArray(resultSharding.getSplitAxes().size());
62+
for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
63+
axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
64+
}
65+
return ShardingOption(axesArray, resultSharding.getMeshAttr());
66+
}
67+
return ShardingOption({}, resultSharding.getMeshAttr());
68+
}
69+
70+
// Creates the spmdized counterpart of the constant with `builder` and records
// the old-to-new mapping in `spmdizationMap`.
// - DenseIntOrFPElementsAttr values: only splats with a result sharding are
//   handled; a new ConstantOp with the sharded type and a resized splat value
//   is created. Any other case of this kind returns failure.
// - All other constant values: the op is cloned unchanged.
// `spmdizedOperands` and `operandShardings` are not used.
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
71+
ArrayRef<MeshSharding> operandShardings,
72+
ArrayRef<MeshSharding> resultShardings,
73+
IRMapping &spmdizationMap,
74+
SymbolTableCollection &symbolTable,
75+
OpBuilder &builder) const {
76+
auto cOp = cast<ConstantOp>(op);
77+
if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
78+
// NOTE(review): `resultShardings[0]` is indexed without a size check in
// this method (unlike getShardingOption) - confirm callers always pass
// exactly one result sharding.
if (!value.isSplat() || !resultShardings[0]) {
79+
// Non-splat constants are currently not supported; a missing result
// sharding is rejected as well.
80+
return failure();
81+
}
82+
auto sharding = resultShardings[0];
83+
// Type of the per-shard tensor for the given mesh and sharding.
auto newType = cast<RankedTensorType>(shardType(
84+
cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
85+
sharding));
86+
// Same splat value, resized to the sharded type.
auto newValue = value.resizeSplat(newType);
87+
auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
88+
// Map both the result value and the operation itself to the new op.
spmdizationMap.map(op->getResult(0), newOp.getResult());
89+
spmdizationMap.map(op, newOp.getOperation());
90+
} else {
91+
// `clone` will populate the mapping of old to new results.
92+
(void)builder.clone(*op, spmdizationMap);
93+
}
94+
return success();
95+
}
96+
};
97+
} // namespace
98+
99+
// Adds an extension to `registry` that attaches ConstantShardingInterface to
// arith's ConstantOp. The attachment is performed by the extension callback,
// which receives the MLIRContext and the ArithDialect.
void mlir::arith::registerShardingInterfaceExternalModels(
100+
DialectRegistry &registry) {
101+
102+
// The `dialect` argument of the callback is not used.
registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
103+
ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
104+
});
105+
}

0 commit comments

Comments
 (0)