@@ -45,7 +45,8 @@ using namespace mlir;
45
45
using namespace mesh ;
46
46
47
47
namespace {
48
- // / Convert vec of OpFoldResults (ints) into vector of Values.
48
+ // / Converts a vector of OpFoldResults (ints) into vector of Values of the
49
+ // / provided type.
49
50
static SmallVector<Value> getMixedAsValues (OpBuilder b, const Location &loc,
50
51
llvm::ArrayRef<int64_t > statics,
51
52
ValueRange dynamics,
@@ -55,14 +56,15 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
55
56
Type i64 = b.getI64Type ();
56
57
if (!type)
57
58
type = i64 ;
58
- assert (i64 == type || b.getIndexType () == type);
59
+ assert ((i64 == type || b.getIndexType () == type) &&
60
+ " expected an i64 or an intex type" );
59
61
for (auto s : statics) {
60
- values. emplace_back (
61
- ShapedType::isDynamic (s)
62
- ? *(dyn++)
63
- : b.create <arith::ConstantOp>(loc, type,
64
- i64 == type ? b. getI64IntegerAttr (s)
65
- : b. getIndexAttr (s)));
62
+ if (s == ShapedType:: kDynamic ) {
63
+ values. emplace_back (*(dyn++));
64
+ } else {
65
+ TypedAttr val = type == i64 ? b.getI64IntegerAttr (s) : b. getIndexAttr (s);
66
+ values. emplace_back (b. create <arith::ConstantOp>(loc, type, val));
67
+ }
66
68
}
67
69
return values;
68
70
};
@@ -129,33 +131,33 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
129
131
ConversionPatternRewriter &rewriter) const override {
130
132
auto splitAxes = op.getSplitAxes ().getAxes ();
131
133
int64_t maxNAxes = 0 ;
132
- for (auto axes : splitAxes) {
134
+ for (auto axes : splitAxes)
133
135
maxNAxes = std::max<int64_t >(maxNAxes, axes.size ());
134
- }
135
136
136
137
// To hold the split axes, create empty 2d tensor with shape
137
138
// {splitAxes.size(), max-size-of-split-groups}.
138
139
// Set trailing elements for smaller split-groups to -1.
139
140
Location loc = op.getLoc ();
140
141
auto i16 = rewriter.getI16Type ();
141
142
auto i64 = rewriter.getI64Type ();
142
- int64_t shape[] = {static_cast <int64_t >(splitAxes.size ()), maxNAxes};
143
+ std::array<int64_t , 2 > shape = {static_cast <int64_t >(splitAxes.size ()),
144
+ maxNAxes};
143
145
Value resSplitAxes = rewriter.create <tensor::EmptyOp>(loc, shape, i16 );
144
- auto attr = IntegerAttr::get (i16 , 0xffff );
146
+ auto attr = IntegerAttr::get (i16 , - 1 );
145
147
Value fillValue = rewriter.create <arith::ConstantOp>(loc, i16 , attr);
146
148
resSplitAxes = rewriter.create <linalg::FillOp>(loc, fillValue, resSplitAxes)
147
149
.getResult (0 );
148
150
149
151
// explicitly write values into tensor row by row
150
- int64_t strides[] = {1 , 1 };
152
+ std::array< int64_t , 2 > strides = {1 , 1 };
151
153
int64_t nSplits = 0 ;
152
154
ValueRange empty = {};
153
155
for (auto [i, axes] : llvm::enumerate (splitAxes)) {
154
156
int64_t size = axes.size ();
155
157
if (size > 0 )
156
158
++nSplits;
157
- int64_t offs[] = {(int64_t )i, 0 };
158
- int64_t sizes[] = {1 , size};
159
+ std::array< int64_t , 2 > offs = {(int64_t )i, 0 };
160
+ std::array< int64_t , 2 > sizes = {1 , size};
159
161
auto tensorType = RankedTensorType::get ({size}, i16 );
160
162
auto attrs = DenseIntElementsAttr::get (tensorType, axes.asArrayRef ());
161
163
auto vals = rewriter.create <arith::ConstantOp>(loc, tensorType, attrs);
@@ -165,7 +167,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
165
167
166
168
// To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
167
169
// Store the halo sizes in the tensor.
168
- auto haloSizes =
170
+ SmallVector<Value> haloSizes =
169
171
getMixedAsValues (rewriter, loc, adaptor.getStaticHaloSizes (),
170
172
adaptor.getDynamicHaloSizes ());
171
173
auto type = RankedTensorType::get ({nSplits, 2 }, i64 );
@@ -190,7 +192,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
190
192
} else {
191
193
SymbolTableCollection symbolTableCollection;
192
194
auto meshOp = getMesh (op, symbolTableCollection);
193
- auto maxSplitSize = 0 ;
195
+ int64_t maxSplitSize = 0 ;
194
196
for (auto axes : splitAxes) {
195
197
int64_t splitSize =
196
198
collectiveProcessGroupSize (axes.asArrayRef (), meshOp.getShape ());
@@ -206,7 +208,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
206
208
loc, i64 , rewriter.getI64IntegerAttr (ShapedType::kDynamic ));
207
209
resOffsets =
208
210
rewriter.create <linalg::FillOp>(loc, zero, resOffsets).getResult (0 );
209
- auto offsets =
211
+ SmallVector<Value> offsets =
210
212
getMixedAsValues (rewriter, loc, adaptor.getStaticShardedDimsOffsets (),
211
213
adaptor.getDynamicShardedDimsOffsets ());
212
214
int64_t curr = 0 ;
@@ -217,8 +219,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
217
219
++splitSize; // add one for the total size
218
220
ArrayRef<Value> values (&offsets[curr], splitSize);
219
221
Value vals = rewriter.create <tensor::FromElementsOp>(loc, values);
220
- int64_t offs[] = {( int64_t )i , 0 };
221
- int64_t sizes[] = {1 , splitSize};
222
+ std::array< int64_t , 2 > offs = {static_cast < int64_t >(i) , 0 };
223
+ std::array< int64_t , 2 > sizes = {1 , splitSize};
222
224
resOffsets = rewriter.create <tensor::InsertSliceOp>(
223
225
loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
224
226
curr += splitSize;
@@ -275,9 +277,9 @@ struct ConvertProcessMultiIndexOp
275
277
if (!axes.empty ()) {
276
278
SmallVector<Value> subIndex;
277
279
for (auto axis : axes) {
278
- subIndex.push_back (mIdx [axis]);
280
+ subIndex.emplace_back (mIdx [axis]);
279
281
}
280
- mIdx = subIndex;
282
+ mIdx = std::move ( subIndex) ;
281
283
}
282
284
283
285
rewriter.replaceOp (op, mIdx );
@@ -294,8 +296,8 @@ class ConvertProcessLinearIndexOp
294
296
295
297
// Constructor accepting worldRank
296
298
ConvertProcessLinearIndexOp (const TypeConverter &typeConverter,
297
- MLIRContext *context, int64_t worldRank_ = -1 )
298
- : OpConversionPattern(typeConverter, context), worldRank(worldRank_ ) {}
299
+ MLIRContext *context, int64_t worldRank = -1 )
300
+ : OpConversionPattern(typeConverter, context), worldRank(worldRank ) {}
299
301
300
302
LogicalResult
301
303
matchAndRewrite (ProcessLinearIndexOp op, OpAdaptor adaptor,
@@ -308,12 +310,11 @@ class ConvertProcessLinearIndexOp
308
310
}
309
311
310
312
// Otherwise call create mpi::CommRankOp
311
- auto rank =
312
- rewriter
313
- .create <mpi::CommRankOp>(
314
- op.getLoc (), TypeRange{mpi::RetvalType::get (op->getContext ()),
313
+ auto rank = rewriter
314
+ .create <mpi::CommRankOp>(
315
+ loc, TypeRange{mpi::RetvalType::get (op->getContext ()),
315
316
rewriter.getI32Type ()})
316
- .getRank ();
317
+ .getRank ();
317
318
rewriter.replaceOpWithNewOp <arith::IndexCastOp>(op, rewriter.getIndexType (),
318
319
rank);
319
320
return success ();
@@ -400,11 +401,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
400
401
}
401
402
402
403
// Compute the sharded shape by applying the sharding to the input shape.
403
- // Without shardedDimsOffsets in the sharding, the shard shape is computed
404
- // by dividing the dimension size by the number of shards in that dimension
405
- // (which is given by the size of the mesh axes provided in split-axes).
406
- // Odd elements get distributed to trailing shards.
407
- // If a shardedDimsOffsets is provided, the shard shape is computed by
404
+ // If shardedDimsOffsets is not defined in the sharding, the shard shape is
405
+ // computed by dividing the dimension size by the number of shards in that
406
+ // dimension (which is given by the size of the mesh axes provided in
407
+ // split-axes). Odd elements get distributed to trailing shards. If a
408
+ // shardedDimsOffsets is provided, the shard shape is computed by
408
409
// subtracting the offset of the current shard from the offset of the next
409
410
// shard.
410
411
@@ -429,8 +430,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
429
430
430
431
// To keep the code simple, convert dims/device to values when they are
431
432
// attributes. Count on canonicalization to fold static values.
432
- auto shape = getMixedAsValues (rewriter, loc, op.getDims (), dynDims, index);
433
- auto multiIdx =
433
+ SmallVector<Value> shape =
434
+ getMixedAsValues (rewriter, loc, op.getDims (), dynDims, index);
435
+ SmallVector<Value> multiIdx =
434
436
getMixedAsValues (rewriter, loc, adaptor.getDevice (), dynDevice, index);
435
437
436
438
// Get the MeshOp, the mesh shape is needed to compute the sharded shape.
@@ -448,7 +450,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
448
450
// local shard-size.
449
451
Value shardedDimsOffs;
450
452
{
451
- auto tmp = getMixedAsValues (
453
+ SmallVector<Value> tmp = getMixedAsValues (
452
454
rewriter, loc, sharding.getStaticShardedDimsOffsets (),
453
455
sharding.getDynamicShardedDimsOffsets (), index);
454
456
if (!tmp.empty ())
@@ -478,7 +480,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
478
480
rewriter.create <arith::ConstantOp>(loc, rewriter.getIndexAttr (pos));
479
481
// Get the index of the local shard in the mesh axis.
480
482
Value idx = multiIdx[axes[0 ]];
481
- auto _numShards =
483
+ auto numShards =
482
484
collectiveProcessGroupSize (axes.asArrayRef (), meshOp.getShape ());
483
485
if (shardedDimsOffs) {
484
486
// If sharded dims offsets are provided, use them to compute the
@@ -497,22 +499,22 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
497
499
Value sz = rewriter.create <arith::SubIOp>(loc, nextOff, off);
498
500
shardShape.emplace_back (sz);
499
501
} else {
500
- auto numShards = rewriter.create <arith::ConstantOp>(
501
- loc, rewriter.getIndexAttr (_numShards ));
502
+ Value numShardsVal = rewriter.create <arith::ConstantOp>(
503
+ loc, rewriter.getIndexAttr (numShards ));
502
504
// Compute shard dim size by distributing odd elements to trailing
503
505
// shards:
504
506
// sz = dim / numShards
505
507
// + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
506
- Value sz = rewriter.create <arith::DivSIOp>(loc, dim, numShards );
507
- Value sz1 = rewriter.create <arith::RemSIOp>(loc, dim, numShards );
508
- sz1 = rewriter.create <arith::SubIOp>(loc, numShards , sz1);
508
+ Value sz = rewriter.create <arith::DivSIOp>(loc, dim, numShardsVal );
509
+ Value sz1 = rewriter.create <arith::RemSIOp>(loc, dim, numShardsVal );
510
+ sz1 = rewriter.create <arith::SubIOp>(loc, numShardsVal , sz1);
509
511
auto cond = rewriter.create <arith::CmpIOp>(
510
512
loc, arith::CmpIPredicate::sge, idx, sz1);
511
513
Value odd = rewriter.create <arith::SelectOp>(loc, cond, one, zero);
512
514
sz = rewriter.create <arith::AddIOp>(loc, sz, odd);
513
515
shardShape.emplace_back (sz);
514
516
}
515
- pos += _numShards + 1 ; // add one for the total size.
517
+ pos += numShards + 1 ; // add one for the total size.
516
518
} // else no sharding if split axis is empty or no split axis
517
519
// If no size was added -> no sharding in this dimension.
518
520
if (shardShape.size () <= i)
@@ -698,25 +700,24 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
698
700
offsets[dim] = orgOffset;
699
701
};
700
702
701
- auto get_i32val = [&](OpFoldResult &v) {
702
- return isa<Value>(v)
703
- ? cast<Value>(v)
704
- : rewriter.create <arith::ConstantOp>(
705
- loc,
706
- rewriter.getI32IntegerAttr (
707
- cast<IntegerAttr>(cast<Attribute>(v)).getInt ()));
708
- };
709
-
710
- for (int i = 0 ; i < 2 ; ++i) {
711
- Value haloSz = get_i32val (haloSizes[currHaloDim * 2 + i]);
703
+ auto doSendRecv = [&](int upOrDown) {
704
+ OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
705
+ Value haloSz = dyn_cast<Value>(v);
706
+ if (!haloSz)
707
+ haloSz = rewriter.create <arith::ConstantOp>(
708
+ loc, rewriter.getI32IntegerAttr (
709
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt ()));
712
710
auto hasSize = rewriter.create <arith::CmpIOp>(
713
711
loc, arith::CmpIPredicate::sgt, haloSz, zero);
714
712
rewriter.create <scf::IfOp>(loc, hasSize,
715
713
[&](OpBuilder &builder, Location loc) {
716
- genSendRecv (i > 0 );
714
+ genSendRecv (upOrDown > 0 );
717
715
builder.create <scf::YieldOp>(loc);
718
716
});
719
- }
717
+ };
718
+
719
+ doSendRecv (0 );
720
+ doSendRecv (1 );
720
721
721
722
// the shape for lower dims include higher dims' halos
722
723
dimSizes[dim] = shape[dim];
@@ -775,8 +776,8 @@ struct ConvertMeshToMPIPass
775
776
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
776
777
auto i16 = IntegerType::get (type.getContext (), 16 );
777
778
auto i64 = IntegerType::get (type.getContext (), 64 );
778
- std::array<int64_t , 2 > shp{ShapedType::kDynamic ,
779
- ShapedType::kDynamic };
779
+ std::array<int64_t , 2 > shp = {ShapedType::kDynamic ,
780
+ ShapedType::kDynamic };
780
781
results.emplace_back (RankedTensorType::get (shp, i16 ));
781
782
results.emplace_back (RankedTensorType::get (shp, i64 )); // actually ?x2
782
783
results.emplace_back (RankedTensorType::get (shp, i64 ));
0 commit comments