@@ -100,14 +100,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
100
100
// Convert ops in target-specific patterns.
101
101
mod.walk ([&](mlir::Operation *op) {
102
102
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
103
- if (!hasPortableSignature (call.getFunctionType ()))
103
+ if (!hasPortableSignature (call.getFunctionType (), op ))
104
104
convertCallOp (call);
105
105
} else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
106
- if (!hasPortableSignature (dispatch.getFunctionType ()))
106
+ if (!hasPortableSignature (dispatch.getFunctionType (), op ))
107
107
convertCallOp (dispatch);
108
108
} else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
109
109
if (addr.getType ().isa <mlir::FunctionType>() &&
110
- !hasPortableSignature (addr.getType ()))
110
+ !hasPortableSignature (addr.getType (), op ))
111
111
convertAddrOp (addr);
112
112
}
113
113
});
@@ -443,19 +443,23 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
443
443
// / then it is considered portable for any target, and this function will
444
444
// / return `true`. Otherwise, the signature is not portable and `false` is
445
445
// / returned.
446
- bool hasPortableSignature (mlir::Type signature) {
446
+ bool hasPortableSignature (mlir::Type signature, mlir::Operation *op ) {
447
447
assert (signature.isa <mlir::FunctionType>());
448
448
auto func = signature.dyn_cast <mlir::FunctionType>();
449
+ bool hasFirRuntime = op->hasAttrOfType <mlir::UnitAttr>(
450
+ fir::FIROpsDialect::getFirRuntimeAttrName ());
449
451
for (auto ty : func.getResults ())
450
452
if ((ty.isa <fir::BoxCharType>() && !noCharacterConversion) ||
451
- (fir::isa_complex (ty) && !noComplexConversion)) {
453
+ (fir::isa_complex (ty) && !noComplexConversion) ||
454
+ (ty.isa <mlir::IntegerType>() && hasFirRuntime)) {
452
455
LLVM_DEBUG (llvm::dbgs () << " rewrite " << signature << " for target\n " );
453
456
return false ;
454
457
}
455
458
for (auto ty : func.getInputs ())
456
459
if (((ty.isa <fir::BoxCharType>() || fir::isCharacterProcedureTuple (ty)) &&
457
460
!noCharacterConversion) ||
458
- (fir::isa_complex (ty) && !noComplexConversion)) {
461
+ (fir::isa_complex (ty) && !noComplexConversion) ||
462
+ (ty.isa <mlir::IntegerType>() && hasFirRuntime)) {
459
463
LLVM_DEBUG (llvm::dbgs () << " rewrite " << signature << " for target\n " );
460
464
return false ;
461
465
}
@@ -476,13 +480,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
476
480
// / the immediately subsequent target code gen.
477
481
void convertSignature (mlir::func::FuncOp func) {
478
482
auto funcTy = func.getFunctionType ().cast <mlir::FunctionType>();
479
- if (hasPortableSignature (funcTy) && !hasHostAssociations (func))
483
+ if (hasPortableSignature (funcTy, func ) && !hasHostAssociations (func))
480
484
return ;
481
485
llvm::SmallVector<mlir::Type> newResTys;
482
486
llvm::SmallVector<mlir::Type> newInTys;
483
487
llvm::SmallVector<std::pair<unsigned , mlir::NamedAttribute>> savedAttrs;
484
488
llvm::SmallVector<std::pair<unsigned , mlir::NamedAttribute>> extraAttrs;
485
489
llvm::SmallVector<FixupTy> fixups;
490
+ llvm::SmallVector<std::pair<unsigned , mlir::NamedAttrList>, 1 > resultAttrs;
486
491
487
492
// Save argument attributes in case there is a shift so we can replace them
488
493
// correctly.
@@ -509,6 +514,22 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
509
514
else
510
515
doComplexReturn (func, cmplx, newResTys, newInTys, fixups);
511
516
})
517
+ .Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
518
+ auto m = specifics->integerArgumentType (func.getLoc (), intTy);
519
+ assert (m.size () == 1 );
520
+ auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
521
+ auto retTy = std::get<mlir::Type>(m[0 ]);
522
+ std::size_t resId = newResTys.size ();
523
+ llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName ();
524
+ if (!extensionAttrName.empty () &&
525
+ // TODO: we have to do the same for BIND(C) routines.
526
+ func->hasAttrOfType <mlir::UnitAttr>(
527
+ fir::FIROpsDialect::getFirRuntimeAttrName ()))
528
+ resultAttrs.emplace_back (
529
+ resId, rewriter->getNamedAttr (extensionAttrName,
530
+ rewriter->getUnitAttr ()));
531
+ newResTys.push_back (retTy);
532
+ })
512
533
.Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
513
534
514
535
// Saved potential shift in argument. Handling of result can add arguments
@@ -572,6 +593,26 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
572
593
newInTys.push_back (ty);
573
594
}
574
595
})
596
+ .Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
597
+ auto m = specifics->integerArgumentType (func.getLoc (), intTy);
598
+ assert (m.size () == 1 );
599
+ auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
600
+ auto argTy = std::get<mlir::Type>(m[0 ]);
601
+ auto argNo = newInTys.size ();
602
+ llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName ();
603
+ if (!extensionAttrName.empty () &&
604
+ // TODO: we have to do the same for BIND(C) routines.
605
+ func->hasAttrOfType <mlir::UnitAttr>(
606
+ fir::FIROpsDialect::getFirRuntimeAttrName ())) {
607
+ fixups.emplace_back (FixupTy::Codes::ArgumentType, argNo,
608
+ [=](mlir::func::FuncOp func) {
609
+ func.setArgAttr (
610
+ argNo, extensionAttrName,
611
+ mlir::UnitAttr::get (func.getContext ()));
612
+ });
613
+ }
614
+ newInTys.push_back (argTy);
615
+ })
575
616
.Default ([&](mlir::Type ty) { newInTys.push_back (ty); });
576
617
577
618
if (func.getArgAttrOfType <mlir::UnitAttr>(index,
@@ -608,14 +649,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
608
649
case FixupTy::Codes::ArgumentType: {
609
650
// Argument is pass-by-value, but its type has likely been modified to
610
651
// suit the target ABI convention.
652
+ auto oldArgTy =
653
+ fir::ReferenceType::get (oldArgTys[fixup.index - offset]);
654
+ // If type did not change, keep the original argument.
655
+ if (newInTys[fixup.index ] == oldArgTy)
656
+ break ;
657
+
611
658
auto newArg = func.front ().insertArgument (fixup.index ,
612
659
newInTys[fixup.index ], loc);
613
660
rewriter->setInsertionPointToStart (&func.front ());
614
661
auto mem =
615
662
rewriter->create <fir::AllocaOp>(loc, newInTys[fixup.index ]);
616
663
rewriter->create <fir::StoreOp>(loc, newArg, mem);
617
- auto oldArgTy =
618
- fir::ReferenceType::get (oldArgTys[fixup.index - offset]);
619
664
auto cast = rewriter->create <fir::ConvertOp>(loc, oldArgTy, mem);
620
665
mlir::Value load = rewriter->create <fir::LoadOp>(loc, cast);
621
666
func.getArgument (fixup.index + 1 ).replaceAllUsesWith (load);
@@ -744,6 +789,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
744
789
func.setArgAttr (extraAttr.first , extraAttr.second .getName (),
745
790
extraAttr.second .getValue ());
746
791
792
+ for (auto [resId, resAttrList] : resultAttrs)
793
+ for (mlir::NamedAttribute resAttr : resAttrList)
794
+ func.setResultAttr (resId, resAttr.getName (), resAttr.getValue ());
795
+
747
796
// Replace attributes to the correct argument if there was an argument shift
748
797
// to the right.
749
798
if (argumentShift > 0 ) {
0 commit comments