@@ -41,7 +41,7 @@ void OpenMPDialect::initialize() {
41
41
// /
42
42
// / operand-and-type-list ::= `(` ssa-id-and-type-list `)`
43
43
// / ssa-id-and-type-list ::= ssa-id-and-type |
44
- // / ssa-id-and-type ',' ssa-id-and-type-list
44
+ // / ssa-id-and-type `,` ssa-id-and-type-list
45
45
// / ssa-id-and-type ::= ssa-id `:` type
46
46
static ParseResult
47
47
parseOperandAndTypeList (OpAsmParser &parser,
@@ -65,6 +65,52 @@ parseOperandAndTypeList(OpAsmParser &parser,
65
65
return success ();
66
66
}
67
67
68
+ // / Parse an allocate clause with allocators and a list of operands with types.
69
+ // /
70
+ // / operand-and-type-list ::= `(` allocate-operand-list `)`
71
+ // / allocate-operand-list :: = allocate-operand |
72
+ // / allocator-operand `,` allocate-operand-list
73
+ // / allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
74
+ // / ssa-id-and-type ::= ssa-id `:` type
75
+ static ParseResult parseAllocateAndAllocator (
76
+ OpAsmParser &parser,
77
+ SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
78
+ SmallVectorImpl<Type> &typesAllocate,
79
+ SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
80
+ SmallVectorImpl<Type> &typesAllocator) {
81
+ if (parser.parseLParen ())
82
+ return failure ();
83
+
84
+ do {
85
+ OpAsmParser::OperandType operand;
86
+ Type type;
87
+
88
+ if (parser.parseOperand (operand) || parser.parseColonType (type))
89
+ return failure ();
90
+ operandsAllocator.push_back (operand);
91
+ typesAllocator.push_back (type);
92
+ if (parser.parseArrow ())
93
+ return failure ();
94
+ if (parser.parseOperand (operand) || parser.parseColonType (type))
95
+ return failure ();
96
+
97
+ operandsAllocate.push_back (operand);
98
+ typesAllocate.push_back (type);
99
+ } while (succeeded (parser.parseOptionalComma ()));
100
+
101
+ if (parser.parseRParen ())
102
+ return failure ();
103
+
104
+ return success ();
105
+ }
106
+
107
+ static LogicalResult verifyParallelOp (ParallelOp op) {
108
+ if (op.allocate_vars ().size () != op.allocators_vars ().size ())
109
+ return op.emitError (
110
+ " expected equal sizes for allocate and allocator variables" );
111
+ return success ();
112
+ }
113
+
68
114
static void printParallelOp (OpAsmPrinter &p, ParallelOp op) {
69
115
p << " omp.parallel" ;
70
116
@@ -84,10 +130,26 @@ static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
84
130
}
85
131
}
86
132
};
133
+
134
+ // Print allocator and allocate parameters
135
+ auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
136
+ OperandRange varsAllocator) {
137
+ if (varsAllocate.empty ())
138
+ return ;
139
+
140
+ p << " allocate(" ;
141
+ for (unsigned i = 0 ; i < varsAllocate.size (); ++i) {
142
+ std::string separator = i == varsAllocate.size () - 1 ? " )" : " , " ;
143
+ p << varsAllocator[i] << " : " << varsAllocator[i].getType () << " -> " ;
144
+ p << varsAllocate[i] << " : " << varsAllocate[i].getType () << separator;
145
+ }
146
+ };
147
+
87
148
printDataVars (" private" , op.private_vars ());
88
149
printDataVars (" firstprivate" , op.firstprivate_vars ());
89
150
printDataVars (" shared" , op.shared_vars ());
90
151
printDataVars (" copyin" , op.copyin_vars ());
152
+ printAllocateAndAllocator (op.allocate_vars (), op.allocators_vars ());
91
153
92
154
if (auto def = op.default_val ())
93
155
p << " default(" << def->drop_front (3 ) << " )" ;
@@ -118,6 +180,7 @@ static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
118
180
// / firstprivate ::= `firstprivate` operand-and-type-list
119
181
// / shared ::= `shared` operand-and-type-list
120
182
// / copyin ::= `copyin` operand-and-type-list
183
+ // / allocate ::= `allocate` operand-and-type `->` operand-and-type-list
121
184
// / default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
122
185
// / procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
123
186
// /
@@ -134,7 +197,11 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
134
197
SmallVector<Type, 4 > sharedTypes;
135
198
SmallVector<OpAsmParser::OperandType, 4 > copyins;
136
199
SmallVector<Type, 4 > copyinTypes;
137
- std::array<int , 6 > segments{0 , 0 , 0 , 0 , 0 , 0 };
200
+ SmallVector<OpAsmParser::OperandType, 4 > allocates;
201
+ SmallVector<Type, 4 > allocateTypes;
202
+ SmallVector<OpAsmParser::OperandType, 4 > allocators;
203
+ SmallVector<Type, 4 > allocatorTypes;
204
+ std::array<int , 8 > segments{0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 };
138
205
llvm::StringRef keyword;
139
206
bool defaultVal = false ;
140
207
bool procBind = false ;
@@ -145,6 +212,8 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
145
212
const int firstprivateClausePos = 3 ;
146
213
const int sharedClausePos = 4 ;
147
214
const int copyinClausePos = 5 ;
215
+ const int allocateClausePos = 6 ;
216
+ const int allocatorPos = 7 ;
148
217
const llvm::StringRef opName = result.name .getStringRef ();
149
218
150
219
while (succeeded (parser.parseOptionalKeyword (&keyword))) {
@@ -192,6 +261,15 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
192
261
if (parseOperandAndTypeList (parser, copyins, copyinTypes))
193
262
return failure ();
194
263
segments[copyinClausePos] = copyins.size ();
264
+ } else if (keyword == " allocate" ) {
265
+ // fail if there was already another allocate clause
266
+ if (segments[allocateClausePos])
267
+ return allowedOnce (parser, " allocate" , opName);
268
+ if (parseAllocateAndAllocator (parser, allocates, allocateTypes,
269
+ allocators, allocatorTypes))
270
+ return failure ();
271
+ segments[allocateClausePos] = allocates.size ();
272
+ segments[allocatorPos] = allocators.size ();
195
273
} else if (keyword == " default" ) {
196
274
// fail if there was already another default clause
197
275
if (defaultVal)
@@ -261,6 +339,18 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
261
339
result.operands ))
262
340
return failure ();
263
341
342
+ // Add allocate parameters
343
+ if (segments[allocateClausePos] &&
344
+ parser.resolveOperands (allocates, allocateTypes, allocates[0 ].location ,
345
+ result.operands ))
346
+ return failure ();
347
+
348
+ // Add allocator parameters
349
+ if (segments[allocatorPos] &&
350
+ parser.resolveOperands (allocators, allocatorTypes, allocators[0 ].location ,
351
+ result.operands ))
352
+ return failure ();
353
+
264
354
result.addAttribute (" operand_segment_sizes" ,
265
355
parser.getBuilder ().getI32VectorAttr (segments));
266
356
0 commit comments