Skip to content

Commit f8a9973

Browse files
authored
[flang][cuda] Add verifier for cuda_alloc/cuda_free (#90983)
Adding a verifier to check the associated cuda attribute.
1 parent a4d1026 commit f8a9973

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3364,6 +3364,8 @@ def fir_CUDAAllocOp : fir_Op<"cuda_alloc", [AttrSizedOperandSegments,
33643364
CArg<"mlir::ValueRange", "{}">:$typeparams,
33653365
CArg<"mlir::ValueRange", "{}">:$shape,
33663366
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>];
3367+
3368+
let hasVerifier = 1;
33673369
}
33683370

33693371
def fir_CUDAFreeOp : fir_Op<"cuda_free", [MemoryEffects<[MemFree]>]> {
@@ -3381,6 +3383,8 @@ def fir_CUDAFreeOp : fir_Op<"cuda_free", [MemoryEffects<[MemFree]>]> {
33813383
);
33823384

33833385
let assemblyFormat = "$devptr `:` qualified(type($devptr)) attr-dict";
3386+
3387+
let hasVerifier = 1;
33843388
}
33853389

33863390
#endif

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4048,6 +4048,19 @@ void fir::CUDAAllocOp::build(
40484048
result.addAttributes(attributes);
40494049
}
40504050

4051+
template <typename Op>
4052+
static mlir::LogicalResult checkCudaAttr(Op op) {
4053+
if (op.getCudaAttr() == fir::CUDADataAttribute::Device ||
4054+
op.getCudaAttr() == fir::CUDADataAttribute::Managed ||
4055+
op.getCudaAttr() == fir::CUDADataAttribute::Unified)
4056+
return mlir::success();
4057+
return op.emitOpError("expect device, managed or unified cuda attribute");
4058+
}
4059+
4060+
mlir::LogicalResult fir::CUDAAllocOp::verify() { return checkCudaAttr(*this); }
4061+
4062+
mlir::LogicalResult fir::CUDAFreeOp::verify() { return checkCudaAttr(*this); }
4063+
40514064
//===----------------------------------------------------------------------===//
40524065
// FIROpsDialect
40534066
//===----------------------------------------------------------------------===//

flang/test/Fir/cuf-invalid.fir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,21 @@ func.func @_QPsub1() {
8585
%13 = fir.cuda_deallocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) {cuda_attr = #fir.cuda<device>} -> i32
8686
return
8787
}
88+
89+
// -----
90+
91+
func.func @_QPsub1() {
92+
// expected-error@+1{{'fir.cuda_alloc' op expect device, managed or unified cuda attribute}}
93+
%0 = fir.cuda_alloc f32 {bindc_name = "r", cuda_attr = #fir.cuda<pinned>, uniq_name = "_QFsub1Er"} -> !fir.ref<f32>
94+
fir.cuda_free %0 : !fir.ref<f32> {cuda_attr = #fir.cuda<constant>}
95+
return
96+
}
97+
98+
// -----
99+
100+
func.func @_QPsub1() {
101+
%0 = fir.cuda_alloc f32 {bindc_name = "r", cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub1Er"} -> !fir.ref<f32>
102+
// expected-error@+1{{'fir.cuda_free' op expect device, managed or unified cuda attribute}}
103+
fir.cuda_free %0 : !fir.ref<f32> {cuda_attr = #fir.cuda<constant>}
104+
return
105+
}

0 commit comments

Comments
 (0)