@@ -3469,6 +3469,72 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
3469
3469
llvm_unreachable (" unsupported declarative directive" );
3470
3470
}
3471
3471
3472
+ static bool hasDeviceType (llvm::SmallVector<mlir::Attribute> &arrayAttr,
3473
+ mlir::acc::DeviceType deviceType) {
3474
+ for (auto attr : arrayAttr) {
3475
+ auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3476
+ if (deviceTypeAttr.getValue () == deviceType)
3477
+ return true ;
3478
+ }
3479
+ return false ;
3480
+ }
3481
+
3482
+ template <typename RetTy, typename AttrTy>
3483
+ static std::optional<RetTy>
3484
+ getAttributeValueByDeviceType (llvm::SmallVector<mlir::Attribute> &attributes,
3485
+ llvm::SmallVector<mlir::Attribute> &deviceTypes,
3486
+ mlir::acc::DeviceType deviceType) {
3487
+ assert (attributes.size () == deviceTypes.size () &&
3488
+ " expect same number of attributes" );
3489
+ for (auto it : llvm::enumerate (deviceTypes)) {
3490
+ auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value ());
3491
+ if (deviceTypeAttr.getValue () == deviceType) {
3492
+ if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
3493
+ auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index ()]);
3494
+ return strAttr.getValue ();
3495
+ } else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
3496
+ auto intAttr =
3497
+ mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index ()]);
3498
+ return intAttr.getInt ();
3499
+ }
3500
+ }
3501
+ }
3502
+ return std::nullopt;
3503
+ }
3504
+
3505
+ static bool compareDeviceTypeInfo (
3506
+ mlir::acc::RoutineOp op,
3507
+ llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
3508
+ llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
3509
+ llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
3510
+ llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
3511
+ llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
3512
+ llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
3513
+ llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
3514
+ llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
3515
+ for (uint32_t dtypeInt = 0 ;
3516
+ dtypeInt != mlir::acc::getMaxEnumValForDeviceType (); ++dtypeInt) {
3517
+ auto dtype = static_cast <mlir::acc::DeviceType>(dtypeInt);
3518
+ if (op.getBindNameValue (dtype) !=
3519
+ getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
3520
+ bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
3521
+ return false ;
3522
+ if (op.hasGang (dtype) != hasDeviceType (gangArrayAttr, dtype))
3523
+ return false ;
3524
+ if (op.getGangDimValue (dtype) !=
3525
+ getAttributeValueByDeviceType<int64_t , mlir::IntegerAttr>(
3526
+ gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
3527
+ return false ;
3528
+ if (op.hasSeq (dtype) != hasDeviceType (seqArrayAttr, dtype))
3529
+ return false ;
3530
+ if (op.hasWorker (dtype) != hasDeviceType (workerArrayAttr, dtype))
3531
+ return false ;
3532
+ if (op.hasVector (dtype) != hasDeviceType (vectorArrayAttr, dtype))
3533
+ return false ;
3534
+ }
3535
+ return true ;
3536
+ }
3537
+
3472
3538
static void attachRoutineInfo (mlir::func::FuncOp func,
3473
3539
mlir::SymbolRefAttr routineAttr) {
3474
3540
llvm::SmallVector<mlir::SymbolRefAttr> routines;
@@ -3518,17 +3584,23 @@ void Fortran::lower::genOpenACCRoutineConstruct(
3518
3584
funcName = funcOp.getName ();
3519
3585
}
3520
3586
}
3521
- bool hasSeq = false , hasGang = false , hasWorker = false , hasVector = false ,
3522
- hasNohost = false ;
3523
- std::optional<std::string> bindName = std::nullopt;
3524
- std::optional<int64_t > gangDim = std::nullopt;
3587
+ bool hasNohost = false ;
3588
+
3589
+ llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
3590
+ workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
3591
+ gangDimDeviceTypes, gangDimValues;
3592
+
3593
+ // device_type attribute is set to `none` until a device_type clause is
3594
+ // encountered.
3595
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
3596
+ builder.getContext (), mlir::acc::DeviceType::None);
3525
3597
3526
3598
for (const Fortran::parser::AccClause &clause : clauses.v ) {
3527
3599
if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u )) {
3528
- hasSeq = true ;
3600
+ seqDeviceTypes. push_back (crtDeviceTypeAttr) ;
3529
3601
} else if (const auto *gangClause =
3530
3602
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u )) {
3531
- hasGang = true ;
3603
+
3532
3604
if (gangClause->v ) {
3533
3605
const Fortran::parser::AccGangArgList &x = *gangClause->v ;
3534
3606
for (const Fortran::parser::AccGangArg &gangArg : x.v ) {
@@ -3539,21 +3611,27 @@ void Fortran::lower::genOpenACCRoutineConstruct(
3539
3611
if (!dimValue)
3540
3612
mlir::emitError (loc,
3541
3613
" dim value must be a constant positive integer" );
3542
- gangDim = *dimValue;
3614
+ gangDimValues.push_back (
3615
+ builder.getIntegerAttr (builder.getI64Type (), *dimValue));
3616
+ gangDimDeviceTypes.push_back (crtDeviceTypeAttr);
3543
3617
}
3544
3618
}
3619
+ } else {
3620
+ gangDeviceTypes.push_back (crtDeviceTypeAttr);
3545
3621
}
3546
3622
} else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u )) {
3547
- hasVector = true ;
3623
+ vectorDeviceTypes. push_back (crtDeviceTypeAttr) ;
3548
3624
} else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u )) {
3549
- hasWorker = true ;
3625
+ workerDeviceTypes. push_back (crtDeviceTypeAttr) ;
3550
3626
} else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u )) {
3551
3627
hasNohost = true ;
3552
3628
} else if (const auto *bindClause =
3553
3629
std::get_if<Fortran::parser::AccClause::Bind>(&clause.u )) {
3554
3630
if (const auto *name =
3555
3631
std::get_if<Fortran::parser::Name>(&bindClause->v .u )) {
3556
- bindName = converter.mangleName (*name->symbol );
3632
+ bindNames.push_back (
3633
+ builder.getStringAttr (converter.mangleName (*name->symbol )));
3634
+ bindNameDeviceTypes.push_back (crtDeviceTypeAttr);
3557
3635
} else if (const auto charExpr =
3558
3636
std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
3559
3637
&bindClause->v .u )) {
@@ -3562,8 +3640,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
3562
3640
*charExpr);
3563
3641
if (!name)
3564
3642
mlir::emitError (loc, " Could not retrieve the bind name" );
3565
- bindName = *name;
3643
+ bindNames.push_back (builder.getStringAttr (*name));
3644
+ bindNameDeviceTypes.push_back (crtDeviceTypeAttr);
3566
3645
}
3646
+ } else if (const auto *deviceTypeClause =
3647
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
3648
+ &clause.u )) {
3649
+ const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
3650
+ deviceTypeClause->v ;
3651
+ assert (deviceTypeExprList.v .size () == 1 &&
3652
+ " expect only one device_type expr" );
3653
+ crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
3654
+ builder.getContext (), getDeviceType (deviceTypeExprList.v .front ().v ));
3567
3655
}
3568
3656
}
3569
3657
@@ -3575,23 +3663,31 @@ void Fortran::lower::genOpenACCRoutineConstruct(
3575
3663
if (routineOp.getFuncName ().str ().compare (funcName) == 0 ) {
3576
3664
// If the routine is already specified with the same clauses, just skip
3577
3665
// the operation creation.
3578
- if (routineOp.getBindName () == bindName &&
3579
- routineOp.getGang () == hasGang &&
3580
- routineOp.getWorker () == hasWorker &&
3581
- routineOp.getVector () == hasVector && routineOp.getSeq () == hasSeq &&
3582
- routineOp.getNohost () == hasNohost &&
3583
- routineOp.getGangDim () == gangDim)
3666
+ if (compareDeviceTypeInfo (routineOp, bindNames, bindNameDeviceTypes,
3667
+ gangDeviceTypes, gangDimValues,
3668
+ gangDimDeviceTypes, seqDeviceTypes,
3669
+ workerDeviceTypes, vectorDeviceTypes) &&
3670
+ routineOp.getNohost () == hasNohost)
3584
3671
return ;
3585
3672
mlir::emitError (loc, " Routine already specified with different clauses" );
3586
3673
}
3587
3674
}
3588
3675
3589
3676
modBuilder.create <mlir::acc::RoutineOp>(
3590
3677
loc, routineOpName.str (), funcName,
3591
- bindName ? builder.getStringAttr (*bindName) : mlir::StringAttr{}, hasGang,
3592
- hasWorker, hasVector, hasSeq, hasNohost, /* implicit=*/ false ,
3593
- gangDim ? builder.getIntegerAttr (builder.getIntegerType (32 ), *gangDim)
3594
- : mlir::IntegerAttr{});
3678
+ bindNames.empty () ? nullptr : builder.getArrayAttr (bindNames),
3679
+ bindNameDeviceTypes.empty () ? nullptr
3680
+ : builder.getArrayAttr (bindNameDeviceTypes),
3681
+ workerDeviceTypes.empty () ? nullptr
3682
+ : builder.getArrayAttr (workerDeviceTypes),
3683
+ vectorDeviceTypes.empty () ? nullptr
3684
+ : builder.getArrayAttr (vectorDeviceTypes),
3685
+ seqDeviceTypes.empty () ? nullptr : builder.getArrayAttr (seqDeviceTypes),
3686
+ hasNohost, /* implicit=*/ false ,
3687
+ gangDeviceTypes.empty () ? nullptr : builder.getArrayAttr (gangDeviceTypes),
3688
+ gangDimValues.empty () ? nullptr : builder.getArrayAttr (gangDimValues),
3689
+ gangDimDeviceTypes.empty () ? nullptr
3690
+ : builder.getArrayAttr (gangDimDeviceTypes));
3595
3691
3596
3692
if (funcOp)
3597
3693
attachRoutineInfo (funcOp, builder.getSymbolRefAttr (routineOpName.str ()));
0 commit comments