Skip to content

Commit d56b74e

Browse files
committed
[mlir][openacc] Add acc.kernels operation
The acc.kernels operation models the OpenACC kernels construct. The kernels construct defines a region of a program that is compiled into a sequence of kernels to be executed on the current device. The operation is modelled on the acc.parallel operation and will receive similar updates once the data operand operations are implemented. Reviewed By: PeteSteinfeld. Differential Revision: https://reviews.llvm.org/D148277
1 parent 3865e08 commit d56b74e

File tree

3 files changed

+198
-0
lines changed

3 files changed

+198
-0
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,87 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", [AttrSizedOperandSegments]> {
238238
}];
239239
}
240240

241+
//===----------------------------------------------------------------------===//
242+
// 2.5.1 kernels Construct
243+
//===----------------------------------------------------------------------===//
244+
245+
// Models the OpenACC `kernels` construct as a single-region operation.
//
// Operands are declared in this order: `async`, the variadic wait operands,
// `numGangs`, `numWorkers`, `vectorLength`, `ifCond`, `selfCond`, and finally
// the data operands (`copy` through `attach`). Code that indexes the flat
// operand list relies on this order.
def OpenACC_KernelsOp : OpenACC_Op<"kernels", [AttrSizedOperandSegments]> {
  let summary = "kernels construct";
  let description = [{
    The "acc.kernels" operation represents a kernels construct block. It has
    one region to be compiled into a sequence of kernels for execution on the
    current device.

    Example:

    ```mlir
    acc.kernels num_gangs(%c10 : index) num_workers(%c10 : index)
        copyin(%c : memref<10xf32>) {
      // kernels region
    }
    ```
  }];

  let arguments = (ins Optional<IntOrIndex>:$async,
                       UnitAttr:$asyncAttr,
                       Variadic<IntOrIndex>:$waitOperands,
                       UnitAttr:$waitAttr,
                       Optional<IntOrIndex>:$numGangs,
                       Optional<IntOrIndex>:$numWorkers,
                       Optional<IntOrIndex>:$vectorLength,
                       Optional<I1>:$ifCond,
                       Optional<I1>:$selfCond,
                       UnitAttr:$selfAttr,
                       Variadic<AnyType>:$copyOperands,
                       Variadic<AnyType>:$copyinOperands,
                       Variadic<AnyType>:$copyinReadonlyOperands,
                       Variadic<AnyType>:$copyoutOperands,
                       Variadic<AnyType>:$copyoutZeroOperands,
                       Variadic<AnyType>:$createOperands,
                       Variadic<AnyType>:$createZeroOperands,
                       Variadic<AnyType>:$noCreateOperands,
                       Variadic<AnyType>:$presentOperands,
                       Variadic<AnyType>:$devicePtrOperands,
                       Variadic<AnyType>:$attachOperands,
                       OptionalAttr<DefaultValueAttr>:$defaultAttr);

  let regions = (region AnyRegion:$region);

  let extraClassDeclaration = [{
    /// The number of data operands.
    unsigned getNumDataOperands();

    /// The i-th data operand passed.
    Value getDataOperand(unsigned i);
  }];

  // Every clause except `self` and `if` carries an explicit `: type` list.
  let assemblyFormat = [{
    oilist(
        `attach` `(` $attachOperands `:` type($attachOperands) `)`
      | `async` `(` $async `:` type($async) `)`
      | `copy` `(` $copyOperands `:` type($copyOperands) `)`
      | `copyin` `(` $copyinOperands `:` type($copyinOperands) `)`
      | `copyin_readonly` `(` $copyinReadonlyOperands `:`
          type($copyinReadonlyOperands) `)`
      | `copyout` `(` $copyoutOperands `:` type($copyoutOperands) `)`
      | `copyout_zero` `(` $copyoutZeroOperands `:`
          type($copyoutZeroOperands) `)`
      | `create` `(` $createOperands `:` type($createOperands) `)`
      | `create_zero` `(` $createZeroOperands `:` type($createZeroOperands) `)`
      | `deviceptr` `(` $devicePtrOperands `:` type($devicePtrOperands) `)`
      | `no_create` `(` $noCreateOperands `:` type($noCreateOperands) `)`
      | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
      | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
      | `present` `(` $presentOperands `:` type($presentOperands) `)`
      | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
      | `self` `(` $selfCond `)`
      | `if` `(` $ifCond `)`
    )
    $region attr-dict-with-keyword
  }];
}
321+
241322
//===----------------------------------------------------------------------===//
242323
// 2.6.5 data Construct
243324
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,26 @@ Value SerialOp::getDataOperand(unsigned i) {
208208
return getOperand(getWaitOperands().size() + numOptional + i);
209209
}
210210

211+
//===----------------------------------------------------------------------===//
212+
// KernelsOp
213+
//===----------------------------------------------------------------------===//
214+
215+
unsigned KernelsOp::getNumDataOperands() {
216+
return getCopyOperands().size() + getCopyinOperands().size() +
217+
getCopyinReadonlyOperands().size() + getCopyoutOperands().size() +
218+
getCopyoutZeroOperands().size() + getCreateOperands().size() +
219+
getCreateZeroOperands().size() + getNoCreateOperands().size() +
220+
getPresentOperands().size() + getDevicePtrOperands().size() +
221+
getAttachOperands().size();
222+
}
223+
224+
/// Returns the i-th data operand of this kernels op.
///
/// The op's operands are stored in declaration order: `async`, the variadic
/// wait operands, `num_gangs`, `num_workers`, `vector_length`, `if` and `self`,
/// followed by the data operands. Each optional operand that is present
/// therefore shifts the first data operand by one, so all of them must be
/// counted here (not only `async`, `if` and `self`).
Value KernelsOp::getDataOperand(unsigned i) {
  unsigned numOptional = getAsync() ? 1 : 0;
  numOptional += getNumGangs() ? 1 : 0;
  numOptional += getNumWorkers() ? 1 : 0;
  numOptional += getVectorLength() ? 1 : 0;
  numOptional += getIfCond() ? 1 : 0;
  numOptional += getSelfCond() ? 1 : 0;
  return getOperand(getWaitOperands().size() + numOptional + i);
}
230+
211231
//===----------------------------------------------------------------------===//
212232
// LoopOp
213233
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenACC/ops.mlir

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,103 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
586586

587587
// -----
588588

589+
590+
// Round-trip test for acc.kernels covering the async, wait and data clauses as
// well as the default/async/wait/self attributes. Clauses are printed in the
// order listed in the op's assembly format, so `deviceptr(...) attach(...)`
// comes back as `attach(...) deviceptr(...)`.
func.func @testkernelsop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
  %i64value = arith.constant 1 : i64
  %i32value = arith.constant 1 : i32
  %idxValue = arith.constant 1 : index
  acc.kernels async(%i64value: i64) {
  }
  acc.kernels async(%i32value: i32) {
  }
  acc.kernels async(%idxValue: index) {
  }
  acc.kernels wait(%i64value: i64) {
  }
  acc.kernels wait(%i32value: i32) {
  }
  acc.kernels wait(%idxValue: index) {
  }
  acc.kernels wait(%i64value, %i32value, %idxValue : i64, i32, index) {
  }
  acc.kernels copyin(%a, %b : memref<10xf32>, memref<10xf32>) {
  }
  acc.kernels copyin_readonly(%a, %b : memref<10xf32>, memref<10xf32>) {
  }
  acc.kernels copyin(%a: memref<10xf32>) copyout_zero(%b, %c : memref<10xf32>, memref<10x10xf32>) {
  }
  acc.kernels copyout(%b, %c : memref<10xf32>, memref<10x10xf32>) create(%a: memref<10xf32>) {
  }
  acc.kernels copyout_zero(%b, %c : memref<10xf32>, memref<10x10xf32>) create_zero(%a: memref<10xf32>) {
  }
  acc.kernels no_create(%a: memref<10xf32>) present(%b, %c : memref<10xf32>, memref<10x10xf32>) {
  }
  acc.kernels deviceptr(%a: memref<10xf32>) attach(%b, %c : memref<10xf32>, memref<10x10xf32>) {
  }
  acc.kernels {
  } attributes {defaultAttr = #acc<defaultvalue none>}
  acc.kernels {
  } attributes {defaultAttr = #acc<defaultvalue present>}
  acc.kernels {
  } attributes {asyncAttr}
  acc.kernels {
  } attributes {waitAttr}
  acc.kernels {
  } attributes {selfAttr}
  acc.kernels {
    acc.terminator
  } attributes {selfAttr}
  return
}

// CHECK:      func @testkernelsop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
// CHECK:      [[I64VALUE:%.*]] = arith.constant 1 : i64
// CHECK:      [[I32VALUE:%.*]] = arith.constant 1 : i32
// CHECK:      [[IDXVALUE:%.*]] = arith.constant 1 : index
// CHECK:      acc.kernels async([[I64VALUE]] : i64) {
// CHECK-NEXT: }
// CHECK:      acc.kernels async([[I32VALUE]] : i32) {
// CHECK-NEXT: }
// CHECK:      acc.kernels async([[IDXVALUE]] : index) {
// CHECK-NEXT: }
// CHECK:      acc.kernels wait([[I64VALUE]] : i64) {
// CHECK-NEXT: }
// CHECK:      acc.kernels wait([[I32VALUE]] : i32) {
// CHECK-NEXT: }
// CHECK:      acc.kernels wait([[IDXVALUE]] : index) {
// CHECK-NEXT: }
// CHECK:      acc.kernels wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
// CHECK-NEXT: }
// CHECK:      acc.kernels copyin([[ARGA]], [[ARGB]] : memref<10xf32>, memref<10xf32>) {
// CHECK-NEXT: }
// CHECK:      acc.kernels copyin_readonly([[ARGA]], [[ARGB]] : memref<10xf32>, memref<10xf32>) {
// CHECK-NEXT: }
// CHECK:      acc.kernels copyin([[ARGA]] : memref<10xf32>) copyout_zero([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
// CHECK-NEXT: }
// CHECK:      acc.kernels copyout([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) create([[ARGA]] : memref<10xf32>) {
// CHECK-NEXT: }
// CHECK:      acc.kernels copyout_zero([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) create_zero([[ARGA]] : memref<10xf32>) {
// CHECK-NEXT: }
// CHECK:      acc.kernels no_create([[ARGA]] : memref<10xf32>) present([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
// CHECK-NEXT: }
// CHECK:      acc.kernels attach([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) deviceptr([[ARGA]] : memref<10xf32>) {
// CHECK-NEXT: }
// CHECK:      acc.kernels {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
// CHECK:      acc.kernels {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue present>}
// CHECK:      acc.kernels {
// CHECK-NEXT: } attributes {asyncAttr}
// CHECK:      acc.kernels {
// CHECK-NEXT: } attributes {waitAttr}
// CHECK:      acc.kernels {
// CHECK-NEXT: } attributes {selfAttr}
// CHECK:      acc.kernels {
// CHECK:        acc.terminator
// CHECK-NEXT: } attributes {selfAttr}
683+
684+
// -----
685+
589686
func.func @testdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
590687
%ifCond = arith.constant true
591688
acc.data if(%ifCond) present(%a : memref<10xf32>) {

0 commit comments

Comments
 (0)