@@ -261,6 +261,104 @@ class OuterProductFusion2Way
261
261
}
262
262
};
263
263
264
+ // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
265
+ //
266
+ // This transforms IR like:
267
+ // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
268
+ // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
269
+ // Into:
270
+ // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
271
+ // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
272
+ //
273
+ // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
274
+ // pass when the result is the input to an outer product.
275
+ struct SwapVectorExtractOfArithExtend
276
+ : public OpRewritePattern<vector::ExtractOp> {
277
+ using OpRewritePattern::OpRewritePattern;
278
+
279
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
280
+ PatternRewriter &rewriter) const override {
281
+ VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType ());
282
+ if (!resultType)
283
+ return rewriter.notifyMatchFailure (extractOp,
284
+ " extracted type is not a vector type" );
285
+
286
+ auto numScalableDims = llvm::count (resultType.getScalableDims (), true );
287
+ if (numScalableDims != 1 )
288
+ return rewriter.notifyMatchFailure (
289
+ extractOp, " extracted type is not a 1-D scalable vector type" );
290
+
291
+ auto *extendOp = extractOp.getVector ().getDefiningOp ();
292
+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
293
+ extendOp))
294
+ return rewriter.notifyMatchFailure (extractOp,
295
+ " extract not from extend op" );
296
+
297
+ auto loc = extractOp.getLoc ();
298
+ StringAttr extendOpName = extendOp->getName ().getIdentifier ();
299
+ Value extendSource = extendOp->getOperand (0 );
300
+
301
+ // Create new extract from source of extend.
302
+ Value newExtract = rewriter.create <vector::ExtractOp>(
303
+ loc, extendSource, extractOp.getMixedPosition ());
304
+
305
+ // Extend new extract to original result type.
306
+ Operation *newExtend =
307
+ rewriter.create (loc, extendOpName, Value (newExtract), resultType);
308
+
309
+ rewriter.replaceOp (extractOp, newExtend);
310
+
311
+ return success ();
312
+ }
313
+ };
314
+
315
+ // Same as above, but for vector.scalable.extract.
316
+ //
317
+ // This transforms IR like:
318
+ // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
319
+ // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
320
+ // Into:
321
+ // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
322
+ // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
323
+ //
324
+ // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
325
+ // pass when the result is the input to an outer product.
326
+ struct SwapVectorScalableExtractOfArithExtend
327
+ : public OpRewritePattern<vector::ScalableExtractOp> {
328
+ using OpRewritePattern::OpRewritePattern;
329
+
330
+ LogicalResult matchAndRewrite (vector::ScalableExtractOp extractOp,
331
+ PatternRewriter &rewriter) const override {
332
+ auto *extendOp = extractOp.getSource ().getDefiningOp ();
333
+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
334
+ extendOp))
335
+ return rewriter.notifyMatchFailure (extractOp,
336
+ " extract not from extend op" );
337
+
338
+ auto loc = extractOp.getLoc ();
339
+ VectorType resultType = extractOp.getResultVectorType ();
340
+
341
+ Value extendSource = extendOp->getOperand (0 );
342
+ StringAttr extendOpName = extendOp->getName ().getIdentifier ();
343
+ VectorType extendSourceVectorType =
344
+ cast<VectorType>(extendSource.getType ());
345
+
346
+ // Create new extract from source of extend.
347
+ VectorType extractResultVectorType =
348
+ resultType.clone (extendSourceVectorType.getElementType ());
349
+ Value newExtract = rewriter.create <vector::ScalableExtractOp>(
350
+ loc, extractResultVectorType, extendSource, extractOp.getPos ());
351
+
352
+ // Extend new extract to original result type.
353
+ Operation *newExtend =
354
+ rewriter.create (loc, extendOpName, Value (newExtract), resultType);
355
+
356
+ rewriter.replaceOp (extractOp, newExtend);
357
+
358
+ return success ();
359
+ }
360
+ };
361
+
264
362
struct OuterProductFusionPass
265
363
: public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
266
364
@@ -278,7 +376,11 @@ struct OuterProductFusionPass
278
376
279
377
void mlir::arm_sme::populateOuterProductFusionPatterns (
280
378
RewritePatternSet &patterns) {
281
- patterns.add <OuterProductFusion2Way>(patterns.getContext ());
379
+ MLIRContext *context = patterns.getContext ();
380
+ // Note: High benefit to ensure extract(extend) are swapped first.
381
+ patterns.add <SwapVectorExtractOfArithExtend,
382
+ SwapVectorScalableExtractOfArithExtend>(context, 1024 );
383
+ patterns.add <OuterProductFusion2Way>(context);
282
384
}
283
385
284
386
std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass () {
0 commit comments