Skip to content

Commit f78027d

Browse files
authored
[mlir][mesh] Better op result names (#82408)
Implement OpAsmOpInterface for most ops to increase IR readability. For example `mesh.process_linear_index` would produce a value with name `proc_linear_idx`.
1 parent 98db8d0 commit f78027d

File tree

3 files changed

+103
-10
lines changed

3 files changed

+103
-10
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1616
include "mlir/IR/BuiltinTypes.td"
1717
include "mlir/IR/CommonAttrConstraints.td"
1818
include "mlir/IR/CommonTypeConstraints.td"
19+
include "mlir/IR/OpAsmInterface.td"
1920
include "mlir/IR/SymbolInterfaces.td"
2021

2122
//===----------------------------------------------------------------------===//
@@ -78,7 +79,10 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
7879
}
7980

8081
def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
81-
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
82+
Pure,
83+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
84+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
85+
]> {
8286
let summary = "Get the shape of the mesh.";
8387
let arguments = (ins
8488
FlatSymbolRefAttr:$mesh,
@@ -101,7 +105,11 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
101105
];
102106
}
103107

104-
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
108+
def Mesh_ShardOp : Mesh_Op<"shard", [
109+
Pure,
110+
SameOperandsAndResultType,
111+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
112+
]> {
105113
let summary = "Annotate on how a tensor is sharded across a mesh.";
106114
let description = [{
107115
The mesh.shard operation is designed to specify and guide the sharding
@@ -194,7 +202,8 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
194202

195203
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
196204
Pure,
197-
DeclareOpInterfaceMethods<SymbolUserOpInterface>
205+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
206+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
198207
]> {
199208
let summary = "Get the multi index of current device along specified mesh axes.";
200209
let description = [{
@@ -221,7 +230,8 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
221230

222231
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
223232
Pure,
224-
DeclareOpInterfaceMethods<SymbolUserOpInterface>
233+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
234+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
225235
]> {
226236
let summary = "Get the linear index of the current device.";
227237
let description = [{
@@ -248,7 +258,10 @@ class Mesh_CollectiveCommunicationOpBase<
248258
string mnemonic, list<Trait> traits = []> :
249259
Mesh_Op<mnemonic,
250260
!listconcat(traits,
251-
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
261+
[
262+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
263+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
264+
])> {
252265
dag commonArgs = (ins
253266
FlatSymbolRefAttr:$mesh,
254267
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
@@ -258,7 +271,7 @@ class Mesh_CollectiveCommunicationOpBase<
258271
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
259272
Pure,
260273
SameOperandsAndResultElementType,
261-
SameOperandsAndResultRank
274+
SameOperandsAndResultRank,
262275
]> {
263276
let summary = "All-gather over a device mesh.";
264277
let description = [{

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
#include "mlir/Support/LLVM.h"
2525
#include "mlir/Support/LogicalResult.h"
2626
#include "llvm/ADT/ArrayRef.h"
27-
#include "llvm/ADT/DenseSet.h"
2827
#include "llvm/ADT/STLExtras.h"
2928
#include "llvm/ADT/SmallSet.h"
3029
#include "llvm/ADT/SmallVector.h"
@@ -34,7 +33,6 @@
3433
#include <iterator>
3534
#include <numeric>
3635
#include <optional>
37-
#include <string>
3836
#include <utility>
3937

4038
#define DEBUG_TYPE "mesh-ops"
@@ -244,6 +242,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
244242
MeshAxesAttr::get(odsBuilder.getContext(), axes));
245243
}
246244

245+
void MeshShapeOp::getAsmResultNames(
246+
function_ref<void(Value, StringRef)> setNameFn) {
247+
setNameFn(getResults()[0], "mesh_shape");
248+
}
249+
247250
//===----------------------------------------------------------------------===//
248251
// mesh.shard attr
249252
//===----------------------------------------------------------------------===//
@@ -307,6 +310,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
307310
std::mem_fn(&MeshAxesAttr::empty));
308311
}
309312

313+
//===----------------------------------------------------------------------===//
314+
// mesh.shard op
315+
//===----------------------------------------------------------------------===//
316+
317+
void ShardOp::getAsmResultNames(
318+
function_ref<void(Value, StringRef)> setNameFn) {
319+
setNameFn(getResult(), "sharding_annotated");
320+
}
321+
310322
//===----------------------------------------------------------------------===//
311323
// mesh.process_multi_index op
312324
//===----------------------------------------------------------------------===//
@@ -345,6 +357,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
345357
MeshAxesAttr::get(odsBuilder.getContext(), axes));
346358
}
347359

360+
void ProcessMultiIndexOp::getAsmResultNames(
361+
function_ref<void(Value, StringRef)> setNameFn) {
362+
setNameFn(getResults()[0], "proc_linear_idx");
363+
}
364+
348365
//===----------------------------------------------------------------------===//
349366
// mesh.process_linear_index op
350367
//===----------------------------------------------------------------------===//
@@ -363,6 +380,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
363380
build(odsBuilder, odsState, mesh.getSymName());
364381
}
365382

383+
void ProcessLinearIndexOp::getAsmResultNames(
384+
function_ref<void(Value, StringRef)> setNameFn) {
385+
setNameFn(getResult(), "proc_linear_idx");
386+
}
387+
366388
//===----------------------------------------------------------------------===//
367389
// collective communication ops
368390
//===----------------------------------------------------------------------===//
@@ -606,6 +628,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
606628
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
607629
}
608630

631+
void AllGatherOp::getAsmResultNames(
632+
function_ref<void(Value, StringRef)> setNameFn) {
633+
setNameFn(getResult(), "all_gather");
634+
}
635+
609636
//===----------------------------------------------------------------------===//
610637
// mesh.all_reduce op
611638
//===----------------------------------------------------------------------===//
@@ -620,6 +647,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
620647
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
621648
}
622649

650+
void AllReduceOp::getAsmResultNames(
651+
function_ref<void(Value, StringRef)> setNameFn) {
652+
setNameFn(getResult(), "all_reduce");
653+
}
654+
623655
//===----------------------------------------------------------------------===//
624656
// mesh.all_slice op
625657
//===----------------------------------------------------------------------===//
@@ -654,6 +686,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
654686
APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
655687
}
656688

689+
void AllSliceOp::getAsmResultNames(
690+
function_ref<void(Value, StringRef)> setNameFn) {
691+
setNameFn(getResult(), "all_slice");
692+
}
693+
657694
//===----------------------------------------------------------------------===//
658695
// mesh.all_to_all op
659696
//===----------------------------------------------------------------------===//
@@ -674,6 +711,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
674711
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
675712
}
676713

714+
void AllToAllOp::getAsmResultNames(
715+
function_ref<void(Value, StringRef)> setNameFn) {
716+
setNameFn(getResult(), "all_to_all");
717+
}
718+
677719
//===----------------------------------------------------------------------===//
678720
// mesh.broadcast op
679721
//===----------------------------------------------------------------------===//
@@ -698,6 +740,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
698740
patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
699741
}
700742

743+
void BroadcastOp::getAsmResultNames(
744+
function_ref<void(Value, StringRef)> setNameFn) {
745+
setNameFn(getResult(), "broadcast");
746+
}
747+
701748
//===----------------------------------------------------------------------===//
702749
// mesh.gather op
703750
//===----------------------------------------------------------------------===//
@@ -724,6 +771,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
724771
patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
725772
}
726773

774+
void GatherOp::getAsmResultNames(
775+
function_ref<void(Value, StringRef)> setNameFn) {
776+
setNameFn(getResult(), "gather");
777+
}
778+
727779
//===----------------------------------------------------------------------===//
728780
// mesh.recv op
729781
//===----------------------------------------------------------------------===//
@@ -747,6 +799,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
747799
patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
748800
}
749801

802+
void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
803+
setNameFn(getResult(), "recv");
804+
}
805+
750806
//===----------------------------------------------------------------------===//
751807
// mesh.reduce op
752808
//===----------------------------------------------------------------------===//
@@ -770,6 +826,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
770826
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
771827
}
772828

829+
void ReduceOp::getAsmResultNames(
830+
function_ref<void(Value, StringRef)> setNameFn) {
831+
setNameFn(getResult(), "reduce");
832+
}
833+
773834
//===----------------------------------------------------------------------===//
774835
// mesh.reduce_scatter op
775836
//===----------------------------------------------------------------------===//
@@ -791,6 +852,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
791852
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
792853
}
793854

855+
void ReduceScatterOp::getAsmResultNames(
856+
function_ref<void(Value, StringRef)> setNameFn) {
857+
setNameFn(getResult(), "reduce_scatter");
858+
}
859+
794860
//===----------------------------------------------------------------------===//
795861
// mesh.scatter op
796862
//===----------------------------------------------------------------------===//
@@ -817,6 +883,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
817883
patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
818884
}
819885

886+
void ScatterOp::getAsmResultNames(
887+
function_ref<void(Value, StringRef)> setNameFn) {
888+
setNameFn(getResult(), "scatter");
889+
}
890+
820891
//===----------------------------------------------------------------------===//
821892
// mesh.send op
822893
//===----------------------------------------------------------------------===//
@@ -839,6 +910,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
839910
patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
840911
}
841912

913+
void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
914+
setNameFn(getResult(), "send");
915+
}
916+
842917
//===----------------------------------------------------------------------===//
843918
// mesh.shift op
844919
//===----------------------------------------------------------------------===//
@@ -865,6 +940,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
865940
// offset % shift_axis_mesh_dim_size == 0.
866941
}
867942

943+
void ShiftOp::getAsmResultNames(
944+
function_ref<void(Value, StringRef)> setNameFn) {
945+
setNameFn(getResult(), "shift");
946+
}
947+
868948
//===----------------------------------------------------------------------===//
869949
// TableGen'd op method definitions
870950
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ mesh.mesh @mesh2d(shape = ?x?)
66
func.func @multi_index_2d_mesh() -> (index, index) {
77
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
88
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
9-
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
9+
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
1010
%0:2 = mesh.process_multi_index on @mesh2d : index, index
1111
// CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
1212
return %0#0, %0#1 : index, index
@@ -16,7 +16,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
1616
func.func @multi_index_2d_mesh_single_inner_axis() -> index {
1717
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
1818
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
19-
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
19+
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
2020
%0 = mesh.process_multi_index on @mesh2d axes = [0] : index
2121
// CHECK: return %[[MULTI_IDX]]#0 : index
2222
return %0 : index

0 commit comments

Comments
 (0)