Skip to content

[mlir][mesh] Better Op result names #82408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -78,7 +79,10 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
}

def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the shape of the mesh.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
Expand All @@ -101,7 +105,11 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
];
}

def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
def Mesh_ShardOp : Mesh_Op<"shard", [
Pure,
SameOperandsAndResultType,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Annotate on how a tensor is sharded across a mesh.";
let description = [{
The mesh.shard operation is designed to specify and guide the sharding
Expand Down Expand Up @@ -194,7 +202,8 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {

def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the multi index of current device along specified mesh axes.";
let description = [{
Expand All @@ -221,7 +230,8 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [

def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the linear index of the current device.";
let description = [{
Expand All @@ -248,7 +258,10 @@ class Mesh_CollectiveCommunicationOpBase<
string mnemonic, list<Trait> traits = []> :
Mesh_Op<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
[
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
Expand All @@ -258,7 +271,7 @@ class Mesh_CollectiveCommunicationOpBase<
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank
SameOperandsAndResultRank,
]> {
let summary = "All-gather over a device mesh.";
let description = [{
Expand Down
84 changes: 82 additions & 2 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
Expand All @@ -34,7 +33,6 @@
#include <iterator>
#include <numeric>
#include <optional>
#include <string>
#include <utility>

#define DEBUG_TYPE "mesh-ops"
Expand Down Expand Up @@ -244,6 +242,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}

void MeshShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResults()[0], "mesh_shape");
}

//===----------------------------------------------------------------------===//
// mesh.shard attr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -307,6 +310,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
std::mem_fn(&MeshAxesAttr::empty));
}

//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//

void ShardOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "sharding_annotated");
}

//===----------------------------------------------------------------------===//
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -345,6 +357,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}

void ProcessMultiIndexOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResults()[0], "proc_linear_idx");
}

//===----------------------------------------------------------------------===//
// mesh.process_linear_index op
//===----------------------------------------------------------------------===//
Expand All @@ -363,6 +380,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
build(odsBuilder, odsState, mesh.getSymName());
}

void ProcessLinearIndexOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "proc_linear_idx");
}

//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -606,6 +628,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}

void AllGatherOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "all_gather");
}

//===----------------------------------------------------------------------===//
// mesh.all_reduce op
//===----------------------------------------------------------------------===//
Expand All @@ -620,6 +647,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}

void AllReduceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "all_reduce");
}

//===----------------------------------------------------------------------===//
// mesh.all_slice op
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -654,6 +686,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}

void AllSliceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "all_slice");
}

//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
Expand All @@ -674,6 +711,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}

void AllToAllOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "all_to_all");
}

//===----------------------------------------------------------------------===//
// mesh.broadcast op
//===----------------------------------------------------------------------===//
Expand All @@ -698,6 +740,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}

void BroadcastOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "broadcast");
}

//===----------------------------------------------------------------------===//
// mesh.gather op
//===----------------------------------------------------------------------===//
Expand All @@ -724,6 +771,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}

void GatherOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "gather");
}

//===----------------------------------------------------------------------===//
// mesh.recv op
//===----------------------------------------------------------------------===//
Expand All @@ -747,6 +799,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}

void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "recv");
}

//===----------------------------------------------------------------------===//
// mesh.reduce op
//===----------------------------------------------------------------------===//
Expand All @@ -770,6 +826,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}

void ReduceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "reduce");
}

//===----------------------------------------------------------------------===//
// mesh.reduce_scatter op
//===----------------------------------------------------------------------===//
Expand All @@ -791,6 +852,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}

void ReduceScatterOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "reduce_scatter");
}

//===----------------------------------------------------------------------===//
// mesh.scatter op
//===----------------------------------------------------------------------===//
Expand All @@ -817,6 +883,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}

void ScatterOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "scatter");
}

//===----------------------------------------------------------------------===//
// mesh.send op
//===----------------------------------------------------------------------===//
Expand All @@ -839,6 +910,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}

void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "send");
}

//===----------------------------------------------------------------------===//
// mesh.shift op
//===----------------------------------------------------------------------===//
Expand All @@ -865,6 +940,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// offset % shift_axis_mesh_dim_size == 0.
}

void ShiftOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "shift");
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mesh.mesh @mesh2d(shape = ?x?)
func.func @multi_index_2d_mesh() -> (index, index) {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0:2 = mesh.process_multi_index on @mesh2d : index, index
// CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
return %0#0, %0#1 : index, index
Expand All @@ -16,7 +16,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
func.func @multi_index_2d_mesh_single_inner_axis() -> index {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0 = mesh.process_multi_index on @mesh2d axes = [0] : index
// CHECK: return %[[MULTI_IDX]]#0 : index
return %0 : index
Expand Down