@@ -34,9 +34,10 @@ namespace {
34
34
// Utilities
35
35
// ===----------------------------------------------------------------------===//
36
36
37
- // This method returns all possible sharding attributes. For example,
38
- // mustShardings = [shard0, None] and optionalShardings = [None, shard1], the
39
- // result will be [[shard0, shard1], [shard0, None]]
37
+ // This method retrieves all potential sharding attributes, prioritizing
38
+ // specific shardings. For example, mustShardings = [shard0, None] and
39
+ // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
40
+ // [shard0, None]]
40
41
static SmallVector<SmallVector<MeshShardingAttr>>
41
42
getOrderedPossibleShardingAttrs (ArrayRef<MeshShardingAttr> mustShardings,
42
43
ArrayRef<MeshShardingAttr> optionalShardings) {
@@ -92,15 +93,20 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
92
93
}
93
94
94
95
// collect MeshShardingAttr from results
95
- SmallVector<MeshShardingAttr> resultShardings;
96
- resultShardings.reserve (op->getNumResults ());
96
+ SmallVector<MeshShardingAttr> allowConflictsResultShardings;
97
+ allowConflictsResultShardings.resize (op->getNumResults ());
98
+ SmallVector<MeshShardingAttr> resultMustShardings;
99
+ resultMustShardings.resize (op->getNumResults ());
97
100
for (OpResult result : op->getResults ()) {
98
- FailureOr<MeshShardingAttr> shardAttr =
99
- getMeshShardingAttr (result, /* useOperandSharding*/ true );
100
- if (succeeded (shardAttr))
101
- resultShardings.push_back (*shardAttr);
101
+ FailureOr<std::pair<bool , MeshShardingAttr>> maybeShardAttr =
102
+ getMeshShardingAttr (result);
103
+ if (failed (maybeShardAttr))
104
+ continue ;
105
+ if (!maybeShardAttr->first )
106
+ resultMustShardings[result.getResultNumber ()] = maybeShardAttr->second ;
102
107
else
103
- resultShardings.push_back (nullptr );
108
+ allowConflictsResultShardings[result.getResultNumber ()] =
109
+ maybeShardAttr->second ;
104
110
}
105
111
106
112
// collect MeshShardingAttr from operands
@@ -114,8 +120,7 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
114
120
if (failed (maybeShardAttr))
115
121
continue ;
116
122
117
- bool annotateForUsers = maybeShardAttr->first ;
118
- if (annotateForUsers)
123
+ if (maybeShardAttr->first )
119
124
operandMustShardings[opOperand.getOperandNumber ()] =
120
125
maybeShardAttr->second ;
121
126
else
@@ -127,14 +132,22 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
127
132
SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
128
133
getOrderedPossibleShardingAttrs (operandMustShardings,
129
134
allowConflictsOperandShardings);
135
+ SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
136
+ getOrderedPossibleShardingAttrs (resultMustShardings,
137
+ allowConflictsResultShardings);
130
138
FailureOr<ShardingOption> finalShardingOption = failure ();
131
- for (ArrayRef<MeshShardingAttr> operandShardings :
132
- possibleOperandShardingAttrs) {
133
- FailureOr<ShardingOption> shardingOption =
134
- shardingOp.getShardingOption (operandShardings, resultShardings);
135
- if (succeeded (shardingOption)) {
136
- finalShardingOption = shardingOption;
139
+ for (ArrayRef<MeshShardingAttr> resultShardings :
140
+ possibleResultShardingAttrs) {
141
+ if (succeeded (finalShardingOption))
137
142
break ;
143
+ for (ArrayRef<MeshShardingAttr> operandShardings :
144
+ possibleOperandShardingAttrs) {
145
+ FailureOr<ShardingOption> shardingOption =
146
+ shardingOp.getShardingOption (operandShardings, resultShardings);
147
+ if (succeeded (shardingOption)) {
148
+ finalShardingOption = shardingOption;
149
+ break ;
150
+ }
138
151
}
139
152
}
140
153
0 commit comments