@@ -106,10 +106,9 @@ checkOperandAffineExprRecursively(AffineExpr expr,
106
106
AffineExpr lhs = binOpExpr.getLHS ();
107
107
AffineExpr rhs = binOpExpr.getRHS ();
108
108
AffineExpr dimExpr;
109
- if (lhs.getKind () == AffineExprKind::DimId) {
109
+ if (lhs.getKind () == AffineExprKind::DimId &&
110
+ rhs.getKind () == AffineExprKind::Constant) {
110
111
dimExpr = lhs;
111
- if (rhs.getKind () != AffineExprKind::Constant)
112
- return failure ();
113
112
} else if (rhs.getKind () == AffineExprKind::DimId &&
114
113
lhs.getKind () == AffineExprKind::Constant) {
115
114
dimExpr = rhs;
@@ -275,7 +274,8 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
275
274
// 1. Fill sharding option based on op results
276
275
for (OpResult result : op->getResults ()) {
277
276
AffineMap map = maps[numOperands + result.getResultNumber ()];
278
- FailureOr<MeshShardingAttr> shardAttr = getMeshShardingAttr (result, true );
277
+ FailureOr<MeshShardingAttr> shardAttr =
278
+ getMeshShardingAttr (result, /* useOperandSharding*/ true );
279
279
if (failed (shardAttr))
280
280
continue ;
281
281
anyShardingInResultsOrOperands = true ;
@@ -324,11 +324,11 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
324
324
AffineMap map = maps[opOperand.getOperandNumber ()];
325
325
unsigned numDims = map.getNumDims ();
326
326
327
- // Handle the split axes, and partial axes don't need to be handled because
328
- // they only affect the defining op of the operand
327
+ // Handle the split axes. Partial axes don't need to be handled because they
328
+ // only affect the defining op of the operand.
329
329
//
330
330
// TODO: Change to process the operands with single loop index first and
331
- // then the operands with multiple loop indices
331
+ // then the operands with multiple loop indices.
332
332
for (auto it : llvm::zip (map.getResults (), shardAttr.getSplitAxes ())) {
333
333
AffineExpr expr = std::get<0 >(it);
334
334
ArrayRef<int32_t > axes = std::get<1 >(it).asArrayRef ();
@@ -411,7 +411,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
411
411
const ShardingOption &shardingOption,
412
412
AffineMap map,
413
413
ArrayRef<IteratorType> loopTypes) {
414
- if (succeeded (getMeshShardingAttr (result, false )))
414
+ if (succeeded (getMeshShardingAttr (result, /* useOperandSharding */ false )))
415
415
return success ();
416
416
417
417
auto resultType = result.getType ().cast <RankedTensorType>();
@@ -421,6 +421,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
421
421
// process the split axes
422
422
for (auto it : llvm::enumerate (map.getResults ())) {
423
423
AffineExpr expr = it.value ();
424
+ // `expr` must be an `AffineDimExpr` because `map` is verified by
425
+ // isProjectedPermutation
424
426
auto dim = expr.cast <AffineDimExpr>();
425
427
unsigned loopIdx = dim.getPosition ();
426
428
if (loopIdx < shardingOption.shardingArray .size ())
0 commit comments