Skip to content

Commit 79aa776

Browse files
authored
[mlir][mesh] Add lowering of process multi-index op (#77490)
* Rename mesh.process_index -> mesh.process_multi_index. * Add mesh.process_linear_index op. * Add lowering of mesh.process_multi_index into an expression using mesh.process_linear_index, mesh.cluster_shape and affine.delinearize_index. This is useful to lower mesh ops and prepare them for further lowering where the runtime may have only the linear index of a device/process. For example in MPI we have a rank (linear index) in a communicator.
1 parent fef2fc3 commit 79aa776

File tree

14 files changed

+291
-38
lines changed

14 files changed

+291
-38
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
9696
let hasVerifier = 1;
9797
}
9898

99-
def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
99+
def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
100+
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
100101
let summary = "Get the shape of the cluster.";
101102
let arguments = (ins
102103
FlatSymbolRefAttr:$mesh,
@@ -209,11 +210,15 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
209210
}];
210211
}
211212

212-
def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
213-
let summary = "Get the index of current device along specified mesh axis.";
213+
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
214+
Pure,
215+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
216+
]> {
217+
let summary = "Get the multi index of current device along specified mesh axes.";
214218
let description = [{
215219
It is used in the SPMD format of IR.
216220
The `axes` must be non-negative and less than the total number of mesh axes.
221+
If the axes are empty then get the index along all axes.
217222
}];
218223
let arguments = (ins
219224
FlatSymbolRefAttr:$mesh,
@@ -232,6 +237,27 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
232237
];
233238
}
234239

240+
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
241+
Pure,
242+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
243+
]> {
244+
let summary = "Get the linear index of the current device.";
245+
let description = [{
246+
Example:
247+
```
248+
%idx = mesh.process_linear_index on @mesh : index
249+
```
250+
If `@mesh` has shape `(10, 20, 30)`, a device with multi
251+
index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
252+
}];
253+
let arguments = (ins FlatSymbolRefAttr:$mesh);
254+
let results = (outs Index:$result);
255+
let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
256+
let builders = [
257+
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
258+
];
259+
}
260+
235261
//===----------------------------------------------------------------------===//
236262
// collective communication ops
237263
//===----------------------------------------------------------------------===//
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
10+
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
11+
12+
namespace mlir {
13+
class RewritePatternSet;
14+
class SymbolTableCollection;
15+
class DialectRegistry;
16+
namespace mesh {
17+
18+
/// Adds to `patterns` the rewrite pattern that lowers
/// `mesh.process_multi_index` into `mesh.process_linear_index`,
/// `mesh.cluster_shape` and `affine.delinearize_index`.
///
/// The pattern keeps a reference to `symbolTableCollection` and uses it to
/// resolve the mesh symbol of each matched op, so the collection must outlive
/// `patterns`.
void processMultiIndexOpLoweringPopulatePatterns(
19+
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
20+
21+
/// Inserts into `registry` the dialects whose ops the
/// `mesh.process_multi_index` lowering creates: the affine and mesh dialects.
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
22+
23+
} // namespace mesh
24+
} // namespace mlir
25+
26+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
250250
ClusterOp mesh) {
251251
build(odsBuilder, odsState,
252252
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
253-
mesh.getSymName(), MeshAxesAttr());
253+
mesh.getSymName(),
254+
MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
254255
}
255256

256257
void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
@@ -325,11 +326,11 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
325326
}
326327

327328
//===----------------------------------------------------------------------===//
328-
// mesh.process_index op
329+
// mesh.process_multi_index op
329330
//===----------------------------------------------------------------------===//
330331

331332
LogicalResult
332-
ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
333+
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
333334
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
334335
if (failed(mesh)) {
335336
return failure();
@@ -348,20 +349,38 @@ ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
348349
return success();
349350
}
350351

351-
void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
352-
ClusterOp mesh) {
352+
/// Builds a `mesh.process_multi_index` that queries the index along every
/// axis of `mesh`: one `index` result per mesh axis and an empty `axes` list
/// (the op description defines empty axes as "all axes").
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                ClusterOp mesh) {
  // Pass an explicit empty axes attribute rather than a default-constructed
  // (null) MeshAxesAttr, matching ClusterShapeOp::build for the same case.
  build(odsBuilder, odsState,
        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
        mesh.getSymName(),
        MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
}
357358

358-
void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
359-
StringRef mesh, ArrayRef<MeshAxis> axes) {
359+
/// Builds a `mesh.process_multi_index` on the mesh named `mesh` for the given
/// `axes`, producing one `index` result per listed axis.
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                StringRef mesh, ArrayRef<MeshAxis> axes) {
  // Result arity follows the number of requested axes.
  SmallVector<Type> resultTypes(axes.size(), odsBuilder.getIndexType());
  MeshAxesAttr axesAttr = MeshAxesAttr::get(odsBuilder.getContext(), axes);
  build(odsBuilder, odsState, resultTypes, mesh, axesAttr);
}
364365

366+
//===----------------------------------------------------------------------===//
367+
// mesh.process_linear_index op
368+
//===----------------------------------------------------------------------===//
369+
370+
/// Verifies the symbol use of `mesh.process_linear_index`: succeeds exactly
/// when the `::getMesh` lookup of the referenced mesh symbol succeeds.
LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto meshOrFailure = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
  // No further constraints beyond the mesh lookup itself.
  return failure(failed(meshOrFailure));
}
378+
379+
/// Builds a `mesh.process_linear_index` that refers to `mesh` by its symbol
/// name.
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, ClusterOp mesh) {
  StringRef meshName = mesh.getSymName();
  build(odsBuilder, odsState, meshName);
}
383+
365384
//===----------------------------------------------------------------------===//
366385
// collective communication ops
367386
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
22
Simplifications.cpp
33
ShardingPropagation.cpp
44
Spmdization.cpp
5+
Transforms.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
1112
MLIRShardingInterface
1213

1314
LINK_LIBS PUBLIC
15+
MLIRAffineDialect
1416
MLIRArithDialect
1517
MLIRControlFlowDialect
1618
MLIRFuncDialect

mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
1+
//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
206206

207207
Value processIndexAlongAxis =
208208
builder
209-
.create<ProcessIndexOp>(mesh.getSymName(),
210-
SmallVector<MeshAxis>({splitMeshAxis}))
209+
.create<ProcessMultiIndexOp>(mesh.getSymName(),
210+
SmallVector<MeshAxis>({splitMeshAxis}))
211211
.getResult()[0];
212212

213213
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//===- Transforms.cpp - Mesh Transforms -----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
10+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
11+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
12+
#include "mlir/IR/BuiltinTypes.h"
13+
#include "mlir/IR/DialectRegistry.h"
14+
#include "mlir/IR/ImplicitLocOpBuilder.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/IR/Value.h"
17+
#include "llvm/ADT/STLExtras.h"
18+
#include "llvm/ADT/SmallVector.h"
19+
#include <iterator>
20+
#include <numeric>
21+
22+
namespace mlir::mesh {
23+
24+
namespace {
25+
26+
/// Lower `mesh.process_multi_index` into expression using
27+
/// `mesh.process_linear_index` and `mesh.cluster_shape`.
28+
struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
29+
template <typename... OpRewritePatternArgs>
30+
ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
31+
OpRewritePatternArgs &&...opRewritePatternArgs)
32+
: OpRewritePattern(
33+
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
34+
symbolTableCollection(symbolTableCollection) {}
35+
36+
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
37+
PatternRewriter &rewriter) const override {
38+
ClusterOp mesh =
39+
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
40+
op.getOperation(), op.getMeshAttr());
41+
if (!mesh) {
42+
return failure();
43+
}
44+
45+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
46+
builder.setInsertionPointAfter(op.getOperation());
47+
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
48+
ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
49+
SmallVector<Value> completeMultiIndex =
50+
builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
51+
.getMultiIndex();
52+
SmallVector<Value> multiIndex;
53+
ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
54+
SmallVector<MeshAxis> opAxesIota;
55+
if (opMeshAxes.empty()) {
56+
opAxesIota.resize(mesh.getRank());
57+
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
58+
opMeshAxes = opAxesIota;
59+
}
60+
llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
61+
[&completeMultiIndex](MeshAxis meshAxis) {
62+
return completeMultiIndex[meshAxis];
63+
});
64+
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
65+
return success();
66+
}
67+
68+
private:
69+
SymbolTableCollection &symbolTableCollection;
70+
};
71+
72+
} // namespace
73+
74+
/// Registers the `ProcessMultiIndexOpLowering` pattern in `patterns`. The
/// pattern is constructed with `symbolTableCollection` and the context of the
/// pattern set.
void processMultiIndexOpLoweringPopulatePatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  auto *context = patterns.getContext();
  patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection, context);
}
79+
80+
/// Inserts the affine and mesh dialects into `registry`.
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
  registry.insert<affine::AffineDialect>();
  registry.insert<mesh::MeshDialect>();
}
83+
84+
} // namespace mlir::mesh

mlir/test/Dialect/Mesh/invalid.mlir

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,48 +128,56 @@ func.func @cluster_shape_invalid_mesh_name() -> (index) {
128128

129129
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
130130

131-
func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
131+
func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
132132
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
133-
%0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
133+
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
134134
return %0#0, %0#1 : index, index
135135
}
136136

137137
// -----
138138

139139
mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
140140

141-
func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
141+
func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
142142
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
143-
%0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
143+
%0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
144144
return %0#0, %0#1, %0#2 : index, index, index
145145
}
146146

147147
// -----
148148

149149
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
150150

151-
func.func @process_index_wrong_number_of_results() -> (index, index) {
151+
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
152152
// expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
153-
%0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
153+
%0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
154154
return %0#0, %0#1 : index, index
155155
}
156156

157157
// -----
158158

159159
mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
160160

161-
func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
161+
func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
162162
// expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
163-
%0:2 = mesh.process_index on @mesh0 : index, index
163+
%0:2 = mesh.process_multi_index on @mesh0 : index, index
164164
return %0#0, %0#1 : index, index
165165
}
166166

167167
// -----
168168

169-
func.func @process_index_invalid_mesh_name() -> (index) {
169+
func.func @process_multi_index_invalid_mesh_name() -> (index) {
170170
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
171-
%0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
172-
return %0#0 : index
171+
%0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
172+
return %0 : index
173+
}
174+
175+
// -----
176+
177+
func.func @process_linear_index_invalid_mesh_name() -> (index) {
178+
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
179+
%0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
180+
return %0 : index
173181
}
174182

175183
// -----

mlir/test/Dialect/Mesh/ops.mlir

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,30 +156,37 @@ func.func @cluster_shape_empty_axes() -> (index, index, index) {
156156
return %0#0, %0#1, %0#2 : index, index, index
157157
}
158158

159-
// CHECK-LABEL: func @process_index
160-
func.func @process_index() -> (index, index) {
161-
// CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
162-
%0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
159+
// CHECK-LABEL: func @process_multi_index
160+
func.func @process_multi_index() -> (index, index) {
161+
// CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
162+
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
163163
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
164164
return %0#0, %0#1 : index, index
165165
}
166166

167-
// CHECK-LABEL: func @process_index_default_axes
168-
func.func @process_index_default_axes() -> (index, index, index) {
169-
// CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
170-
%0:3 = mesh.process_index on @mesh0 : index, index, index
167+
// CHECK-LABEL: func @process_multi_index_default_axes
168+
func.func @process_multi_index_default_axes() -> (index, index, index) {
169+
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
170+
%0:3 = mesh.process_multi_index on @mesh0 : index, index, index
171171
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
172172
return %0#0, %0#1, %0#2 : index, index, index
173173
}
174174

175-
// CHECK-LABEL: func @process_index_empty_axes
176-
func.func @process_index_empty_axes() -> (index, index, index) {
177-
// CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
178-
%0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
175+
// CHECK-LABEL: func @process_multi_index_empty_axes
176+
func.func @process_multi_index_empty_axes() -> (index, index, index) {
177+
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
178+
%0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
179179
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
180180
return %0#0, %0#1, %0#2 : index, index, index
181181
}
182182

183+
// CHECK-LABEL: func @process_linear_index
184+
func.func @process_linear_index() -> index {
185+
// CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
186+
%0 = mesh.process_linear_index on @mesh0 : index
187+
// CHECK: return %[[RES]] : index
188+
return %0 : index
189+
}
183190

184191
// CHECK-LABEL: func @all_reduce
185192
func.func @all_reduce(

0 commit comments

Comments
 (0)