
Commit 17db9ef

[OpenMP][MLIR] Add omp.distribute op to the OMP dialect (#67720)
This patch adds the omp.distribute operation to the OMP dialect. The purpose is to be able to represent the OpenMP distribute construct together with its associated clauses. The effect of the operation is to distribute the loop iterations of the loop(s) contained inside its region across multiple teams.
1 parent ca8605a commit 17db9ef
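For context, a minimal sketch of how the new op would typically appear, nested inside an omp.teams region that the distribute region binds to. The enclosing function, the omp.teams wrapper, and the SSA names are illustrative assumptions, not part of this commit:

func.func @teams_distribute_example(%chunk : i32) {
  omp.teams {
    // Iterations of the loop nest inside this region are distributed across
    // the initial teams, using a static schedule with the given chunk size.
    omp.distribute dist_schedule_static chunk_size(%chunk : i32) {
      omp.terminator
    }
    omp.terminator
  }
  return
}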

File tree

4 files changed, +109 -0 lines changed


mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 54 additions & 0 deletions
@@ -638,6 +638,60 @@ def YieldOp : OpenMP_Op<"yield",
     let assemblyFormat = [{ ( `(` $results^ `:` type($results) `)` )? attr-dict}];
 }
 
+//===----------------------------------------------------------------------===//
+// Distribute construct [2.9.4.1]
+//===----------------------------------------------------------------------===//
+def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
+                             MemoryEffects<[MemWrite]>]> {
+  let summary = "distribute construct";
+  let description = [{
+    The distribute construct specifies that the iterations of one or more loops
+    (optionally specified using the collapse clause) will be executed by the
+    initial teams in the context of their implicit tasks. The loops that the
+    distribute op is associated with start with the outermost loop enclosed by
+    the distribute op region and go down the loop nest toward the innermost
+    loop. The iterations are distributed across the initial threads of all
+    initial teams that execute the teams region to which the distribute region
+    binds.
+
+    The distribute loop construct specifies that the iterations of the loop(s)
+    will be executed in parallel by threads in the current context. These
+    iterations are spread across threads that already exist in the enclosing
+    region. The lower and upper bounds specify a half-open range: the
+    range includes the lower bound but does not include the upper bound. If the
+    `inclusive` attribute is specified then the upper bound is also included.
+
+    The `dist_schedule_static` attribute specifies the schedule for this
+    loop, determining how the loop is distributed across the parallel threads.
+    The optional `chunk_size` associated with it further controls this
+    distribution.
+
+    // TODO: private_var, firstprivate_var, lastprivate_var, collapse
+  }];
+  let arguments = (ins
+             UnitAttr:$dist_schedule_static,
+             Optional<IntLikeType>:$chunk_size,
+             Variadic<AnyType>:$allocate_vars,
+             Variadic<AnyType>:$allocators_vars,
+             OptionalAttr<OrderKindAttr>:$order_val);
+
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    oilist(`dist_schedule_static` $dist_schedule_static
+          |`chunk_size` `(` $chunk_size `:` type($chunk_size) `)`
+          |`order` `(` custom<ClauseAttr>($order_val) `)`
+          |`allocate` `(`
+             custom<AllocateAndAllocator>(
+               $allocate_vars, type($allocate_vars),
+               $allocators_vars, type($allocators_vars)
+             ) `)`
+    ) $region attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.10.1 task Construct
 //===----------------------------------------------------------------------===//
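A hedged sketch of the pretty-printed form this assemblyFormat accepts. Because the clauses live in an oilist, each may appear at most once; combining them all on a single op, as below with illustrative SSA values, is an extrapolation from the individual cases exercised in ops.mlir rather than something this commit tests directly:

omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) order(concurrent)
    allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
  omp.terminator
}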

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 16 additions & 0 deletions
@@ -1153,6 +1153,22 @@ LogicalResult SimdLoopOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Verifier for Distribute construct [2.9.4.1]
+//===----------------------------------------------------------------------===//
+
+LogicalResult DistributeOp::verify() {
+  if (this->getChunkSize() && !this->getDistScheduleStatic())
+    return emitOpError() << "chunk size set without "
+                            "dist_schedule_static being present";
+
+  if (getAllocateVars().size() != getAllocatorsVars().size())
+    return emitError(
+        "expected equal sizes for allocate and allocator variables");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//
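For illustration only (not part of the commit), a snippet that the first check would reject, since `chunk_size` appears without `dist_schedule_static`; the SSA value name is assumed:

omp.distribute chunk_size(%chunk_size : i32) {
  // Fails verification; the diagnostic would read roughly:
  // 'omp.distribute' op chunk size set without dist_schedule_static being present
  omp.terminator
}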

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 9 additions & 0 deletions
@@ -1729,3 +1729,12 @@ func.func @omp_target_update_invalid_motion_modifier_5(%map1 : memref<?xi32>) {
   return
 }
 llvm.mlir.global internal @_QFsubEx() : i32
+
+// -----
+
+func.func @omp_distribute(%data_var : memref<i32>) -> () {
+  // expected-error @below {{expected equal sizes for allocate and allocator variables}}
+  "omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({
+      "omp.terminator"() : () -> ()
+    }) : (memref<i32>) -> ()
+}
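For contrast, a hypothetical variant of the test above that the new verifier would accept: it supplies one allocate operand and one allocator operand (segment sizes 0, 1, 1), so the two variadic segments match. This snippet is illustrative and not part of the commit:

"omp.distribute"(%data_var, %data_var) <{operandSegmentSizes = array<i32: 0, 1, 1>}> ({
  "omp.terminator"() : () -> ()
}) : (memref<i32>, memref<i32>) -> ()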

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 30 additions & 0 deletions
@@ -479,6 +479,36 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
   return
 }
 
+// CHECK-LABEL: omp_distribute
+func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>) -> () {
+  // CHECK: omp.distribute
+  "omp.distribute" () ({
+    omp.terminator
+  }) {} : () -> ()
+  // CHECK: omp.distribute
+  omp.distribute {
+    omp.terminator
+  }
+  // CHECK: omp.distribute dist_schedule_static
+  omp.distribute dist_schedule_static {
+    omp.terminator
+  }
+  // CHECK: omp.distribute dist_schedule_static chunk_size(%{{.+}} : i32)
+  omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) {
+    omp.terminator
+  }
+  // CHECK: omp.distribute order(concurrent)
+  omp.distribute order(concurrent) {
+    omp.terminator
+  }
+  // CHECK: omp.distribute allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
+  omp.distribute allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: omp_target
 func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
