Skip to content

Commit d254e84

Browse files
committed
change getShardingOption signature
1 parent 2d1c8c6 commit d254e84

File tree

4 files changed

+199
-102
lines changed

4 files changed

+199
-102
lines changed

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,22 @@ struct ShardingOption {
3535
: shardingArray(std::move(shardingArray)), cluster(cluster) {}
3636
};
3737

38+
// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
39+
// for a given operation result.
40+
FailureOr<MeshShardingAttr> getMeshShardingAttr(OpResult result,
41+
bool useOperandSharding);
42+
43+
// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
44+
// for a given operation operand.
45+
FailureOr<std::pair<bool, MeshShardingAttr>>
46+
getMeshShardingAttr(OpOperand &opOperand);
47+
3848
namespace detail {
3949

40-
FailureOr<ShardingOption> defaultGetShardingOption(Operation *op);
50+
FailureOr<ShardingOption>
51+
defaultGetShardingOption(Operation *op,
52+
ArrayRef<MeshShardingAttr> operandShardings,
53+
ArrayRef<MeshShardingAttr> resultShardings);
4154

4255
LogicalResult
4356
defaultAddShardingAnnotations(Operation *op, OpBuilder &b,

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,14 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
6161
}],
6262
/*retTy=*/"FailureOr<ShardingOption>",
6363
/*methodName=*/"getShardingOption",
64-
/*args=*/(ins),
64+
/*args=*/(ins
65+
"ArrayRef<MeshShardingAttr>": $operandShardings,
66+
"ArrayRef<MeshShardingAttr>": $resultShardings
67+
),
6568
/*methodBody=*/"",
6669
/*defaultImplementation=*/[{
6770
return detail::defaultGetShardingOption(
68-
$_op.getOperation());
71+
$_op.getOperation(), operandShardings, resultShardings);
6972
}]
7073
>,
7174
InterfaceMethod<

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

Lines changed: 86 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -29,64 +29,6 @@ using namespace mlir::mesh;
2929
// common util functions
3030
//===----------------------------------------------------------------------===//
3131

32-
// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
33-
// for a given operation result.
34-
static FailureOr<MeshShardingAttr>
35-
getMeshShardingAttr(OpResult result, bool useOperandSharding) {
36-
Value val = result.cast<Value>();
37-
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
38-
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
39-
if (!shardOp)
40-
return false;
41-
return !shardOp.getAnnotateForUsers();
42-
});
43-
44-
if (anyShardedForDef) {
45-
// expected to have exact one use if it has a use of `mesh.shard` without
46-
// unit attr annotate_for_users
47-
if (!val.hasOneUse())
48-
return failure();
49-
auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
50-
return shardOp.getShard();
51-
} else if (useOperandSharding) {
52-
bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
53-
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
54-
if (!shardOp)
55-
return false;
56-
return shardOp.getAnnotateForUsers();
57-
});
58-
if (anyShardedForUsers) {
59-
SmallVector<ShardOp> shardOps;
60-
for (Operation *user : val.getUsers()) {
61-
ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
62-
if (shardOp)
63-
shardOps.push_back(shardOp);
64-
}
65-
MeshShardingAttr shardForDef = shardOps[0].getShard();
66-
for (size_t i = 1; i < shardOps.size(); ++i) {
67-
// TODO: Deduce a reasonable mesh sharding attr for def when they are
68-
// different
69-
assert(shardOps[i].getShard() == shardForDef &&
70-
"only support all shard ops have the same mesh sharding attr");
71-
}
72-
return shardForDef;
73-
}
74-
}
75-
76-
return failure();
77-
}
78-
79-
// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
80-
// for a given operation operand.
81-
static FailureOr<std::pair<bool, MeshShardingAttr>>
82-
getMeshShardingAttr(OpOperand &opOperand) {
83-
Value val = opOperand.get();
84-
if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
85-
return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
86-
87-
return failure();
88-
}
89-
9032
static LogicalResult
9133
checkOperandAffineExprRecursively(AffineExpr expr,
9234
SmallVectorImpl<bool> &seenIds) {
@@ -146,6 +88,64 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
14688
return positions;
14789
}
14890

91+
//===----------------------------------------------------------------------===//
92+
// mesh::getMeshShardingAttr
93+
//===----------------------------------------------------------------------===//
94+
95+
FailureOr<MeshShardingAttr> mesh::getMeshShardingAttr(OpResult result,
96+
bool useOperandSharding) {
97+
Value val = result.cast<Value>();
98+
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
99+
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
100+
if (!shardOp)
101+
return false;
102+
return !shardOp.getAnnotateForUsers();
103+
});
104+
105+
if (anyShardedForDef) {
106+
// expected to have exact one use if it has a use of `mesh.shard` without
107+
// unit attr annotate_for_users
108+
if (!val.hasOneUse())
109+
return failure();
110+
auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
111+
return shardOp.getShard();
112+
} else if (useOperandSharding) {
113+
bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
114+
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
115+
if (!shardOp)
116+
return false;
117+
return shardOp.getAnnotateForUsers();
118+
});
119+
if (anyShardedForUsers) {
120+
SmallVector<ShardOp> shardOps;
121+
for (Operation *user : val.getUsers()) {
122+
ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
123+
if (shardOp)
124+
shardOps.push_back(shardOp);
125+
}
126+
MeshShardingAttr shardForDef = shardOps[0].getShard();
127+
for (size_t i = 1; i < shardOps.size(); ++i) {
128+
// TODO: Deduce a reasonable mesh sharding attr for def when they are
129+
// different
130+
assert(shardOps[i].getShard() == shardForDef &&
131+
"only support all shard ops have the same mesh sharding attr");
132+
}
133+
return shardForDef;
134+
}
135+
}
136+
137+
return failure();
138+
}
139+
140+
FailureOr<std::pair<bool, MeshShardingAttr>>
141+
mesh::getMeshShardingAttr(OpOperand &opOperand) {
142+
Value val = opOperand.get();
143+
if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
144+
return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
145+
146+
return failure();
147+
}
148+
149149
//===----------------------------------------------------------------------===//
150150
// ShardingInterface::verifyShardingInterfaceImpl
151151
//===----------------------------------------------------------------------===//
@@ -214,19 +214,18 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
214214
namespace {
215215

216216
// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
217-
static LogicalResult
218-
fillShardingOption(Operation *op, ShardingOption &shardingOption,
219-
SymbolRefAttr cluster, ArrayRef<int32_t> meshAxes,
220-
unsigned loopIdx, bool ignoreIfConflicted = false) {
217+
static LogicalResult fillShardingOption(Operation *op,
218+
ShardingOption &shardingOption,
219+
SymbolRefAttr cluster,
220+
ArrayRef<int32_t> meshAxes,
221+
unsigned loopIdx) {
221222
if ((shardingOption.cluster && cluster &&
222223
shardingOption.cluster != cluster) ||
223224
(!shardingOption.shardingArray[loopIdx].empty() &&
224225
shardingOption.shardingArray[loopIdx] != meshAxes)) {
225-
if (ignoreIfConflicted)
226-
return success();
227-
else
228-
return op->emitOpError()
229-
<< "sharding option conflicts on loop iterator " << loopIdx;
226+
LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
227+
<< loopIdx << "\n");
228+
return failure();
230229
}
231230
for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
232231
if (i == loopIdx)
@@ -236,12 +235,9 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
236235
if (std::find(shardingOption.shardingArray[i].begin(),
237236
shardingOption.shardingArray[i].end(),
238237
axis) != shardingOption.shardingArray[i].end()) {
239-
if (ignoreIfConflicted)
240-
return success();
241-
else
242-
return op->emitOpError()
243-
<< "sharding option conflicts because mesh axes " << axis
244-
<< " duplicate";
238+
LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
239+
<< axis << " duplicate");
240+
return failure();
245241
}
246242
}
247243
}
@@ -255,8 +251,9 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
255251

256252
} // namespace
257253

258-
FailureOr<ShardingOption>
259-
mesh::detail::defaultGetShardingOption(Operation *op) {
254+
FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
255+
Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
256+
ArrayRef<MeshShardingAttr> resultShardings) {
260257
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
261258
ShardingOption shardingOption;
262259

@@ -272,35 +269,34 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
272269
bool anyShardingInResultsOrOperands = false;
273270

274271
// 1. Fill sharding option based on op results
275-
for (OpResult result : op->getResults()) {
276-
AffineMap map = maps[numOperands + result.getResultNumber()];
277-
FailureOr<MeshShardingAttr> shardAttr =
278-
getMeshShardingAttr(result, /*useOperandSharding*/ true);
279-
if (failed(shardAttr))
272+
for (auto shardingIt : llvm::enumerate(resultShardings)) {
273+
MeshShardingAttr shardAttr = shardingIt.value();
274+
if (!shardAttr)
280275
continue;
276+
AffineMap map = maps[numOperands + shardingIt.index()];
281277
anyShardingInResultsOrOperands = true;
282278
// Handle the split axes: calculate the corresponding loop index for each
283279
// split axes sub-array, and then store the sub-array to
284280
// shardingOption[index]
285-
for (auto it : llvm::zip(map.getResults(), shardAttr->getSplitAxes())) {
281+
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
286282
AffineExpr expr = std::get<0>(it);
287283
ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
288284
auto dim = expr.cast<AffineDimExpr>();
289285
unsigned index = dim.getPosition();
290286
visitedLoopIndices.insert(index);
291-
if (failed(fillShardingOption(op, shardingOption, shardAttr->getCluster(),
287+
if (failed(fillShardingOption(op, shardingOption, shardAttr.getCluster(),
292288
axes, index)))
293289
return failure();
294290
}
295291

296292
// Handle the partial axes: at this stage, the exact loop index/indices
297293
// cannot be decided because there could be multiple reduction loops.
298-
ArrayRef<int32_t> partialAxes = shardAttr->getPartialAxes();
294+
ArrayRef<int32_t> partialAxes = shardAttr.getPartialAxes();
299295
if (!partialAxes.empty()) {
300296
if (!partialMeshAxes.empty())
301297
return op->emitOpError() << "at most one result with partial axes is "
302298
"supported at present";
303-
partialType = shardAttr->getPartialType();
299+
partialType = shardAttr.getPartialType();
304300
partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
305301
// Add all the reduction loop indices to `visitedLoopIndices` if
306302
// `partialAxes` is not empty
@@ -312,16 +308,13 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
312308
}
313309

314310
// 2. Fill sharding option based on operands
315-
for (OpOperand &opOperand : op->getOpOperands()) {
316-
FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
317-
getMeshShardingAttr(opOperand);
318-
if (failed(maybeShardAttr))
311+
for (auto shardingIt : llvm::enumerate(operandShardings)) {
312+
MeshShardingAttr shardAttr = shardingIt.value();
313+
if (!shardAttr)
319314
continue;
320315

321316
anyShardingInResultsOrOperands = true;
322-
bool annotateForUsers = maybeShardAttr->first;
323-
MeshShardingAttr shardAttr = maybeShardAttr->second;
324-
AffineMap map = maps[opOperand.getOperandNumber()];
317+
AffineMap map = maps[shardingIt.index()];
325318
unsigned numDims = map.getNumDims();
326319

327320
// Handle the split axes. Partial axes don't need to be handled because they
@@ -344,8 +337,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
344337
unsigned loopIdx = *loopIndices->begin();
345338
visitedLoopIndices.insert(loopIdx);
346339
if (failed(fillShardingOption(op, shardingOption,
347-
shardAttr.getCluster(), axes, loopIdx,
348-
!annotateForUsers)))
340+
shardAttr.getCluster(), axes, loopIdx)))
349341
return failure();
350342
}
351343
// If multiple loop indices correspond to a dimension of an operand, it is
@@ -361,7 +353,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
361353
}
362354
if (!seenLoopIndices)
363355
return op->emitOpError()
364-
<< "the operand " << opOperand.getOperandNumber()
356+
<< "the operand " << shardingIt.index()
365357
<< " has multiple loop indices in a dimension, but none of "
366358
"them could be found in the exactly specified annotation "
367359
"of op results or operands.";

0 commit comments

Comments
 (0)