Skip to content

[mlir][mesh] Add lowering of process multi-index op #77490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}

def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Get the shape of the cluster.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
Expand Down Expand Up @@ -209,11 +210,15 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}

def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Get the index of current device along specified mesh axis.";
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Get the multi index of current device along specified mesh axes.";
let description = [{
It is used in the SPMD format of IR.
The `axes` mush be non-negative and less than the total number of mesh axes.
If the axes are empty then get the index along all axes.
}];
let arguments = (ins
FlatSymbolRefAttr:$mesh,
Expand All @@ -232,6 +237,27 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
];
}

def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Get the linear index of the current device.";
let description = [{
Example:
```
%idx = mesh.process_linear_index on @mesh : index
```
if `@mesh` has shape `(10, 20, 30)`, a device with multi
index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
}];
let arguments = (ins FlatSymbolRefAttr:$mesh);
let results = (outs Index:$result);
let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
let builders = [
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
];
}

//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H

namespace mlir {
class RewritePatternSet;
class SymbolTableCollection;
class DialectRegistry;
namespace mesh {

void processMultiIndexOpLoweringPopulatePatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);

void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);

} // namespace mesh
} // namespace mlir

#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
33 changes: 26 additions & 7 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
ClusterOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
mesh.getSymName(), MeshAxesAttr());
mesh.getSymName(),
MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
}

void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
Expand Down Expand Up @@ -325,11 +326,11 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
}

//===----------------------------------------------------------------------===//
// mesh.process_index op
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//

LogicalResult
ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
Expand All @@ -348,20 +349,38 @@ ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}

void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
ClusterOp mesh) {
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
ClusterOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
mesh.getSymName(), MeshAxesAttr());
}

void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) {
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) {
build(odsBuilder, odsState,
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}

//===----------------------------------------------------------------------===//
// mesh.process_linear_index op
//===----------------------------------------------------------------------===//

LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
}
return success();
}

void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState, ClusterOp mesh) {
build(odsBuilder, odsState, mesh.getSymName());
}

//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
Simplifications.cpp
ShardingPropagation.cpp
Spmdization.cpp
Transforms.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
Expand All @@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
MLIRShardingInterface

LINK_LIBS PUBLIC
MLIRAffineDialect
MLIRArithDialect
MLIRControlFlowDialect
MLIRFuncDialect
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,

Value processIndexAlongAxis =
builder
.create<ProcessIndexOp>(mesh.getSymName(),
SmallVector<MeshAxis>({splitMeshAxis}))
.create<ProcessMultiIndexOp>(mesh.getSymName(),
SmallVector<MeshAxis>({splitMeshAxis}))
.getResult()[0];

MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
Expand Down
84 changes: 84 additions & 0 deletions mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//===- Transforms.cpp ---------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <numeric>

namespace mlir::mesh {

namespace {

/// Lower `mesh.process_multi_index` into expression using
/// `mesh.process_linear_index` and `mesh.cluster_shape`.
struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
template <typename... OpRewritePatternArgs>
ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
OpRewritePatternArgs &&...opRewritePatternArgs)
: OpRewritePattern(
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
symbolTableCollection(symbolTableCollection) {}

LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
PatternRewriter &rewriter) const override {
ClusterOp mesh =
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
op.getOperation(), op.getMeshAttr());
if (!mesh) {
return failure();
}

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
builder.setInsertionPointAfter(op.getOperation());
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
SmallVector<Value> completeMultiIndex =
builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
.getMultiIndex();
SmallVector<Value> multiIndex;
ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
SmallVector<MeshAxis> opAxesIota;
if (opMeshAxes.empty()) {
opAxesIota.resize(mesh.getRank());
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
opMeshAxes = opAxesIota;
}
llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
[&completeMultiIndex](MeshAxis meshAxis) {
return completeMultiIndex[meshAxis];
});
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
return success();
}

private:
SymbolTableCollection &symbolTableCollection;
};

} // namespace

void processMultiIndexOpLoweringPopulatePatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
patterns.getContext());
}

void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
registry.insert<affine::AffineDialect, mesh::MeshDialect>();
}

} // namespace mlir::mesh
30 changes: 19 additions & 11 deletions mlir/test/Dialect/Mesh/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -128,48 +128,56 @@ func.func @cluster_shape_invalid_mesh_name() -> (index) {

mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)

func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
%0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
return %0#0, %0#1 : index, index
}

// -----

mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)

func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
%0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
%0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)

func.func @process_index_wrong_number_of_results() -> (index, index) {
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
%0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
%0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
return %0#0, %0#1 : index, index
}

// -----

mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)

func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
%0:2 = mesh.process_index on @mesh0 : index, index
%0:2 = mesh.process_multi_index on @mesh0 : index, index
return %0#0, %0#1 : index, index
}

// -----

func.func @process_index_invalid_mesh_name() -> (index) {
func.func @process_multi_index_invalid_mesh_name() -> (index) {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
%0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
return %0#0 : index
%0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
return %0 : index
}

// -----

func.func @process_linear_index_invalid_mesh_name() -> (index) {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
%0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
return %0 : index
}

// -----
Expand Down
31 changes: 19 additions & 12 deletions mlir/test/Dialect/Mesh/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -156,30 +156,37 @@ func.func @cluster_shape_empty_axes() -> (index, index, index) {
return %0#0, %0#1, %0#2 : index, index, index
}

// CHECK-LABEL: func @process_index
func.func @process_index() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
%0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
// CHECK-LABEL: func @process_multi_index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ops/invalid check of mesh.process_linear_index seems missing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added tests.

func.func @process_multi_index() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
return %0#0, %0#1 : index, index
}

// CHECK-LABEL: func @process_index_default_axes
func.func @process_index_default_axes() -> (index, index, index) {
// CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
%0:3 = mesh.process_index on @mesh0 : index, index, index
// CHECK-LABEL: func @process_multi_index_default_axes
func.func @process_multi_index_default_axes() -> (index, index, index) {
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
%0:3 = mesh.process_multi_index on @mesh0 : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}

// CHECK-LABEL: func @process_index_empty_axes
func.func @process_index_empty_axes() -> (index, index, index) {
// CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
%0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
// CHECK-LABEL: func @process_multi_index_empty_axes
func.func @process_multi_index_empty_axes() -> (index, index, index) {
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
%0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}

// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
// CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
%0 = mesh.process_linear_index on @mesh0 : index
// CHECK: return %[[RES]] : index
return %0 : index
}

// CHECK-LABEL: func @all_reduce
func.func @all_reduce(
Expand Down
Loading