@@ -76,23 +76,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
76
76
77
77
for (int i = n - 1 ; i >= 0 ; --i) {
78
78
multiIndex[i] = b.create <arith::RemSIOp>(loc, linearIndex, dimensions[i]);
79
- if (i > 0 ) {
79
+ if (i > 0 )
80
80
linearIndex = b.create <arith::DivSIOp>(loc, linearIndex, dimensions[i]);
81
- }
82
81
}
83
82
84
83
return multiIndex;
85
84
}
86
85
87
- // Create operations converting a multi-dimensional index to a linear index
86
+ // / Create operations converting a multi-dimensional index to a linear index.
88
87
Value multiToLinearIndex (Location loc, OpBuilder b, ValueRange multiIndex,
89
88
ValueRange dimensions) {
90
89
91
- auto linearIndex = b.create <arith::ConstantIndexOp>(loc, 0 ). getResult ( );
92
- auto stride = b.create <arith::ConstantIndexOp>(loc, 1 ). getResult ( );
90
+ Value linearIndex = b.create <arith::ConstantIndexOp>(loc, 0 );
91
+ Value stride = b.create <arith::ConstantIndexOp>(loc, 1 );
93
92
94
93
for (int i = multiIndex.size () - 1 ; i >= 0 ; --i) {
95
- auto off = b.create <arith::MulIOp>(loc, multiIndex[i], stride);
94
+ Value off = b.create <arith::MulIOp>(loc, multiIndex[i], stride);
96
95
linearIndex = b.create <arith::AddIOp>(loc, linearIndex, off);
97
96
stride = b.create <arith::MulIOp>(loc, stride, dimensions[i]);
98
97
}
@@ -247,34 +246,32 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
247
246
};
248
247
249
248
struct ConvertProcessMultiIndexOp
250
- : public mlir::OpRewritePattern<mlir::mesh:: ProcessMultiIndexOp> {
251
- using OpRewritePattern::OpRewritePattern ;
249
+ : public OpConversionPattern< ProcessMultiIndexOp> {
250
+ using OpConversionPattern::OpConversionPattern ;
252
251
253
- mlir:: LogicalResult
254
- matchAndRewrite (mlir::mesh:: ProcessMultiIndexOp op,
255
- mlir::PatternRewriter &rewriter) const override {
252
+ LogicalResult
253
+ matchAndRewrite (ProcessMultiIndexOp op, OpAdaptor adaptor ,
254
+ ConversionPatternRewriter &rewriter) const override {
256
255
257
256
// Currently converts its linear index to a multi-dimensional index.
258
257
259
258
SymbolTableCollection symbolTableCollection;
260
- auto loc = op.getLoc ();
259
+ Location loc = op.getLoc ();
261
260
auto meshOp = getMesh (op, symbolTableCollection);
262
261
// For now we only support static mesh shapes
263
- if (ShapedType::isDynamicShape (meshOp.getShape ())) {
264
- return mlir::failure ();
265
- }
262
+ if (ShapedType::isDynamicShape (meshOp.getShape ()))
263
+ return failure ();
266
264
267
265
SmallVector<Value> dims;
268
266
llvm::transform (
269
267
meshOp.getShape (), std::back_inserter (dims), [&](int64_t i) {
270
268
return rewriter.create <arith::ConstantIndexOp>(loc, i).getResult ();
271
269
});
272
- auto rank =
273
- rewriter.create <ProcessLinearIndexOp>(op.getLoc (), meshOp).getResult ();
270
+ Value rank = rewriter.create <ProcessLinearIndexOp>(op.getLoc (), meshOp);
274
271
auto mIdx = linearToMultiIndex (loc, rewriter, rank, dims);
275
272
276
273
// optionally extract subset of mesh axes
277
- auto axes = op .getAxes ();
274
+ auto axes = adaptor .getAxes ();
278
275
if (!axes.empty ()) {
279
276
SmallVector<Value> subIndex;
280
277
for (auto axis : axes) {
@@ -319,44 +316,43 @@ class ConvertProcessLinearIndexOp
319
316
.getRank ();
320
317
rewriter.replaceOpWithNewOp <arith::IndexCastOp>(op, rewriter.getIndexType (),
321
318
rank);
322
- return mlir:: success ();
319
+ return success ();
323
320
}
324
321
};
325
322
326
323
struct ConvertNeighborsLinearIndicesOp
327
- : public mlir::OpRewritePattern<mlir::mesh:: NeighborsLinearIndicesOp> {
328
- using OpRewritePattern::OpRewritePattern ;
324
+ : public OpConversionPattern< NeighborsLinearIndicesOp> {
325
+ using OpConversionPattern::OpConversionPattern ;
329
326
330
- mlir:: LogicalResult
331
- matchAndRewrite (mlir::mesh:: NeighborsLinearIndicesOp op,
332
- mlir::PatternRewriter &rewriter) const override {
327
+ LogicalResult
328
+ matchAndRewrite (NeighborsLinearIndicesOp op, OpAdaptor adaptor ,
329
+ ConversionPatternRewriter &rewriter) const override {
333
330
334
331
// Computes the neighbors indices along a split axis by simply
335
332
// adding/subtracting 1 to the current index in that dimension.
336
333
// Assigns -1 if neighbor is out of bounds.
337
334
338
- auto axes = op .getSplitAxes ();
335
+ auto axes = adaptor .getSplitAxes ();
339
336
// For now only single axis sharding is supported
340
- if (axes.size () != 1 ) {
341
- return mlir::failure ();
342
- }
337
+ if (axes.size () != 1 )
338
+ return failure ();
343
339
344
- auto loc = op.getLoc ();
340
+ Location loc = op.getLoc ();
345
341
SymbolTableCollection symbolTableCollection;
346
342
auto meshOp = getMesh (op, symbolTableCollection);
347
- auto mIdx = op .getDevice ();
343
+ auto mIdx = adaptor .getDevice ();
348
344
auto orgIdx = mIdx [axes[0 ]];
349
345
SmallVector<Value> dims;
350
346
llvm::transform (
351
347
meshOp.getShape (), std::back_inserter (dims), [&](int64_t i) {
352
348
return rewriter.create <arith::ConstantIndexOp>(loc, i).getResult ();
353
349
});
354
- auto dimSz = dims[axes[0 ]];
355
- auto one = rewriter.create <arith::ConstantIndexOp>(loc, 1 ). getResult ( );
356
- auto minus1 = rewriter.create <arith::ConstantIndexOp>(loc, -1 ). getResult ( );
357
- auto atBorder = rewriter.create <arith::CmpIOp>(
350
+ Value dimSz = dims[axes[0 ]];
351
+ Value one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
352
+ Value minus1 = rewriter.create <arith::ConstantIndexOp>(loc, -1 );
353
+ Value atBorder = rewriter.create <arith::CmpIOp>(
358
354
loc, arith::CmpIPredicate::sle, orgIdx,
359
- rewriter.create <arith::ConstantIndexOp>(loc, 0 ). getResult () );
355
+ rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
360
356
auto down = rewriter.create <scf::IfOp>(
361
357
loc, atBorder,
362
358
[&](OpBuilder &builder, Location loc) {
@@ -598,23 +594,20 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
598
594
// we need the actual shape to compute offsets and sizes
599
595
for (auto i = 0 ; i < rank; ++i) {
600
596
auto s = dstShape[i];
601
- if (ShapedType::isDynamic (s)) {
597
+ if (ShapedType::isDynamic (s))
602
598
shape[i] = rewriter.create <memref::DimOp>(loc, array, s).getResult ();
603
- } else {
599
+ else
604
600
shape[i] = rewriter.getIndexAttr (s);
605
- }
606
601
607
602
if ((size_t )i < opSplitAxes.size () && !opSplitAxes[i].empty ()) {
608
603
++currHaloDim;
609
604
// the offsets for lower dim sstarts after their down halo
610
605
offsets[i] = haloSizes[currHaloDim * 2 ];
611
606
612
607
// prepare shape and offsets of highest dim's halo exchange
613
- auto _haloSz =
614
- rewriter
615
- .create <arith::AddIOp>(loc, toValue (haloSizes[currHaloDim * 2 ]),
616
- toValue (haloSizes[currHaloDim * 2 + 1 ]))
617
- .getResult ();
608
+ Value _haloSz = rewriter.create <arith::AddIOp>(
609
+ loc, toValue (haloSizes[currHaloDim * 2 ]),
610
+ toValue (haloSizes[currHaloDim * 2 + 1 ]));
618
611
// the halo shape of lower dims exlude the halos
619
612
dimSizes[i] =
620
613
rewriter.create <arith::SubIOp>(loc, toValue (shape[i]), _haloSz)
@@ -625,9 +618,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
625
618
}
626
619
627
620
auto tagAttr = rewriter.getI32IntegerAttr (91 ); // we just pick something
628
- auto tag = rewriter.create <::mlir:: arith::ConstantOp>(loc, tagAttr);
621
+ auto tag = rewriter.create <arith::ConstantOp>(loc, tagAttr);
629
622
auto zeroAttr = rewriter.getI32IntegerAttr (0 ); // for detecting v<0
630
- auto zero = rewriter.create <::mlir:: arith::ConstantOp>(loc, zeroAttr);
623
+ auto zero = rewriter.create <arith::ConstantOp>(loc, zeroAttr);
631
624
632
625
SmallVector<Type> indexResultTypes (meshOp.getShape ().size (),
633
626
rewriter.getIndexType ());
@@ -637,9 +630,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
637
630
// traverse all split axes from high to low dim
638
631
for (ssize_t dim = opSplitAxes.size () - 1 ; dim >= 0 ; --dim) {
639
632
auto splitAxes = opSplitAxes[dim];
640
- if (splitAxes.empty ()) {
633
+ if (splitAxes.empty ())
641
634
continue ;
642
- }
643
635
assert (currHaloDim >= 0 && (size_t )currHaloDim < haloSizes.size () / 2 );
644
636
// Get the linearized ids of the neighbors (down and up) for the
645
637
// given split
0 commit comments