@@ -349,205 +349,6 @@ class YieldOpConversion : public ConvertToLLVMPattern {
349
349
};
350
350
} // namespace
351
351
352
- template <typename LinalgOp>
353
- static SmallVector<Type, 4 > ExtractOperandTypes (Operation *op) {
354
- return SmallVector<Type, 4 >{op->getOperandTypes ()};
355
- }
356
-
357
- template <>
358
- SmallVector<Type, 4 > ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
359
- auto ctx = op->getContext ();
360
- auto indexedGenericOp = cast<IndexedGenericOp>(op);
361
- auto numLoops = indexedGenericOp.getNumLoops ();
362
-
363
- SmallVector<Type, 4 > result;
364
- result.reserve (numLoops + op->getNumOperands ());
365
- for (unsigned i = 0 ; i < numLoops; ++i) {
366
- result.push_back (IndexType::get (ctx));
367
- }
368
- for (auto type : op->getOperandTypes ()) {
369
- result.push_back (type);
370
- }
371
- return result;
372
- }
373
-
374
- // Get a SymbolRefAttr containing the library function name for the LinalgOp.
375
- // If the library function does not exist, insert a declaration.
376
- template <typename LinalgOp>
377
- static FlatSymbolRefAttr getLibraryCallSymbolRef (Operation *op,
378
- PatternRewriter &rewriter) {
379
- auto linalgOp = cast<LinalgOp>(op);
380
- auto fnName = linalgOp.getLibraryCallName ();
381
- if (fnName.empty ()) {
382
- op->emitWarning (" No library call defined for: " ) << *op;
383
- return {};
384
- }
385
-
386
- // fnName is a dynamic std::String, unique it via a SymbolRefAttr.
387
- FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr (fnName);
388
- auto module = op->getParentOfType <ModuleOp>();
389
- if (module .lookupSymbol (fnName)) {
390
- return fnNameAttr;
391
- }
392
-
393
- SmallVector<Type, 4 > inputTypes (ExtractOperandTypes<LinalgOp>(op));
394
- assert (op->getNumResults () == 0 &&
395
- " Library call for linalg operation can be generated only for ops that "
396
- " have void return types" );
397
- auto libFnType = FunctionType::get (inputTypes, {}, rewriter.getContext ());
398
-
399
- OpBuilder::InsertionGuard guard (rewriter);
400
- // Insert before module terminator.
401
- rewriter.setInsertionPoint (module .getBody (),
402
- std::prev (module .getBody ()->end ()));
403
- FuncOp funcOp =
404
- rewriter.create <FuncOp>(op->getLoc (), fnNameAttr.getValue (), libFnType,
405
- ArrayRef<NamedAttribute>{});
406
- // Insert a function attribute that will trigger the emission of the
407
- // corresponding `_mlir_ciface_xxx` interface so that external libraries see
408
- // a normalized ABI. This interface is added during std to llvm conversion.
409
- funcOp.setAttr (" llvm.emit_c_interface" , UnitAttr::get (op->getContext ()));
410
- return fnNameAttr;
411
- }
412
-
413
- namespace {
414
-
415
- // LinalgOpConversion<LinalgOp> creates a new call to the
416
- // `LinalgOp::getLibraryCallName()` function.
417
- // The implementation of the function can be either in the same module or in an
418
- // externally linked library.
419
- template <typename LinalgOp>
420
- class LinalgOpConversion : public OpRewritePattern <LinalgOp> {
421
- public:
422
- using OpRewritePattern<LinalgOp>::OpRewritePattern;
423
-
424
- LogicalResult matchAndRewrite (LinalgOp op,
425
- PatternRewriter &rewriter) const override {
426
- auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
427
- if (!libraryCallName)
428
- return failure ();
429
-
430
- rewriter.replaceOpWithNewOp <mlir::CallOp>(
431
- op, libraryCallName.getValue (), ArrayRef<Type>{}, op.getOperands ());
432
- return success ();
433
- }
434
- };
435
-
436
- // / Conversion pattern specialization for CopyOp. This kicks in when both input
437
- // / and output permutations are left unspecified or are the identity.
438
- template <> class LinalgOpConversion <CopyOp> : public OpRewritePattern<CopyOp> {
439
- public:
440
- using OpRewritePattern<CopyOp>::OpRewritePattern;
441
-
442
- LogicalResult matchAndRewrite (CopyOp op,
443
- PatternRewriter &rewriter) const override {
444
- auto inputPerm = op.inputPermutation ();
445
- if (inputPerm.hasValue () && !inputPerm->isIdentity ())
446
- return failure ();
447
- auto outputPerm = op.outputPermutation ();
448
- if (outputPerm.hasValue () && !outputPerm->isIdentity ())
449
- return failure ();
450
-
451
- auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
452
- if (!libraryCallName)
453
- return failure ();
454
-
455
- rewriter.replaceOpWithNewOp <mlir::CallOp>(
456
- op, libraryCallName.getValue (), ArrayRef<Type>{}, op.getOperands ());
457
- return success ();
458
- }
459
- };
460
-
461
- // / Conversion pattern specialization for IndexedGenericOp.
462
- template <>
463
- class LinalgOpConversion <IndexedGenericOp>
464
- : public OpRewritePattern<IndexedGenericOp> {
465
- public:
466
- using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
467
-
468
- LogicalResult matchAndRewrite (IndexedGenericOp op,
469
- PatternRewriter &rewriter) const override {
470
- auto libraryCallName =
471
- getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
472
- if (!libraryCallName)
473
- return failure ();
474
-
475
- // TODO(pifon, ntv): Use induction variables values instead of zeros, when
476
- // IndexedGenericOp is tiled.
477
- auto zero = rewriter.create <mlir::ConstantOp>(
478
- op.getLoc (), rewriter.getIntegerAttr (rewriter.getIndexType (), 0 ));
479
- auto indexedGenericOp = cast<IndexedGenericOp>(op);
480
- auto numLoops = indexedGenericOp.getNumLoops ();
481
- SmallVector<Value, 4 > operands;
482
- operands.reserve (numLoops + op.getNumOperands ());
483
- for (unsigned i = 0 ; i < numLoops; ++i) {
484
- operands.push_back (zero);
485
- }
486
- for (auto operand : op.getOperands ()) {
487
- operands.push_back (operand);
488
- }
489
- rewriter.replaceOpWithNewOp <mlir::CallOp>(op, libraryCallName.getValue (),
490
- ArrayRef<Type>{}, operands);
491
- return success ();
492
- }
493
- };
494
-
495
- // / A non-conversion rewrite pattern kicks in to convert CopyOp with
496
- // / permutations into a sequence of TransposeOp and permutation-free CopyOp.
497
- // / This interplays together with TransposeOpConversion and
498
- // / LinalgConversion<CopyOp> to create a path to the LLVM dialect.
499
- class CopyTransposeConversion : public OpRewritePattern <CopyOp> {
500
- public:
501
- using OpRewritePattern<CopyOp>::OpRewritePattern;
502
-
503
- LogicalResult matchAndRewrite (CopyOp op,
504
- PatternRewriter &rewriter) const override {
505
- Value in = op.input (), out = op.output ();
506
-
507
- // If either inputPerm or outputPerm are non-identities, insert transposes.
508
- auto inputPerm = op.inputPermutation ();
509
- if (inputPerm.hasValue () && !inputPerm->isIdentity ())
510
- in = rewriter.create <linalg::TransposeOp>(op.getLoc (), in,
511
- AffineMapAttr::get (*inputPerm));
512
- auto outputPerm = op.outputPermutation ();
513
- if (outputPerm.hasValue () && !outputPerm->isIdentity ())
514
- out = rewriter.create <linalg::TransposeOp>(
515
- op.getLoc (), out, AffineMapAttr::get (*outputPerm));
516
-
517
- // If nothing was transposed, fail and let the conversion kick in.
518
- if (in == op.input () && out == op.output ())
519
- return failure ();
520
-
521
- rewriter.replaceOpWithNewOp <CopyOp>(op, in, out);
522
- return success ();
523
- }
524
- };
525
-
526
- // / Populate the given list with patterns that convert from Linalg to Standard.
527
- static void
528
- populateLinalgToStandardConversionPatterns (OwningRewritePatternList &patterns,
529
- MLIRContext *ctx) {
530
- // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
531
- // attribute values such as kernel striding and dilation.
532
- // clang-format off
533
- patterns.insert <
534
- CopyTransposeConversion,
535
- LinalgOpConversion<ConvOp>,
536
- LinalgOpConversion<PoolingMaxOp>,
537
- LinalgOpConversion<PoolingMinOp>,
538
- LinalgOpConversion<PoolingSumOp>,
539
- LinalgOpConversion<CopyOp>,
540
- LinalgOpConversion<DotOp>,
541
- LinalgOpConversion<FillOp>,
542
- LinalgOpConversion<GenericOp>,
543
- LinalgOpConversion<IndexedGenericOp>,
544
- LinalgOpConversion<MatmulOp>,
545
- LinalgOpConversion<MatvecOp>>(ctx);
546
- // clang-format on
547
- }
548
-
549
- } // namespace
550
-
551
352
// / Populate the given list with patterns that convert from Linalg to LLVM.
552
353
void mlir::populateLinalgToLLVMConversionPatterns (
553
354
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
@@ -579,7 +380,6 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
579
380
populateVectorToLoopsConversionPatterns (patterns, &getContext ());
580
381
populateVectorToLLVMMatrixConversionPatterns (converter, patterns);
581
382
populateVectorToLLVMConversionPatterns (converter, patterns);
582
- populateLinalgToStandardConversionPatterns (patterns, &getContext ());
583
383
populateLinalgToLLVMConversionPatterns (converter, patterns, &getContext ());
584
384
585
385
LLVMConversionTarget target (getContext ());
0 commit comments