@@ -142,20 +142,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
142
142
143
143
mlir::ModuleOp getModule () { return getOperation (); }
144
144
145
- template <typename A , typename B, typename C >
145
+ template <typename Ty , typename Callback >
146
146
std::optional<std::function<mlir::Value(mlir::Operation *)>>
147
- rewriteCallComplexResultType (
148
- mlir::Location loc, A ty, B &newResTys,
149
- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, C &newOpers,
150
- mlir::Value &savedStackPtr) {
151
- if (noComplexConversion) {
152
- newResTys.push_back (ty);
153
- return std::nullopt;
154
- }
155
- auto m = specifics->complexReturnType (loc, ty.getElementType ());
156
- // Currently targets mandate COMPLEX is a single aggregate or packed
157
- // scalar, including the sret case.
158
- assert (m.size () == 1 && " target of complex return not supported" );
147
+ rewriteCallResultType (mlir::Location loc, mlir::Type originalResTy,
148
+ Ty &newResTys,
149
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
150
+ Callback &newOpers, mlir::Value &savedStackPtr,
151
+ fir::CodeGenSpecifics::Marshalling &m) {
152
+ // Currently, targets mandate COMPLEX or STRUCT is a single aggregate or
153
+ // packed scalar, including the sret case.
154
+ assert (m.size () == 1 && " return type not supported on this target" );
159
155
auto resTy = std::get<mlir::Type>(m[0 ]);
160
156
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
161
157
if (attr.isSRet ()) {
@@ -170,7 +166,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
170
166
newInTyAndAttrs.push_back (m[0 ]);
171
167
newOpers.push_back (stack);
172
168
return [=](mlir::Operation *) -> mlir::Value {
173
- auto memTy = fir::ReferenceType::get (ty );
169
+ auto memTy = fir::ReferenceType::get (originalResTy );
174
170
auto cast = rewriter->create <fir::ConvertOp>(loc, memTy, stack);
175
171
return rewriter->create <fir::LoadOp>(loc, cast);
176
172
};
@@ -180,11 +176,41 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
180
176
// We are going to generate an alloca, so save the stack pointer.
181
177
if (!savedStackPtr)
182
178
savedStackPtr = genStackSave (loc);
183
- return this ->convertValueInMemory (loc, call->getResult (0 ), ty ,
179
+ return this ->convertValueInMemory (loc, call->getResult (0 ), originalResTy ,
184
180
/* inputMayBeBigger=*/ true );
185
181
};
186
182
}
187
183
184
+ template <typename Ty, typename Callback>
185
+ std::optional<std::function<mlir::Value(mlir::Operation *)>>
186
+ rewriteCallComplexResultType (
187
+ mlir::Location loc, mlir::ComplexType ty, Ty &newResTys,
188
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
189
+ mlir::Value &savedStackPtr) {
190
+ if (noComplexConversion) {
191
+ newResTys.push_back (ty);
192
+ return std::nullopt;
193
+ }
194
+ auto m = specifics->complexReturnType (loc, ty.getElementType ());
195
+ return rewriteCallResultType (loc, ty, newResTys, newInTyAndAttrs, newOpers,
196
+ savedStackPtr, m);
197
+ }
198
+
199
+ template <typename Ty, typename Callback>
200
+ std::optional<std::function<mlir::Value(mlir::Operation *)>>
201
+ rewriteCallStructResultType (
202
+ mlir::Location loc, fir::RecordType recTy, Ty &newResTys,
203
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
204
+ mlir::Value &savedStackPtr) {
205
+ if (noStructConversion) {
206
+ newResTys.push_back (recTy);
207
+ return std::nullopt;
208
+ }
209
+ auto m = specifics->structReturnType (loc, recTy);
210
+ return rewriteCallResultType (loc, recTy, newResTys, newInTyAndAttrs,
211
+ newOpers, savedStackPtr, m);
212
+ }
213
+
188
214
void passArgumentOnStackOrWithNewType (
189
215
mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
190
216
mlir::Type oldType, mlir::Value oper,
@@ -356,6 +382,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
356
382
newInTyAndAttrs, newOpers,
357
383
savedStackPtr);
358
384
})
385
+ .template Case <fir::RecordType>([&](fir::RecordType recTy) {
386
+ wrap = rewriteCallStructResultType (loc, recTy, newResTys,
387
+ newInTyAndAttrs, newOpers,
388
+ savedStackPtr);
389
+ })
359
390
.Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
360
391
} else if (fnTy.getResults ().size () > 1 ) {
361
392
TODO (loc, " multiple results not supported yet" );
@@ -562,6 +593,24 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
562
593
}
563
594
}
564
595
596
+ template <typename Ty>
597
+ void
598
+ lowerStructSignatureRes (mlir::Location loc, fir::RecordType recTy,
599
+ Ty &newResTys,
600
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
601
+ if (noComplexConversion) {
602
+ newResTys.push_back (recTy);
603
+ return ;
604
+ } else {
605
+ for (auto &tup : specifics->structReturnType (loc, recTy)) {
606
+ if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet ())
607
+ newInTyAndAttrs.push_back (tup);
608
+ else
609
+ newResTys.push_back (std::get<mlir::Type>(tup));
610
+ }
611
+ }
612
+ }
613
+
565
614
void
566
615
lowerStructSignatureArg (mlir::Location loc, fir::RecordType recTy,
567
616
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
@@ -595,6 +644,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
595
644
.Case <mlir::ComplexType>([&](mlir::ComplexType ty) {
596
645
lowerComplexSignatureRes (loc, ty, newResTys, newInTyAndAttrs);
597
646
})
647
+ .Case <fir::RecordType>([&](fir::RecordType ty) {
648
+ lowerStructSignatureRes (loc, ty, newResTys, newInTyAndAttrs);
649
+ })
598
650
.Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
599
651
}
600
652
llvm::SmallVector<mlir::Type> trailingInTys;
@@ -696,7 +748,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
696
748
for (auto ty : func.getResults ())
697
749
if ((mlir::isa<fir::BoxCharType>(ty) && !noCharacterConversion) ||
698
750
(fir::isa_complex (ty) && !noComplexConversion) ||
699
- (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv)) {
751
+ (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) ||
752
+ (mlir::isa<fir::RecordType>(ty) && !noStructConversion)) {
700
753
LLVM_DEBUG (llvm::dbgs () << " rewrite " << signature << " for target\n " );
701
754
return false ;
702
755
}
@@ -770,6 +823,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
770
823
rewriter->getUnitAttr ()));
771
824
newResTys.push_back (retTy);
772
825
})
826
+ .Case <fir::RecordType>([&](fir::RecordType recTy) {
827
+ doStructReturn (func, recTy, newResTys, newInTyAndAttrs, fixups);
828
+ })
773
829
.Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
774
830
775
831
// Saved potential shift in argument. Handling of result can add arguments
@@ -1062,21 +1118,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1062
1118
return false ;
1063
1119
}
1064
1120
1065
- // / Convert a complex return value. This can involve converting the return
1066
- // / value to a "hidden" first argument or packing the complex into a wide
1067
- // / GPR.
1068
1121
template <typename Ty, typename FIXUPS>
1069
- void doComplexReturn (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1070
- Ty &newResTys,
1071
- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1072
- FIXUPS &fixups) {
1073
- if (noComplexConversion) {
1074
- newResTys.push_back (cmplx);
1075
- return ;
1076
- }
1077
- auto m =
1078
- specifics->complexReturnType (func.getLoc (), cmplx.getElementType ());
1079
- assert (m.size () == 1 );
1122
+ void doReturn (mlir::func::FuncOp func, Ty &newResTys,
1123
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1124
+ FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
1125
+ assert (m.size () == 1 &&
1126
+ " expect result to be turned into single argument or result so far" );
1080
1127
auto &tup = m[0 ];
1081
1128
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
1082
1129
auto argTy = std::get<mlir::Type>(tup);
@@ -1117,6 +1164,36 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
1117
1164
newResTys.push_back (argTy);
1118
1165
}
1119
1166
1167
+ // / Convert a complex return value. This can involve converting the return
1168
+ // / value to a "hidden" first argument or packing the complex into a wide
1169
+ // / GPR.
1170
+ template <typename Ty, typename FIXUPS>
1171
+ void doComplexReturn (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1172
+ Ty &newResTys,
1173
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1174
+ FIXUPS &fixups) {
1175
+ if (noComplexConversion) {
1176
+ newResTys.push_back (cmplx);
1177
+ return ;
1178
+ }
1179
+ auto m =
1180
+ specifics->complexReturnType (func.getLoc (), cmplx.getElementType ());
1181
+ doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
1182
+ }
1183
+
1184
+ template <typename Ty, typename FIXUPS>
1185
+ void doStructReturn (mlir::func::FuncOp func, fir::RecordType recTy,
1186
+ Ty &newResTys,
1187
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1188
+ FIXUPS &fixups) {
1189
+ if (noStructConversion) {
1190
+ newResTys.push_back (recTy);
1191
+ return ;
1192
+ }
1193
+ auto m = specifics->structReturnType (func.getLoc (), recTy);
1194
+ doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
1195
+ }
1196
+
1120
1197
template <typename FIXUPS>
1121
1198
void
1122
1199
createFuncOpArgFixups (mlir::func::FuncOp func,
0 commit comments