29
29
#include " flang/Lower/PFTBuilder.h"
30
30
#include " flang/Lower/StatementContext.h"
31
31
#include " flang/Lower/Support/Utils.h"
32
+ #include " flang/Optimizer/Builder/Complex.h"
32
33
#include " flang/Optimizer/Builder/DirectivesCommon.h"
33
34
#include " flang/Optimizer/Builder/HLFIRTools.h"
34
35
#include " flang/Optimizer/Dialect/FIRType.h"
@@ -103,6 +104,61 @@ static void processOmpAtomicTODO(mlir::Type elementType,
103
104
}
104
105
}
105
106
107
+ // / Emits an implicit cast for atomic statements
108
+ static void emitImplicitCast (Fortran::lower::AbstractConverter &converter,
109
+ mlir::Location loc, mlir::Value &fromAddress,
110
+ mlir::Value &toAddress, mlir::Type &elementType) {
111
+ if (fromAddress.getType () == toAddress.getType ())
112
+ return ;
113
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
114
+ mlir::Value alloca = builder.create <fir::AllocaOp>(
115
+ loc, fir::unwrapRefType (toAddress.getType ()));
116
+ mlir::Value loadedVal = builder.create <fir::LoadOp>(loc, fromAddress);
117
+ mlir::Type toType = fir::unwrapRefType (toAddress.getType ());
118
+ mlir::Type fromType = fir::unwrapRefType (fromAddress.getType ());
119
+ if (!fir::isa_complex (toType) && !fir::isa_complex (fromType)) {
120
+ loadedVal = builder.create <fir::ConvertOp>(
121
+ loc, fir::unwrapRefType (toAddress.getType ()), loadedVal);
122
+ builder.create <fir::StoreOp>(loc, loadedVal, alloca);
123
+ } else if (!fir::isa_complex (toType) && fir::isa_complex (fromType)) {
124
+ loadedVal = builder.create <fir::ExtractValueOp>(
125
+ loc, mlir::cast<mlir::ComplexType>(fromType).getElementType (),
126
+ loadedVal,
127
+ builder.getArrayAttr (
128
+ builder.getIntegerAttr (builder.getIndexType (), 0 )));
129
+ loadedVal = builder.create <fir::ConvertOp>(loc, toType, loadedVal);
130
+ builder.create <fir::StoreOp>(loc, loadedVal, alloca);
131
+ } else if (fir::isa_complex (toType) && fir::isa_complex (fromType)) {
132
+ mlir::Value firstComp = builder.create <fir::ExtractValueOp>(
133
+ loc, mlir::cast<mlir::ComplexType>(fromType).getElementType (),
134
+ loadedVal,
135
+ builder.getArrayAttr (
136
+ builder.getIntegerAttr (builder.getIndexType (), 0 )));
137
+ mlir::Value secondComp = builder.create <fir::ExtractValueOp>(
138
+ loc, mlir::cast<mlir::ComplexType>(fromType).getElementType (),
139
+ loadedVal,
140
+ builder.getArrayAttr (
141
+ builder.getIntegerAttr (builder.getIndexType (), 1 )));
142
+ firstComp = builder.create <fir::ConvertOp>(
143
+ loc, mlir::cast<mlir::ComplexType>(toType).getElementType (), firstComp);
144
+ secondComp = builder.create <fir::ConvertOp>(
145
+ loc, mlir::cast<mlir::ComplexType>(toType).getElementType (),
146
+ secondComp);
147
+ auto undef = builder.create <fir::UndefOp>(loc, toType);
148
+ mlir::Value pair1 = builder.create <fir::InsertValueOp>(
149
+ loc, toType, undef, firstComp,
150
+ builder.getArrayAttr (
151
+ builder.getIntegerAttr (builder.getIndexType (), 0 )));
152
+ mlir::Value pair = builder.create <fir::InsertValueOp>(
153
+ loc, toType, pair1, secondComp,
154
+ builder.getArrayAttr (
155
+ builder.getIntegerAttr (builder.getIndexType (), 1 )));
156
+ builder.create <fir::StoreOp>(loc, pair, alloca);
157
+ }
158
+ fromAddress = alloca;
159
+ elementType = fir::unwrapRefType (toAddress.getType ());
160
+ }
161
+
106
162
// / Used to generate atomic.read operation which is created in existing
107
163
// / location set by builder.
108
164
template <typename AtomicListT>
@@ -386,6 +442,7 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
386
442
fir::getBase (converter.genExprAddr (fromExpr, stmtCtx));
387
443
mlir::Value toAddress = fir::getBase (converter.genExprAddr (
388
444
*Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx));
445
+ emitImplicitCast (converter, loc, fromAddress, toAddress, elementType);
389
446
genOmpAccAtomicCaptureStatement (converter, fromAddress, toAddress,
390
447
leftHandClauseList, rightHandClauseList,
391
448
elementType, loc);
@@ -481,6 +538,30 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
481
538
mlir::Type stmt2VarType =
482
539
fir::getBase (converter.genExprValue (assign2.lhs , stmtCtx)).getType ();
483
540
541
+ // Checks helpful in constructing the `atomic.capture` region
542
+ bool hasSingleVariable =
543
+ Fortran::semantics::checkForSingleVariableOnRHS (stmt1);
544
+ bool hasSymMatch = Fortran::semantics::checkForSymbolMatch (stmt2);
545
+
546
+ // Implicit casts
547
+ mlir::Type captureStmtElemTy;
548
+ if (hasSingleVariable) {
549
+ if (hasSymMatch) {
550
+ // Atomic capture construct is of the form [capture-stmt, update-stmt]
551
+ // FIXME: Emit an implicit cast if there is a type mismatch
552
+ } else {
553
+ // Atomic capture construct is of the form [capture-stmt, write-stmt]
554
+ const Fortran::semantics::SomeExpr &fromExpr =
555
+ *Fortran::semantics::GetExpr (stmt1Expr);
556
+ captureStmtElemTy = converter.genType (fromExpr);
557
+ emitImplicitCast (converter, loc, stmt2LHSArg, stmt1LHSArg,
558
+ captureStmtElemTy);
559
+ }
560
+ } else {
561
+ // Atomic capture construct is of the form [update-stmt, capture-stmt]
562
+ // FIXME: Emit an implicit cast if there is a type mismatch
563
+ }
564
+
484
565
mlir::Operation *atomicCaptureOp = nullptr ;
485
566
if constexpr (std::is_same<AtomicListT,
486
567
Fortran::parser::OmpAtomicClauseList>()) {
@@ -501,8 +582,8 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
501
582
firOpBuilder.createBlock (&(atomicCaptureOp->getRegion (0 )));
502
583
mlir::Block &block = atomicCaptureOp->getRegion (0 ).back ();
503
584
firOpBuilder.setInsertionPointToStart (&block);
504
- if (Fortran::semantics::checkForSingleVariableOnRHS (stmt1) ) {
505
- if (Fortran::semantics::checkForSymbolMatch (stmt2) ) {
585
+ if (hasSingleVariable ) {
586
+ if (hasSymMatch ) {
506
587
// Atomic capture construct is of the form [capture-stmt, update-stmt]
507
588
const Fortran::semantics::SomeExpr &fromExpr =
508
589
*Fortran::semantics::GetExpr (stmt1Expr);
@@ -521,13 +602,10 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
521
602
mlir::Value stmt2RHSArg =
522
603
fir::getBase (converter.genExprValue (assign2.rhs , stmtCtx));
523
604
firOpBuilder.setInsertionPointToStart (&block);
524
- const Fortran::semantics::SomeExpr &fromExpr =
525
- *Fortran::semantics::GetExpr (stmt1Expr);
526
- mlir::Type elementType = converter.genType (fromExpr);
527
605
genOmpAccAtomicCaptureStatement<AtomicListT>(
528
606
converter, stmt2LHSArg, stmt1LHSArg,
529
607
/* leftHandClauseList=*/ nullptr ,
530
- /* rightHandClauseList=*/ nullptr , elementType , loc);
608
+ /* rightHandClauseList=*/ nullptr , captureStmtElemTy , loc);
531
609
genOmpAccAtomicWriteStatement<AtomicListT>(
532
610
converter, stmt2LHSArg, stmt2RHSArg,
533
611
/* leftHandClauseList=*/ nullptr ,
0 commit comments