@@ -124,7 +124,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
124
124
static ParseResult
125
125
parseCommonStructuredOpParts (OpAsmParser &parser, OperationState &result,
126
126
SmallVectorImpl<Type> &inputTypes,
127
- SmallVectorImpl<Type> &outputTypes) {
127
+ SmallVectorImpl<Type> &outputTypes,
128
+ bool addOperandSegmentSizes = true ) {
128
129
SMLoc inputsOperandsLoc, outputsOperandsLoc;
129
130
SmallVector<OpAsmParser::UnresolvedOperand, 4 > inputsOperands,
130
131
outputsOperands;
@@ -155,10 +156,12 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
155
156
result.operands ))
156
157
return failure ();
157
158
158
- result.addAttribute (" operand_segment_sizes" ,
159
- parser.getBuilder ().getDenseI32ArrayAttr (
160
- {static_cast <int32_t >(inputsOperands.size ()),
161
- static_cast <int32_t >(outputsOperands.size ())}));
159
+ if (addOperandSegmentSizes) {
160
+ result.addAttribute (" operand_segment_sizes" ,
161
+ parser.getBuilder ().getDenseI32ArrayAttr (
162
+ {static_cast <int32_t >(inputsOperands.size ()),
163
+ static_cast <int32_t >(outputsOperands.size ())}));
164
+ }
162
165
return success ();
163
166
}
164
167
@@ -1180,6 +1183,209 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
1180
1183
return foldMemRefCast (*this );
1181
1184
}
1182
1185
1186
+ // ===----------------------------------------------------------------------===//
1187
+ // ReduceOp
1188
+ // ===----------------------------------------------------------------------===//
1189
+
1190
+ ArrayAttr ReduceOp::getIteratorTypes () {
1191
+ int64_t inputRank = getInputs ()[0 ].getType ().cast <ShapedType>().getRank ();
1192
+ SmallVector<StringRef> iteratorTypes (inputRank,
1193
+ getParallelIteratorTypeName ());
1194
+ for (int64_t reductionDim : getDimensions ())
1195
+ iteratorTypes[reductionDim] = getReductionIteratorTypeName ();
1196
+ return Builder (getContext ()).getStrArrayAttr (iteratorTypes);
1197
+ }
1198
+
1199
+ ArrayAttr ReduceOp::getIndexingMaps () {
1200
+ int64_t inputRank = getInputs ()[0 ].getType ().cast <ShapedType>().getRank ();
1201
+ SmallVector<AffineMap> affineMaps (
1202
+ getNumInputs (),
1203
+ AffineMap::getMultiDimIdentityMap (inputRank, getContext ()));
1204
+ AffineMap resultMap =
1205
+ AffineMap::getMultiDimIdentityMap (inputRank, getContext ())
1206
+ .dropResults (getDimensions ());
1207
+ for (int64_t i = 0 , e = getNumOutputs (); i < e; ++i)
1208
+ affineMaps.push_back (resultMap);
1209
+ return Builder (getContext ()).getAffineMapArrayAttr (affineMaps);
1210
+ }
1211
+
1212
+ void ReduceOp::getEffects (
1213
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1214
+ &effects) {
1215
+ SmallVector<Value> inputBuffers = getInputBufferOperands ();
1216
+ SmallVector<Value> outputBuffers = getOutputBufferOperands ();
1217
+ getGenericEffectsImpl (effects, getOperation ()->getResults (), inputBuffers,
1218
+ outputBuffers);
1219
+ }
1220
+
1221
+ static ParseResult parseDstStyleOp (
1222
+ OpAsmParser &parser, OperationState &result,
1223
+ function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1224
+ nullptr) {
1225
+ // Parse `ins` and `outs`.
1226
+ SmallVector<Type, 4 > inputTypes, outputTypes;
1227
+ if (parseCommonStructuredOpParts (parser, result, inputTypes, outputTypes,
1228
+ /* addOperandSegmentSizes=*/ false ))
1229
+ return failure ();
1230
+
1231
+ // Add result types.
1232
+ for (Type outputType : outputTypes) {
1233
+ if (!outputType.isa <RankedTensorType>())
1234
+ return failure ();
1235
+ result.addTypes (outputType);
1236
+ }
1237
+
1238
+ // Parse required attributes.
1239
+ if (parseAttrsFn && failed (parseAttrsFn (parser, result.attributes )))
1240
+ return failure ();
1241
+
1242
+ // Parse optional attributes.
1243
+ if (parser.parseOptionalAttrDict (result.attributes ))
1244
+ return failure ();
1245
+ return success ();
1246
+ }
1247
+
1248
+ static ParseResult parseDenseI64ArrayAttr (OpAsmParser &parser,
1249
+ NamedAttrList &attributes,
1250
+ StringRef attributeName) {
1251
+ if (parser.parseKeyword (attributeName) || parser.parseEqual ())
1252
+ return failure ();
1253
+
1254
+ attributes.set (attributeName, DenseI64ArrayAttr::parse (parser, Type{}));
1255
+ return success ();
1256
+ }
1257
+
1258
+ ParseResult ReduceOp::parse (OpAsmParser &parser, OperationState &result) {
1259
+ if (parseDstStyleOp (
1260
+ parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1261
+ return parseDenseI64ArrayAttr (parser, attributes, " dimensions" );
1262
+ }))
1263
+ return failure ();
1264
+
1265
+ SmallVector<OpAsmParser::Argument> regionArgs;
1266
+ if (parser.parseArgumentList (regionArgs, OpAsmParser::Delimiter::Paren,
1267
+ /* allowType=*/ true , /* allowAttrs=*/ true )) {
1268
+ return failure ();
1269
+ }
1270
+
1271
+ Region *body = result.addRegion ();
1272
+ if (parser.parseRegion (*body, regionArgs))
1273
+ return failure ();
1274
+
1275
+ return success ();
1276
+ }
1277
+
1278
+ static void printDenseI64ArrayAttr (OpAsmPrinter &p, StringRef attributeName,
1279
+ ArrayRef<int64_t > attributeValue) {
1280
+ p << " " << attributeName << " = [" << attributeValue << " ] " ;
1281
+ }
1282
+
1283
+ void ReduceOp::print (OpAsmPrinter &p) {
1284
+ printCommonStructuredOpParts (p, getInputs (), getOutputs ());
1285
+ printDenseI64ArrayAttr (p, getDimensionsAttrName (), getDimensions ());
1286
+ p.printOptionalAttrDict ((*this )->getAttrs (), {getDimensionsAttrName ()});
1287
+
1288
+ p << " (" ;
1289
+ llvm::interleaveComma (getCombiner ().getArguments (), p,
1290
+ [&](auto arg) { p.printRegionArgument (arg); });
1291
+ p << " ) " ;
1292
+
1293
+ p.printRegion (getCombiner (), /* printEntryBlockArgs=*/ false );
1294
+ }
1295
+
1296
+ LogicalResult ReduceOp::verify () {
1297
+ ArrayRef<int64_t > dimensionsRef = getDimensions ();
1298
+
1299
+ for (int64_t i = 1 ; i < getNumInputs (); ++i) {
1300
+ if (getInputs ()[i].getType ().cast <ShapedType>().getShape () !=
1301
+ getInputs ()[0 ].getType ().cast <ShapedType>().getShape ()) {
1302
+ return emitOpError () << " expects all inputs to have the same shapes. "
1303
+ " Shape at input-index "
1304
+ << i
1305
+ << " is not equal to the shape at input-index 0." ;
1306
+ }
1307
+ }
1308
+ for (int64_t i = 1 ; i < getNumOutputs (); ++i) {
1309
+ if (getInits ()[i].getType ().cast <ShapedType>().getShape () !=
1310
+ getInits ()[0 ].getType ().cast <ShapedType>().getShape ()) {
1311
+ return emitOpError () << " expects all outputs to have the same shapes. "
1312
+ " Shape at output-index "
1313
+ << i
1314
+ << " is not equal to the shape at output-index 0." ;
1315
+ }
1316
+ }
1317
+ auto inputType = getInputs ()[0 ].getType ().cast <ShapedType>();
1318
+ auto initType = getInits ()[0 ].getType ().cast <ShapedType>();
1319
+
1320
+ DenseSet<int64_t > dimensionsToReduce;
1321
+ int64_t lastDimension = -1 ;
1322
+ for (int64_t dimension : dimensionsRef) {
1323
+ if (dimension < 0 || dimension >= inputType.getRank ()) {
1324
+ return emitOpError ()
1325
+ << " dimensions for reduction should be in the range [0, "
1326
+ << inputType.getRank () - 1 << " ]." ;
1327
+ }
1328
+ if (dimension <= lastDimension) {
1329
+ return emitOpError ()
1330
+ << " reduction dimensions are not in increasing order: "
1331
+ << dimensionsRef;
1332
+ }
1333
+
1334
+ lastDimension = dimension;
1335
+ dimensionsToReduce.insert (dimension);
1336
+ }
1337
+
1338
+ auto inputDims = inputType.getShape ();
1339
+ auto initDims = initType.getShape ();
1340
+
1341
+ // Input dimensions that will be left after the reduction.
1342
+ SmallVector<int64_t > reducedInputDims;
1343
+ for (const auto &en : llvm::enumerate (inputDims)) {
1344
+ if (!dimensionsToReduce.count (en.index ()))
1345
+ reducedInputDims.push_back (en.value ());
1346
+ }
1347
+
1348
+ if (reducedInputDims.size () != initType.getRank ()) {
1349
+ return emitOpError () << " number of dimensions after reduction "
1350
+ << reducedInputDims.size ()
1351
+ << " doesn't match the init rank "
1352
+ << initType.getRank ();
1353
+ }
1354
+
1355
+ if (reducedInputDims != initDims)
1356
+ return emitOpError () << " init dimensions [" << initDims
1357
+ << " ] doesn't match input dimensions after reduction ["
1358
+ << reducedInputDims << " ]" ;
1359
+
1360
+ Block *block = getBody ();
1361
+ if (block->getNumArguments () != this ->getNumOperands ())
1362
+ return emitOpError ()
1363
+ << " mismatching number of operands and block arguments" ;
1364
+
1365
+ // Check that the first block arguments match the element type of the inputs.
1366
+ for (auto [input, bbArg] : llvm::zip (getInputs (), block->getArguments ())) {
1367
+ Type inputElementType = input.getType ().cast <ShapedType>().getElementType ();
1368
+ if (inputElementType != bbArg.getType ())
1369
+ return emitOpError ()
1370
+ << " input element type " << inputElementType
1371
+ << " does not match corresponding block argument type "
1372
+ << bbArg.getType ();
1373
+ }
1374
+
1375
+ // Check that the last block arguments match the element type of the outputs.
1376
+ for (auto [output, bbArg] : llvm::zip (
1377
+ getOutputs (), block->getArguments ().take_back (getNumOutputs ()))) {
1378
+ auto outputElementType =
1379
+ output.getType ().cast <ShapedType>().getElementType ();
1380
+ if (outputElementType != bbArg.getType ())
1381
+ return emitOpError ()
1382
+ << " output element type " << outputElementType
1383
+ << " does not match corresponding block argument type "
1384
+ << bbArg.getType ();
1385
+ }
1386
+ return success ();
1387
+ }
1388
+
1183
1389
// ===----------------------------------------------------------------------===//
1184
1390
// InitTensorOp
1185
1391
// ===----------------------------------------------------------------------===//
0 commit comments