Skip to content

Commit 3c2a74a

Browse files
[mlir][linalg][transform] Add TileOp to transform dialect
This commit adds a tiling op to the transform dialect as an external op. Differential Revision: https://reviews.llvm.org/D124661
1 parent e66127e commit 3c2a74a

File tree

12 files changed

+414
-9
lines changed

12 files changed

+414
-9
lines changed

mlir/include/mlir/Dialect/Linalg/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_subdirectory(IR)
2+
add_subdirectory(TransformOps)
23

34
set(LLVM_TARGET_DEFINITIONS Passes.td)
45
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td)
2+
mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRLinalgTransformOpsIncGen)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- LinalgTransformOps.h - Linalg transform ops --------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
10+
#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13+
#include "mlir/IR/OpImplementation.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// Linalg Transform Operations
17+
//===----------------------------------------------------------------------===//
18+
19+
#define GET_OP_CLASSES
20+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"
21+
22+
namespace mlir {
23+
class DialectRegistry;
24+
25+
namespace linalg {
26+
void registerTransformDialectExtension(DialectRegistry &registry);
27+
} // namespace linalg
28+
} // namespace mlir
29+
30+
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===- LinalgTransformOps.td - Linalg transform ops --------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LINALG_TRANSFORM_OPS
10+
#define LINALG_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/IR/TransformEffects.td"
14+
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
15+
include "mlir/Dialect/PDL/IR/PDLTypes.td"
16+
include "mlir/Interfaces/SideEffectInterfaces.td"
17+
include "mlir/IR/OpBase.td"
18+
19+
def TileOp : Op<Transform_Dialect, "structured.tile",
20+
[DeclareOpInterfaceMethods<TransformOpInterface>,
21+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
22+
let description = [{
23+
Indicates that the given `target` op should be tiled with the options
24+
provided as attributes. This transform generates a loop nest with a smaller
25+
("tiled") target operation in its body. Currently limited to LinalgOps.
26+
27+
`sizes` are the tile sizes. A tile size of `0` indicates that the
28+
respective dimension should not be tiled. No loop will be generated for such
29+
dimensions. If all tile sizes are `0`, this transform is effectively a
30+
no-op.
31+
32+
This op returns handles to the tiled op (in the generated loop nest) and the
33+
generated loops. The number of loops is the number of non-zero tile sizes.
34+
}];
35+
36+
let arguments = (ins PDL_Operation:$target,
37+
DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
38+
DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
39+
let results = (outs PDL_Operation:$tiled_linalg_op,
40+
Variadic<PDL_Operation>:$loops);
41+
42+
let hasCustomAssemblyFormat = 1;
43+
}
44+
45+
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/InitAllDialects.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
3434
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
3535
#include "mlir/Dialect/Linalg/IR/Linalg.h"
36+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
3637
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
3738
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
3839
#include "mlir/Dialect/Math/IR/Math.h"
@@ -101,6 +102,11 @@ inline void registerAllDialects(DialectRegistry &registry) {
101102
tosa::TosaDialect,
102103
x86vector::X86VectorDialect>();
103104
// clang-format on
105+
106+
// Register all dialect extensions.
107+
linalg::registerTransformDialectExtension(registry);
108+
109+
// Register all external models.
104110
arith::registerBufferizableOpInterfaceExternalModels(registry);
105111
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
106112
registry);
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_subdirectory(Analysis)
22
add_subdirectory(IR)
3+
add_subdirectory(TransformOps)
34
add_subdirectory(Transforms)
45
add_subdirectory(Utils)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_dialect_library(MLIRLinalgTransformOps
2+
LinalgTransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/TransformOps
6+
7+
DEPENDS
8+
MLIRLinalgTransformOpsIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRLinalg
13+
MLIRLinalgTransforms
14+
MLIRParser
15+
MLIRPDL
16+
MLIRSideEffectInterfaces
17+
MLIRTransformDialect
18+
)
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10+
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
12+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13+
#include "mlir/Dialect/PDL/IR/PDL.h"
14+
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
15+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16+
#include "mlir/Interfaces/SideEffectInterfaces.h"
17+
#include "mlir/Parser/Parser.h"
18+
#include "llvm/Support/FormatVariadic.h"
19+
20+
using namespace mlir;
21+
using namespace mlir::linalg;
22+
using namespace mlir::transform;
23+
24+
/// Extracts a vector of int64_t from an array attribute. Asserts if the
25+
/// attribute contains values other than integers.
26+
static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
27+
SmallVector<int64_t> result;
28+
result.reserve(attr.size());
29+
for (APInt value : attr.getAsValueRange<IntegerAttr>())
30+
result.push_back(value.getSExtValue());
31+
return result;
32+
}
33+
34+
/// Extracts a vector of unsigned from an array attribute. Asserts if the
35+
/// attribute contains values other than intergers. May truncate.
36+
static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
37+
SmallVector<unsigned> result;
38+
result.reserve(attr.size());
39+
for (APInt value : attr.getAsValueRange<IntegerAttr>())
40+
result.push_back(value.getZExtValue());
41+
return result;
42+
}
43+
44+
namespace {
45+
/// A simple pattern rewriter that implements no special logic.
46+
class SimpleRewriter : public PatternRewriter {
47+
public:
48+
SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
49+
};
50+
} // namespace
51+
52+
//===----------------------------------------------------------------------===//
53+
// TileOp
54+
//===----------------------------------------------------------------------===//
55+
56+
/// Apply a tiling transformation to all payload ops and store both the
57+
/// tiled operation as well as the created tile loops.
58+
static LogicalResult
59+
applyTilingToAll(Operation *transformOp, Value target,
60+
ArrayRef<int64_t> tileSizes,
61+
transform::TransformResults &transformResults,
62+
transform::TransformState &state,
63+
function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
64+
// Number of loops: Number of tiles sizes that are not zero.
65+
size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
66+
// All payload ops. These should all be LinalgOps for now.
67+
ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
68+
69+
SmallVector<Operation *> tiledLinalgOps;
70+
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
71+
for (unsigned int i = 0; i < numLoops; ++i)
72+
loopOps[i].reserve(payloadOps.size());
73+
74+
for (Operation *target : payloadOps) {
75+
auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
76+
if (!linalgOp)
77+
return transformOp->emitError("only LinalgOps are supported");
78+
79+
FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
80+
if (failed(tiled))
81+
return failure();
82+
83+
tiledLinalgOps.push_back(tiled->op);
84+
if (tiled->loops.size() != numLoops)
85+
// Not enough loops were generated. This usually means that the input size
86+
// was smaller than the tiling size.
87+
// TODO: LinalgTilingPattern should return failure().
88+
return failure();
89+
for (unsigned int i = 0; i < numLoops; ++i)
90+
loopOps[i].push_back(tiled->loops[i]);
91+
}
92+
93+
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
94+
for (unsigned int i = 0; i < numLoops; ++i)
95+
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
96+
return success();
97+
}
98+
99+
LogicalResult transform::TileOp::apply(TransformResults &transformResults,
100+
TransformState &state) {
101+
LinalgTilingOptions tilingOptions;
102+
SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
103+
104+
if (!tileSizes.empty())
105+
tilingOptions.setTileSizes(tileSizes);
106+
tilingOptions.setInterchange(extractUIntArray(getInterchange()));
107+
LinalgTilingPattern pattern(getContext(), tilingOptions);
108+
109+
return applyTilingToAll(getOperation(), getTarget(), tileSizes,
110+
transformResults, state, [&](LinalgOp linalgOp) {
111+
SimpleRewriter rewriter(linalgOp.getContext());
112+
return pattern.returningMatchAndRewrite(linalgOp,
113+
rewriter);
114+
});
115+
}
116+
117+
ParseResult transform::TileOp::parse(OpAsmParser &parser,
118+
OperationState &result) {
119+
StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
120+
OpAsmParser::UnresolvedOperand targetOperand;
121+
SMLoc opLoc;
122+
parser.getCurrentLocation(&opLoc);
123+
if (parser.parseOperand(targetOperand))
124+
return parser.emitError(opLoc, "expected 'target' operand");
125+
if (parser.parseOptionalAttrDict(result.attributes))
126+
return failure();
127+
Attribute sizesAttr = result.attributes.get(sizesAttrName);
128+
if (!sizesAttr)
129+
return parser.emitError(opLoc)
130+
<< "expected '" << sizesAttrName << "' attribute";
131+
auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
132+
if (!sizesArrayAttr)
133+
return parser.emitError(opLoc)
134+
<< "'" << sizesAttrName << "' attribute must be an array";
135+
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
136+
size_t numExpectedLoops =
137+
sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
138+
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
139+
if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
140+
return failure();
141+
return success();
142+
}
143+
144+
void TileOp::print(OpAsmPrinter &p) {
145+
p << ' ';
146+
p << getTarget();
147+
p.printOptionalAttrDict((*this)->getAttrs());
148+
}
149+
150+
void TileOp::getEffects(
151+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
152+
&effects) {
153+
// `target` arg is consumed and can no longer be used.
154+
effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
155+
TransformMappingResource::get());
156+
effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
157+
TransformMappingResource::get());
158+
159+
for (Value r : getResults()) {
160+
effects.emplace_back(MemoryEffects::Write::get(), r,
161+
TransformMappingResource::get());
162+
effects.emplace_back(MemoryEffects::Allocate::get(), r,
163+
TransformMappingResource::get());
164+
}
165+
166+
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
167+
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
168+
}
169+
170+
//===----------------------------------------------------------------------===//
171+
// Transform op registration
172+
//===----------------------------------------------------------------------===//
173+
174+
namespace {
175+
/// Registers new ops and declares PDL as dependent dialect since the additional
176+
/// ops are using PDL types for operands and results.
177+
class LinalgTransformDialectExtension
178+
: public transform::TransformDialectExtension<
179+
LinalgTransformDialectExtension> {
180+
public:
181+
LinalgTransformDialectExtension() {
182+
declareDependentDialect<pdl::PDLDialect>();
183+
declareDependentDialect<scf::SCFDialect>();
184+
registerTransformOps<
185+
#define GET_OP_LIST
186+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
187+
>();
188+
}
189+
};
190+
} // namespace
191+
192+
#define GET_OP_CLASSES
193+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
194+
195+
void mlir::linalg::registerTransformDialectExtension(
196+
DialectRegistry &registry) {
197+
registry.addExtensions<LinalgTransformDialectExtension>();
198+
}

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
168168

169169
// Shift all IndexOp results by the tile offset.
170170
SmallVector<Value> allIvs;
171-
transform(loopRanges, std::back_inserter(allIvs),
172-
[](Range range) { return range.offset; });
171+
llvm::transform(loopRanges, std::back_inserter(allIvs),
172+
[](Range range) { return range.offset; });
173173
addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
174174

175175
return clonedOp;

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,11 @@ getTiledProducerLoops(OpResult producerResult,
8787
assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
8888
"expect slice and producer loop dimensions map one-to-one");
8989
SmallVector<int64_t> tiledProducerLoopIndices;
90-
transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
91-
std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
92-
return tiledProducerIndexingSubMap.getDimPosition(idx);
93-
});
90+
llvm::transform(
91+
llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
92+
std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
93+
return tiledProducerIndexingSubMap.getDimPosition(idx);
94+
});
9495

9596
return tiledProducerLoopIndices;
9697
}
@@ -141,9 +142,9 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
141142

142143
// Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
143144
SmallVector<Value> producerLoopBounds;
144-
transform(producerOp.createLoopRanges(b, loc),
145-
std::back_inserter(producerLoopBounds),
146-
[](Range range) { return range.size; });
145+
llvm::transform(producerOp.createLoopRanges(b, loc),
146+
std::back_inserter(producerLoopBounds),
147+
[](Range range) { return range.size; });
147148
SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
148149

149150
// Tile the producer operands given the `sliceOp` ranges. Iterate the

0 commit comments

Comments
 (0)