@@ -180,11 +180,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
180
180
// We are going to generate an alloca, so save the stack pointer.
181
181
if (!savedStackPtr)
182
182
savedStackPtr = genStackSave (loc);
183
- auto mem = rewriter->create <fir::AllocaOp>(loc, resTy);
184
- rewriter->create <fir::StoreOp>(loc, call->getResult (0 ), mem);
185
- auto memTy = fir::ReferenceType::get (ty);
186
- auto cast = rewriter->create <fir::ConvertOp>(loc, memTy, mem);
187
- return rewriter->create <fir::LoadOp>(loc, cast);
183
+ return this ->convertValueInMemory (loc, call->getResult (0 ), ty,
184
+ /* inputMayBeBigger=*/ true );
188
185
};
189
186
}
190
187
@@ -195,7 +192,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
195
192
mlir::Value &savedStackPtr) {
196
193
auto resTy = std::get<mlir::Type>(newTypeAndAttr);
197
194
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(newTypeAndAttr);
198
- auto oldRefTy = fir::ReferenceType::get (oldType);
199
195
// We are going to generate an alloca, so save the stack pointer.
200
196
if (!savedStackPtr)
201
197
savedStackPtr = genStackSave (loc);
@@ -206,11 +202,83 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
206
202
mem = rewriter->create <fir::ConvertOp>(loc, resTy, mem);
207
203
newOpers.push_back (mem);
208
204
} else {
209
- auto mem = rewriter->create <fir::AllocaOp>(loc, resTy);
205
+ mlir::Value bitcast =
206
+ convertValueInMemory (loc, oper, resTy, /* inputMayBeBigger=*/ false );
207
+ newOpers.push_back (bitcast);
208
+ }
209
+ }
210
+
211
+ // Do a bitcast (convert a value via its memory representation).
212
+ // The input and output types may have different storage sizes,
213
+ // "inputMayBeBigger" should be set to indicate which of the input or
214
+ // output type may be bigger in order for the load/store to be safe.
215
+ // The mismatch comes from the fact that the LLVM register used for passing
216
+ // may be bigger than the value being passed (e.g., passing
217
+ // a `!fir.type<t{fir.array<3xi8>}>` into an i32 LLVM register).
218
+ mlir::Value convertValueInMemory (mlir::Location loc, mlir::Value value,
219
+ mlir::Type newType, bool inputMayBeBigger) {
220
+ if (inputMayBeBigger) {
221
+ auto newRefTy = fir::ReferenceType::get (newType);
222
+ auto mem = rewriter->create <fir::AllocaOp>(loc, value.getType ());
223
+ rewriter->create <fir::StoreOp>(loc, value, mem);
224
+ auto cast = rewriter->create <fir::ConvertOp>(loc, newRefTy, mem);
225
+ return rewriter->create <fir::LoadOp>(loc, cast);
226
+ } else {
227
+ auto oldRefTy = fir::ReferenceType::get (value.getType ());
228
+ auto mem = rewriter->create <fir::AllocaOp>(loc, newType);
210
229
auto cast = rewriter->create <fir::ConvertOp>(loc, oldRefTy, mem);
211
- rewriter->create <fir::StoreOp>(loc, oper, cast);
212
- newOpers.push_back (rewriter->create <fir::LoadOp>(loc, mem));
230
+ rewriter->create <fir::StoreOp>(loc, value, cast);
231
+ return rewriter->create <fir::LoadOp>(loc, mem);
232
+ }
233
+ }
234
+
235
+ void passSplitArgument (mlir::Location loc,
236
+ fir::CodeGenSpecifics::Marshalling splitArgs,
237
+ mlir::Type oldType, mlir::Value oper,
238
+ llvm::SmallVectorImpl<mlir::Value> &newOpers,
239
+ mlir::Value &savedStackPtr) {
240
+ // COMPLEX or struct argument split into separate arguments
241
+ if (!fir::isa_complex (oldType)) {
242
+ // Cast original operand to a tuple of the new arguments
243
+ // via memory.
244
+ llvm::SmallVector<mlir::Type> partTypes;
245
+ for (auto argPart : splitArgs)
246
+ partTypes.push_back (std::get<mlir::Type>(argPart));
247
+ mlir::Type tupleType =
248
+ mlir::TupleType::get (oldType.getContext (), partTypes);
249
+ if (!savedStackPtr)
250
+ savedStackPtr = genStackSave (loc);
251
+ oper = convertValueInMemory (loc, oper, tupleType,
252
+ /* inputMayBeBigger=*/ false );
253
+ }
254
+ auto iTy = rewriter->getIntegerType (32 );
255
+ for (auto e : llvm::enumerate (splitArgs)) {
256
+ auto &tup = e.value ();
257
+ auto ty = std::get<mlir::Type>(tup);
258
+ auto index = e.index ();
259
+ auto idx = rewriter->getIntegerAttr (iTy, index);
260
+ auto val = rewriter->create <fir::ExtractValueOp>(
261
+ loc, ty, oper, rewriter->getArrayAttr (idx));
262
+ newOpers.push_back (val);
263
+ }
264
+ }
265
+
266
+ void rewriteCallOperands (
267
+ mlir::Location loc, fir::CodeGenSpecifics::Marshalling passArgAs,
268
+ mlir::Type originalArgTy, mlir::Value oper,
269
+ llvm::SmallVectorImpl<mlir::Value> &newOpers, mlir::Value &savedStackPtr,
270
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
271
+ if (passArgAs.size () == 1 ) {
272
+ // COMPLEX or derived type is passed as a single argument.
273
+ passArgumentOnStackOrWithNewType (loc, passArgAs[0 ], originalArgTy, oper,
274
+ newOpers, savedStackPtr);
275
+ } else {
276
+ // COMPLEX or derived type is split into separate arguments
277
+ passSplitArgument (loc, passArgAs, originalArgTy, oper, newOpers,
278
+ savedStackPtr);
213
279
}
280
+ newInTyAndAttrs.insert (newInTyAndAttrs.end (), passArgAs.begin (),
281
+ passArgAs.end ());
214
282
}
215
283
216
284
template <typename CPLX>
@@ -224,28 +292,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
224
292
newOpers.push_back (oper);
225
293
return ;
226
294
}
227
-
228
295
auto m = specifics->complexArgumentType (loc, ty.getElementType ());
229
- if (m.size () == 1 ) {
230
- // COMPLEX is a single aggregate
231
- passArgumentOnStackOrWithNewType (loc, m[0 ], ty, oper, newOpers,
232
- savedStackPtr);
233
- newInTyAndAttrs.push_back (m[0 ]);
234
- } else {
235
- assert (m.size () == 2 );
236
- // COMPLEX is split into 2 separate arguments
237
- auto iTy = rewriter->getIntegerType (32 );
238
- for (auto e : llvm::enumerate (m)) {
239
- auto &tup = e.value ();
240
- auto ty = std::get<mlir::Type>(tup);
241
- auto index = e.index ();
242
- auto idx = rewriter->getIntegerAttr (iTy, index);
243
- auto val = rewriter->create <fir::ExtractValueOp>(
244
- loc, ty, oper, rewriter->getArrayAttr (idx));
245
- newInTyAndAttrs.push_back (tup);
246
- newOpers.push_back (val);
247
- }
248
- }
296
+ rewriteCallOperands (loc, m, ty, oper, newOpers, savedStackPtr,
297
+ newInTyAndAttrs);
249
298
}
250
299
251
300
void rewriteCallStructInputType (
@@ -260,11 +309,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
260
309
}
261
310
auto structArgs =
262
311
specifics->structArgumentType (loc, recTy, newInTyAndAttrs);
263
- if (structArgs.size () != 1 )
264
- TODO (loc, " splitting BIND(C), VALUE derived type into several arguments" );
265
- passArgumentOnStackOrWithNewType (loc, structArgs[0 ], recTy, oper, newOpers,
266
- savedStackPtr);
267
- structArgs.push_back (structArgs[0 ]);
312
+ rewriteCallOperands (loc, structArgs, recTy, oper, newOpers, savedStackPtr,
313
+ newInTyAndAttrs);
268
314
}
269
315
270
316
static bool hasByValOrSRetArgs (
@@ -849,20 +895,17 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
849
895
case FixupTy::Codes::ArgumentType: {
850
896
// Argument is pass-by-value, but its type has likely been modified to
851
897
// suit the target ABI convention.
852
- auto oldArgTy =
853
- fir::ReferenceType::get (oldArgTys[fixup.index - offset]);
898
+ auto oldArgTy = oldArgTys[fixup.index - offset];
854
899
// If type did not change, keep the original argument.
855
900
if (fixupType == oldArgTy)
856
901
break ;
857
902
858
903
auto newArg =
859
904
func.front ().insertArgument (fixup.index , fixupType, loc);
860
905
rewriter->setInsertionPointToStart (&func.front ());
861
- auto mem = rewriter->create <fir::AllocaOp>(loc, fixupType);
862
- rewriter->create <fir::StoreOp>(loc, newArg, mem);
863
- auto cast = rewriter->create <fir::ConvertOp>(loc, oldArgTy, mem);
864
- mlir::Value load = rewriter->create <fir::LoadOp>(loc, cast);
865
- func.getArgument (fixup.index + 1 ).replaceAllUsesWith (load);
906
+ mlir::Value bitcast = convertValueInMemory (loc, newArg, oldArgTy,
907
+ /* inputMayBeBigger=*/ true );
908
+ func.getArgument (fixup.index + 1 ).replaceAllUsesWith (bitcast);
866
909
func.front ().eraseArgument (fixup.index + 1 );
867
910
LLVM_DEBUG (llvm::dbgs ()
868
911
<< " old argument: " << oldArgTy.getEleTy ()
@@ -907,34 +950,43 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
907
950
func.walk ([&](mlir::func::ReturnOp ret) {
908
951
rewriter->setInsertionPoint (ret);
909
952
auto oldOper = ret.getOperand (0 );
910
- auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
911
- auto mem =
912
- rewriter->create <fir::AllocaOp>(loc, newResTys[fixup.index ]);
913
- auto cast = rewriter->create <fir::ConvertOp>(loc, oldOperTy, mem);
914
- rewriter->create <fir::StoreOp>(loc, oldOper, cast);
915
- mlir::Value load = rewriter->create <fir::LoadOp>(loc, mem);
916
- rewriter->create <mlir::func::ReturnOp>(loc, load);
953
+ mlir::Value bitcast =
954
+ convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
955
+ /* inputMayBeBigger=*/ false );
956
+ rewriter->create <mlir::func::ReturnOp>(loc, bitcast);
917
957
ret.erase ();
918
958
});
919
959
} break ;
920
960
case FixupTy::Codes::Split: {
921
961
// The FIR argument has been split into a pair of distinct arguments
922
- // that are in juxtaposition to each other. (For COMPLEX value.)
962
+ // that are in juxtaposition to each other. (For COMPLEX value or
963
+ // derived type passed with VALUE in BIND(C) context).
923
964
auto newArg =
924
965
func.front ().insertArgument (fixup.index , fixupType, loc);
925
966
if (fixup.second == 1 ) {
926
967
rewriter->setInsertionPointToStart (&func.front ());
927
- auto cplxTy = oldArgTys[fixup.index - offset - fixup.second ];
928
- auto undef = rewriter->create <fir::UndefOp>(loc, cplxTy);
968
+ mlir::Value firstArg = func.front ().getArgument (fixup.index - 1 );
969
+ mlir::Type originalTy =
970
+ oldArgTys[fixup.index - offset - fixup.second ];
971
+ mlir::Type pairTy = originalTy;
972
+ if (!fir::isa_complex (originalTy)) {
973
+ pairTy = mlir::TupleType::get (
974
+ originalTy.getContext (),
975
+ mlir::TypeRange{firstArg.getType (), newArg.getType ()});
976
+ }
977
+ auto undef = rewriter->create <fir::UndefOp>(loc, pairTy);
929
978
auto iTy = rewriter->getIntegerType (32 );
930
979
auto zero = rewriter->getIntegerAttr (iTy, 0 );
931
980
auto one = rewriter->getIntegerAttr (iTy, 1 );
932
- auto cplx1 = rewriter->create <fir::InsertValueOp>(
933
- loc, cplxTy, undef, func.front ().getArgument (fixup.index - 1 ),
934
- rewriter->getArrayAttr (zero));
935
- auto cplx = rewriter->create <fir::InsertValueOp>(
936
- loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr (one));
937
- func.getArgument (fixup.index + 1 ).replaceAllUsesWith (cplx);
981
+ mlir::Value pair1 = rewriter->create <fir::InsertValueOp>(
982
+ loc, pairTy, undef, firstArg, rewriter->getArrayAttr (zero));
983
+ mlir::Value pair = rewriter->create <fir::InsertValueOp>(
984
+ loc, pairTy, pair1, newArg, rewriter->getArrayAttr (one));
985
+ // Cast local argument tuple to original type via memory if needed.
986
+ if (pairTy != originalTy)
987
+ pair = convertValueInMemory (loc, pair, originalTy,
988
+ /* inputMayBeBigger=*/ true );
989
+ func.getArgument (fixup.index + 1 ).replaceAllUsesWith (pair);
938
990
func.front ().eraseArgument (fixup.index + 1 );
939
991
offset++;
940
992
}
0 commit comments