Skip to content

Commit fa79dff

Browse files
committed
[mlir][Linalg] Reland Add ReduceOp to Linalg structured ops.
This op will allow modeling (variadic) reductions with this special op instead of using GenericOp. This reverts commit 535fd75. Additional fix: implement a getLibraryCallName method.
1 parent e0d5012 commit fa79dff

File tree

5 files changed

+493
-6
lines changed

5 files changed

+493
-6
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,70 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> {
221221
}
222222

223223

224+
//===----------------------------------------------------------------------===//
225+
// Reduce op.
226+
//===----------------------------------------------------------------------===//
227+
228+
def TensorOrMemref :
229+
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
230+
231+
def ReduceOp : LinalgStructuredBase_Op<"reduce", [
232+
SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">
233+
]> {
234+
let summary = "Reduce operator";
235+
let description = [{
236+
Executes `combiner` on the `dimensions` of `inputs` and returns the
237+
reduced result. The `dimensions` attribute needs to list the reduction
238+
dimensions in increasing order.
239+
240+
Example:
241+
```
242+
%reduce = linalg.reduce
243+
ins(%input:tensor<16x32x64xf32>)
244+
outs(%init:tensor<16x64xf32>)
245+
dimensions = [1]
246+
(%in: f32, %out: f32) {
247+
%0 = arith.addf %in, %out: f32
248+
linalg.yield %0: f32
249+
}
250+
```
251+
}];
252+
253+
let arguments = (ins
254+
// Input arg
255+
Variadic<TensorOrMemref>:$inputs,
256+
// Output arg
257+
Variadic<TensorOrMemref>:$inits,
258+
259+
DenseI64ArrayAttr:$dimensions
260+
);
261+
let results = (outs Variadic<TensorOrMemref>);
262+
let regions = (region SizedRegion<1>:$combiner);
263+
264+
let extraClassDeclaration = structuredOpsBaseDecls # [{
265+
// Declare functions necessary for LinalgStructuredInterface.
266+
ArrayAttr getIteratorTypes();
267+
ArrayAttr getIndexingMaps();
268+
std::string getLibraryCallName() {
269+
return "op_has_no_registered_library_name";
270+
}
271+
272+
// Implement functions necessary for DestinationStyleOpInterface.
273+
mlir::ValueRange getOutputs() { return getInits(); }
274+
unsigned getNumInputs() { return getInputs().size(); };
275+
unsigned getNumOutputs() { return getInits().size(); };
276+
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
277+
mlir::ArrayRef<mlir::NamedAttribute>)>
278+
getRegionBuilder() {
279+
return nullptr;
280+
}
281+
}];
282+
283+
let hasCustomAssemblyFormat = 1;
284+
let hasVerifier = 1;
285+
}
286+
287+
224288
//===----------------------------------------------------------------------===//
225289
// Named Linalg ops, implemented as a declarative configurations of generic ops.
226290
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/AffineMap.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,21 @@ class AffineMap {
243243

244244
/// Returns a new AffineMap with the same number of dims and symbols and one
245245
/// less result at `pos`, dropped.
246-
AffineMap dropResult(unsigned pos) {
246+
AffineMap dropResult(int64_t pos) {
247247
auto exprs = llvm::to_vector<4>(getResults());
248248
exprs.erase(exprs.begin() + pos);
249249
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
250250
}
251251

252+
// Returns a new AffineMap with the same number of dims and symbols, but all
253+
// positions in `positions` dropped from results.
254+
AffineMap dropResults(ArrayRef<int64_t> positions) {
255+
AffineMap resultMap = *this;
256+
for (int64_t pos : positions)
257+
resultMap = resultMap.dropResult(pos);
258+
return resultMap;
259+
}
260+
252261
/// Returns a new AffineMap with the same number of dims and symbols and an
253262
/// extra result inserted at `pos`.
254263
AffineMap insertResult(AffineExpr expr, unsigned pos) {

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 211 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
124124
static ParseResult
125125
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
126126
SmallVectorImpl<Type> &inputTypes,
127-
SmallVectorImpl<Type> &outputTypes) {
127+
SmallVectorImpl<Type> &outputTypes,
128+
bool addOperandSegmentSizes = true) {
128129
SMLoc inputsOperandsLoc, outputsOperandsLoc;
129130
SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
130131
outputsOperands;
@@ -155,10 +156,12 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
155156
result.operands))
156157
return failure();
157158

158-
result.addAttribute("operand_segment_sizes",
159-
parser.getBuilder().getDenseI32ArrayAttr(
160-
{static_cast<int32_t>(inputsOperands.size()),
161-
static_cast<int32_t>(outputsOperands.size())}));
159+
if (addOperandSegmentSizes) {
160+
result.addAttribute("operand_segment_sizes",
161+
parser.getBuilder().getDenseI32ArrayAttr(
162+
{static_cast<int32_t>(inputsOperands.size()),
163+
static_cast<int32_t>(outputsOperands.size())}));
164+
}
162165
return success();
163166
}
164167

@@ -1180,6 +1183,209 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
11801183
return foldMemRefCast(*this);
11811184
}
11821185

1186+
//===----------------------------------------------------------------------===//
1187+
// ReduceOp
1188+
//===----------------------------------------------------------------------===//
1189+
1190+
// Returns one iterator type per input dimension: parallel everywhere except
// the dimensions listed in the `dimensions` attribute, which are reductions.
ArrayAttr ReduceOp::getIteratorTypes() {
  int64_t rank = getInputs()[0].getType().cast<ShapedType>().getRank();
  SmallVector<StringRef> iterators(rank, getParallelIteratorTypeName());
  for (int64_t dim : getDimensions())
    iterators[dim] = getReductionIteratorTypeName();
  return Builder(getContext()).getStrArrayAttr(iterators);
}
1198+
1199+
// Builds the indexing maps: every input uses the rank-d identity map, every
// output uses the identity map with the reduced dimensions dropped.
ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t rank = getInputs()[0].getType().cast<ShapedType>().getRank();
  AffineMap identityMap = AffineMap::getMultiDimIdentityMap(rank, getContext());
  AffineMap outputMap = identityMap.dropResults(getDimensions());
  SmallVector<AffineMap> maps(getNumInputs(), identityMap);
  maps.append(getNumOutputs(), outputMap);
  return Builder(getContext()).getAffineMapArrayAttr(maps);
}
1211+
1212+
// Reports memory effects: only operands backed by buffers (memrefs) read or
// write memory; tensor operands are pure values.
void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> readBuffers = getInputBufferOperands();
  SmallVector<Value> writtenBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), readBuffers,
                        writtenBuffers);
}
1220+
1221+
// Parses the common pieces of a destination-style op: the `ins`/`outs`
// operand lists, the result types derived from the tensor outputs, the
// required attributes (via `parseAttrsFn`), and an optional attr-dict.
// Does not add `operand_segment_sizes` (callers use SameVariadicOperandSize).
static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types: only tensor outputs produce SSA results; memref outputs
  // are updated in place and contribute none. (Previously any non-tensor
  // output type caused a silent `failure()`, which rejected the legal memref
  // form of the op without emitting a diagnostic.)
  for (Type outputType : outputTypes) {
    if (outputType.isa<RankedTensorType>())
      result.addTypes(outputType);
  }

  // Parse required attributes (e.g. `dimensions = [...]`).
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
1247+
1248+
// Parses `<attributeName> = [i64, i64, ...]` and stores the resulting
// DenseI64ArrayAttr under `attributeName` in `attributes`.
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  // DenseI64ArrayAttr::parse returns a null attribute on malformed input;
  // propagate that as a parse failure instead of storing a null attribute
  // (the original code stored it unchecked and reported success).
  Attribute attr = DenseI64ArrayAttr::parse(parser, Type{});
  if (!attr)
    return failure();
  attributes.set(attributeName, attr);
  return success();
}
1257+
1258+
// Custom parser: `ins(...) outs(...) dimensions = [...] (%a: T, ...) {body}`.
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse ins/outs, result types, and the required `dimensions` attribute.
  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  // Parse the combiner's block arguments, e.g. `(%in: f32, %out: f32)`.
  SmallVector<OpAsmParser::Argument> regionArgs;
  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true))
    return failure();

  // Parse the combiner region itself.
  Region *combiner = result.addRegion();
  if (parser.parseRegion(*combiner, regionArgs))
    return failure();

  return success();
}
1277+
1278+
// Prints ` <attributeName> = [v0, v1, ...] `, the inverse of
// parseDenseI64ArrayAttr.
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << " " << attributeName << " = [";
  p << attributeValue;
  p << "] ";
}
1282+
1283+
// Custom printer, mirror of ReduceOp::parse.
void ReduceOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getInputs(), getOutputs());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  // `dimensions` was printed explicitly above, so elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});

  // Print the combiner's block arguments inline, then the region without
  // repeating the entry-block arguments.
  Region &combiner = getCombiner();
  p << "(";
  llvm::interleaveComma(combiner.getArguments(), p,
                        [&](auto arg) { p.printRegionArgument(arg); });
  p << ") ";
  p.printRegion(combiner, /*printEntryBlockArgs=*/false);
}
1295+
1296+
// Verifies the structural invariants of linalg.reduce: uniform input/init
// shapes, well-formed `dimensions`, init rank/shape consistent with the
// reduction, and a combiner block whose arguments match the operand element
// types.
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dims = getDimensions();

  // All inputs must share one shape; likewise all inits.
  for (int64_t idx = 1; idx < getNumInputs(); ++idx) {
    if (getInputs()[idx].getType().cast<ShapedType>().getShape() !=
        getInputs()[0].getType().cast<ShapedType>().getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << idx
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t idx = 1; idx < getNumOutputs(); ++idx) {
    if (getInits()[idx].getType().cast<ShapedType>().getShape() !=
        getInits()[0].getType().cast<ShapedType>().getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << idx
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = getInputs()[0].getType().cast<ShapedType>();
  auto initType = getInits()[0].getType().cast<ShapedType>();

  // Reduction dimensions must be in-bounds and strictly increasing.
  DenseSet<int64_t> dimensionsToReduce;
  int64_t lastDimension = -1;
  for (int64_t dimension : dims) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    if (dimension <= lastDimension) {
      return emitOpError()
             << "reduction dimensions are not in increasing order: " << dims;
    }
    lastDimension = dimension;
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (int64_t i = 0, e = inputDims.size(); i < e; ++i) {
    if (!dimensionsToReduce.count(i))
      reducedInputDims.push_back(inputDims[i]);
  }

  if (reducedInputDims.size() != initType.getRank()) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  // The combiner takes one block argument per input followed by one per init.
  Block *body = getBody();
  if (body->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Leading block arguments must carry the input element types.
  for (auto [input, bbArg] : llvm::zip(getInputs(), body->getArguments())) {
    Type inputElementType = input.getType().cast<ShapedType>().getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Trailing block arguments must carry the output element types.
  for (auto [output, bbArg] : llvm::zip(
           getOutputs(), body->getArguments().take_back(getNumOutputs()))) {
    auto outputElementType =
        output.getType().cast<ShapedType>().getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1388+
11831389
//===----------------------------------------------------------------------===//
11841390
// InitTensorOp
11851391
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)