Skip to content

Commit 2d1c8c6

Browse files
committed
fix comments, 3rd
1 parent 50205fc commit 2d1c8c6

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,9 @@ checkOperandAffineExprRecursively(AffineExpr expr,
106106
AffineExpr lhs = binOpExpr.getLHS();
107107
AffineExpr rhs = binOpExpr.getRHS();
108108
AffineExpr dimExpr;
109-
if (lhs.getKind() == AffineExprKind::DimId) {
109+
if (lhs.getKind() == AffineExprKind::DimId &&
110+
rhs.getKind() == AffineExprKind::Constant) {
110111
dimExpr = lhs;
111-
if (rhs.getKind() != AffineExprKind::Constant)
112-
return failure();
113112
} else if (rhs.getKind() == AffineExprKind::DimId &&
114113
lhs.getKind() == AffineExprKind::Constant) {
115114
dimExpr = rhs;
@@ -275,7 +274,8 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
275274
// 1. Fill sharding option based on op results
276275
for (OpResult result : op->getResults()) {
277276
AffineMap map = maps[numOperands + result.getResultNumber()];
278-
FailureOr<MeshShardingAttr> shardAttr = getMeshShardingAttr(result, true);
277+
FailureOr<MeshShardingAttr> shardAttr =
278+
getMeshShardingAttr(result, /*useOperandSharding*/ true);
279279
if (failed(shardAttr))
280280
continue;
281281
anyShardingInResultsOrOperands = true;
@@ -324,11 +324,11 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
324324
AffineMap map = maps[opOperand.getOperandNumber()];
325325
unsigned numDims = map.getNumDims();
326326

327-
// Handle the split axes, and partial axes don't need to be handled because
328-
// they only affect the defining op of the operand
327+
// Handle the split axes. Partial axes don't need to be handled because they
328+
// only affect the defining op of the operand.
329329
//
330330
// TODO: Change to process the operands with single loop index first and
331-
// then the operands with multiple loop indices
331+
// then the operands with multiple loop indices.
332332
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
333333
AffineExpr expr = std::get<0>(it);
334334
ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
@@ -411,7 +411,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
411411
const ShardingOption &shardingOption,
412412
AffineMap map,
413413
ArrayRef<IteratorType> loopTypes) {
414-
if (succeeded(getMeshShardingAttr(result, false)))
414+
if (succeeded(getMeshShardingAttr(result, /*useOperandSharding*/ false)))
415415
return success();
416416

417417
auto resultType = result.getType().cast<RankedTensorType>();
@@ -421,6 +421,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
421421
// process the split axes
422422
for (auto it : llvm::enumerate(map.getResults())) {
423423
AffineExpr expr = it.value();
424+
// `expr` must be an `AffineDimExpr` because `map` is verified by
425+
// isProjectedPermutation
424426
auto dim = expr.cast<AffineDimExpr>();
425427
unsigned loopIdx = dim.getPosition();
426428
if (loopIdx < shardingOption.shardingArray.size())

0 commit comments

Comments
 (0)