@@ -429,68 +429,98 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
429
429
// Parser, printer and verifier for ReductionVarList
430
430
// ===----------------------------------------------------------------------===//
431
431
432
- ParseResult
433
- parseReductionClause (OpAsmParser &parser, Region ®ion,
434
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
435
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
436
- SmallVectorImpl<OpAsmParser::Argument> &privates) {
437
- if (failed (parser.parseOptionalKeyword (" reduction" )))
438
- return failure ();
439
-
432
+ ParseResult parseClauseWithRegionArgs (
433
+ OpAsmParser &parser, Region ®ion,
434
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
435
+ SmallVectorImpl<Type> &types, ArrayAttr &symbols,
436
+ SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs) {
440
437
SmallVector<SymbolRefAttr> reductionVec;
438
+ unsigned regionArgOffset = regionPrivateArgs.size ();
441
439
442
440
if (failed (
443
441
parser.parseCommaSeparatedList (OpAsmParser::Delimiter::Paren, [&]() {
444
442
if (parser.parseAttribute (reductionVec.emplace_back ()) ||
445
443
parser.parseOperand (operands.emplace_back ()) ||
446
444
parser.parseArrow () ||
447
- parser.parseArgument (privates .emplace_back ()) ||
445
+ parser.parseArgument (regionPrivateArgs .emplace_back ()) ||
448
446
parser.parseColonType (types.emplace_back ()))
449
447
return failure ();
450
448
return success ();
451
449
})))
452
450
return failure ();
453
451
454
- for (auto [prv, type] : llvm::zip_equal (privates, types)) {
452
+ auto *argsBegin = regionPrivateArgs.begin ();
453
+ MutableArrayRef argsSubrange (argsBegin + regionArgOffset,
454
+ argsBegin + regionArgOffset + types.size ());
455
+ for (auto [prv, type] : llvm::zip_equal (argsSubrange, types)) {
455
456
prv.type = type;
456
457
}
457
458
SmallVector<Attribute> reductions (reductionVec.begin (), reductionVec.end ());
458
- reductionSymbols = ArrayAttr::get (parser.getContext (), reductions);
459
+ symbols = ArrayAttr::get (parser.getContext (), reductions);
459
460
return success ();
460
461
}
461
462
462
- static void printReductionClause (OpAsmPrinter &p, Operation *op, Region ®ion,
463
- ValueRange operands, TypeRange types,
464
- ArrayAttr reductionSymbols) {
465
- p << " reduction(" ;
466
- llvm::interleaveComma (llvm::zip_equal (reductionSymbols, operands,
467
- region.front ().getArguments (), types),
468
- p, [&p](auto t) {
469
- auto [sym, op, arg, type] = t;
470
- p << sym << " " << op << " -> " << arg << " : "
471
- << type;
472
- });
463
+ static void printClauseWithRegionArgs (OpAsmPrinter &p, Operation *op,
464
+ Region ®ion, StringRef clauseName,
465
+ ValueRange operands, TypeRange types,
466
+ ArrayAttr symbols,
467
+ unsigned regionArgOffset) {
468
+ p << clauseName << " (" ;
469
+
470
+ auto *argsBegin = region.front ().getArguments ().begin ();
471
+ MutableArrayRef argsSubrange (argsBegin + regionArgOffset,
472
+ argsBegin + regionArgOffset + types.size ());
473
+ llvm::interleaveComma (
474
+ llvm::zip_equal (symbols, operands, argsSubrange, types), p, [&p](auto t) {
475
+ auto [sym, op, arg, type] = t;
476
+ p << sym << " " << op << " -> " << arg << " : " << type;
477
+ });
473
478
p << " ) " ;
474
479
}
475
480
476
- static ParseResult
477
- parseParallelRegion (OpAsmParser &parser, Region ®ion,
478
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
481
+ static ParseResult parseParallelRegion (
482
+ OpAsmParser &parser, Region ®ion,
483
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
484
+ SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
485
+ llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
486
+ llvm::SmallVectorImpl<Type> &privateVarsTypes,
487
+ ArrayAttr &privatizerSymbols) {
488
+ llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
489
+
490
+ if (succeeded (parser.parseOptionalKeyword (" reduction" ))) {
491
+ if (failed (parseClauseWithRegionArgs (parser, region, reductionVarOperands,
492
+ reductionVarTypes, reductionSymbols,
493
+ regionPrivateArgs)))
494
+ return failure ();
495
+ }
480
496
481
- llvm::SmallVector<OpAsmParser::Argument> privates;
482
- if (succeeded (parseReductionClause (parser, region, operands, types,
483
- reductionSymbols, privates)))
484
- return parser.parseRegion (region, privates);
497
+ if (succeeded (parser.parseOptionalKeyword (" private" ))) {
498
+ if (failed (parseClauseWithRegionArgs (parser, region, privateVarOperands,
499
+ privateVarsTypes, privatizerSymbols,
500
+ regionPrivateArgs)))
501
+ return failure ();
502
+ }
485
503
486
- return parser.parseRegion (region);
504
+ return parser.parseRegion (region, regionPrivateArgs );
487
505
}
488
506
489
507
static void printParallelRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
490
- ValueRange operands, TypeRange types,
491
- ArrayAttr reductionSymbols) {
508
+ ValueRange reductionVarOperands,
509
+ TypeRange reductionVarTypes,
510
+ ArrayAttr reductionSymbols,
511
+ ValueRange privateVarOperands,
512
+ TypeRange privateVarTypes,
513
+ ArrayAttr privatizerSymbols) {
492
514
if (reductionSymbols)
493
- printReductionClause (p, op, region, operands, types, reductionSymbols);
515
+ printClauseWithRegionArgs (p, op, region, " reduction" , reductionVarOperands,
516
+ reductionVarTypes, reductionSymbols,
517
+ /* regionArgOffset=*/ 0 );
518
+
519
+ if (privatizerSymbols)
520
+ printClauseWithRegionArgs (p, op, region, " private" , privateVarOperands,
521
+ privateVarTypes, privatizerSymbols,
522
+ reductionVarOperands.size ());
523
+
494
524
p.printRegion (region, /* printEntryBlockArgs=*/ false );
495
525
}
496
526
@@ -1057,14 +1087,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
1057
1087
builder, state, /* if_expr_var=*/ nullptr , /* num_threads_var=*/ nullptr ,
1058
1088
/* allocate_vars=*/ ValueRange (), /* allocators_vars=*/ ValueRange (),
1059
1089
/* reduction_vars=*/ ValueRange (), /* reductions=*/ nullptr ,
1060
- /* proc_bind_val=*/ nullptr );
1090
+ /* proc_bind_val=*/ nullptr , /* private_vars=*/ ValueRange (),
1091
+ /* privatizers=*/ nullptr );
1061
1092
state.addAttributes (attributes);
1062
1093
}
1063
1094
1095
+ static LogicalResult verifyPrivateVarList (ParallelOp &op) {
1096
+ auto privateVars = op.getPrivateVars ();
1097
+ auto privatizers = op.getPrivatizersAttr ();
1098
+
1099
+ if (privateVars.empty () && (privatizers == nullptr || privatizers.empty ()))
1100
+ return success ();
1101
+
1102
+ auto numPrivateVars = privateVars.size ();
1103
+ auto numPrivatizers = (privatizers == nullptr ) ? 0 : privatizers.size ();
1104
+
1105
+ if (numPrivateVars != numPrivatizers)
1106
+ return op.emitError () << " inconsistent number of private variables and "
1107
+ " privatizer op symbols, private vars: "
1108
+ << numPrivateVars
1109
+ << " vs. privatizer op symbols: " << numPrivatizers;
1110
+
1111
+ for (auto privateVarInfo : llvm::zip (privateVars, privatizers)) {
1112
+ Type varType = std::get<0 >(privateVarInfo).getType ();
1113
+ SymbolRefAttr privatizerSym =
1114
+ std::get<1 >(privateVarInfo).cast <SymbolRefAttr>();
1115
+ PrivateClauseOp privatizerOp =
1116
+ SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1117
+ privatizerSym);
1118
+
1119
+ if (privatizerOp == nullptr )
1120
+ return op.emitError () << " failed to lookup privatizer op with symbol: '"
1121
+ << privatizerSym << " '" ;
1122
+
1123
+ Type privatizerType = privatizerOp.getType ();
1124
+
1125
+ if (varType != privatizerType)
1126
+ return op.emitError ()
1127
+ << " type mismatch between a "
1128
+ << (privatizerOp.getDataSharingType () ==
1129
+ DataSharingClauseType::Private
1130
+ ? " private"
1131
+ : " firstprivate" )
1132
+ << " variable and its privatizer op, var type: " << varType
1133
+ << " vs. privatizer op type: " << privatizerType;
1134
+ }
1135
+
1136
+ return success ();
1137
+ }
1138
+
1064
1139
LogicalResult ParallelOp::verify () {
1065
1140
if (getAllocateVars ().size () != getAllocatorsVars ().size ())
1066
1141
return emitError (
1067
1142
" expected equal sizes for allocate and allocator variables" );
1143
+
1144
+ if (failed (verifyPrivateVarList (*this )))
1145
+ return failure ();
1146
+
1068
1147
return verifyReductionVarList (*this , getReductions (), getReductionVars ());
1069
1148
}
1070
1149
0 commit comments