Skip to content

Commit dc3258c

Browse files
authored
[mlir][mesh] Add all-slice operation (#81218)
This op is the inverse of all-gather. It is useful to have an explicit concise representation instead of having a blob of slicing logic. Add lowering for the op that slices from the tensor based on the in-group process index. Make resharding generate an all-slice instead of inserting the slicing logic directly.
1 parent e3f88a9 commit dc3258c

File tree

26 files changed

+701
-214
lines changed

26 files changed

+701
-214
lines changed

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515

1616
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1717
#include "mlir/Dialect/Affine/IR/AffineOps.h"
18+
#include "mlir/IR/OpDefinition.h"
1819
#include <optional>
1920

2021
namespace mlir {
2122
class DominanceInfo;
2223
class Operation;
2324
class PostDominanceInfo;
25+
class ImplicitLocOpBuilder;
2426

2527
namespace func {
2628
class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
309311
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
310312
Value linearIndex,
311313
ArrayRef<Value> basis);
314+
// Generate IR that extracts the linear index from a multi-index according to
315+
// a basis/shape.
316+
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
317+
ArrayRef<OpFoldResult> basis,
318+
ImplicitLocOpBuilder &builder);
312319

313320
/// Ensure that all operations that could be executed after `start`
314321
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/Matchers.h"
2121
#include "mlir/IR/PatternMatch.h"
2222
#include "mlir/IR/Value.h"
23+
#include "llvm/ADT/ArrayRef.h"
2324

2425
namespace mlir {
2526

@@ -81,6 +82,22 @@ struct ArithBuilder {
8182
OpBuilder &b;
8283
Location loc;
8384
};
85+
86+
namespace arith {
87+
88+
// Build the product of a sequence.
89+
// If values = (v0, v1, ..., vn) then the returned
90+
// value is v0 * v1 * ... * vn.
91+
// All values must have the same type.
92+
//
93+
// The version without `resultType` requires `values` to be non-empty.
94+
// Then the result will have the same type as the elements in `values`.
95+
// The version with `resultType` returns 1 of type `resultType` if `values`
96+
// is empty.
97+
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
98+
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
99+
Type resultType);
100+
} // namespace arith
84101
} // namespace mlir
85102

86103
#endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
9292
return res;
9393
}
9494

95+
template <typename MeshAxesRange>
96+
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
97+
return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
98+
mesh.getShape());
99+
}
100+
95101
// Get the size of a sharded dimension.
96102
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
97103
if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
9696

9797
let builders = [
9898
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
99+
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
99100
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
100101
];
101102
}
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
341342
let hasCanonicalizer = 1;
342343
}
343344

345+
def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
346+
Pure,
347+
SameOperandsAndResultElementType,
348+
SameOperandsAndResultRank
349+
]> {
350+
let summary = "All-slice over a device mesh. This is the inverse of all-gather.";
351+
let description = [{
352+
Slice along the `slice_axis` tensor axis.
353+
This operation can be thought of as the inverse of all-gather.
354+
Technically, it is not required that all processes have the same input tensor.
355+
Each process will slice a piece of its local tensor based on its in-group device index.
356+
The operation does not communicate data between devices.
357+
358+
Example:
359+
```mlir
360+
mesh.mesh @mesh0(shape = 2x2)
361+
...
362+
%1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
363+
: tensor<2x4xi8> -> tensor<2x2xi8>
364+
```
365+
Input:
366+
```
367+
+-------------+
368+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
369+
| 3 4 7 8 |
370+
+-------------+
371+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
372+
| 11 12 15 16 |
373+
+-------------+
374+
```
375+
Result:
376+
```
377+
slice tensor
378+
axis 1
379+
------------>
380+
+-------+-------+
381+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
382+
| 3 4 | 7 8 |
383+
+-------+-------+
384+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
385+
| 11 12 | 15 16 |
386+
+-------+-------+
387+
```
388+
}];
389+
let arguments = !con(commonArgs, (ins
390+
AnyNon0RankedTensor:$input,
391+
IndexAttr:$slice_axis
392+
));
393+
let results = (outs
394+
AnyNon0RankedTensor:$result
395+
);
396+
let assemblyFormat = [{
397+
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
398+
attr-dict `:` type($input) `->` type($result)
399+
}];
400+
let hasCanonicalizer = 1;
401+
let builders = [
402+
OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>,
403+
OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>
404+
];
405+
}
406+
344407
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
345408
Pure,
346409
SameOperandsAndResultElementType,

mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,33 @@
99
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
1010
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
1111

12+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13+
#include "mlir/IR/BuiltinTypes.h"
14+
#include "mlir/IR/Value.h"
15+
#include "mlir/Support/LLVM.h"
16+
1217
namespace mlir {
1318
class RewritePatternSet;
1419
class SymbolTableCollection;
1520
class DialectRegistry;
21+
class ImplicitLocOpBuilder;
1622
namespace mesh {
1723

18-
void processMultiIndexOpLoweringPopulatePatterns(
24+
void populateProcessMultiIndexOpLoweringPatterns(
25+
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
26+
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry);
27+
28+
void populateAllSliceOpLoweringPatterns(
29+
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
30+
void registerAllSliceOpLoweringDialects(DialectRegistry &registry);
31+
32+
void populateAllOpLoweringPatterns(
1933
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
34+
void registerAllOpLoweringDialects(DialectRegistry &registry);
2035

21-
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
36+
TypedValue<IndexType>
37+
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
38+
ImplicitLocOpBuilder &builder);
2239

2340
} // namespace mesh
2441
} // namespace mlir

mlir/include/mlir/IR/Builders.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ class Builder {
118118
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
119119
// ranked tensor of them. Returns null attribute otherwise.
120120
TypedAttr getZeroAttr(Type type);
121+
// Returns a 1-valued attribute of the given `type`.
122+
// Type constraints are the same as `getZeroAttr`.
123+
TypedAttr getOneAttr(Type type);
121124

122125
// Convenience methods for fixed types.
123126
FloatAttr getF16FloatAttr(float value);

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
#include "mlir/Dialect/Arith/Utils/Utils.h"
2121
#include "mlir/Dialect/Func/IR/FuncOps.h"
2222
#include "mlir/Dialect/MemRef/IR/MemRef.h"
23+
#include "mlir/Dialect/Utils/IndexingUtils.h"
2324
#include "mlir/IR/AffineExprVisitor.h"
2425
#include "mlir/IR/Dominance.h"
2526
#include "mlir/IR/IRMapping.h"
27+
#include "mlir/IR/ImplicitLocOpBuilder.h"
2628
#include "mlir/IR/IntegerSet.h"
2729
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2830
#include <optional>
@@ -1869,3 +1871,27 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
18691871
results.push_back(residual);
18701872
return results;
18711873
}
1874+
1875+
/// Generates IR that computes the linear index for `multiIndex` with respect
/// to `basis`.
///
/// `multiIndex` and `basis` must have the same number of elements. Strides
/// are obtained by running `computeStrides` on symbolic placeholders for the
/// basis entries and then applying each resulting stride expression to the
/// actual `basis` operands. The final index is produced by
/// `computeLinearIndex` with a zero offset. All IR is created through
/// `builder` at `builder.getLoc()`.
OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                          ArrayRef<OpFoldResult> basis,
                                          ImplicitLocOpBuilder &builder) {
  assert(multiIndex.size() == basis.size());

  // One affine symbol per basis entry, in order.
  SmallVector<AffineExpr> basisSymbols;
  for (size_t dim = 0, rank = basis.size(); dim < rank; ++dim)
    basisSymbols.push_back(getAffineSymbolExpr(dim, builder.getContext()));

  // Materialize every symbolic stride against the concrete basis operands.
  SmallVector<AffineExpr> strideExprs = computeStrides(basisSymbols);
  SmallVector<OpFoldResult> strideValues;
  strideValues.reserve(strideExprs.size());
  for (AffineExpr strideExpr : strideExprs)
    strideValues.push_back(affine::makeComposedFoldedAffineApply(
        builder, builder.getLoc(), strideExpr, basis));

  // Combine the strides with the multi-index, starting from a zero offset.
  auto &&[indexExpr, indexOperands] = computeLinearIndex(
      OpFoldResult(builder.getIndexAttr(0)), strideValues, multiIndex);
  return affine::makeComposedFoldedAffineApply(builder, builder.getLoc(),
                                               indexExpr, indexOperands);
}

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Complex/IR/Complex.h"
1616
#include "mlir/IR/ImplicitLocOpBuilder.h"
1717
#include "llvm/ADT/SmallBitVector.h"
18+
#include <numeric>
1819

1920
using namespace mlir;
2021

@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
262263
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
263264
return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
264265
}
266+
267+
namespace mlir::arith {
268+
269+
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
270+
return createProduct(builder, loc, values, values.front().getType());
271+
}
272+
273+
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
274+
Type resultType) {
275+
Value one = builder.create<ConstantOp>(loc, resultType,
276+
builder.getOneAttr(resultType));
277+
ArithBuilder arithBuilder(builder, loc);
278+
return std::accumulate(
279+
values.begin(), values.end(), one,
280+
[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
281+
}
282+
283+
} // namespace mlir::arith

0 commit comments

Comments
 (0)