@@ -338,6 +338,105 @@ struct LegalizeTransferWriteOpsByDecomposition
338
338
}
339
339
};
340
340
341
+ // Shuffles arith extend ops after vector.extract op.
342
+ //
343
+ // This transforms IR like:
344
+ // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
345
+ // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
346
+ // Into:
347
+ // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
348
+ // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
349
+ //
350
+ // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
351
+ // pass when the result is the input to an outer product.
352
+ struct SwapVectorExtractOfArithExtend
353
+ : public OpRewritePattern<vector::ExtractOp> {
354
+ using OpRewritePattern::OpRewritePattern;
355
+
356
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
357
+ PatternRewriter &rewriter) const override {
358
+ VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType ());
359
+ if (!resultType)
360
+ return rewriter.notifyMatchFailure (extractOp,
361
+ " extracted type is not a vector type" );
362
+
363
+ auto numScalableDims = llvm::count (resultType.getScalableDims (), true );
364
+ if (numScalableDims != 1 )
365
+ return rewriter.notifyMatchFailure (
366
+ extractOp, " extracted type is not a 1-D scalable vector type" );
367
+
368
+ auto *extendOp = extractOp.getVector ().getDefiningOp ();
369
+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
370
+ extendOp))
371
+ return rewriter.notifyMatchFailure (extractOp,
372
+ " extract not from extend op" );
373
+
374
+ auto loc = extractOp.getLoc ();
375
+ StringAttr extendOpName = extendOp->getName ().getIdentifier ();
376
+ Value extendSource = extendOp->getOperand (0 );
377
+
378
+ // Create new extract from source of extend.
379
+ Value newExtract = rewriter.create <vector::ExtractOp>(
380
+ loc, extendSource, extractOp.getMixedPosition ());
381
+
382
+ // Extend new extract to original result type.
383
+ Operation *newExtend =
384
+ rewriter.create (loc, extendOpName, Value (newExtract), resultType);
385
+
386
+ rewriter.replaceOp (extractOp, newExtend->getResult (0 ));
387
+
388
+ return success ();
389
+ }
390
+ };
391
+
392
+ // Shuffles arith extend ops after vector.scalable.extract op.
393
+ //
394
+ // This transforms IR like:
395
+ // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
396
+ // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
397
+ // Into:
398
+ // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
399
+ // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
400
+ //
401
+ // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
402
+ // pass when the result is the input to an outer product.
403
+ struct SwapVectorScalableExtractOfArithExtend
404
+ : public OpRewritePattern<vector::ScalableExtractOp> {
405
+ using OpRewritePattern::OpRewritePattern;
406
+
407
+ LogicalResult matchAndRewrite (vector::ScalableExtractOp extractOp,
408
+ PatternRewriter &rewriter) const override {
409
+ auto *extendOp = extractOp.getSource ().getDefiningOp ();
410
+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
411
+ extendOp))
412
+ return rewriter.notifyMatchFailure (extractOp,
413
+ " extract not from extend op" );
414
+
415
+ auto loc = extractOp.getLoc ();
416
+ VectorType resultType = extractOp.getResultVectorType ();
417
+
418
+ Value extendSource = extendOp->getOperand (0 );
419
+ StringAttr extendOpName = extendOp->getName ().getIdentifier ();
420
+ VectorType extendSourceVectorType =
421
+ cast<VectorType>(extendSource.getType ());
422
+
423
+ // Create new extract from source of extend.
424
+ VectorType extractResultVectorType =
425
+ VectorType::Builder (resultType)
426
+ .setElementType (extendSourceVectorType.getElementType ());
427
+ Value newExtract = rewriter.create <vector::ScalableExtractOp>(
428
+ loc, extractResultVectorType, extendSource, extractOp.getPos ());
429
+
430
+ // Extend new extract to original result type.
431
+ Operation *newExtend =
432
+ rewriter.create (loc, extendOpName, Value (newExtract), resultType);
433
+
434
+ rewriter.replaceOp (extractOp, newExtend->getResult (0 ));
435
+
436
+ return success ();
437
+ }
438
+ };
439
+
341
440
struct VectorLegalizationPass
342
441
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
343
442
void runOnOperation () override {
@@ -358,6 +457,8 @@ struct VectorLegalizationPass
358
457
return success ();
359
458
});
360
459
460
+ patterns.add <SwapVectorExtractOfArithExtend,
461
+ SwapVectorScalableExtractOfArithExtend>(context);
361
462
// Note: High benefit to ensure masked outer products are lowered first.
362
463
patterns.add <LegalizeMaskedVectorOuterProductOpsByDecomposition>(
363
464
converter, context, 1024 );
0 commit comments