@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
339
339
let hasCanonicalizer = 1;
340
340
}
341
341
342
// mesh.broadcast: collective op whose `input` and `result` are constrained to
// share both shape and element type.
def Mesh_BroadcastOp
    : Mesh_CollectiveCommunicationOpBase<"broadcast",
        [AllShapesMatch<["input", "result"]>,
         AllElementTypesMatch<["input", "result"]>]> {
  let summary = "Broadcast over a device mesh.";
  let description = [{
    Broadcast the tensor on `root` to all devices in each respective group.
    The operation broadcasts along mesh axes `mesh_axes`.
    The `root` device specifies the in-group multi-index that is broadcast to
    all other devices in the group.

    Example:
    ```
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])

    %1 = mesh.broadcast %0 on @mesh0
      mesh_axes = [0]
      root = [0]
      : (tensor<2xi8>) -> tensor<2xi8>
    ```

    Input:
    ```
                     +-------+-------+                   | broadcast
    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                     +-------+-------+                   ↓
    device (1, 0) -> |       |       | <- device (1, 1)
                     +-------+-------+
    ```

    Output:
    ```
                     +-------+-------+
    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                     +-------+-------+
    device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                     +-------+-------+
    ```
  }];

  // `root` is carried as a static array attribute paired with variadic index
  // operands; both are consumed by the DynamicIndexList directive below.
  // NOTE(review): `$mesh` / `$mesh_axes` used in the format are presumably
  // contributed by `commonArgs` -- confirm against the base class.
  let arguments = !con(commonArgs,
                       (ins AnyRankedTensor:$input,
                            DenseI64ArrayAttr:$root,
                            Variadic<Index>:$root_dynamic));
  let results = (outs AnyRankedTensor:$result);

  // `mesh_axes` is optional in the textual form; `root` is always printed.
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
    attr-dict `:` functional-type(operands, results)
  }];
}
395
+
396
// mesh.gather: collective op whose `input` and `result` are constrained to
// share rank and element type (the shape may differ along the gather axis).
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
    AllRanksMatch<["input", "result"]>,
    AllElementTypesMatch<["input", "result"]>
  ]> {
  let summary = "Gather over a device mesh.";
  let description = [{
    Gathers on device `root` along the `gather_axis` tensor axis.
    `root` specifies the coordinates of a device along `mesh_axes`.
    It uniquely identifies the root device for each device group.
    The result tensor on non-root devices is undefined.
    Using it will result in undefined behavior.

    Example:
    ```mlir
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
    ...
    %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
      gather_axis = 1 root = [1]
      : (tensor<2x2xi8>) -> tensor<2x4xi8>
    ```
    Input:
    ```
                      gather tensor
                      axis 1
                      ------------>
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                     |  3  4 |  7  8 |
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                     | 11 12 | 15 16 |
                     +-------+-------+
    ```
    Result:
    ```
    +-------------+
    |  1  2  5  6 | <- device (0, 1)
    |  3  4  7  8 |
    +-------------+
    |  9 10 13 14 | <- device (1, 1)
    | 11 12 15 16 |
    +-------------+
    ```
    Devices `(0, 0)` and `(1, 0)` have undefined result.
  }];
  // `root` is carried as a static array attribute paired with variadic index
  // operands; both are consumed by the DynamicIndexList directive below.
  // NOTE(review): `$mesh` / `$mesh_axes` used in the format are presumably
  // contributed by `commonArgs` -- confirm against the base class.
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$gather_axis,
    DenseI64ArrayAttr:$root,
    Variadic<Index>:$root_dynamic
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  // `mesh_axes` is optional in the textual form; `gather_axis` and `root` are
  // always printed.
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `gather_axis` `=` $gather_axis
    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
    attr-dict `:` functional-type(operands, results)
  }];
}
457
+
458
// mesh.recv: collective op whose `input` and `result` are constrained to share
// both shape and element type.
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
    AllShapesMatch<["input", "result"]>,
    AllElementTypesMatch<["input", "result"]>
  ]> {
  // Fixed: the summary previously read "Send over a device mesh.", duplicating
  // the summary of the send op.
  let summary = "Receive over a device mesh.";
  let description = [{
    Receive from a device within a device group.
  }];
  // `source` is optional (OptionalAttr) and, when present, is paired with the
  // variadic index operands `source_dynamic` via the DynamicIndexList
  // directive below.
  // NOTE(review): `$mesh` / `$mesh_axes` used in the format are presumably
  // contributed by `commonArgs` -- confirm against the base class.
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    OptionalAttr<DenseI64ArrayAttr>:$source,
    Variadic<Index>:$source_dynamic
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  // Both `mesh_axes` and `source` are optional groups in the textual form.
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
    attr-dict `:` functional-type(operands, results)
  }];
}
480
+
481
// mesh.reduce: collective op constrained only on shape; no element-type trait
// is attached, so `input` and `result` element types may differ.
def Mesh_ReduceOp
    : Mesh_CollectiveCommunicationOpBase<"reduce",
        [AllShapesMatch<["input", "result"]>]> {
  let summary = "Reduce over a device mesh.";
  let description = [{
    Reduces on device `root` within each device group.
    `root` specifies the coordinates of a device along `mesh_axes`.
    It uniquely identifies the root device within its device group.
    The accumulation element type is specified by the result type and
    it does not need to match the input element type.
    The input element is converted to the result element type before
    performing the reduction.

    Attributes:
    `reduction`: Indicates the reduction method.

    Example:
    ```
    %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
      reduction = <max> root = [2, 3]
      : (tensor<3x4xf32>) -> tensor<3x4xf64>
    ```
  }];

  // `reduction` falls back to ::mlir::mesh::Partial::Sum when not given.
  // `root` is carried as a static array attribute paired with variadic index
  // operands; both are consumed by the DynamicIndexList directive below.
  // NOTE(review): `$mesh` / `$mesh_axes` used in the format are presumably
  // contributed by `commonArgs` -- confirm against the base class.
  let arguments = !con(commonArgs,
                       (ins AnyRankedTensor:$input,
                            DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
                            DenseI64ArrayAttr:$root,
                            Variadic<Index>:$root_dynamic));
  let results = (outs AnyRankedTensor:$result);

  // `mesh_axes` and `reduction` are optional groups; `root` is always printed.
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    (`reduction` `=` $reduction^)?
    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
    attr-dict `:` functional-type(operands, results)
  }];
}
520
+
342
521
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
343
522
SameOperandsAndResultRank]> {
344
523
let summary = "Reduce-scatter over a device mesh.";
@@ -400,4 +579,154 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
400
579
let hasCanonicalizer = 1;
401
580
}
402
581
582
// mesh.scatter: collective op whose `input` and `result` are constrained to
// share rank and element type (the shape may differ along the scatter axis).
def Mesh_ScatterOp
    : Mesh_CollectiveCommunicationOpBase<"scatter",
        [AllRanksMatch<["input", "result"]>,
         AllElementTypesMatch<["input", "result"]>]> {
  let summary = "Scatter over a device mesh.";
  let description = [{
    For each device group split the input tensor on the `root` device along
    axis `scatter_axis` and scatter the parts across the group devices.

    Example:
    ```
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
    %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
      scatter_axis = 0
      root = [1]
      : (tensor<2x2xi8>) -> tensor<1x2xi8>
    ```

    Input:
    ```
                              device
                              (0, 1)
                                 ↓
                     +-------+-------+  | scatter tensor
    device (0, 0) -> |       |       |  | axis 0
                     |       |       |  ↓
                     +-------+-------+
    device (1, 0) -> |  1  2 |  5  6 |
                     |  3  4 |  7  8 |
                     +-------+-------+
                                 ↑
                              device
                              (1, 1)
    ```

    Result:
    ```
                              device
                              (0, 1)
                                 ↓
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 |
                     +-------+-------+
    device (1, 0) -> |  3  4 |  7  8 |
                     +-------+-------+
                                 ↑
                              device
                              (1, 1)
    ```
  }];

  // `root` is carried as a static array attribute paired with variadic index
  // operands; both are consumed by the DynamicIndexList directive below.
  // NOTE(review): `$mesh` / `$mesh_axes` used in the format are presumably
  // contributed by `commonArgs` -- confirm against the base class.
  let arguments = !con(commonArgs,
                       (ins AnyNon0RankedTensor:$input,
                            IndexAttr:$scatter_axis,
                            DenseI64ArrayAttr:$root,
                            Variadic<Index>:$root_dynamic));
  let results = (outs AnyRankedTensor:$result);

  // `mesh_axes` is optional in the textual form; `scatter_axis` and `root`
  // are always printed.
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `scatter_axis` `=` $scatter_axis
    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
    attr-dict `:` functional-type(operands, results)
  }];
}
648
+
649
// mesh.send: collective op whose `input` and `result` are constrained to share
// both shape and element type.
def Mesh_SendOp
    : Mesh_CollectiveCommunicationOpBase<"send",
        [AllShapesMatch<["input", "result"]>,
         AllElementTypesMatch<["input", "result"]>]> {
  let summary = "Send over a device mesh.";
  let description = [{
    Send from one device to another within a device group.
  }];

  // `destination` is carried as a static array attribute paired with variadic
  // index operands; both are consumed by the DynamicIndexList directive below.
  // NOTE(review): `$mesh` / `$mesh_axes` used in the format are presumably
  // contributed by `commonArgs` -- confirm against the base class.
  let arguments = !con(commonArgs,
                       (ins AnyNon0RankedTensor:$input,
                            DenseI64ArrayAttr:$destination,
                            Variadic<Index>:$destination_dynamic));
  let results = (outs AnyRankedTensor:$result);

  // `mesh_axes` is optional in the textual form; `destination` is always
  // printed.
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
    attr-dict `:` functional-type(operands, results)
  }];
}
671
+
672
// mesh.shift: collective op whose operand and result are constrained to share
// element type and shape.
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
    SameOperandsAndResultElementType,
    SameOperandsAndResultShape
  ]> {
  // Fixed typo: the summary previously read "Sift over a device mesh.".
  let summary = "Shift over a device mesh.";
  let description = [{
    Within each device group shift along mesh axis `shift_axis` by an offset
    `offset`.
    The result on devices that do not have a corresponding source is undefined.
    `shift_axis` must be one of `mesh_axes`.
    If the `rotate` attribute is present,
    instead of a shift a rotation is done.

    Example:
    ```
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
    %1 = mesh.shift %0 on @mesh0 mesh_axes = [1]
      shift_axis = 1 offset = 2 rotate
      : tensor<2xi8> -> tensor<2xi8>
    ```

    Input:
    ```
            mesh axis 1
            ----------->

    +----+----+----+----+
    |  1 |  2 |  3 |  4 |
    +----+----+----+----+
    |  5 |  6 |  7 |  8 |
    +----+----+----+----+
    ```

    Result:
    ```
    +----+----+----+----+
    |  3 |  4 |  1 |  2 |
    +----+----+----+----+
    |  7 |  8 |  5 |  6 |
    +----+----+----+----+
    ```
  }];
  // `rotate` is a unit attribute: its presence alone selects rotation.
  // NOTE(review): `$mesh` / `$mesh_axes` used in the format are presumably
  // contributed by `commonArgs` -- confirm against the base class.
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$shift_axis,
    I64Attr:$offset,
    UnitAttr:$rotate
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  // The format starts with `$input`, so the textual form always names the
  // operand before `on`. Unlike the sibling ops, the type is spelled as
  // `type($input) -> type($result)` rather than as a functional type.
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `shift_axis` `=` $shift_axis
    `offset` `=` $offset
    (`rotate` $rotate^)?
    attr-dict `:` type($input) `->` type($result)
  }];
}
731
+
403
732
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
0 commit comments