@@ -69,10 +69,6 @@ struct SingleRegion {
69
69
Block::iterator begin, end;
70
70
};
71
71
72
- static bool isSupportedByFirAlloca (Type ty) {
73
- return !isa<fir::ReferenceType>(ty);
74
- }
75
-
76
72
static bool mustParallelizeOp (Operation *op) {
77
73
// TODO as in shouldUseWorkshareLowering we be careful not to pick up
78
74
// workshare_loop_wrapper in nested omp.parallel ops
@@ -98,14 +94,10 @@ static bool isSafeToParallelize(Operation *op) {
98
94
static mlir::func::FuncOp createCopyFunc (mlir::Location loc, mlir::Type varType,
99
95
fir::FirOpBuilder builder) {
100
96
mlir::ModuleOp module = builder.getModule ();
101
- std::string copyFuncName;
102
- if (auto rt = dyn_cast<fir::ReferenceType>(varType)) {
103
- mlir::Type eleTy = rt.getEleTy ();
104
- copyFuncName =
105
- fir::getTypeAsString (eleTy, builder.getKindMap (), " _workshare_copy" );
106
- } else {
107
- copyFuncName = " _workshare_copy_llvm_ptr" ;
108
- }
97
+ auto rt = cast<fir::ReferenceType>(varType);
98
+ mlir::Type eleTy = rt.getEleTy ();
99
+ std::string copyFuncName =
100
+ fir::getTypeAsString (eleTy, builder.getKindMap (), " _workshare_copy" );
109
101
110
102
if (auto decl = module .lookupSymbol <mlir::func::FuncOp>(copyFuncName))
111
103
return decl;
@@ -120,6 +112,10 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
120
112
builder.createBlock (&funcOp.getRegion (), funcOp.getRegion ().end (), argsTy,
121
113
{loc, loc});
122
114
builder.setInsertionPointToStart (&funcOp.getRegion ().back ());
115
+
116
+ Value loaded = builder.create <fir::LoadOp>(loc, funcOp.getArgument (0 ));
117
+ builder.create <fir::StoreOp>(loc, loaded, funcOp.getArgument (1 ));
118
+
123
119
builder.create <mlir::func::ReturnOp>(loc);
124
120
return funcOp;
125
121
}
@@ -168,28 +164,10 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
168
164
OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
169
165
if (auto reloaded = rootMapping.lookupOrNull (v))
170
166
return nullptr ;
171
- Type llvmPtrTy = LLVM::LLVMPointerType::get (allocaBuilder.getContext ());
172
167
Type ty = v.getType ();
173
- Value alloc, reloaded;
174
- if (isSupportedByFirAlloca (ty)) {
175
- alloc = allocaBuilder.create <fir::AllocaOp>(loc, ty);
176
- singleBuilder.create <fir::StoreOp>(loc, singleMapping.lookup (v), alloc);
177
- reloaded = parallelBuilder.create <fir::LoadOp>(loc, ty, alloc);
178
- } else {
179
- auto one = allocaBuilder.create <LLVM::ConstantOp>(
180
- loc, allocaBuilder.getI32Type (), 1 );
181
- alloc =
182
- allocaBuilder.create <LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
183
- Value toStore = singleBuilder
184
- .create <UnrealizedConversionCastOp>(
185
- loc, llvmPtrTy, singleMapping.lookup (v))
186
- .getResult (0 );
187
- singleBuilder.create <LLVM::StoreOp>(loc, toStore, alloc);
188
- reloaded = parallelBuilder.create <LLVM::LoadOp>(loc, llvmPtrTy, alloc);
189
- reloaded =
190
- parallelBuilder.create <UnrealizedConversionCastOp>(loc, ty, reloaded)
191
- .getResult (0 );
192
- }
168
+ Value alloc = allocaBuilder.create <fir::AllocaOp>(loc, ty);
169
+ singleBuilder.create <fir::StoreOp>(loc, singleMapping.lookup (v), alloc);
170
+ Value reloaded = parallelBuilder.create <fir::LoadOp>(loc, ty, alloc);
193
171
rootMapping.map (v, reloaded);
194
172
return alloc;
195
173
};
0 commit comments