Skip to content

Commit 3ffaa8b

Browse files
committed
[MLIR][OpenMP][Flang] Normalize clause arguments names
Currently, there are some inconsistencies to how clause arguments are named in the OpenMP dialect. Additionally, the clause operand structures associated to them also diverge in certain cases. The purpose of this patch is to normalize argument names across all `OpenMP_Clause` tablegen definitions and clause operand structures. This has the benefit of providing more consistent representations for clauses in the dialect, but the main short-term advantage is that it enables the development of an OpenMP-specific tablegen backend to automatically generate the clause operand structures without breaking dependent code. The main re-naming decisions made in this patch are the following: - Variadic arguments (i.e. multiple values) have the "_vars" suffix. This and other similar suffixes are removed from array attribute arguments. - Individual required or optional value arguments do not have any suffix added to them (e.g. "val", "var", "expr", ...), except for `if` which would otherwise result in an invalid C++ variable name. - The associated clause's name is prepended to argument names that don't already contain it as part of its name. This avoids future collisions between arguments named the same way on different clauses and adding both clauses to the same operation. - Privatization and reduction related arguments that contain lists of symbols pointing to privatizer/reducer operations use the "_syms" suffix. This removes the inconsistencies between the names for "copyprivate_funcs", "[in]reductions", "privatizers", etc. - General improvements to names, replacement of camel case for snake case everywhere, etc. - Renaming of operation-associated operand structures to use the "Operands" suffix in place of "ClauseOps", to better differentiate between clause operand structures and operation operand structures. - Fields on clause operand structures are sorted according to the tablegen definition of the same clause. The assembly format for a few arguments is updated to better reflect the clause they are associated with: - `chunk_size` -> `dist_schedule_chunk_size` - `grain_size` -> `grainsize` - `simd` -> `par_level_simd`
1 parent 06c1e1b commit 3ffaa8b

File tree

14 files changed

+920
-954
lines changed

14 files changed

+920
-954
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,13 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
187187
// The types of lower bound, upper bound, and step are converted into the
188188
// type of the loop variable if necessary.
189189
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
190-
for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) {
191-
result.loopLBVar[it] =
192-
firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]);
193-
result.loopUBVar[it] =
194-
firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]);
195-
result.loopStepVar[it] =
196-
firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]);
190+
for (unsigned it = 0; it < (unsigned)result.collapseLowerBound.size(); it++) {
191+
result.collapseLowerBound[it] = firOpBuilder.createConvert(
192+
loc, loopVarType, result.collapseLowerBound[it]);
193+
result.collapseUpperBound[it] = firOpBuilder.createConvert(
194+
loc, loopVarType, result.collapseUpperBound[it]);
195+
result.collapseStep[it] =
196+
firOpBuilder.createConvert(loc, loopVarType, result.collapseStep[it]);
197197
}
198198
}
199199

@@ -232,15 +232,15 @@ bool ClauseProcessor::processCollapse(
232232
std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
233233
assert(bounds && "Expected bounds for worksharing do loop");
234234
lower::StatementContext stmtCtx;
235-
result.loopLBVar.push_back(fir::getBase(
235+
result.collapseLowerBound.push_back(fir::getBase(
236236
converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
237-
result.loopUBVar.push_back(fir::getBase(
237+
result.collapseUpperBound.push_back(fir::getBase(
238238
converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
239239
if (bounds->step) {
240-
result.loopStepVar.push_back(fir::getBase(
240+
result.collapseStep.push_back(fir::getBase(
241241
converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
242242
} else { // If `step` is not present, assume it as `1`.
243-
result.loopStepVar.push_back(firOpBuilder.createIntegerConstant(
243+
result.collapseStep.push_back(firOpBuilder.createIntegerConstant(
244244
currentLocation, firOpBuilder.getIntegerType(32), 1));
245245
}
246246
iv.push_back(bounds->name.thing.symbol);
@@ -291,8 +291,7 @@ bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
291291
}
292292
}
293293
const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
294-
result.deviceVar =
295-
fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
294+
result.device = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
296295
return true;
297296
}
298297
return false;
@@ -322,10 +321,10 @@ bool ClauseProcessor::processDistSchedule(
322321
lower::StatementContext &stmtCtx,
323322
mlir::omp::DistScheduleClauseOps &result) const {
324323
if (auto *clause = findUniqueClause<omp::clause::DistSchedule>()) {
325-
result.distScheduleStaticAttr = converter.getFirOpBuilder().getUnitAttr();
324+
result.distScheduleStatic = converter.getFirOpBuilder().getUnitAttr();
326325
const auto &chunkSize = std::get<std::optional<ExprTy>>(clause->t);
327326
if (chunkSize)
328-
result.distScheduleChunkSizeVar =
327+
result.distScheduleChunkSize =
329328
fir::getBase(converter.genExprValue(*chunkSize, stmtCtx));
330329
return true;
331330
}
@@ -335,7 +334,7 @@ bool ClauseProcessor::processDistSchedule(
335334
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
336335
mlir::omp::FilterClauseOps &result) const {
337336
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
338-
result.filteredThreadIdVar =
337+
result.filteredThreadId =
339338
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
340339
return true;
341340
}
@@ -351,7 +350,7 @@ bool ClauseProcessor::processFinal(lower::StatementContext &stmtCtx,
351350

352351
mlir::Value finalVal =
353352
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
354-
result.finalVar = firOpBuilder.createConvert(
353+
result.final = firOpBuilder.createConvert(
355354
clauseLocation, firOpBuilder.getI1Type(), finalVal);
356355
return true;
357356
}
@@ -362,19 +361,19 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
362361
if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
363362
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
364363
int64_t hintValue = *evaluate::ToInt64(clause->v);
365-
result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue);
364+
result.hint = firOpBuilder.getI64IntegerAttr(hintValue);
366365
return true;
367366
}
368367
return false;
369368
}
370369

371370
bool ClauseProcessor::processMergeable(
372371
mlir::omp::MergeableClauseOps &result) const {
373-
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeableAttr);
372+
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
374373
}
375374

376375
bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
377-
return markClauseOccurrence<omp::clause::Nowait>(result.nowaitAttr);
376+
return markClauseOccurrence<omp::clause::Nowait>(result.nowait);
378377
}
379378

380379
bool ClauseProcessor::processNumTeams(
@@ -385,7 +384,7 @@ bool ClauseProcessor::processNumTeams(
385384
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
386385
// auto lowerBound = std::get<std::optional<ExprTy>>(clause->t);
387386
auto &upperBound = std::get<ExprTy>(clause->t);
388-
result.numTeamsUpperVar =
387+
result.numTeamsUpper =
389388
fir::getBase(converter.genExprValue(upperBound, stmtCtx));
390389
return true;
391390
}
@@ -397,7 +396,7 @@ bool ClauseProcessor::processNumThreads(
397396
mlir::omp::NumThreadsClauseOps &result) const {
398397
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
399398
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
400-
result.numThreadsVar =
399+
result.numThreads =
401400
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
402401
return true;
403402
}
@@ -408,17 +407,17 @@ bool ClauseProcessor::processOrder(mlir::omp::OrderClauseOps &result) const {
408407
using Order = omp::clause::Order;
409408
if (auto *clause = findUniqueClause<Order>()) {
410409
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
411-
result.orderAttr = mlir::omp::ClauseOrderKindAttr::get(
410+
result.order = mlir::omp::ClauseOrderKindAttr::get(
412411
firOpBuilder.getContext(), mlir::omp::ClauseOrderKind::Concurrent);
413412
const auto &modifier =
414413
std::get<std::optional<Order::OrderModifier>>(clause->t);
415414
if (modifier && *modifier == Order::OrderModifier::Unconstrained) {
416-
result.orderModAttr = mlir::omp::OrderModifierAttr::get(
415+
result.orderMod = mlir::omp::OrderModifierAttr::get(
417416
firOpBuilder.getContext(), mlir::omp::OrderModifier::unconstrained);
418417
} else {
419418
// "If order-modifier is not unconstrained, the behavior is as if the
420419
// reproducible modifier is present."
421-
result.orderModAttr = mlir::omp::OrderModifierAttr::get(
420+
result.orderMod = mlir::omp::OrderModifierAttr::get(
422421
firOpBuilder.getContext(), mlir::omp::OrderModifier::reproducible);
423422
}
424423
return true;
@@ -433,7 +432,7 @@ bool ClauseProcessor::processOrdered(
433432
int64_t orderedClauseValue = 0l;
434433
if (clause->v.has_value())
435434
orderedClauseValue = *evaluate::ToInt64(*clause->v);
436-
result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
435+
result.ordered = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
437436
return true;
438437
}
439438
return false;
@@ -443,8 +442,7 @@ bool ClauseProcessor::processPriority(
443442
lower::StatementContext &stmtCtx,
444443
mlir::omp::PriorityClauseOps &result) const {
445444
if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
446-
result.priorityVar =
447-
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
445+
result.priority = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
448446
return true;
449447
}
450448
return false;
@@ -454,7 +452,7 @@ bool ClauseProcessor::processProcBind(
454452
mlir::omp::ProcBindClauseOps &result) const {
455453
if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
456454
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
457-
result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause);
455+
result.procBindKind = genProcBindKindAttr(firOpBuilder, *clause);
458456
return true;
459457
}
460458
return false;
@@ -465,7 +463,7 @@ bool ClauseProcessor::processSafelen(
465463
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
466464
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
467465
const std::optional<std::int64_t> safelenVal = evaluate::ToInt64(clause->v);
468-
result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal);
466+
result.safelen = firOpBuilder.getI64IntegerAttr(*safelenVal);
469467
return true;
470468
}
471469
return false;
@@ -498,19 +496,19 @@ bool ClauseProcessor::processSchedule(
498496
break;
499497
}
500498

501-
result.scheduleValAttr =
499+
result.scheduleKind =
502500
mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
503501

504-
mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
505-
if (scheduleModifier != mlir::omp::ScheduleModifier::none)
506-
result.scheduleModAttr =
507-
mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
502+
mlir::omp::ScheduleModifier scheduleMod = getScheduleModifier(*clause);
503+
if (scheduleMod != mlir::omp::ScheduleModifier::none)
504+
result.scheduleMod =
505+
mlir::omp::ScheduleModifierAttr::get(context, scheduleMod);
508506

509507
if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
510-
result.scheduleSimdAttr = firOpBuilder.getUnitAttr();
508+
result.scheduleSimd = firOpBuilder.getUnitAttr();
511509

512510
if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
513-
result.scheduleChunkVar =
511+
result.scheduleChunk =
514512
fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
515513

516514
return true;
@@ -523,7 +521,7 @@ bool ClauseProcessor::processSimdlen(
523521
if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
524522
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
525523
const std::optional<std::int64_t> simdlenVal = evaluate::ToInt64(clause->v);
526-
result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal);
524+
result.simdlen = firOpBuilder.getI64IntegerAttr(*simdlenVal);
527525
return true;
528526
}
529527
return false;
@@ -533,15 +531,15 @@ bool ClauseProcessor::processThreadLimit(
533531
lower::StatementContext &stmtCtx,
534532
mlir::omp::ThreadLimitClauseOps &result) const {
535533
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
536-
result.threadLimitVar =
534+
result.threadLimit =
537535
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
538536
return true;
539537
}
540538
return false;
541539
}
542540

543541
bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
544-
return markClauseOccurrence<omp::clause::Untied>(result.untiedAttr);
542+
return markClauseOccurrence<omp::clause::Untied>(result.untied);
545543
}
546544

547545
//===----------------------------------------------------------------------===//
@@ -565,7 +563,7 @@ static void
565563
addAlignedClause(lower::AbstractConverter &converter,
566564
const omp::clause::Aligned &clause,
567565
llvm::SmallVectorImpl<mlir::Value> &alignedVars,
568-
llvm::SmallVectorImpl<mlir::Attribute> &alignmentAttrs) {
566+
llvm::SmallVectorImpl<mlir::Attribute> &alignments) {
569567
using Aligned = omp::clause::Aligned;
570568
lower::StatementContext stmtCtx;
571569
mlir::IntegerAttr alignmentValueAttr;
@@ -594,7 +592,7 @@ addAlignedClause(lower::AbstractConverter &converter,
594592
alignmentValueAttr = builder.getI64IntegerAttr(alignment);
595593
// All the list items in a aligned clause will have same alignment
596594
for (std::size_t i = 0; i < objects.size(); i++)
597-
alignmentAttrs.push_back(alignmentValueAttr);
595+
alignments.push_back(alignmentValueAttr);
598596
}
599597
}
600598

@@ -603,7 +601,7 @@ bool ClauseProcessor::processAligned(
603601
return findRepeatableClause<omp::clause::Aligned>(
604602
[&](const omp::clause::Aligned &clause, const parser::CharBlock &) {
605603
addAlignedClause(converter, clause, result.alignedVars,
606-
result.alignmentAttrs);
604+
result.alignments);
607605
});
608606
}
609607

@@ -798,7 +796,7 @@ bool ClauseProcessor::processCopyprivate(
798796
result.copyprivateVars.push_back(cpVar);
799797
mlir::func::FuncOp funcOp =
800798
createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
801-
result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
799+
result.copyprivateSyms.push_back(mlir::SymbolRefAttr::get(funcOp));
802800
};
803801

804802
bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
@@ -832,7 +830,7 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
832830

833831
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
834832
genDependKindAttr(firOpBuilder, kind);
835-
result.dependTypeAttrs.append(objects.size(), dependTypeOperand);
833+
result.dependKinds.append(objects.size(), dependTypeOperand);
836834

837835
for (const omp::Object &object : objects) {
838836
assert(object.ref() && "Expecting designator");
@@ -1037,10 +1035,9 @@ bool ClauseProcessor::processReduction(
10371035

10381036
// Copy local lists into the output.
10391037
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
1040-
llvm::copy(reduceVarByRef,
1041-
std::back_inserter(result.reductionVarsByRef));
1038+
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
10421039
llvm::copy(reductionDeclSymbols,
1043-
std::back_inserter(result.reductionDeclSymbols));
1040+
std::back_inserter(result.reductionSyms));
10441041

10451042
if (outReductionTypes) {
10461043
outReductionTypes->reserve(outReductionTypes->size() +

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
226226
firOpBuilder.setInsertionPoint(lastOper);
227227

228228
mlir::Value iv = loopOp.getIVs()[0];
229-
mlir::Value ub = loopOp.getUpperBound()[0];
230-
mlir::Value step = loopOp.getStep()[0];
229+
mlir::Value ub = loopOp.getCollapseUpperBound()[0];
230+
mlir::Value step = loopOp.getCollapseStep()[0];
231231

232232
// v = iv + step
233233
// cmp = step < 0 ? v < ub : v > ub
@@ -537,7 +537,7 @@ void DataSharingProcessor::doPrivatize(const semantics::Symbol *sym,
537537
}();
538538

539539
if (clauseOps) {
540-
clauseOps->privatizers.push_back(mlir::SymbolRefAttr::get(privatizerOp));
540+
clauseOps->privateSyms.push_back(mlir::SymbolRefAttr::get(privatizerOp));
541541
clauseOps->privateVars.push_back(hsb.getAddr());
542542
}
543543

0 commit comments

Comments
 (0)