-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mesh] Better Op result names #82408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Implement OpAsmOpInterface for most ops to increase IR readability. For example `mesh.process_linear_index` would produce a value with name `proc_linear_idx`.
@llvm/pr-subscribers-mlir Author: Boian Petkantchin (sogartar) ChangesImplement OpAsmOpInterface for most ops to increase IR readability. For example Full diff: https://github.com/llvm/llvm-project/pull/82408.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ba7c111aea6bb..b9cd15e2062669 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
+include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -78,7 +79,10 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
}
def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
- Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
let summary = "Get the shape of the mesh.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
@@ -101,7 +105,11 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
];
}
-def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
+def Mesh_ShardOp : Mesh_Op<"shard", [
+ Pure,
+ SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
let summary = "Annotate on how a tensor is sharded across a mesh.";
let description = [{
The mesh.shard operation is designed to specify and guide the sharding
@@ -194,7 +202,8 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
Pure,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the multi index of current device along specified mesh axes.";
let description = [{
@@ -221,7 +230,8 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
Pure,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the linear index of the current device.";
let description = [{
@@ -248,7 +258,10 @@ class Mesh_CollectiveCommunicationOpBase<
string mnemonic, list<Trait> traits = []> :
Mesh_Op<mnemonic,
!listconcat(traits,
- [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+ [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
@@ -258,7 +271,7 @@ class Mesh_CollectiveCommunicationOpBase<
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Pure,
SameOperandsAndResultElementType,
- SameOperandsAndResultRank
+ SameOperandsAndResultRank,
]> {
let summary = "All-gather over a device mesh.";
let description = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 838255cf5a5ba3..07a320752b2595 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -27,7 +27,9 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <functional>
@@ -180,6 +182,20 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
return type;
}
+// static void getAsmMultiResultNames(Operation* op, StringRef namePrefix,
+// function_ref<void(Value, StringRef)> setNameFn) {
+// if (op->getNumResults() == 1) {
+// setNameFn(op->getResult(0), namePrefix);
+// return;
+// }
+// SmallString<64> str;
+// for (auto [i, result]: llvm::enumerate(op->getResults())) {
+// (Twine(namePrefix) + "_" + Twine(i) + "_").toStringRef(str);
+// setNameFn(result, str);
+// str.clear();
+// }
+// }
+
//===----------------------------------------------------------------------===//
// mesh.mesh op
//===----------------------------------------------------------------------===//
@@ -244,6 +260,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
+void MeshShapeOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResults()[0], "mesh_shape");
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard attr
//===----------------------------------------------------------------------===//
@@ -307,6 +328,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
std::mem_fn(&MeshAxesAttr::empty));
}
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+void ShardOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "sharding_annotated");
+}
+
//===----------------------------------------------------------------------===//
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//
@@ -345,6 +375,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
+void ProcessMultiIndexOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResults()[0], "proc_linear_idx");
+}
+
//===----------------------------------------------------------------------===//
// mesh.process_linear_index op
//===----------------------------------------------------------------------===//
@@ -363,6 +398,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
build(odsBuilder, odsState, mesh.getSymName());
}
+void ProcessLinearIndexOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "proc_linear_idx");
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -606,6 +646,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}
+void AllGatherOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "all_gather");
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_reduce op
//===----------------------------------------------------------------------===//
@@ -620,6 +665,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+void AllReduceOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "all_reduce");
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_slice op
//===----------------------------------------------------------------------===//
@@ -654,6 +704,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
+void AllSliceOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "all_slice");
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
@@ -674,6 +729,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}
+void AllToAllOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "all_to_all");
+}
+
//===----------------------------------------------------------------------===//
// mesh.broadcast op
//===----------------------------------------------------------------------===//
@@ -698,6 +758,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}
+void BroadcastOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "broadcast");
+}
+
//===----------------------------------------------------------------------===//
// mesh.gather op
//===----------------------------------------------------------------------===//
@@ -724,6 +789,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}
+void GatherOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "gather");
+}
+
//===----------------------------------------------------------------------===//
// mesh.recv op
//===----------------------------------------------------------------------===//
@@ -747,6 +817,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}
+void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "recv");
+}
+
//===----------------------------------------------------------------------===//
// mesh.reduce op
//===----------------------------------------------------------------------===//
@@ -770,6 +844,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}
+void ReduceOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "reduce");
+}
+
//===----------------------------------------------------------------------===//
// mesh.reduce_scatter op
//===----------------------------------------------------------------------===//
@@ -791,6 +870,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
+void ReduceScatterOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "reduce_scatter");
+}
+
//===----------------------------------------------------------------------===//
// mesh.scatter op
//===----------------------------------------------------------------------===//
@@ -817,6 +901,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}
+void ScatterOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "scatter");
+}
+
//===----------------------------------------------------------------------===//
// mesh.send op
//===----------------------------------------------------------------------===//
@@ -839,6 +928,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}
+void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "send");
+}
+
//===----------------------------------------------------------------------===//
// mesh.shift op
//===----------------------------------------------------------------------===//
@@ -865,6 +958,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// offset % shift_axis_mesh_dim_size == 0.
}
+void ShiftOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "shift");
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index 677a5982ea2540..e23cfd79a42745 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -6,7 +6,7 @@ mesh.mesh @mesh2d(shape = ?x?)
func.func @multi_index_2d_mesh() -> (index, index) {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+ // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0:2 = mesh.process_multi_index on @mesh2d : index, index
// CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
return %0#0, %0#1 : index, index
@@ -16,7 +16,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
func.func @multi_index_2d_mesh_single_inner_axis() -> index {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+ // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0 = mesh.process_multi_index on @mesh2d axes = [0] : index
// CHECK: return %[[MULTI_IDX]]#0 : index
return %0 : index
|
@yaochengji, could you review this PR? |
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/ADT/Twine.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: these includes don't seem needed from the diff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cleaned up the unused headers.
@joker-eph, thank you for your timely review. |
Implement OpAsmOpInterface for most ops to increase IR readability. For example
mesh.process_linear_index
would produce a value with nameproc_linear_idx
.