14
14
15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
17
+ #include " mlir/Dialect/DLTI/DLTI.h"
18
+ #include " mlir/Dialect/Func/IR/FuncOps.h"
19
+ #include " mlir/Dialect/Func/Transforms/FuncConversions.h"
20
+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
17
21
#include " mlir/Dialect/MPI/IR/MPI.h"
18
22
#include " mlir/Dialect/MemRef/IR/MemRef.h"
23
+ #include " mlir/Dialect/Mesh/IR/MeshDialect.h"
19
24
#include " mlir/Dialect/Mesh/IR/MeshOps.h"
20
25
#include " mlir/Dialect/SCF/IR/SCF.h"
21
26
#include " mlir/Dialect/Tensor/IR/Tensor.h"
25
30
#include " mlir/IR/BuiltinTypes.h"
26
31
#include " mlir/IR/PatternMatch.h"
27
32
#include " mlir/IR/SymbolTable.h"
33
+ #include " mlir/Transforms/DialectConversion.h"
28
34
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
29
35
30
36
#define DEBUG_TYPE " mesh-to-mpi"
@@ -36,10 +42,32 @@ namespace mlir {
36
42
} // namespace mlir
37
43
38
44
using namespace mlir ;
39
- using namespace mlir :: mesh;
45
+ using namespace mesh ;
40
46
41
47
namespace {
42
- // Create operations converting a linear index to a multi-dimensional index
48
+ // / Convert vec of OpFoldResults (ints) into vector of Values.
49
+ static SmallVector<Value> getMixedAsValues (OpBuilder b, const Location &loc,
50
+ llvm::ArrayRef<int64_t > statics,
51
+ ValueRange dynamics,
52
+ Type type = Type()) {
53
+ SmallVector<Value> values;
54
+ auto dyn = dynamics.begin ();
55
+ Type i64 = b.getI64Type ();
56
+ if (!type)
57
+ type = i64 ;
58
+ assert (i64 == type || b.getIndexType () == type);
59
+ for (auto s : statics) {
60
+ values.emplace_back (
61
+ ShapedType::isDynamic (s)
62
+ ? *(dyn++)
63
+ : b.create <arith::ConstantOp>(loc, type,
64
+ i64 == type ? b.getI64IntegerAttr (s)
65
+ : b.getIndexAttr (s)));
66
+ }
67
+ return values;
68
+ };
69
+
70
+ // / Create operations converting a linear index to a multi-dimensional index.
43
71
static SmallVector<Value> linearToMultiIndex (Location loc, OpBuilder b,
44
72
Value linearIndex,
45
73
ValueRange dimensions) {
@@ -72,6 +100,152 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
72
100
return linearIndex;
73
101
}
74
102
103
+ // / Replace GetShardingOp with related/dependent ShardingOp.
104
+ struct ConvertGetShardingOp : public OpConversionPattern <GetShardingOp> {
105
+ using OpConversionPattern::OpConversionPattern;
106
+
107
+ LogicalResult
108
+ matchAndRewrite (GetShardingOp op, OpAdaptor adaptor,
109
+ ConversionPatternRewriter &rewriter) const override {
110
+ auto shardOp = adaptor.getSource ().getDefiningOp <ShardOp>();
111
+ if (!shardOp)
112
+ return failure ();
113
+ auto shardingOp = shardOp.getSharding ().getDefiningOp <ShardingOp>();
114
+ if (!shardingOp)
115
+ return failure ();
116
+
117
+ rewriter.replaceOp (op, shardingOp.getResult ());
118
+ return success ();
119
+ }
120
+ };
121
+
122
/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by type converter.
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Widest split-group determines the second dimension of the split-axes
    // tensor.
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes) {
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
    }

    // To hold the split axes, create empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    // (0xffff interpreted as i16 is -1.)
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
    auto attr = IntegerAttr::get(i16, 0xffff);
    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                       .getResult(0);

    // explicitly write values into tensor row by row
    int64_t strides[] = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      int64_t offs[] = {(int64_t)i, 0};
      int64_t sizes[] = {1, size};
      // Materialize this row's axes as a 1d constant and insert it as a slice
      // at row i; trailing columns keep the -1 fill.
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
    }

    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
    auto haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    // No halo sizes -> emit an empty 0x0 tensor instead.
    Value resHaloSizes =
        haloSizes.empty()
            ? rewriter
                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
                                           i64)
                  .getResult()
            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create Tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups needs using collectiveProcessGroupSize (which needs the
    // MeshOp)
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto meshOp = getMesh(op, symbolTableCollection);
      auto maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      // NOTE(review): despite its name, `zero` holds ShapedType::kDynamic
      // (the -1 "unset" marker used for trailing elements) — consider
      // renaming; verify intent against the fill-with--1 comment above.
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
      auto offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      // `curr` walks the flat offsets list; each split-group consumes
      // splitSize+1 entries (its offsets plus the total size).
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
        int64_t offs[] = {(int64_t)i, 0};
        int64_t sizes[] = {1, splitSize};
        resOffsets = rewriter.create<tensor::InsertSliceOp>(
            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // return a tuple of tensors as defined by type converter
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    // Cast the statically-shaped results to the (dynamic) converted types.
    resSplitAxes =
        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);

    // Bundle the three tensors into a tuple via an unrealized cast; the
    // target materialization unpacks it again.
    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};
248
+
75
249
struct ConvertProcessMultiIndexOp
76
250
: public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
77
251
using OpRewritePattern::OpRewritePattern;
@@ -419,14 +593,95 @@ struct ConvertMeshToMPIPass
419
593
420
594
// / Run the dialect converter on the module.
421
595
void runOnOperation () override {
422
- auto *ctx = &getContext ();
423
- mlir::RewritePatternSet patterns (ctx);
596
+ uint64_t worldRank = -1 ;
597
+ // Try to get DLTI attribute for MPI:comm_world_rank
598
+ // If found, set worldRank to the value of the attribute.
599
+ {
600
+ auto dltiAttr =
601
+ dlti::query (getOperation (), {" MPI:comm_world_rank" }, false );
602
+ if (succeeded (dltiAttr)) {
603
+ if (!isa<IntegerAttr>(dltiAttr.value ())) {
604
+ getOperation ()->emitError ()
605
+ << " Expected an integer attribute for MPI:comm_world_rank" ;
606
+ return signalPassFailure ();
607
+ }
608
+ worldRank = cast<IntegerAttr>(dltiAttr.value ()).getInt ();
609
+ }
610
+ }
424
611
425
- patterns.insert <ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
426
- ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
427
- ctx);
612
+ auto *ctxt = &getContext ();
613
+ RewritePatternSet patterns (ctxt);
614
+ ConversionTarget target (getContext ());
615
+
616
+ // Define a type converter to convert mesh::ShardingType,
617
+ // mostly for use in return operations.
618
+ TypeConverter typeConverter;
619
+ typeConverter.addConversion ([](Type type) { return type; });
620
+
621
+ // convert mesh::ShardingType to a tuple of RankedTensorTypes
622
+ typeConverter.addConversion (
623
+ [](ShardingType type,
624
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
625
+ auto i16 = IntegerType::get (type.getContext (), 16 );
626
+ auto i64 = IntegerType::get (type.getContext (), 64 );
627
+ std::array<int64_t , 2 > shp{ShapedType::kDynamic ,
628
+ ShapedType::kDynamic };
629
+ results.emplace_back (RankedTensorType::get (shp, i16 ));
630
+ results.emplace_back (RankedTensorType::get (shp, i64 )); // actually ?x2
631
+ results.emplace_back (RankedTensorType::get (shp, i64 ));
632
+ return success ();
633
+ });
634
+
635
+ // To 'extract' components, a UnrealizedConversionCastOp is expected
636
+ // to define the input
637
+ typeConverter.addTargetMaterialization (
638
+ [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
639
+ Location loc) {
640
+ // Expecting a single input.
641
+ if (inputs.size () != 1 || !isa<TupleType>(inputs[0 ].getType ()))
642
+ return SmallVector<Value>();
643
+ auto castOp = inputs[0 ].getDefiningOp <UnrealizedConversionCastOp>();
644
+ // Expecting an UnrealizedConversionCastOp.
645
+ if (!castOp)
646
+ return SmallVector<Value>();
647
+ // Fill a vector with elements of the tuple/castOp.
648
+ SmallVector<Value> results;
649
+ for (auto oprnd : castOp.getInputs ()) {
650
+ if (!isa<RankedTensorType>(oprnd.getType ()))
651
+ return SmallVector<Value>();
652
+ results.emplace_back (oprnd);
653
+ }
654
+ return results;
655
+ });
428
656
429
- (void )mlir::applyPatternsGreedily (getOperation (), std::move (patterns));
657
+ // No mesh dialect should left after conversion...
658
+ target.addIllegalDialect <mesh::MeshDialect>();
659
+ // ...except the global MeshOp
660
+ target.addLegalOp <mesh::MeshOp>();
661
+ // Allow all the stuff that our patterns will convert to
662
+ target.addLegalDialect <BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
663
+ arith::ArithDialect, tensor::TensorDialect,
664
+ bufferization::BufferizationDialect,
665
+ linalg::LinalgDialect, memref::MemRefDialect>();
666
+ // Make sure the function signature, calls etc. are legal
667
+ target.addDynamicallyLegalOp <func::FuncOp>([&](func::FuncOp op) {
668
+ return typeConverter.isSignatureLegal (op.getFunctionType ());
669
+ });
670
+ target.addDynamicallyLegalOp <func::CallOp, func::ReturnOp>(
671
+ [&](Operation *op) { return typeConverter.isLegal (op); });
672
+
673
+ patterns.add <ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
674
+ ConvertProcessMultiIndexOp, ConvertGetShardingOp,
675
+ ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
676
+ // ConvertProcessLinearIndexOp accepts an optional worldRank
677
+ patterns.add <ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
678
+
679
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
680
+ patterns, typeConverter);
681
+ populateCallOpTypeConversionPattern (patterns, typeConverter);
682
+ populateReturnOpTypeConversionPattern (patterns, typeConverter);
683
+
684
+ (void )applyPartialConversion (getOperation (), target, std::move (patterns));
430
685
}
431
686
};
432
687
0 commit comments