8
8
9
9
#include " flang/Common/Fortran.h"
10
10
#include " flang/Optimizer/Builder/Runtime/RTBuilder.h"
11
+ #include " flang/Optimizer/CodeGen/TypeConverter.h"
11
12
#include " flang/Optimizer/Dialect/CUF/CUFOps.h"
12
13
#include " flang/Optimizer/Dialect/FIRDialect.h"
13
14
#include " flang/Optimizer/Dialect/FIROps.h"
14
15
#include " flang/Optimizer/HLFIR/HLFIROps.h"
16
+ #include " flang/Optimizer/Support/DataLayout.h"
17
+ #include " flang/Runtime/CUDA/descriptor.h"
15
18
#include " flang/Runtime/allocatable.h"
16
19
#include " mlir/Pass/Pass.h"
17
20
#include " mlir/Transforms/DialectConversion.h"
@@ -25,6 +28,7 @@ namespace fir {
25
28
using namespace fir ;
26
29
using namespace mlir ;
27
30
using namespace Fortran ::runtime;
31
+ using namespace Fortran ::runtime::cuf;
28
32
29
33
namespace {
30
34
@@ -75,11 +79,11 @@ static mlir::LogicalResult convertOpToCall(OpTy op,
75
79
}
76
80
77
81
struct CufAllocateOpConversion
78
- : public mlir::OpRewritePattern<cuf::AllocateOp> {
82
+ : public mlir::OpRewritePattern<:: cuf::AllocateOp> {
79
83
using OpRewritePattern::OpRewritePattern;
80
84
81
85
mlir::LogicalResult
82
- matchAndRewrite (cuf::AllocateOp op,
86
+ matchAndRewrite (:: cuf::AllocateOp op,
83
87
mlir::PatternRewriter &rewriter) const override {
84
88
// TODO: Allocation with source will need a new entry point in the runtime.
85
89
if (op.getSource ())
@@ -108,16 +112,16 @@ struct CufAllocateOpConversion
108
112
mlir::func::FuncOp func =
109
113
fir::runtime::getRuntimeFunc<mkRTKey (AllocatableAllocate)>(loc,
110
114
builder);
111
- return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
115
+ return convertOpToCall<:: cuf::AllocateOp>(op, rewriter, func);
112
116
}
113
117
};
114
118
115
119
struct CufDeallocateOpConversion
116
- : public mlir::OpRewritePattern<cuf::DeallocateOp> {
120
+ : public mlir::OpRewritePattern<:: cuf::DeallocateOp> {
117
121
using OpRewritePattern::OpRewritePattern;
118
122
119
123
mlir::LogicalResult
120
- matchAndRewrite (cuf::DeallocateOp op,
124
+ matchAndRewrite (:: cuf::DeallocateOp op,
121
125
mlir::PatternRewriter &rewriter) const override {
122
126
// TODO: Allocation of module variable will need more work as the descriptor
123
127
// will be duplicated and needs to be synced after allocation.
@@ -133,7 +137,84 @@ struct CufDeallocateOpConversion
133
137
mlir::func::FuncOp func =
134
138
fir::runtime::getRuntimeFunc<mkRTKey (AllocatableDeallocate)>(loc,
135
139
builder);
136
- return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
140
+ return convertOpToCall<::cuf::DeallocateOp>(op, rewriter, func);
141
+ }
142
+ };
143
+
144
+ struct CufAllocOpConversion : public mlir ::OpRewritePattern<::cuf::AllocOp> {
145
+ using OpRewritePattern::OpRewritePattern;
146
+
147
+ CufAllocOpConversion (mlir::MLIRContext *context, mlir::DataLayout *dl,
148
+ fir::LLVMTypeConverter *typeConverter)
149
+ : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
150
+
151
+ mlir::LogicalResult
152
+ matchAndRewrite (::cuf::AllocOp op,
153
+ mlir::PatternRewriter &rewriter) const override {
154
+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType ());
155
+
156
+ // Only convert cuf.alloc that allocates a descriptor.
157
+ if (!boxTy)
158
+ return failure ();
159
+
160
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
161
+ fir::FirOpBuilder builder (rewriter, mod);
162
+ mlir::Location loc = op.getLoc ();
163
+ mlir::func::FuncOp func =
164
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFAllocDesciptor)>(loc, builder);
165
+
166
+ auto fTy = func.getFunctionType ();
167
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
168
+ mlir::Value sourceLine =
169
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
170
+
171
+ mlir::Type structTy = typeConverter->convertBoxTypeAsStruct (boxTy);
172
+ std::size_t boxSize = dl->getTypeSizeInBits (structTy) / 8 ;
173
+ mlir::Value sizeInBytes =
174
+ builder.createIntegerConstant (loc, builder.getIndexType (), boxSize);
175
+
176
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
177
+ builder, loc, fTy , sizeInBytes, sourceFile, sourceLine)};
178
+ auto callOp = builder.create <fir::CallOp>(loc, func, args);
179
+ auto convOp = builder.createConvert (loc, op.getResult ().getType (),
180
+ callOp.getResult (0 ));
181
+ rewriter.replaceOp (op, convOp);
182
+ return mlir::success ();
183
+ }
184
+
185
+ private:
186
+ mlir::DataLayout *dl;
187
+ fir::LLVMTypeConverter *typeConverter;
188
+ };
189
+
190
+ struct CufFreeOpConversion : public mlir ::OpRewritePattern<::cuf::FreeOp> {
191
+ using OpRewritePattern::OpRewritePattern;
192
+
193
+ mlir::LogicalResult
194
+ matchAndRewrite (::cuf::FreeOp op,
195
+ mlir::PatternRewriter &rewriter) const override {
196
+ // Only convert cuf.free on descriptor.
197
+ if (!mlir::isa<fir::ReferenceType>(op.getDevptr ().getType ()))
198
+ return failure ();
199
+ auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr ().getType ());
200
+ if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy ()))
201
+ return failure ();
202
+
203
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
204
+ fir::FirOpBuilder builder (rewriter, mod);
205
+ mlir::Location loc = op.getLoc ();
206
+ mlir::func::FuncOp func =
207
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFFreeDesciptor)>(loc, builder);
208
+
209
+ auto fTy = func.getFunctionType ();
210
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
211
+ mlir::Value sourceLine =
212
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
213
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
214
+ builder, loc, fTy , op.getDevptr (), sourceFile, sourceLine)};
215
+ builder.create <fir::CallOp>(loc, func, args);
216
+ rewriter.eraseOp (op);
217
+ return mlir::success ();
137
218
}
138
219
};
139
220
@@ -143,8 +224,22 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
143
224
auto *ctx = &getContext ();
144
225
mlir::RewritePatternSet patterns (ctx);
145
226
mlir::ConversionTarget target (*ctx);
146
- target.addIllegalOp <cuf::AllocateOp, cuf::DeallocateOp>();
147
- patterns.insert <CufAllocateOpConversion, CufDeallocateOpConversion>(ctx);
227
+
228
+ mlir::Operation *op = getOperation ();
229
+ mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
230
+ if (!module )
231
+ return signalPassFailure ();
232
+
233
+ std::optional<mlir::DataLayout> dl =
234
+ fir::support::getOrSetDataLayout (module , /* allowDefaultLayout=*/ false );
235
+ fir::LLVMTypeConverter typeConverter (module , /* applyTBAA=*/ false ,
236
+ /* forceUnifiedTBAATree=*/ false , *dl);
237
+
238
+ target.addIllegalOp <::cuf::AllocOp, ::cuf::AllocateOp, ::cuf::DeallocateOp,
239
+ ::cuf::FreeOp>();
240
+ patterns.insert <CufAllocOpConversion>(ctx, &*dl, &typeConverter);
241
+ patterns.insert <CufAllocateOpConversion, CufDeallocateOpConversion,
242
+ CufFreeOpConversion>(ctx);
148
243
if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
149
244
std::move (patterns)))) {
150
245
mlir::emitError (mlir::UnknownLoc::get (ctx),
0 commit comments