@@ -1172,6 +1172,9 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
1172
1172
using GenericTarget::GenericTarget;
1173
1173
1174
1174
static constexpr int defaultWidth = 64 ;
1175
+ static constexpr int GRLen = defaultWidth; /* eight bytes */
1176
+ static constexpr int GRLenInChar = GRLen / 8 ;
1177
+ static constexpr int FRLen = defaultWidth; /* eight bytes */
1175
1178
1176
1179
CodeGenSpecifics::Marshalling
1177
1180
complexArgumentType (mlir::Location loc, mlir::Type eleTy) const override {
@@ -1242,6 +1245,313 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
1242
1245
1243
1246
return GenericTarget::integerArgumentType (loc, argTy);
1244
1247
}
1248
+
1249
+ // / Flatten non-basic types, resulting in an array of types containing only
1250
+ // / `IntegerType` and `FloatType`.
1251
+ llvm::SmallVector<mlir::Type> flattenTypeList (mlir::Location loc,
1252
+ const mlir::Type type) const {
1253
+ llvm::SmallVector<mlir::Type> flatTypes;
1254
+
1255
+ llvm::TypeSwitch<mlir::Type>(type)
1256
+ .template Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
1257
+ if (intTy.getWidth () != 0 )
1258
+ flatTypes.push_back (intTy);
1259
+ })
1260
+ .template Case <mlir::FloatType>([&](mlir::FloatType floatTy) {
1261
+ if (floatTy.getWidth () != 0 )
1262
+ flatTypes.push_back (floatTy);
1263
+ })
1264
+ .template Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
1265
+ const auto *sem = &floatToSemantics (kindMap, cmplx.getElementType ());
1266
+ if (sem == &llvm::APFloat::IEEEsingle () ||
1267
+ sem == &llvm::APFloat::IEEEdouble () ||
1268
+ sem == &llvm::APFloat::IEEEquad ())
1269
+ std::fill_n (std::back_inserter (flatTypes), 2 ,
1270
+ cmplx.getElementType ());
1271
+ else
1272
+ TODO (loc, " unsupported complex type(not IEEEsingle, IEEEdouble, "
1273
+ " IEEEquad) as a structure component for BIND(C), "
1274
+ " VALUE derived type argument and type return" );
1275
+ })
1276
+ .template Case <fir::LogicalType>([&](fir::LogicalType logicalTy) {
1277
+ const unsigned width =
1278
+ kindMap.getLogicalBitsize (logicalTy.getFKind ());
1279
+ if (width != 0 )
1280
+ flatTypes.push_back (
1281
+ mlir::IntegerType::get (type.getContext (), width));
1282
+ })
1283
+ .template Case <fir::CharacterType>([&](fir::CharacterType charTy) {
1284
+ assert (kindMap.getCharacterBitsize (charTy.getFKind ()) <= 8 &&
1285
+ " the bit size of characterType as an interoperable type must "
1286
+ " not exceed 8" );
1287
+ for (unsigned i = 0 ; i < charTy.getLen (); ++i)
1288
+ flatTypes.push_back (mlir::IntegerType::get (type.getContext (), 8 ));
1289
+ })
1290
+ .template Case <fir::SequenceType>([&](fir::SequenceType seqTy) {
1291
+ if (!seqTy.hasDynamicExtents ()) {
1292
+ const std::uint64_t numOfEle = seqTy.getConstantArraySize ();
1293
+ mlir::Type eleTy = seqTy.getEleTy ();
1294
+ if (!mlir::isa<mlir::IntegerType, mlir::FloatType>(eleTy)) {
1295
+ llvm::SmallVector<mlir::Type> subTypeList =
1296
+ flattenTypeList (loc, eleTy);
1297
+ if (subTypeList.size () != 0 )
1298
+ for (std::uint64_t i = 0 ; i < numOfEle; ++i)
1299
+ llvm::copy (subTypeList, std::back_inserter (flatTypes));
1300
+ } else {
1301
+ std::fill_n (std::back_inserter (flatTypes), numOfEle, eleTy);
1302
+ }
1303
+ } else
1304
+ TODO (loc, " unsupported dynamic extent sequence type as a structure "
1305
+ " component for BIND(C), "
1306
+ " VALUE derived type argument and type return" );
1307
+ })
1308
+ .template Case <fir::RecordType>([&](fir::RecordType recTy) {
1309
+ for (auto &component : recTy.getTypeList ()) {
1310
+ mlir::Type eleTy = component.second ;
1311
+ llvm::SmallVector<mlir::Type> subTypeList =
1312
+ flattenTypeList (loc, eleTy);
1313
+ if (subTypeList.size () != 0 )
1314
+ llvm::copy (subTypeList, std::back_inserter (flatTypes));
1315
+ }
1316
+ })
1317
+ .template Case <fir::VectorType>([&](fir::VectorType vecTy) {
1318
+ auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash (
1319
+ loc, vecTy, getDataLayout (), kindMap);
1320
+ if (sizeAndAlign.first == 2 * GRLenInChar)
1321
+ flatTypes.push_back (
1322
+ mlir::IntegerType::get (type.getContext (), 2 * GRLen));
1323
+ else
1324
+ TODO (loc, " unsupported vector width(must be 128 bits)" );
1325
+ })
1326
+ .Default ([&](mlir::Type ty) {
1327
+ if (fir::conformsWithPassByRef (ty))
1328
+ flatTypes.push_back (
1329
+ mlir::IntegerType::get (type.getContext (), GRLen));
1330
+ else
1331
+ TODO (loc, " unsupported component type for BIND(C), VALUE derived "
1332
+ " type argument and type return" );
1333
+ });
1334
+
1335
+ return flatTypes;
1336
+ }
1337
+
1338
+ // / Determine if a struct is eligible to be passed in FARs (and GARs) (i.e.,
1339
+ // / when flattened it contains a single fp value, fp+fp, or int+fp of
1340
+ // / appropriate size).
1341
+ bool detectFARsEligibleStruct (mlir::Location loc, fir::RecordType recTy,
1342
+ mlir::Type &field1Ty,
1343
+ mlir::Type &field2Ty) const {
1344
+ field1Ty = field2Ty = nullptr ;
1345
+ llvm::SmallVector<mlir::Type> flatTypes = flattenTypeList (loc, recTy);
1346
+ size_t flatSize = flatTypes.size ();
1347
+
1348
+ // Cannot be eligible if the number of flattened types is equal to 0 or
1349
+ // greater than 2.
1350
+ if (flatSize == 0 || flatSize > 2 )
1351
+ return false ;
1352
+
1353
+ bool isFirstAvaliableFloat = false ;
1354
+
1355
+ assert ((mlir::isa<mlir::IntegerType, mlir::FloatType>(flatTypes[0 ])) &&
1356
+ " Type must be integerType or floatType after flattening" );
1357
+ if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(flatTypes[0 ])) {
1358
+ const unsigned Size = floatTy.getWidth ();
1359
+ // Can't be eligible if larger than the FP registers. Half precision isn't
1360
+ // currently supported on LoongArch and the ABI hasn't been confirmed, so
1361
+ // default to the integer ABI in that case.
1362
+ if (Size > FRLen || Size < 32 )
1363
+ return false ;
1364
+ isFirstAvaliableFloat = true ;
1365
+ field1Ty = floatTy;
1366
+ } else if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(flatTypes[0 ])) {
1367
+ if (intTy.getWidth () > GRLen)
1368
+ return false ;
1369
+ field1Ty = intTy;
1370
+ }
1371
+
1372
+ // flatTypes has two elements
1373
+ if (flatSize == 2 ) {
1374
+ assert ((mlir::isa<mlir::IntegerType, mlir::FloatType>(flatTypes[1 ])) &&
1375
+ " Type must be integerType or floatType after flattening" );
1376
+ if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(flatTypes[1 ])) {
1377
+ const unsigned Size = floatTy.getWidth ();
1378
+ if (Size > FRLen || Size < 32 )
1379
+ return false ;
1380
+ field2Ty = floatTy;
1381
+ return true ;
1382
+ } else if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(flatTypes[1 ])) {
1383
+ // Can't be eligible if an integer type was already found (int+int pairs
1384
+ // are not eligible).
1385
+ if (!isFirstAvaliableFloat)
1386
+ return false ;
1387
+ if (intTy.getWidth () > GRLen)
1388
+ return false ;
1389
+ field2Ty = intTy;
1390
+ return true ;
1391
+ }
1392
+ }
1393
+
1394
+ // return isFirstAvaliableFloat if flatTypes only has one element
1395
+ return isFirstAvaliableFloat;
1396
+ }
1397
+
1398
+ bool checkTypeHasEnoughRegs (mlir::Location loc, int &GARsLeft, int &FARsLeft,
1399
+ const mlir::Type type) const {
1400
+ if (!type)
1401
+ return true ;
1402
+
1403
+ llvm::TypeSwitch<mlir::Type>(type)
1404
+ .template Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
1405
+ const unsigned width = intTy.getWidth ();
1406
+ if (width > 128 )
1407
+ TODO (loc,
1408
+ " integerType with width exceeding 128 bits is unsupported" );
1409
+ if (width == 0 )
1410
+ return ;
1411
+ if (width <= GRLen)
1412
+ --GARsLeft;
1413
+ else if (width <= 2 * GRLen)
1414
+ GARsLeft = GARsLeft - 2 ;
1415
+ })
1416
+ .template Case <mlir::FloatType>([&](mlir::FloatType floatTy) {
1417
+ const unsigned width = floatTy.getWidth ();
1418
+ if (width > 128 )
1419
+ TODO (loc, " floatType with width exceeding 128 bits is unsupported" );
1420
+ if (width == 0 )
1421
+ return ;
1422
+ if (width == 32 || width == 64 )
1423
+ --FARsLeft;
1424
+ else if (width <= GRLen)
1425
+ --GARsLeft;
1426
+ else if (width <= 2 * GRLen)
1427
+ GARsLeft = GARsLeft - 2 ;
1428
+ })
1429
+ .Default ([&](mlir::Type ty) {
1430
+ if (fir::conformsWithPassByRef (ty))
1431
+ --GARsLeft; // Pointers.
1432
+ else
1433
+ TODO (loc, " unsupported component type for BIND(C), VALUE derived "
1434
+ " type argument and type return" );
1435
+ });
1436
+
1437
+ return GARsLeft >= 0 && FARsLeft >= 0 ;
1438
+ }
1439
+
1440
+ bool hasEnoughRegisters (mlir::Location loc, int GARsLeft, int FARsLeft,
1441
+ const Marshalling &previousArguments,
1442
+ const mlir::Type &field1Ty,
1443
+ const mlir::Type &field2Ty) const {
1444
+ for (auto &typeAndAttr : previousArguments) {
1445
+ const auto &attr = std::get<Attributes>(typeAndAttr);
1446
+ if (attr.isByVal ()) {
1447
+ // Previous argument passed on the stack, and its address is passed in
1448
+ // GAR.
1449
+ --GARsLeft;
1450
+ continue ;
1451
+ }
1452
+
1453
+ // Previous aggregate arguments were marshalled into simpler arguments.
1454
+ const auto &type = std::get<mlir::Type>(typeAndAttr);
1455
+ llvm::SmallVector<mlir::Type> flatTypes = flattenTypeList (loc, type);
1456
+
1457
+ for (auto &flatTy : flatTypes) {
1458
+ if (!checkTypeHasEnoughRegs (loc, GARsLeft, FARsLeft, flatTy))
1459
+ return false ;
1460
+ }
1461
+ }
1462
+
1463
+ if (!checkTypeHasEnoughRegs (loc, GARsLeft, FARsLeft, field1Ty))
1464
+ return false ;
1465
+ if (!checkTypeHasEnoughRegs (loc, GARsLeft, FARsLeft, field2Ty))
1466
+ return false ;
1467
+ return true ;
1468
+ }
1469
+
1470
+ // / LoongArch64 subroutine calling sequence ABI in:
1471
+ // / https://github.com/loongson/la-abi-specs/blob/release/lapcs.adoc#subroutine-calling-sequence
1472
+ CodeGenSpecifics::Marshalling
1473
+ classifyStruct (mlir::Location loc, fir::RecordType recTy, int GARsLeft,
1474
+ int FARsLeft, bool isResult,
1475
+ const Marshalling &previousArguments) const {
1476
+ CodeGenSpecifics::Marshalling marshal;
1477
+
1478
+ auto [recSize, recAlign] = fir::getTypeSizeAndAlignmentOrCrash (
1479
+ loc, recTy, getDataLayout (), kindMap);
1480
+ mlir::MLIRContext *context = recTy.getContext ();
1481
+
1482
+ if (recSize == 0 ) {
1483
+ TODO (loc, " unsupported empty struct type for BIND(C), "
1484
+ " VALUE derived type argument and type return" );
1485
+ }
1486
+
1487
+ if (recSize > 2 * GRLenInChar) {
1488
+ marshal.emplace_back (
1489
+ fir::ReferenceType::get (recTy),
1490
+ AT{recAlign, /* byval=*/ !isResult, /* sret=*/ isResult});
1491
+ return marshal;
1492
+ }
1493
+
1494
+ // Pass by FARs(and GARs)
1495
+ mlir::Type field1Ty = nullptr , field2Ty = nullptr ;
1496
+ if (detectFARsEligibleStruct (loc, recTy, field1Ty, field2Ty) &&
1497
+ hasEnoughRegisters (loc, GARsLeft, FARsLeft, previousArguments, field1Ty,
1498
+ field2Ty)) {
1499
+ if (!isResult) {
1500
+ if (field1Ty)
1501
+ marshal.emplace_back (field1Ty, AT{});
1502
+ if (field2Ty)
1503
+ marshal.emplace_back (field2Ty, AT{});
1504
+ } else {
1505
+ // field1Ty is always preferred over field2Ty for assignment, so there
1506
+ // will never be a case where field1Ty == nullptr and field2Ty !=
1507
+ // nullptr.
1508
+ if (field1Ty && !field2Ty)
1509
+ marshal.emplace_back (field1Ty, AT{});
1510
+ else if (field1Ty && field2Ty)
1511
+ marshal.emplace_back (
1512
+ mlir::TupleType::get (context,
1513
+ mlir::TypeRange{field1Ty, field2Ty}),
1514
+ AT{/* alignment=*/ 0 , /* byval=*/ true });
1515
+ }
1516
+ return marshal;
1517
+ }
1518
+
1519
+ if (recSize <= GRLenInChar) {
1520
+ marshal.emplace_back (mlir::IntegerType::get (context, GRLen), AT{});
1521
+ return marshal;
1522
+ }
1523
+
1524
+ if (recAlign == 2 * GRLenInChar) {
1525
+ marshal.emplace_back (mlir::IntegerType::get (context, 2 * GRLen), AT{});
1526
+ return marshal;
1527
+ }
1528
+
1529
+ // recSize > GRLenInChar && recSize <= 2 * GRLenInChar
1530
+ marshal.emplace_back (
1531
+ fir::SequenceType::get ({2 }, mlir::IntegerType::get (context, GRLen)),
1532
+ AT{});
1533
+ return marshal;
1534
+ }
1535
+
1536
+ // / Marshal a derived type passed by value like a C struct.
1537
+ CodeGenSpecifics::Marshalling
1538
+ structArgumentType (mlir::Location loc, fir::RecordType recTy,
1539
+ const Marshalling &previousArguments) const override {
1540
+ int GARsLeft = 8 ;
1541
+ int FARsLeft = FRLen ? 8 : 0 ;
1542
+
1543
+ return classifyStruct (loc, recTy, GARsLeft, FARsLeft, /* isResult=*/ false ,
1544
+ previousArguments);
1545
+ }
1546
+
1547
+ CodeGenSpecifics::Marshalling
1548
+ structReturnType (mlir::Location loc, fir::RecordType recTy) const override {
1549
+ // The rules for return and argument types are the same.
1550
+ int GARsLeft = 2 ;
1551
+ int FARsLeft = FRLen ? 2 : 0 ;
1552
+ return classifyStruct (loc, recTy, GARsLeft, FARsLeft, /* isResult=*/ true ,
1553
+ {});
1554
+ }
1245
1555
};
1246
1556
} // namespace
1247
1557
0 commit comments