Skip to content

Commit b58581b

Browse files
committed
minor fix
1 parent d254e84 commit b58581b

File tree

3 files changed

+65
-51
lines changed

3 files changed

+65
-51
lines changed

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ struct ShardingOption {
3535
: shardingArray(std::move(shardingArray)), cluster(cluster) {}
3636
};
3737

38-
// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
39-
// for a given operation result.
40-
FailureOr<MeshShardingAttr> getMeshShardingAttr(OpResult result,
41-
bool useOperandSharding);
38+
// This method retrieves the 'MeshShardingAttr' attribute from a given operation
39+
// result, paired with a bool that is true for 'annotate_for_users' shardings.
40+
FailureOr<std::pair<bool, MeshShardingAttr>>
41+
getMeshShardingAttr(OpResult result);
4242

43-
// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
44-
// for a given operation operand.
43+
// This method retrieves the 'MeshShardingAttr' attribute from a given operation
44+
// operand, paired with a bool that is true for 'annotate_for_users' shardings.
4545
FailureOr<std::pair<bool, MeshShardingAttr>>
4646
getMeshShardingAttr(OpOperand &opOperand);
4747

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
9292
// mesh::getMeshShardingAttr
9393
//===----------------------------------------------------------------------===//
9494

95-
FailureOr<MeshShardingAttr> mesh::getMeshShardingAttr(OpResult result,
96-
bool useOperandSharding) {
95+
FailureOr<std::pair<bool, MeshShardingAttr>>
96+
mesh::getMeshShardingAttr(OpResult result) {
9797
Value val = result.cast<Value>();
9898
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
9999
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
@@ -108,32 +108,31 @@ FailureOr<MeshShardingAttr> mesh::getMeshShardingAttr(OpResult result,
108108
if (!val.hasOneUse())
109109
return failure();
110110
auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
111-
return shardOp.getShard();
112-
} else if (useOperandSharding) {
113-
bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
114-
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
115-
if (!shardOp)
116-
return false;
117-
return shardOp.getAnnotateForUsers();
118-
});
119-
if (anyShardedForUsers) {
120-
SmallVector<ShardOp> shardOps;
121-
for (Operation *user : val.getUsers()) {
122-
ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
123-
if (shardOp)
124-
shardOps.push_back(shardOp);
125-
}
126-
MeshShardingAttr shardForDef = shardOps[0].getShard();
127-
for (size_t i = 1; i < shardOps.size(); ++i) {
128-
// TODO: Deduce a reasonable mesh sharding attr for def when they are
129-
// different
130-
assert(shardOps[i].getShard() == shardForDef &&
131-
"only support all shard ops have the same mesh sharding attr");
132-
}
133-
return shardForDef;
134-
}
111+
return std::make_pair(false, shardOp.getShard());
135112
}
136113

114+
bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
115+
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
116+
if (!shardOp)
117+
return false;
118+
return shardOp.getAnnotateForUsers();
119+
});
120+
if (anyShardedForUsers) {
121+
SmallVector<ShardOp> shardOps;
122+
for (Operation *user : val.getUsers()) {
123+
ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
124+
if (shardOp)
125+
shardOps.push_back(shardOp);
126+
}
127+
MeshShardingAttr shardForDef = shardOps[0].getShard();
128+
for (size_t i = 1; i < shardOps.size(); ++i) {
129+
// TODO: Deduce a reasonable mesh sharding attr for def when they are
130+
// different
131+
assert(shardOps[i].getShard() == shardForDef &&
132+
"only support all shard ops have the same mesh sharding attr");
133+
}
134+
return std::make_pair(true, shardForDef);
135+
}
137136
return failure();
138137
}
139138

@@ -403,7 +402,9 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
403402
const ShardingOption &shardingOption,
404403
AffineMap map,
405404
ArrayRef<IteratorType> loopTypes) {
406-
if (succeeded(getMeshShardingAttr(result, /*useOperandSharding*/ false)))
405+
FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
406+
getMeshShardingAttr(result);
407+
if (succeeded(maybeSharding) && !maybeSharding->first)
407408
return success();
408409

409410
auto resultType = result.getType().cast<RankedTensorType>();

mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ namespace {
3434
// Utilities
3535
//===----------------------------------------------------------------------===//
3636

37-
// This method returns all possible sharding attributes. For example,
38-
// mustShardings = [shard0, None] and optionalShardings = [None, shard1], the
39-
// result will be [[shard0, shard1], [shard0, None]]
37+
// This method retrieves all potential sharding attributes, prioritizing
38+
// specific shardings. For example, mustShardings = [shard0, None] and
39+
// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
40+
// [shard0, None]]
4041
static SmallVector<SmallVector<MeshShardingAttr>>
4142
getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
4243
ArrayRef<MeshShardingAttr> optionalShardings) {
@@ -92,15 +93,20 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
9293
}
9394

9495
// collect MeshShardingAttr from results
95-
SmallVector<MeshShardingAttr> resultShardings;
96-
resultShardings.reserve(op->getNumResults());
96+
SmallVector<MeshShardingAttr> allowConflictsResultShardings;
97+
allowConflictsResultShardings.resize(op->getNumResults());
98+
SmallVector<MeshShardingAttr> resultMustShardings;
99+
resultMustShardings.resize(op->getNumResults());
97100
for (OpResult result : op->getResults()) {
98-
FailureOr<MeshShardingAttr> shardAttr =
99-
getMeshShardingAttr(result, /*useOperandSharding*/ true);
100-
if (succeeded(shardAttr))
101-
resultShardings.push_back(*shardAttr);
101+
FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
102+
getMeshShardingAttr(result);
103+
if (failed(maybeShardAttr))
104+
continue;
105+
if (!maybeShardAttr->first)
106+
resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
102107
else
103-
resultShardings.push_back(nullptr);
108+
allowConflictsResultShardings[result.getResultNumber()] =
109+
maybeShardAttr->second;
104110
}
105111

106112
// collect MeshShardingAttr from operands
@@ -114,8 +120,7 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
114120
if (failed(maybeShardAttr))
115121
continue;
116122

117-
bool annotateForUsers = maybeShardAttr->first;
118-
if (annotateForUsers)
123+
if (maybeShardAttr->first)
119124
operandMustShardings[opOperand.getOperandNumber()] =
120125
maybeShardAttr->second;
121126
else
@@ -127,14 +132,22 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
127132
SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
128133
getOrderedPossibleShardingAttrs(operandMustShardings,
129134
allowConflictsOperandShardings);
135+
SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
136+
getOrderedPossibleShardingAttrs(resultMustShardings,
137+
allowConflictsResultShardings);
130138
FailureOr<ShardingOption> finalShardingOption = failure();
131-
for (ArrayRef<MeshShardingAttr> operandShardings :
132-
possibleOperandShardingAttrs) {
133-
FailureOr<ShardingOption> shardingOption =
134-
shardingOp.getShardingOption(operandShardings, resultShardings);
135-
if (succeeded(shardingOption)) {
136-
finalShardingOption = shardingOption;
139+
for (ArrayRef<MeshShardingAttr> resultShardings :
140+
possibleResultShardingAttrs) {
141+
if (succeeded(finalShardingOption))
137142
break;
143+
for (ArrayRef<MeshShardingAttr> operandShardings :
144+
possibleOperandShardingAttrs) {
145+
FailureOr<ShardingOption> shardingOption =
146+
shardingOp.getShardingOption(operandShardings, resultShardings);
147+
if (succeeded(shardingOption)) {
148+
finalShardingOption = shardingOption;
149+
break;
150+
}
138151
}
139152
}
140153

0 commit comments

Comments
 (0)