Skip to content

Commit aa5f843

Browse files
committed
[mlir][OpenMP] map argument to reduction initialization region
The argument to the initialization region of reduction declarations was never mapped. This meant that if this argument was accessed inside the initialization region, that mlir operation would be translated to an llvm operation with a null argument (failing verification). Adding the mapping ensures that the right LLVM value can be found when inlining and converting the initialization region. We have to separately establish and clean up these mappings for each use of the reduction declaration because repeated usage of the same declaration will inline it using a different concrete value for the block argument.
1 parent 97e02bc commit aa5f843

File tree

2 files changed

+182
-0
lines changed

2 files changed

+182
-0
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,25 @@ static void allocByValReductionVars(
825825
}
826826
}
827827

828+
/// Map input argument to all reduction initialization regions
829+
template <typename T>
830+
static void
831+
mapInitializationArg(T loop, LLVM::ModuleTranslation &moduleTranslation,
832+
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
833+
unsigned i) {
834+
// map input argument to the initialization region
835+
mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
836+
Region &initializerRegion = reduction.getInitializerRegion();
837+
Block &entry = initializerRegion.front();
838+
assert(entry.getNumArguments() == 1 &&
839+
"the initialization region has one argument");
840+
841+
mlir::Value mlirSource = loop.getReductionVars()[i];
842+
llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
843+
assert(llvmSource && "lookup reduction var");
844+
moduleTranslation.mapValue(entry.getArgument(0), llvmSource);
845+
}
846+
828847
/// Collect reduction info
829848
template <typename T>
830849
static void collectReductionInfo(
@@ -902,6 +921,10 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
902921
loop.getRegion().getArguments().take_back(loop.getNumReductionVars());
903922
for (unsigned i = 0; i < loop.getNumReductionVars(); ++i) {
904923
SmallVector<llvm::Value *> phis;
924+
925+
// map block argument to initializer region
926+
mapInitializationArg(loop, moduleTranslation, reductionDecls, i);
927+
905928
if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
906929
"omp.reduction.neutral", builder,
907930
moduleTranslation, &phis)))
@@ -925,6 +948,11 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
925948
builder.CreateStore(phis[0], privateReductionVariables[i]);
926949
// the rest was handled in allocByValReductionVars
927950
}
951+
952+
// forget the mapping for the initializer region because we might need a
953+
// different mapping if this reduction declaration is re-used for a
954+
// different variable
955+
moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
928956
}
929957

930958
// Store the mapping between reduction variables and their private copies on
@@ -1118,6 +1146,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11181146
opInst.getNumReductionVars());
11191147
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
11201148
SmallVector<llvm::Value *> phis;
1149+
1150+
// map the block argument
1151+
mapInitializationArg(opInst, moduleTranslation, reductionDecls, i);
11211152
if (failed(inlineConvertOmpRegions(
11221153
reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
11231154
builder, moduleTranslation, &phis)))
@@ -1144,6 +1175,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11441175
builder.CreateStore(phis[0], privateReductionVariables[i]);
11451176
// the rest is done in allocByValReductionVars
11461177
}
1178+
1179+
// clear block argument mapping in case it needs to be re-created with a
1180+
// different source for another use of the same reduction decl
1181+
moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
11471182
}
11481183

11491184
// Store the mapping between reduction variables and their private copies on
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// Test that the block argument to the initialization region of
// omp.declare_reduction gets mapped properly when translating to LLVMIR.

module {
  omp.declare_reduction @add_reduction_byref_box_Uxf64 : !llvm.ptr init {
  ^bb0(%arg0: !llvm.ptr):
    // test usage of %arg0: without the argument mapping this load would be
    // translated with a null operand and fail LLVM IR verification
    %11 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
    omp.yield(%arg0 : !llvm.ptr)
  } combiner {
  // dummy combiner: simply yields its first argument
  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
    omp.yield(%arg0 : !llvm.ptr)
  }

  llvm.func internal @_QFPreduce(%arg0: !llvm.ptr {fir.bindc_name = "r"}, %arg1: !llvm.ptr {fir.bindc_name = "r2"}) attributes {sym_visibility = "private"} {
    %8 = llvm.mlir.constant(1 : i32) : i32
    %9 = llvm.mlir.constant(10 : i32) : i32
    %10 = llvm.mlir.constant(0 : i32) : i32
    omp.parallel {
      %83 = llvm.mlir.constant(1 : i64) : i64
      %84 = llvm.alloca %83 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
      %86 = llvm.mlir.constant(1 : i64) : i64
      %87 = llvm.alloca %86 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
      // test multiple reduction variables to ensure they don't interfere with
      // each other when inlining the reduction init region multiple times
      omp.wsloop byref reduction(@add_reduction_byref_box_Uxf64 %84 -> %arg3 : !llvm.ptr, @add_reduction_byref_box_Uxf64 %87 -> %arg4 : !llvm.ptr) for (%arg2) : i32 = (%10) to (%9) inclusive step (%8) {
        omp.yield
      }
      omp.terminator
    }
    llvm.return
  }
}
36+
37+
// CHECK-LABEL: define internal void @_QFPreduce(ptr %{{.*}}, ptr %{{.*}})
38+
// CHECK: br label %entry
39+
// CHECK: entry: ; preds = %[[VAL_1:.*]]
40+
// CHECK: %[[VAL_2:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
41+
// CHECK: br label %[[VAL_3:.*]]
42+
// CHECK: omp_parallel: ; preds = %entry
43+
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 0, ptr @_QFPreduce..omp_par)
44+
// CHECK: br label %[[VAL_4:.*]]
45+
// CHECK: omp.par.outlined.exit: ; preds = %[[VAL_3]]
46+
// CHECK: br label %[[VAL_5:.*]]
47+
// CHECK: omp.par.exit.split: ; preds = %[[VAL_4]]
48+
// CHECK: ret void
49+
// CHECK: omp.par.entry:
50+
// CHECK: %[[VAL_6:.*]] = alloca i32, align 4
51+
// CHECK: %[[VAL_7:.*]] = load i32, ptr %[[VAL_8:.*]], align 4
52+
// CHECK: store i32 %[[VAL_7]], ptr %[[VAL_6]], align 4
53+
// CHECK: %[[VAL_9:.*]] = load i32, ptr %[[VAL_6]], align 4
54+
// CHECK: %[[VAL_10:.*]] = alloca i32, align 4
55+
// CHECK: %[[VAL_11:.*]] = alloca i32, align 4
56+
// CHECK: %[[VAL_12:.*]] = alloca i32, align 4
57+
// CHECK: %[[VAL_13:.*]] = alloca i32, align 4
58+
// CHECK: %[[VAL_14:.*]] = alloca [2 x ptr], align 8
59+
// CHECK: br label %[[VAL_15:.*]]
60+
// CHECK: omp.par.region: ; preds = %[[VAL_16:.*]]
61+
// CHECK: br label %[[VAL_17:.*]]
62+
// CHECK: omp.par.region1: ; preds = %[[VAL_15]]
63+
// CHECK: %[[VAL_18:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
64+
// CHECK: %[[VAL_19:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
65+
// CHECK: %[[VAL_20:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_18]], align 8
66+
// CHECK: %[[VAL_21:.*]] = alloca ptr, align 8
67+
// CHECK: store ptr %[[VAL_18]], ptr %[[VAL_21]], align 8
68+
// CHECK: %[[VAL_22:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_19]], align 8
69+
// CHECK: %[[VAL_23:.*]] = alloca ptr, align 8
70+
// CHECK: store ptr %[[VAL_19]], ptr %[[VAL_23]], align 8
71+
// CHECK: br label %[[VAL_24:.*]]
72+
// CHECK: omp_loop.preheader: ; preds = %[[VAL_17]]
73+
// CHECK: store i32 0, ptr %[[VAL_11]], align 4
74+
// CHECK: store i32 10, ptr %[[VAL_12]], align 4
75+
// CHECK: store i32 1, ptr %[[VAL_13]], align 4
76+
// CHECK: %[[VAL_25:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
77+
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_25]], i32 34, ptr %[[VAL_10]], ptr %[[VAL_11]], ptr %[[VAL_12]], ptr %[[VAL_13]], i32 1, i32 0)
78+
// CHECK: %[[VAL_26:.*]] = load i32, ptr %[[VAL_11]], align 4
79+
// CHECK: %[[VAL_27:.*]] = load i32, ptr %[[VAL_12]], align 4
80+
// CHECK: %[[VAL_28:.*]] = sub i32 %[[VAL_27]], %[[VAL_26]]
81+
// CHECK: %[[VAL_29:.*]] = add i32 %[[VAL_28]], 1
82+
// CHECK: br label %[[VAL_30:.*]]
83+
// CHECK: omp_loop.header: ; preds = %[[VAL_31:.*]], %[[VAL_24]]
84+
// CHECK: %[[VAL_32:.*]] = phi i32 [ 0, %[[VAL_24]] ], [ %[[VAL_33:.*]], %[[VAL_31]] ]
85+
// CHECK: br label %[[VAL_34:.*]]
86+
// CHECK: omp_loop.cond: ; preds = %[[VAL_30]]
87+
// CHECK: %[[VAL_35:.*]] = icmp ult i32 %[[VAL_32]], %[[VAL_29]]
88+
// CHECK: br i1 %[[VAL_35]], label %[[VAL_36:.*]], label %[[VAL_37:.*]]
89+
// CHECK: omp_loop.exit: ; preds = %[[VAL_34]]
90+
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_25]])
91+
// CHECK: %[[VAL_38:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
92+
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_38]])
93+
// CHECK: br label %[[VAL_39:.*]]
94+
// CHECK: omp_loop.after: ; preds = %[[VAL_37]]
95+
// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_14]], i64 0, i64 0
96+
// CHECK: store ptr %[[VAL_21]], ptr %[[VAL_40]], align 8
97+
// CHECK: %[[VAL_41:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_14]], i64 0, i64 1
98+
// CHECK: store ptr %[[VAL_23]], ptr %[[VAL_41]], align 8
99+
// CHECK: %[[VAL_42:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
100+
// CHECK: %[[VAL_43:.*]] = call i32 @__kmpc_reduce(ptr @1, i32 %[[VAL_42]], i32 2, i64 16, ptr %[[VAL_14]], ptr @.omp.reduction.func, ptr @.gomp_critical_user_.reduction.var)
101+
// CHECK: switch i32 %[[VAL_43]], label %[[VAL_44:.*]] [
102+
// CHECK: i32 1, label %[[VAL_45:.*]]
103+
// CHECK: i32 2, label %[[VAL_46:.*]]
104+
// CHECK: ]
105+
// CHECK: reduce.switch.atomic: ; preds = %[[VAL_39]]
106+
// CHECK: unreachable
107+
// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_39]]
108+
// CHECK: %[[VAL_47:.*]] = load ptr, ptr %[[VAL_21]], align 8
109+
// CHECK: %[[VAL_48:.*]] = load ptr, ptr %[[VAL_23]], align 8
110+
// CHECK: call void @__kmpc_end_reduce(ptr @1, i32 %[[VAL_42]], ptr @.gomp_critical_user_.reduction.var)
111+
// CHECK: br label %[[VAL_44]]
112+
// CHECK: reduce.finalize: ; preds = %[[VAL_45]], %[[VAL_39]]
113+
// CHECK: %[[VAL_49:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
114+
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_49]])
115+
// CHECK: br label %[[VAL_50:.*]]
116+
// CHECK: omp.region.cont: ; preds = %[[VAL_44]]
117+
// CHECK: br label %[[VAL_51:.*]]
118+
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_50]]
119+
// CHECK: br label %[[VAL_52:.*]]
120+
// CHECK: omp_loop.body: ; preds = %[[VAL_34]]
121+
// CHECK: %[[VAL_53:.*]] = add i32 %[[VAL_32]], %[[VAL_26]]
122+
// CHECK: %[[VAL_54:.*]] = mul i32 %[[VAL_53]], 1
123+
// CHECK: %[[VAL_55:.*]] = add i32 %[[VAL_54]], 0
124+
// CHECK: br label %[[VAL_56:.*]]
125+
// CHECK: omp.wsloop.region: ; preds = %[[VAL_36]]
126+
// CHECK: br label %[[VAL_57:.*]]
127+
// CHECK: omp.region.cont2: ; preds = %[[VAL_56]]
128+
// CHECK: br label %[[VAL_31]]
129+
// CHECK: omp_loop.inc: ; preds = %[[VAL_57]]
130+
// CHECK: %[[VAL_33]] = add nuw i32 %[[VAL_32]], 1
131+
// CHECK: br label %[[VAL_30]]
132+
// CHECK: omp.par.outlined.exit.exitStub: ; preds = %[[VAL_51]]
133+
// CHECK: ret void
134+
// CHECK: %[[VAL_58:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_59:.*]], i64 0, i64 0
135+
// CHECK: %[[VAL_60:.*]] = load ptr, ptr %[[VAL_58]], align 8
136+
// CHECK: %[[VAL_61:.*]] = load ptr, ptr %[[VAL_60]], align 8
137+
// CHECK: %[[VAL_62:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_63:.*]], i64 0, i64 0
138+
// CHECK: %[[VAL_64:.*]] = load ptr, ptr %[[VAL_62]], align 8
139+
// CHECK: %[[VAL_65:.*]] = load ptr, ptr %[[VAL_64]], align 8
140+
// CHECK: %[[VAL_66:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_59]], i64 0, i64 1
141+
// CHECK: %[[VAL_67:.*]] = load ptr, ptr %[[VAL_66]], align 8
142+
// CHECK: %[[VAL_68:.*]] = load ptr, ptr %[[VAL_67]], align 8
143+
// CHECK: %[[VAL_69:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_63]], i64 0, i64 1
144+
// CHECK: %[[VAL_70:.*]] = load ptr, ptr %[[VAL_69]], align 8
145+
// CHECK: %[[VAL_71:.*]] = load ptr, ptr %[[VAL_70]], align 8
146+
// CHECK: ret void
147+

0 commit comments

Comments
 (0)