24
24
#include " mlir/Support/LLVM.h"
25
25
#include " mlir/Support/LogicalResult.h"
26
26
#include " llvm/ADT/ArrayRef.h"
27
- #include " llvm/ADT/DenseSet.h"
28
27
#include " llvm/ADT/STLExtras.h"
29
28
#include " llvm/ADT/SmallSet.h"
30
29
#include " llvm/ADT/SmallVector.h"
34
33
#include < iterator>
35
34
#include < numeric>
36
35
#include < optional>
37
- #include < string>
38
36
#include < utility>
39
37
40
38
#define DEBUG_TYPE " mesh-ops"
@@ -244,6 +242,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
244
242
MeshAxesAttr::get (odsBuilder.getContext (), axes));
245
243
}
246
244
245
+ void MeshShapeOp::getAsmResultNames (
246
+ function_ref<void (Value, StringRef)> setNameFn) {
247
+ setNameFn (getResults ()[0 ], " mesh_shape" );
248
+ }
249
+
247
250
// ===----------------------------------------------------------------------===//
248
251
// mesh.shard attr
249
252
// ===----------------------------------------------------------------------===//
@@ -307,6 +310,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
307
310
std::mem_fn (&MeshAxesAttr::empty));
308
311
}
309
312
313
+ // ===----------------------------------------------------------------------===//
314
+ // mesh.shard op
315
+ // ===----------------------------------------------------------------------===//
316
+
317
+ void ShardOp::getAsmResultNames (
318
+ function_ref<void (Value, StringRef)> setNameFn) {
319
+ setNameFn (getResult (), " sharding_annotated" );
320
+ }
321
+
310
322
// ===----------------------------------------------------------------------===//
311
323
// mesh.process_multi_index op
312
324
// ===----------------------------------------------------------------------===//
@@ -345,6 +357,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
345
357
MeshAxesAttr::get (odsBuilder.getContext (), axes));
346
358
}
347
359
360
+ void ProcessMultiIndexOp::getAsmResultNames (
361
+ function_ref<void (Value, StringRef)> setNameFn) {
362
+ setNameFn (getResults ()[0 ], " proc_linear_idx" );
363
+ }
364
+
348
365
// ===----------------------------------------------------------------------===//
349
366
// mesh.process_linear_index op
350
367
// ===----------------------------------------------------------------------===//
@@ -363,6 +380,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
363
380
build (odsBuilder, odsState, mesh.getSymName ());
364
381
}
365
382
383
+ void ProcessLinearIndexOp::getAsmResultNames (
384
+ function_ref<void (Value, StringRef)> setNameFn) {
385
+ setNameFn (getResult (), " proc_linear_idx" );
386
+ }
387
+
366
388
// ===----------------------------------------------------------------------===//
367
389
// collective communication ops
368
390
// ===----------------------------------------------------------------------===//
@@ -606,6 +628,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
606
628
patterns.add <EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
607
629
}
608
630
631
+ void AllGatherOp::getAsmResultNames (
632
+ function_ref<void (Value, StringRef)> setNameFn) {
633
+ setNameFn (getResult (), " all_gather" );
634
+ }
635
+
609
636
// ===----------------------------------------------------------------------===//
610
637
// mesh.all_reduce op
611
638
// ===----------------------------------------------------------------------===//
@@ -620,6 +647,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
620
647
patterns.add <EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
621
648
}
622
649
650
+ void AllReduceOp::getAsmResultNames (
651
+ function_ref<void (Value, StringRef)> setNameFn) {
652
+ setNameFn (getResult (), " all_reduce" );
653
+ }
654
+
623
655
// ===----------------------------------------------------------------------===//
624
656
// mesh.all_slice op
625
657
// ===----------------------------------------------------------------------===//
@@ -654,6 +686,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
654
686
APInt (sizeof (sliceAxis) * CHAR_BIT, sliceAxis));
655
687
}
656
688
689
+ void AllSliceOp::getAsmResultNames (
690
+ function_ref<void (Value, StringRef)> setNameFn) {
691
+ setNameFn (getResult (), " all_slice" );
692
+ }
693
+
657
694
// ===----------------------------------------------------------------------===//
658
695
// mesh.all_to_all op
659
696
// ===----------------------------------------------------------------------===//
@@ -674,6 +711,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
674
711
patterns.add <EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
675
712
}
676
713
714
+ void AllToAllOp::getAsmResultNames (
715
+ function_ref<void (Value, StringRef)> setNameFn) {
716
+ setNameFn (getResult (), " all_to_all" );
717
+ }
718
+
677
719
// ===----------------------------------------------------------------------===//
678
720
// mesh.broadcast op
679
721
// ===----------------------------------------------------------------------===//
@@ -698,6 +740,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
698
740
patterns.add <EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
699
741
}
700
742
743
+ void BroadcastOp::getAsmResultNames (
744
+ function_ref<void (Value, StringRef)> setNameFn) {
745
+ setNameFn (getResult (), " broadcast" );
746
+ }
747
+
701
748
// ===----------------------------------------------------------------------===//
702
749
// mesh.gather op
703
750
// ===----------------------------------------------------------------------===//
@@ -724,6 +771,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
724
771
patterns.add <EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
725
772
}
726
773
774
+ void GatherOp::getAsmResultNames (
775
+ function_ref<void (Value, StringRef)> setNameFn) {
776
+ setNameFn (getResult (), " gather" );
777
+ }
778
+
727
779
// ===----------------------------------------------------------------------===//
728
780
// mesh.recv op
729
781
// ===----------------------------------------------------------------------===//
@@ -747,6 +799,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
747
799
patterns.add <EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
748
800
}
749
801
802
+ void RecvOp::getAsmResultNames (function_ref<void (Value, StringRef)> setNameFn) {
803
+ setNameFn (getResult (), " recv" );
804
+ }
805
+
750
806
// ===----------------------------------------------------------------------===//
751
807
// mesh.reduce op
752
808
// ===----------------------------------------------------------------------===//
@@ -770,6 +826,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
770
826
patterns.add <EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
771
827
}
772
828
829
+ void ReduceOp::getAsmResultNames (
830
+ function_ref<void (Value, StringRef)> setNameFn) {
831
+ setNameFn (getResult (), " reduce" );
832
+ }
833
+
773
834
// ===----------------------------------------------------------------------===//
774
835
// mesh.reduce_scatter op
775
836
// ===----------------------------------------------------------------------===//
@@ -791,6 +852,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
791
852
patterns.add <EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
792
853
}
793
854
855
+ void ReduceScatterOp::getAsmResultNames (
856
+ function_ref<void (Value, StringRef)> setNameFn) {
857
+ setNameFn (getResult (), " reduce_scatter" );
858
+ }
859
+
794
860
// ===----------------------------------------------------------------------===//
795
861
// mesh.scatter op
796
862
// ===----------------------------------------------------------------------===//
@@ -817,6 +883,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
817
883
patterns.add <EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
818
884
}
819
885
886
+ void ScatterOp::getAsmResultNames (
887
+ function_ref<void (Value, StringRef)> setNameFn) {
888
+ setNameFn (getResult (), " scatter" );
889
+ }
890
+
820
891
// ===----------------------------------------------------------------------===//
821
892
// mesh.send op
822
893
// ===----------------------------------------------------------------------===//
@@ -839,6 +910,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
839
910
patterns.add <EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
840
911
}
841
912
913
+ void SendOp::getAsmResultNames (function_ref<void (Value, StringRef)> setNameFn) {
914
+ setNameFn (getResult (), " send" );
915
+ }
916
+
842
917
// ===----------------------------------------------------------------------===//
843
918
// mesh.shift op
844
919
// ===----------------------------------------------------------------------===//
@@ -865,6 +940,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
865
940
// offset % shift_axis_mesh_dim_size == 0.
866
941
}
867
942
943
+ void ShiftOp::getAsmResultNames (
944
+ function_ref<void (Value, StringRef)> setNameFn) {
945
+ setNameFn (getResult (), " shift" );
946
+ }
947
+
868
948
// ===----------------------------------------------------------------------===//
869
949
// TableGen'd op method definitions
870
950
// ===----------------------------------------------------------------------===//
0 commit comments