Skip to content

Commit aa4e466

Browse files
committed
[mlir][Linalg] Improve region support in Linalg ops
This revision takes advantage of the newly extended `ref` directive in assembly format to allow better region handling for LinalgOps. Specifically, FillOp and CopyOp now build their regions explicitly which allows retiring older behavior that relied on specific op knowledge in both lowering to loops and vectorization. This reverts commit 3f22547 and reland 973e133 with a workaround for a gcc bug that does not accept lambda default parameters: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59949 Differential Revision: https://reviews.llvm.org/D96598
1 parent 9f17599 commit aa4e466

File tree

11 files changed

+319
-282
lines changed

11 files changed

+319
-282
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,20 +1056,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
10561056
//===------------------------------------------------------------------===//
10571057
// Other static interface methods.
10581058
//===------------------------------------------------------------------===//
1059-
StaticInterfaceMethod<
1060-
/*desc=*/[{
1061-
Create an operation of the current type with the given location,
1062-
operands, and attributes.
1063-
}],
1064-
/*retTy=*/"Operation *",
1065-
/*methodName=*/"create",
1066-
(ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
1067-
"ValueRange":$operands,
1068-
"ArrayRef<NamedAttribute>":$attributes), [{
1069-
return builder.create<ConcreteOp>(
1070-
loc, resultTypes, operands, attributes);
1071-
}]
1072-
>,
10731059
InterfaceMethod<
10741060
/*desc=*/[{
10751061
Clone the current operation with the given location and operands. This
@@ -1082,14 +1068,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
10821068
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
10831069
"ValueRange":$operands),
10841070
[{
1085-
BlockAndValueMapping map;
1086-
unsigned numRegions = $_op->getNumRegions();
1087-
Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs());
1088-
assert(res->getNumRegions() == numRegions && "inconsistent # regions");
1089-
for (unsigned ridx = 0; ridx < numRegions; ++ridx)
1090-
$_op->getRegion(ridx).cloneInto(
1091-
&res->getRegion(ridx), map);
1092-
return res;
1071+
BlockAndValueMapping bvm;
1072+
OperationState state(
1073+
loc, ConcreteOp::getOperationName(), operands, resultTypes,
1074+
$_op->getAttrs());
1075+
for (Region &r : $_op->getRegions())
1076+
r.cloneInto(state.addRegion(), bvm);
1077+
return b.createOperation(state);
10931078
}]
10941079
>,
10951080
StaticInterfaceMethod<
@@ -1098,7 +1083,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
10981083
Returns a null function if this named op does not define a region
10991084
builder.
11001085
}],
1101-
/*retTy=*/"std::function<void(Block &)>",
1086+
/*retTy=*/"std::function<void(Block &, ValueRange)>",
11021087
/*methodName=*/"getRegionBuilder",
11031088
(ins),
11041089
[{ return ConcreteOp::getRegionBuilder(); }]

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,13 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
110110
AnyStridedMemRef:$output,
111111
OptionalAttr<AffineMapAttr>:$inputPermutation,
112112
OptionalAttr<AffineMapAttr>:$outputPermutation);
113+
let regions = (region AnyRegion:$region);
113114

114-
// TODO: this should go away once the usage of OptionalAttr triggers emission
115-
// of builders with default arguments left unspecified.
116-
let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output),
117-
[{
118-
return build(
119-
$_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
120-
}]>];
115+
let builders = [
116+
OpBuilderDAG<(ins "Value":$input, "Value":$output,
117+
CArg<"AffineMap", "AffineMap()">:$inputPermutation,
118+
CArg<"AffineMap", "AffineMap()">:$outputPermutation,
119+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
121120

122121
let extraClassDeclaration = structuredOpsDecls # [{
123122
ValueRange inputs() { return getOperands().take_front(); }
@@ -146,24 +145,31 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
146145
Value getSource() { return input();}
147146
Value getTarget() { return output(); }
148147

149-
static std::function<void(Block &)> getRegionBuilder() {
150-
return nullptr;
148+
static void regionBuilder(Block &block, ValueRange captures);
149+
static std::function<void(Block &block, ValueRange captures)>
150+
getRegionBuilder() {
151+
return &regionBuilder;
151152
}
153+
static unsigned getNumRegionArgs() { return 2; }
152154
}];
153155
let verifier = [{ return ::verify(*this); }];
154156

155157
let assemblyFormat = [{
156-
`(` operands `)` attr-dict `:` type(operands)
158+
`(` $input `,` $output `)` attr-dict `:`
159+
type($input) `,` type($output)
160+
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
157161
}];
158162

159163
let hasFolder = 1;
160164
let hasCanonicalizer = 1;
165+
let skipDefaultBuilders = 1;
161166
}
162167

163168
def FillOp : LinalgStructured_Op<"fill", []> {
164169
let arguments = (ins AnyShaped:$output,
165170
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
166171
let results = (outs Optional<AnyRankedTensor>:$result);
172+
let regions = (region AnyRegion:$region);
167173
let extraClassDeclaration = structuredOpsDecls # [{
168174
ValueRange inputs() { return {}; }
169175
ValueRange outputs() { return getOperands().take_front(); }
@@ -183,13 +189,18 @@ def FillOp : LinalgStructured_Op<"fill", []> {
183189
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
184190
}
185191

186-
static std::function<void(Block &)> getRegionBuilder() {
187-
return nullptr;
192+
static void regionBuilder(Block &block, ValueRange captures);
193+
static std::function<void(Block &block, ValueRange captures)>
194+
getRegionBuilder() {
195+
return &regionBuilder;
188196
}
197+
static unsigned getNumRegionArgs() { return 1; }
189198
}];
190199

191200
let assemblyFormat = [{
192-
`(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
201+
`(` $output `,` $value `)` attr-dict `:`
202+
type($output) `,` type($value) (`->` type($result)^)?
203+
custom<FillOpRegion>($region, ref(type($output)), ref($value))
193204
}];
194205

195206
let builders = [
@@ -268,7 +279,8 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
268279
return padding().getValue().getValue<int64_t>({i, 1});
269280
}
270281

271-
static std::function<void(Block &)> getRegionBuilder() {
282+
static std::function<void(Block &, ValueRange captures)> getRegionBuilder()
283+
{
272284
return nullptr;
273285
}
274286
}];
@@ -519,7 +531,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
519531
library_call()->str() : "op_has_no_registered_library_name";
520532
}
521533

522-
static std::function<void(Block &)> getRegionBuilder() {
534+
static std::function<void(Block &, ValueRange)> getRegionBuilder() {
523535
return nullptr;
524536
}
525537
}];

mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,13 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
154154
if (in == op.input() && out == op.output())
155155
return failure();
156156

157-
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
157+
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
158+
if (!libraryCallName)
159+
return failure();
160+
161+
rewriter.replaceOpWithNewOp<mlir::CallOp>(
162+
op, libraryCallName.getValue(), TypeRange(),
163+
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
158164
return success();
159165
}
160166

mlir/lib/Dialect/Linalg/EDSC/Builders.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ Operation *mlir::edsc::makeGenericLinalgOp(
2727
ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
2828
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
2929
ArrayRef<Attribute> otherAttributes) {
30-
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
31-
3230
// Build maps
3331
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
3432
exprsList.reserve(inputs.size() + outputs.size());
@@ -54,13 +52,10 @@ Operation *mlir::edsc::makeGenericLinalgOp(
5452
resultTensorTypes,
5553
inputValues,
5654
outputValues,
57-
builder.getAffineMapArrayAttr(maps),
58-
builder.getStrArrayAttr(iteratorStrTypes),
59-
StringAttr() /*doc*/,
60-
StringAttr() /*library_call*/,
61-
ArrayAttr() /*sparse*/
62-
/* TODO: other attributes in op */
63-
)
55+
maps,
56+
iteratorStrTypes,
57+
""/*doc*/,
58+
""/*library_call*/)
6459
.getOperation();
6560
// clang-format on
6661

0 commit comments

Comments
 (0)