Skip to content

[mlir][mesh] Better Op result names #82408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 21, 2024
Merged

Conversation

sogartar
Copy link
Contributor

Implement OpAsmOpInterface for most ops to increase IR readability. For example mesh.process_linear_index would produce a value with name proc_linear_idx.

Implement OpAsmOpInterface for most ops to increase IR readability.
For example `mesh.process_linear_index` would produce a value with name
`proc_linear_idx`.
@llvmbot
Copy link
Member

llvmbot commented Feb 20, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Implement OpAsmOpInterface for most ops to increase IR readability. For example mesh.process_linear_index would produce a value with name proc_linear_idx.


Full diff: https://github.com/llvm/llvm-project/pull/82408.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+19-6)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+98)
  • (modified) mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ba7c111aea6bb..b9cd15e2062669 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
 include "mlir/IR/CommonAttrConstraints.td"
 include "mlir/IR/CommonTypeConstraints.td"
+include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -78,7 +79,10 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
 }
 
 def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
-  Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+    Pure,
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
   let summary = "Get the shape of the mesh.";
   let arguments = (ins
     FlatSymbolRefAttr:$mesh,
@@ -101,7 +105,11 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   ];
 }
 
-def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
+def Mesh_ShardOp : Mesh_Op<"shard", [
+    Pure,
+    SameOperandsAndResultType,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
   let summary = "Annotate on how a tensor is sharded across a mesh.";
   let description = [{
     The mesh.shard operation is designed to specify and guide the sharding
@@ -194,7 +202,8 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
 
 def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
   Pure,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary = "Get the multi index of current device along specified mesh axes.";
   let description = [{
@@ -221,7 +230,8 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
 
 def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   Pure,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary = "Get the linear index of the current device.";
   let description = [{
@@ -248,7 +258,10 @@ class Mesh_CollectiveCommunicationOpBase<
     string mnemonic, list<Trait> traits = []> :
     Mesh_Op<mnemonic,
       !listconcat(traits,
-      [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+        [
+          DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+          DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+        ])> {
   dag commonArgs = (ins
     FlatSymbolRefAttr:$mesh,
     DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
@@ -258,7 +271,7 @@ class Mesh_CollectiveCommunicationOpBase<
 def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     Pure,
     SameOperandsAndResultElementType,
-    SameOperandsAndResultRank
+    SameOperandsAndResultRank,
   ]> {
   let summary = "All-gather over a device mesh.";
   let description = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 838255cf5a5ba3..07a320752b2595 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -27,7 +27,9 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <algorithm>
 #include <functional>
@@ -180,6 +182,20 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
   return type;
 }
 
+// static void getAsmMultiResultNames(Operation* op, StringRef namePrefix,
+// function_ref<void(Value, StringRef)> setNameFn) {
+//   if (op->getNumResults() == 1) {
+//     setNameFn(op->getResult(0), namePrefix);
+//     return;
+//   }
+//   SmallString<64> str;
+//   for (auto [i, result]: llvm::enumerate(op->getResults())) {
+//     (Twine(namePrefix) + "_" + Twine(i) + "_").toStringRef(str);
+//     setNameFn(result, str);
+//     str.clear();
+//   }
+// }
+
 //===----------------------------------------------------------------------===//
 // mesh.mesh op
 //===----------------------------------------------------------------------===//
@@ -244,6 +260,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
+void MeshShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults()[0], "mesh_shape");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard attr
 //===----------------------------------------------------------------------===//
@@ -307,6 +328,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
                       std::mem_fn(&MeshAxesAttr::empty));
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+void ShardOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "sharding_annotated");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_multi_index op
 //===----------------------------------------------------------------------===//
@@ -345,6 +375,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
+void ProcessMultiIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults()[0], "proc_linear_idx");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_linear_index op
 //===----------------------------------------------------------------------===//
@@ -363,6 +398,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
   build(odsBuilder, odsState, mesh.getSymName());
 }
 
+void ProcessLinearIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "proc_linear_idx");
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -606,6 +646,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
 }
 
+void AllGatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_gather");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_reduce op
 //===----------------------------------------------------------------------===//
@@ -620,6 +665,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+void AllReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_reduce");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_slice op
 //===----------------------------------------------------------------------===//
@@ -654,6 +704,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
 }
 
+void AllSliceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_slice");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
@@ -674,6 +729,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
 }
 
+void AllToAllOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_to_all");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.broadcast op
 //===----------------------------------------------------------------------===//
@@ -698,6 +758,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
 }
 
+void BroadcastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "broadcast");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.gather op
 //===----------------------------------------------------------------------===//
@@ -724,6 +789,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
 }
 
+void GatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "gather");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.recv op
 //===----------------------------------------------------------------------===//
@@ -747,6 +817,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
 }
 
+void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "recv");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.reduce op
 //===----------------------------------------------------------------------===//
@@ -770,6 +844,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
 }
 
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reduce");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.reduce_scatter op
 //===----------------------------------------------------------------------===//
@@ -791,6 +870,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
 }
 
+void ReduceScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reduce_scatter");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.scatter op
 //===----------------------------------------------------------------------===//
@@ -817,6 +901,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
 }
 
+void ScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "scatter");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.send op
 //===----------------------------------------------------------------------===//
@@ -839,6 +928,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
 }
 
+void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "send");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shift op
 //===----------------------------------------------------------------------===//
@@ -865,6 +958,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   // offset % shift_axis_mesh_dim_size == 0.
 }
 
+void ShiftOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "shift");
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index 677a5982ea2540..e23cfd79a42745 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -6,7 +6,7 @@ mesh.mesh @mesh2d(shape = ?x?)
 func.func @multi_index_2d_mesh() -> (index, index) {
   // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
   // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
-  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
   %0:2 = mesh.process_multi_index on @mesh2d : index, index
   // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
   return %0#0, %0#1 : index, index
@@ -16,7 +16,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
 func.func @multi_index_2d_mesh_single_inner_axis() -> index {
   // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
   // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
-  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
   %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
   // CHECK: return %[[MULTI_IDX]]#0 : index
   return %0 : index

@sogartar
Copy link
Contributor Author

@yaochengji, could you review this PR?

@sogartar sogartar requested a review from joker-eph February 20, 2024 20:34
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: these includes don't seem needed from the diff

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cleaned up the unused headers.

@sogartar sogartar requested a review from joker-eph February 21, 2024 00:30
@sogartar sogartar merged commit f78027d into llvm:main Feb 21, 2024
@sogartar
Copy link
Contributor Author

@joker-eph, thank you for your timely review.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants