
Commit 276e5e6

fschlimb authored and Dinistro committed
Apply suggestions from code review
Co-authored-by: Christian Ulmann <[email protected]>
1 parent ce882e6 · commit 276e5e6

File tree

2 files changed: 63 additions & 62 deletions


mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 60 additions & 59 deletions
@@ -45,7 +45,8 @@ using namespace mlir;
 using namespace mesh;
 
 namespace {
-/// Convert vec of OpFoldResults (ints) into vector of Values.
+/// Converts a vector of OpFoldResults (ints) into vector of Values of the
+/// provided type.
 static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                            llvm::ArrayRef<int64_t> statics,
                                            ValueRange dynamics,
@@ -55,14 +56,15 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
   Type i64 = b.getI64Type();
   if (!type)
     type = i64;
-  assert(i64 == type || b.getIndexType() == type);
+  assert((i64 == type || b.getIndexType() == type) &&
+         "expected an i64 or an intex type");
   for (auto s : statics) {
-    values.emplace_back(
-        ShapedType::isDynamic(s)
-            ? *(dyn++)
-            : b.create<arith::ConstantOp>(loc, type,
-                                          i64 == type ? b.getI64IntegerAttr(s)
-                                                      : b.getIndexAttr(s)));
+    if (s == ShapedType::kDynamic) {
+      values.emplace_back(*(dyn++));
+    } else {
+      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
+      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
+    }
   }
   return values;
 };
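
Note on getMixedAsValues (the hunk above): the "mixed" encoding pairs a static size list, in which dynamic entries are marked with ShapedType::kDynamic, with a separate range of runtime Values that are consumed in order. A minimal standalone sketch of that pairing in plain C++, with stand-in types instead of MLIR Values (names and types here are illustrative only, not part of the patch):

    #include <cassert>
    #include <cstdint>
    #include <limits>
    #include <string>
    #include <vector>

    // Stand-in for ShapedType::kDynamic (the sentinel meaning "not static").
    constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

    // Pair a static list containing kDynamic placeholders with a list of
    // runtime values, consumed in order -- the same walk getMixedAsValues
    // performs over statics/dynamics.
    std::vector<std::string> materialize(const std::vector<int64_t> &statics,
                                         const std::vector<std::string> &dynamics) {
      std::vector<std::string> out;
      auto dyn = dynamics.begin();
      for (int64_t s : statics) {
        if (s == kDynamic)
          out.push_back(*dyn++);            // take the next runtime value
        else
          out.push_back(std::to_string(s)); // materialize the static constant
      }
      assert(dyn == dynamics.end() && "every placeholder consumes one value");
      return out;
    }

    // Example: statics {2, kDynamic, 4} with dynamics {"%d0"} -> {"2", "%d0", "4"}.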
@@ -129,33 +131,33 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto splitAxes = op.getSplitAxes().getAxes();
     int64_t maxNAxes = 0;
-    for (auto axes : splitAxes) {
+    for (auto axes : splitAxes)
       maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
-    }
 
     // To hold the split axes, create empty 2d tensor with shape
     // {splitAxes.size(), max-size-of-split-groups}.
     // Set trailing elements for smaller split-groups to -1.
     Location loc = op.getLoc();
     auto i16 = rewriter.getI16Type();
     auto i64 = rewriter.getI64Type();
-    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
+                                    maxNAxes};
     Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
-    auto attr = IntegerAttr::get(i16, 0xffff);
+    auto attr = IntegerAttr::get(i16, -1);
     Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
     resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                        .getResult(0);
 
     // explicitly write values into tensor row by row
-    int64_t strides[] = {1, 1};
+    std::array<int64_t, 2> strides = {1, 1};
     int64_t nSplits = 0;
     ValueRange empty = {};
     for (auto [i, axes] : llvm::enumerate(splitAxes)) {
       int64_t size = axes.size();
       if (size > 0)
         ++nSplits;
-      int64_t offs[] = {(int64_t)i, 0};
-      int64_t sizes[] = {1, size};
+      std::array<int64_t, 2> offs = {(int64_t)i, 0};
+      std::array<int64_t, 2> sizes = {1, size};
       auto tensorType = RankedTensorType::get({size}, i16);
       auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
       auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
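
Note on the fill value above: the split-axes tensor is padded with a sentinel in its unused trailing slots (the "-1" mentioned in the comment), and the attribute is built on a signed 16-bit type, so the sentinel is now spelled -1 rather than the literal 0xffff; both denote the same 16-bit all-ones pattern. A quick standalone check in plain C++ (not MLIR code):

    #include <cstdint>
    #include <cstdio>

    int main() {
      int16_t sentinel = -1; // sentinel for unused trailing split-axis slots
      // Reinterpreted as unsigned, -1 is the same bit pattern the old literal
      // 0xffff spelled out.
      std::printf("%#x\n", static_cast<uint16_t>(sentinel)); // prints 0xffff
      return 0;
    }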
@@ -165,7 +167,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
 
     // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
     // Store the halo sizes in the tensor.
-    auto haloSizes =
+    SmallVector<Value> haloSizes =
         getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                          adaptor.getDynamicHaloSizes());
     auto type = RankedTensorType::get({nSplits, 2}, i64);
@@ -190,7 +192,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
     } else {
       SymbolTableCollection symbolTableCollection;
       auto meshOp = getMesh(op, symbolTableCollection);
-      auto maxSplitSize = 0;
+      int64_t maxSplitSize = 0;
       for (auto axes : splitAxes) {
         int64_t splitSize =
             collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
@@ -206,7 +208,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
           loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
       resOffsets =
           rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
-      auto offsets =
+      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
       int64_t curr = 0;
@@ -217,8 +219,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
         ++splitSize; // add one for the total size
         ArrayRef<Value> values(&offsets[curr], splitSize);
         Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
-        int64_t offs[] = {(int64_t)i, 0};
-        int64_t sizes[] = {1, splitSize};
+        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
+        std::array<int64_t, 2> sizes = {1, splitSize};
         resOffsets = rewriter.create<tensor::InsertSliceOp>(
             loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
         curr += splitSize;
@@ -275,9 +277,9 @@ struct ConvertProcessMultiIndexOp
     if (!axes.empty()) {
       SmallVector<Value> subIndex;
       for (auto axis : axes) {
-        subIndex.push_back(mIdx[axis]);
+        subIndex.emplace_back(mIdx[axis]);
       }
-      mIdx = subIndex;
+      mIdx = std::move(subIndex);
     }
 
     rewriter.replaceOp(op, mIdx);
@@ -294,8 +296,8 @@ class ConvertProcessLinearIndexOp
 
   // Constructor accepting worldRank
   ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
-                              MLIRContext *context, int64_t worldRank_ = -1)
-      : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+                              MLIRContext *context, int64_t worldRank = -1)
+      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
 
   LogicalResult
   matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
@@ -308,12 +310,11 @@ class ConvertProcessLinearIndexOp
     }
 
     // Otherwise call create mpi::CommRankOp
-    auto rank =
-        rewriter
-            .create<mpi::CommRankOp>(
-                op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()),
+    auto rank = rewriter
+                    .create<mpi::CommRankOp>(
+                        loc, TypeRange{mpi::RetvalType::get(op->getContext()),
                                        rewriter.getI32Type()})
-            .getRank();
+                    .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
     return success();
@@ -400,11 +401,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     }
 
     // Compute the sharded shape by applying the sharding to the input shape.
-    // Without shardedDimsOffsets in the sharding, the shard shape is computed
-    // by dividing the dimension size by the number of shards in that dimension
-    // (which is given by the size of the mesh axes provided in split-axes).
-    // Odd elements get distributed to trailing shards.
-    // If a shardedDimsOffsets is provided, the shard shape is computed by
+    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
+    // computed by dividing the dimension size by the number of shards in that
+    // dimension (which is given by the size of the mesh axes provided in
+    // split-axes). Odd elements get distributed to trailing shards. If a
+    // shardedDimsOffsets is provided, the shard shape is computed by
     // subtracting the offset of the current shard from the offset of the next
     // shard.
 
@@ -429,8 +430,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
 
     // To keep the code simple, convert dims/device to values when they are
     // attributes. Count on canonicalization to fold static values.
-    auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
-    auto multiIdx =
+    SmallVector<Value> shape =
+        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+    SmallVector<Value> multiIdx =
         getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
 
     // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
@@ -448,7 +450,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     // local shard-size.
     Value shardedDimsOffs;
     {
-      auto tmp = getMixedAsValues(
+      SmallVector<Value> tmp = getMixedAsValues(
           rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
@@ -478,7 +480,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
       // Get the index of the local shard in the mesh axis.
       Value idx = multiIdx[axes[0]];
-      auto _numShards =
+      auto numShards =
           collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
       if (shardedDimsOffs) {
         // If sharded dims offsets are provided, use them to compute the
@@ -497,22 +499,22 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
           Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
           shardShape.emplace_back(sz);
         } else {
-          auto numShards = rewriter.create<arith::ConstantOp>(
-              loc, rewriter.getIndexAttr(_numShards));
+          Value numShardsVal = rewriter.create<arith::ConstantOp>(
+              loc, rewriter.getIndexAttr(numShards));
           // Compute shard dim size by distributing odd elements to trailing
           // shards:
           // sz = dim / numShards
           //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
-          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
-          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
-          sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
+          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
+          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
           auto cond = rewriter.create<arith::CmpIOp>(
               loc, arith::CmpIPredicate::sge, idx, sz1);
           Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
           sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
           shardShape.emplace_back(sz);
         }
-        pos += _numShards + 1; // add one for the total size.
+        pos += numShards + 1; // add one for the total size.
       } // else no sharding if split axis is empty or no split axis
       // If no size was added -> no sharding in this dimension.
       if (shardShape.size() <= i)
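
Note on the two code paths above: the else-branch emits, as arith ops, the remainder-to-trailing-shards rule quoted in the comment, while the if-branch (shardedDimsOffsets) takes the difference of adjacent offsets. A standalone sketch of the same arithmetic in plain C++; the concrete extents below are assumed for illustration and happen to match the [3 4 4 4] expectation in the updated test comments:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Division mode: remainder elements go to the trailing shards.
    int64_t shardSizeByDivision(int64_t dim, int64_t numShards, int64_t idx) {
      return dim / numShards + (idx >= (numShards - dim % numShards) ? 1 : 0);
    }

    // Offsets mode: shard i spans [offs[i], offs[i+1]); the extra last entry is
    // the total size ("add one for the total size").
    int64_t shardSizeByOffsets(const std::vector<int64_t> &offs, int64_t idx) {
      return offs[idx + 1] - offs[idx];
    }

    int main() {
      for (int64_t idx = 0; idx < 4; ++idx)
        std::printf("%lld ", (long long)shardSizeByDivision(15, 4, idx)); // 3 4 4 4
      std::printf("\n");
      const std::vector<int64_t> offs = {0, 3, 7, 11, 15}; // assumed example offsets
      for (int64_t idx = 0; idx < 4; ++idx)
        std::printf("%lld ", (long long)shardSizeByOffsets(offs, idx));   // 3 4 4 4
      std::printf("\n");
      return 0;
    }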
@@ -698,25 +700,24 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
       offsets[dim] = orgOffset;
     };
 
-    auto get_i32val = [&](OpFoldResult &v) {
-      return isa<Value>(v)
-                 ? cast<Value>(v)
-                 : rewriter.create<arith::ConstantOp>(
-                       loc,
-                       rewriter.getI32IntegerAttr(
-                           cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
-    };
-
-    for (int i = 0; i < 2; ++i) {
-      Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+    auto doSendRecv = [&](int upOrDown) {
+      OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
+      Value haloSz = dyn_cast<Value>(v);
+      if (!haloSz)
+        haloSz = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(
+                     cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
       auto hasSize = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::sgt, haloSz, zero);
       rewriter.create<scf::IfOp>(loc, hasSize,
                                  [&](OpBuilder &builder, Location loc) {
-                                   genSendRecv(i > 0);
+                                   genSendRecv(upOrDown > 0);
                                    builder.create<scf::YieldOp>(loc);
                                  });
-    }
+    };
+
+    doSendRecv(0);
+    doSendRecv(1);
 
     // the shape for lower dims include higher dims' halos
     dimSizes[dim] = shape[dim];
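
Note on the doSendRecv lambda above: each halo size is an OpFoldResult that may already be an SSA Value or may still be a folded IntegerAttr, in which case an i32 arith constant is materialized before the comparison. The same pattern in isolation, as a sketch built from the calls used in the hunk (the helper name and the exact includes are illustrative, not part of this change):

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/OpDefinition.h"

    using namespace mlir;

    // Return the OpFoldResult as a Value, materializing an i32 constant when it
    // currently holds an attribute.
    static Value getAsI32Value(OpBuilder &b, Location loc, OpFoldResult v) {
      if (Value val = dyn_cast<Value>(v))
        return val; // already an SSA value
      auto intAttr = cast<IntegerAttr>(cast<Attribute>(v));
      return b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(intAttr.getInt()));
    }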
@@ -775,8 +776,8 @@ struct ConvertMeshToMPIPass
             SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
           auto i16 = IntegerType::get(type.getContext(), 16);
           auto i64 = IntegerType::get(type.getContext(), 64);
-          std::array<int64_t, 2> shp{ShapedType::kDynamic,
-                                     ShapedType::kDynamic};
+          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
+                                        ShapedType::kDynamic};
           results.emplace_back(RankedTensorType::get(shp, i16));
           results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
           results.emplace_back(RankedTensorType::get(shp, i64));

mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     return %1#0, %1#1, %1#2 : index, index, index
   }
 
-  // all except first shard in second dim get an extra element
+  // In the second dimension the shard sizes are now [3 4 4 4]
   // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
   func.func @shard_shape_odd_2() -> (index, index, index) {
     %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
@@ -46,7 +46,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     return %1#0, %1#1, %1#2 : index, index, index
   }
 
-  // all except first shard in first dim get an extra element
+  // In the first dimension the shard sizes are now [3 4 4]
   // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
   func.func @shard_shape_odd_3() -> (index, index, index) {
     %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
@@ -72,4 +72,4 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
     return %1#0, %1#1, %1#2 : index, index, index
   }
-}
+}
