@@ -29,64 +29,6 @@ using namespace mlir::mesh;
29
29
// common util functions
30
30
// ===----------------------------------------------------------------------===//
31
31
32
- // This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
33
- // for a given operation result.
34
- static FailureOr<MeshShardingAttr>
35
- getMeshShardingAttr (OpResult result, bool useOperandSharding) {
36
- Value val = result.cast <Value>();
37
- bool anyShardedForDef = llvm::any_of (val.getUsers (), [](Operation *user) {
38
- auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
39
- if (!shardOp)
40
- return false ;
41
- return !shardOp.getAnnotateForUsers ();
42
- });
43
-
44
- if (anyShardedForDef) {
45
- // expected to have exact one use if it has a use of `mesh.shard` without
46
- // unit attr annotate_for_users
47
- if (!val.hasOneUse ())
48
- return failure ();
49
- auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers ().begin ());
50
- return shardOp.getShard ();
51
- } else if (useOperandSharding) {
52
- bool anyShardedForUsers = llvm::any_of (val.getUsers (), [](Operation *user) {
53
- auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
54
- if (!shardOp)
55
- return false ;
56
- return shardOp.getAnnotateForUsers ();
57
- });
58
- if (anyShardedForUsers) {
59
- SmallVector<ShardOp> shardOps;
60
- for (Operation *user : val.getUsers ()) {
61
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
62
- if (shardOp)
63
- shardOps.push_back (shardOp);
64
- }
65
- MeshShardingAttr shardForDef = shardOps[0 ].getShard ();
66
- for (size_t i = 1 ; i < shardOps.size (); ++i) {
67
- // TODO: Deduce a reasonable mesh sharding attr for def when they are
68
- // different
69
- assert (shardOps[i].getShard () == shardForDef &&
70
- " only support all shard ops have the same mesh sharding attr" );
71
- }
72
- return shardForDef;
73
- }
74
- }
75
-
76
- return failure ();
77
- }
78
-
79
- // This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
80
- // for a given operation operand.
81
- static FailureOr<std::pair<bool , MeshShardingAttr>>
82
- getMeshShardingAttr (OpOperand &opOperand) {
83
- Value val = opOperand.get ();
84
- if (ShardOp shardOp = val.getDefiningOp <ShardOp>())
85
- return std::make_pair (shardOp.getAnnotateForUsers (), shardOp.getShard ());
86
-
87
- return failure ();
88
- }
89
-
90
32
static LogicalResult
91
33
checkOperandAffineExprRecursively (AffineExpr expr,
92
34
SmallVectorImpl<bool > &seenIds) {
@@ -146,6 +88,64 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
146
88
return positions;
147
89
}
148
90
91
+ // ===----------------------------------------------------------------------===//
92
+ // mesh::getMeshShardingAttr
93
+ // ===----------------------------------------------------------------------===//
94
+
95
+ FailureOr<MeshShardingAttr> mesh::getMeshShardingAttr (OpResult result,
96
+ bool useOperandSharding) {
97
+ Value val = result.cast <Value>();
98
+ bool anyShardedForDef = llvm::any_of (val.getUsers (), [](Operation *user) {
99
+ auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
100
+ if (!shardOp)
101
+ return false ;
102
+ return !shardOp.getAnnotateForUsers ();
103
+ });
104
+
105
+ if (anyShardedForDef) {
106
+ // expected to have exact one use if it has a use of `mesh.shard` without
107
+ // unit attr annotate_for_users
108
+ if (!val.hasOneUse ())
109
+ return failure ();
110
+ auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers ().begin ());
111
+ return shardOp.getShard ();
112
+ } else if (useOperandSharding) {
113
+ bool anyShardedForUsers = llvm::any_of (val.getUsers (), [](Operation *user) {
114
+ auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
115
+ if (!shardOp)
116
+ return false ;
117
+ return shardOp.getAnnotateForUsers ();
118
+ });
119
+ if (anyShardedForUsers) {
120
+ SmallVector<ShardOp> shardOps;
121
+ for (Operation *user : val.getUsers ()) {
122
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
123
+ if (shardOp)
124
+ shardOps.push_back (shardOp);
125
+ }
126
+ MeshShardingAttr shardForDef = shardOps[0 ].getShard ();
127
+ for (size_t i = 1 ; i < shardOps.size (); ++i) {
128
+ // TODO: Deduce a reasonable mesh sharding attr for def when they are
129
+ // different
130
+ assert (shardOps[i].getShard () == shardForDef &&
131
+ " only support all shard ops have the same mesh sharding attr" );
132
+ }
133
+ return shardForDef;
134
+ }
135
+ }
136
+
137
+ return failure ();
138
+ }
139
+
140
+ FailureOr<std::pair<bool , MeshShardingAttr>>
141
+ mesh::getMeshShardingAttr (OpOperand &opOperand) {
142
+ Value val = opOperand.get ();
143
+ if (ShardOp shardOp = val.getDefiningOp <ShardOp>())
144
+ return std::make_pair (shardOp.getAnnotateForUsers (), shardOp.getShard ());
145
+
146
+ return failure ();
147
+ }
148
+
149
149
// ===----------------------------------------------------------------------===//
150
150
// ShardingInterface::verifyShardingInterfaceImpl
151
151
// ===----------------------------------------------------------------------===//
@@ -214,19 +214,18 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
214
214
namespace {
215
215
216
216
// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
217
- static LogicalResult
218
- fillShardingOption (Operation *op, ShardingOption &shardingOption,
219
- SymbolRefAttr cluster, ArrayRef<int32_t > meshAxes,
220
- unsigned loopIdx, bool ignoreIfConflicted = false ) {
217
+ static LogicalResult fillShardingOption (Operation *op,
218
+ ShardingOption &shardingOption,
219
+ SymbolRefAttr cluster,
220
+ ArrayRef<int32_t > meshAxes,
221
+ unsigned loopIdx) {
221
222
if ((shardingOption.cluster && cluster &&
222
223
shardingOption.cluster != cluster) ||
223
224
(!shardingOption.shardingArray [loopIdx].empty () &&
224
225
shardingOption.shardingArray [loopIdx] != meshAxes)) {
225
- if (ignoreIfConflicted)
226
- return success ();
227
- else
228
- return op->emitOpError ()
229
- << " sharding option conflicts on loop iterator " << loopIdx;
226
+ LLVM_DEBUG (DBGS () << " sharding option conflicts on loop iterator "
227
+ << loopIdx << " \n " );
228
+ return failure ();
230
229
}
231
230
for (size_t i = 0 ; i < shardingOption.shardingArray .size (); ++i) {
232
231
if (i == loopIdx)
@@ -236,12 +235,9 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
236
235
if (std::find (shardingOption.shardingArray [i].begin (),
237
236
shardingOption.shardingArray [i].end (),
238
237
axis) != shardingOption.shardingArray [i].end ()) {
239
- if (ignoreIfConflicted)
240
- return success ();
241
- else
242
- return op->emitOpError ()
243
- << " sharding option conflicts because mesh axes " << axis
244
- << " duplicate" ;
238
+ LLVM_DEBUG (DBGS () << " sharding option conflicts because mesh axes "
239
+ << axis << " duplicate" );
240
+ return failure ();
245
241
}
246
242
}
247
243
}
@@ -255,8 +251,9 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
255
251
256
252
} // namespace
257
253
258
- FailureOr<ShardingOption>
259
- mesh::detail::defaultGetShardingOption (Operation *op) {
254
+ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption (
255
+ Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
256
+ ArrayRef<MeshShardingAttr> resultShardings) {
260
257
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
261
258
ShardingOption shardingOption;
262
259
@@ -272,35 +269,34 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
272
269
bool anyShardingInResultsOrOperands = false ;
273
270
274
271
// 1. Fill sharding option based on op results
275
- for (OpResult result : op->getResults ()) {
276
- AffineMap map = maps[numOperands + result.getResultNumber ()];
277
- FailureOr<MeshShardingAttr> shardAttr =
278
- getMeshShardingAttr (result, /* useOperandSharding*/ true );
279
- if (failed (shardAttr))
272
+ for (auto shardingIt : llvm::enumerate (resultShardings)) {
273
+ MeshShardingAttr shardAttr = shardingIt.value ();
274
+ if (!shardAttr)
280
275
continue ;
276
+ AffineMap map = maps[numOperands + shardingIt.index ()];
281
277
anyShardingInResultsOrOperands = true ;
282
278
// Handle the split axes: calculate the corresponding loop index for each
283
279
// split axes sub-array, and then store the sub-array to
284
280
// shardingOption[index]
285
- for (auto it : llvm::zip (map.getResults (), shardAttr-> getSplitAxes ())) {
281
+ for (auto it : llvm::zip (map.getResults (), shardAttr. getSplitAxes ())) {
286
282
AffineExpr expr = std::get<0 >(it);
287
283
ArrayRef<int32_t > axes = std::get<1 >(it).asArrayRef ();
288
284
auto dim = expr.cast <AffineDimExpr>();
289
285
unsigned index = dim.getPosition ();
290
286
visitedLoopIndices.insert (index);
291
- if (failed (fillShardingOption (op, shardingOption, shardAttr-> getCluster (),
287
+ if (failed (fillShardingOption (op, shardingOption, shardAttr. getCluster (),
292
288
axes, index)))
293
289
return failure ();
294
290
}
295
291
296
292
// Handle the partial axes: at this stage, the exact loop index/indices
297
293
// cannot be decided because there could be multiple reduction loops.
298
- ArrayRef<int32_t > partialAxes = shardAttr-> getPartialAxes ();
294
+ ArrayRef<int32_t > partialAxes = shardAttr. getPartialAxes ();
299
295
if (!partialAxes.empty ()) {
300
296
if (!partialMeshAxes.empty ())
301
297
return op->emitOpError () << " at most one result with partial axes is "
302
298
" supported at present" ;
303
- partialType = shardAttr-> getPartialType ();
299
+ partialType = shardAttr. getPartialType ();
304
300
partialMeshAxes.append (partialAxes.begin (), partialAxes.end ());
305
301
// Add all the reduction loop indices to `visitedLoopIndices` if
306
302
// `partialAxes` is not empty
@@ -312,16 +308,13 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
312
308
}
313
309
314
310
// 2. Fill sharding option based on operands
315
- for (OpOperand &opOperand : op->getOpOperands ()) {
316
- FailureOr<std::pair<bool , MeshShardingAttr>> maybeShardAttr =
317
- getMeshShardingAttr (opOperand);
318
- if (failed (maybeShardAttr))
311
+ for (auto shardingIt : llvm::enumerate (operandShardings)) {
312
+ MeshShardingAttr shardAttr = shardingIt.value ();
313
+ if (!shardAttr)
319
314
continue ;
320
315
321
316
anyShardingInResultsOrOperands = true ;
322
- bool annotateForUsers = maybeShardAttr->first ;
323
- MeshShardingAttr shardAttr = maybeShardAttr->second ;
324
- AffineMap map = maps[opOperand.getOperandNumber ()];
317
+ AffineMap map = maps[shardingIt.index ()];
325
318
unsigned numDims = map.getNumDims ();
326
319
327
320
// Handle the split axes. Partial axes don't need to be handled because they
@@ -344,8 +337,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
344
337
unsigned loopIdx = *loopIndices->begin ();
345
338
visitedLoopIndices.insert (loopIdx);
346
339
if (failed (fillShardingOption (op, shardingOption,
347
- shardAttr.getCluster (), axes, loopIdx,
348
- !annotateForUsers)))
340
+ shardAttr.getCluster (), axes, loopIdx)))
349
341
return failure ();
350
342
}
351
343
// If multiple loop indices correspond to a dimension of an operand, it is
@@ -361,7 +353,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
361
353
}
362
354
if (!seenLoopIndices)
363
355
return op->emitOpError ()
364
- << " the operand " << opOperand. getOperandNumber ()
356
+ << " the operand " << shardingIt. index ()
365
357
<< " has multiple loop indices in a dimension, but none of "
366
358
" them could be found in the exactly specified annotation "
367
359
" of op results or operands." ;
0 commit comments