@@ -380,23 +380,159 @@ struct ConvertNeighborsLinearIndicesOp
         [&](OpBuilder &builder, Location loc) {
           SmallVector<Value> tmp = mIdx;
           tmp[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
-                  .getResult();
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
           builder.create<scf::YieldOp>(
               loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertUpdateHaloOp
-    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+    if (!sharding) {
+      return op->emitError()
+             << "Expected ShardingOp as defining op for sharding"
+             << " but found " << adaptor.getSharding()[0].getDefiningOp();
+    }
+
+    // Compute the sharded shape by applying the sharding to the input shape.
+    // Without shardedDimsOffsets in the sharding, the shard shape is computed
+    // by dividing the dimension size by the number of shards in that dimension
+    // (which is given by the size of the mesh axes provided in split-axes).
+    // Odd elements get distributed to trailing shards.
+    // If a shardedDimsOffsets is provided, the shard shape is computed by
+    // subtracting the offset of the current shard from the offset of the next
+    // shard.
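+    // Illustrative example of both modes (assumed values, not from the op):
+    // a dim of size 10 split over 3 shards gives local sizes 3, 3, 4 in the
+    // balanced case, while shardedDimsOffsets [0, 4, 7, 10] gives 4, 3, 3.
+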
+    Location loc = op.getLoc();
+    Type index = rewriter.getIndexType();
+
+    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
+    // The operands in the adaptor are a vector<ValueRange>. For dims and
+    // device we have a 1:1 conversion.
+    // For simpler access, fill a vector with the dynamic dims.
+    SmallVector<Value> dynDims, dynDevice;
+    for (auto dim : adaptor.getDimsDynamic()) {
+      // The type conversion should be 1:1 for ints.
+      assert(dim.size() == 1);
+      dynDims.emplace_back(dim[0]);
+    }
+    // Same for device.
+    for (auto device : adaptor.getDeviceDynamic()) {
+      assert(device.size() == 1);
+      dynDevice.emplace_back(device[0]);
+    }
+
+    // To keep the code simple, convert dims/device to values when they are
+    // attributes. Count on canonicalization to fold static values.
+    auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+    auto multiIdx =
+        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
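+    // For illustration (assumed values): with static dims [10, dynamic] and a
+    // single dynamic SSA value %d, `shape` ends up as {arith.constant 10, %d}.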
+
+    // Get the MeshOp; the mesh shape is needed to compute the sharded shape.
+    SymbolTableCollection symbolTableCollection;
+    auto meshOp = getMesh(sharding, symbolTableCollection);
+    // For now we only support static mesh shapes.
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return failure();
+
+    auto splitAxes = sharding.getSplitAxes().getAxes();
+    // shardedDimsOffsets is optional and might hold Values (not attributes).
+    // Also, the shard id might be dynamic, which means the position in the
+    // shardedDimsOffsets is not statically known. Create a tensor of the
+    // shardedDimsOffsets and later extract the offsets for computing the
+    // local shard size.
+    Value shardedDimsOffs;
+    {
+      auto tmp = getMixedAsValues(
+          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
+          sharding.getDynamicShardedDimsOffsets(), index);
+      if (!tmp.empty())
+        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
+            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
+    }
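+    // Illustrative layout (assumed values): two dims sharded over 3 shards
+    // each may yield a flattened offsets tensor [0, 4, 7, 10, 0, 4, 8, 12],
+    // i.e. numShards + 1 monotonic entries per sharded dim, 0 to dim size.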
+
+    // With a static mesh shape the sizes of the split axes are known.
+    // Hence the start position for each split axis in shardedDimsOffsets can
+    // be computed statically.
+    int64_t pos = 0;
+    SmallVector<Value> shardShape;
+    Value zero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+    Value one =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+
+    // Iterate over the dimensions of the tensor shape, get their split axes,
+    // and compute the sharded shape.
+    for (auto [i, dim] : llvm::enumerate(shape)) {
+      // Trailing dimensions might not be annotated.
+      if (i < splitAxes.size() && !splitAxes[i].empty()) {
+        auto axes = splitAxes[i];
+        // The current dimension might not be sharded.
+        // Create a value from the static position in shardedDimsOffsets.
+        Value posVal =
+            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
+        // Get the index of the local shard in the mesh axis.
+        Value idx = multiIdx[axes[0]];
+        auto _numShards =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
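+        // Note: collectiveProcessGroupSize yields the number of shards, i.e.
+        // the product of the mesh sizes along `axes` (e.g. axes [0, 1] on a
+        // 2x3 mesh -> 6 shards).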
+        if (shardedDimsOffs) {
+          // If sharded dims offsets are provided, use them to compute the
+          // sharded shape.
+          if (axes.size() > 1) {
+            return op->emitError() << "Only single axis sharding is "
+                                   << "supported for each dimension.";
+          }
+          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
+          Value off =
+              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+          Value nextOff =
+              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+          shardShape.emplace_back(sz);
+        } else {
+          auto numShards = rewriter.create<arith::ConstantOp>(
+              loc, rewriter.getIndexAttr(_numShards));
+          // Compute the shard dim size by distributing odd elements to
+          // trailing shards:
+          //   sz = dim / numShards
+          //        + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
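+          // Worked example: dim = 10, numShards = 3 -> 10 / 3 = 3 and
+          // 10 % 3 = 1, so shards with idx >= 2 get one extra element,
+          // giving local sizes 3, 3, 4.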
+          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
+          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
+          sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+          auto cond = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::sge, idx, sz1);
+          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
+          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
+          shardShape.emplace_back(sz);
+        }
+        pos += _numShards + 1; // Add one for the total size.
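+        // Each sharded dim consumes _numShards + 1 entries of the offsets
+        // (including the trailing total), e.g. 3 shards advance pos by 4.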
+      } // else no sharding if split axis is empty or no split axis
+      // If no size was added -> no sharding in this dimension.
+      if (shardShape.size() <= i)
+        shardShape.emplace_back(dim);
+    }
+    assert(shardShape.size() == shape.size());
+    rewriter.replaceOp(op, shardShape);
+    return success();
+  }
+};
+
+struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // The input/output memref is assumed to be in C memory order.
     // Halos are exchanged as 2 blocks per dimension (one for each side: down