
Commit f310a5d

[mlir][tensor] Add a tensor.concat operation (#72779)
This adds an operation for concatenating ranked tensors along a static dimension, as well as a decomposition mirroring the existing lowering from TOSA to Tensor. This offers a convergence point for "input"-like dialects that include various lowerings for concatenation operations, easing later analysis. In the future, this op can implement the interfaces necessary for tiling, and conversions to some kind of linalg and/or memref counterpart can be added.

This patch adds the op, the decomposition, and some basic folding/canonicalization. Replacing existing lowerings with this op (such as the TOSA lowering) will come as a follow-up.

See https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
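As a rough illustration of the decomposition described above (the value names and shapes are invented for this sketch, not taken from the patch):

```mlir
// Concatenation of two statically shaped tensors along dimension 0 ...
%concat = tensor.concat dim(0) %a, %b :
    (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>

// ... and the decomposed form: an empty destination plus one insert_slice
// per input at its running offset along the concatenated dimension.
%empty = tensor.empty() : tensor<5x4xf32>
%slice0 = tensor.insert_slice %a into %empty[0, 0] [2, 4] [1, 1]
    : tensor<2x4xf32> into tensor<5x4xf32>
%slice1 = tensor.insert_slice %b into %slice0[2, 0] [3, 4] [1, 1]
    : tensor<3x4xf32> into tensor<5x4xf32>
```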
1 parent 4c44dcf commit f310a5d

File tree

13 files changed (+554, -58 lines)


mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 64 additions & 0 deletions
@@ -121,6 +121,70 @@ def Tensor_CastOp : Tensor_Op<"cast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ConcatOp : Tensor_Op<"concat",
+    [Pure,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "tensor concatenation operation";
+  let description = [{
+    The "concat" operation constructs a tensor out of a variadic list of input
+    tensors, concatenated along a static dimension number. All inputs and the
+    result type must share the same rank.
+
+    `dim` specifies the dimension along which to concatenate. The size of the
+    concatenated dimension in the result must be equal to the sum of the sizes
+    of the inputs along that dimension. All other dimensions in both the inputs
+    and result must be the same size.
+
+    Example:
+
+    ```mlir
+    %0 = tensor.concat dim(0) %1, %2, %3 :
+        (tensor<3x6xf32>, tensor<3x6xf32>, tensor<1x6xf32>) -> tensor<7x6xf32>
+
+    // Dynamic + dynamic -> static
+    %4 = tensor.concat dim(1) %5, %6, %7 :
+        (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>
+    ```
+  }];
+
+  let arguments = (ins I64Attr:$dim,
+                       Variadic<AnyRankedTensor>:$inputs);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    `dim` `(` $dim `)` $inputs attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let builders = [
+    // Builder with an inferred result type.
+    OpBuilder<(ins "int64_t":$dim, "ValueRange":$inputs)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // Helper to infer the concatenated result type for the given list of input
+    // types, being concatenated along `dim`. Because concatenation can specify
+    // more static information than can automatically be inferred,
+    // InferTypeOpInterface is not used.
+    static RankedTensorType inferResultType(int64_t dim, TypeRange inputTypes);
+
+    RankedTensorType getResultType() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType());
+    }
+
+    int64_t getRank() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//
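A hedged illustration of the inference helper's role (example shapes and names are mine, not from the patch): the inferred type is only as static as the inputs allow, while the declared result type may carry extra static information, as in the "Dynamic + dynamic -> static" example above.

```mlir
// Inferred result type would be tensor<?x6xf32>: the concatenated dimension
// stays dynamic because %b is dynamic along dim 0, while dim 1 is known to
// be 6 from both inputs.
%0 = tensor.concat dim(0) %a, %b :
    (tensor<4x6xf32>, tensor<?x6xf32>) -> tensor<?x6xf32>

// The declared result type may assert more static information than can be
// inferred, e.g. when the dynamic extent is known to be 3 from elsewhere.
%1 = tensor.concat dim(0) %a, %b :
    (tensor<4x6xf32>, tensor<?x6xf32>) -> tensor<7x6xf32>
```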

mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td

Lines changed: 12 additions & 0 deletions
@@ -15,6 +15,18 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyDecomposeTensorConcatPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.decompose_concat",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.concat ops should be decomposed into a chain of
+    tensor.insert_slice operations inserting into a materialized destination.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+
 def ApplyDropRedundantInsertSliceRankExpansionPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.drop_redundant_insert_slice_rank_expansion",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
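A sketch of how this pattern descriptor might be driven from the transform dialect; the enclosing sequence is assumed boilerplate, and only `apply_patterns.tensor.decompose_concat` comes from this patch.

```mlir
transform.sequence failures(propagate) {
^bb0(%module: !transform.any_op):
  // Apply the decomposition patterns to the payload rooted at %module.
  transform.apply_patterns to %module {
    transform.apply_patterns.tensor.decompose_concat
  } : !transform.any_op
}
```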

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
@@ -67,6 +67,13 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                                      bool foldSingleUseOnly = false);
 
+/// Populates `patterns` with patterns that decompose `tensor.concat` into
+/// `tensor.empty` of a tensor of the concatenated size, followed by a chain
+/// of `tensor.insert_slice` operations on the inputs. This is intended to be
+/// used as a fallback tensor -> tensor lowering that decomposes concat such
+/// that it can be bufferized into a sequence of copies.
+void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
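To make the fallback lowering concrete for the dynamic case, a hedged sketch (names, shapes, and the exact offset/size computation are illustrative assumptions, not lifted from the patch) of a concat along a dynamic dimension and the kind of IR the decomposition would produce:

```mlir
// Concatenation along a dynamic dimension 0.
%concat = tensor.concat dim(0) %a, %b :
    (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>

// Decomposed form: materialize the destination with the summed extent, then
// insert each input at its running offset along the concatenated dimension.
%c0 = arith.constant 0 : index
%d0 = tensor.dim %a, %c0 : tensor<?x4xf32>
%d1 = tensor.dim %b, %c0 : tensor<?x4xf32>
%sum = arith.addi %d0, %d1 : index
%empty = tensor.empty(%sum) : tensor<?x4xf32>
%s0 = tensor.insert_slice %a into %empty[0, 0] [%d0, 4] [1, 1]
    : tensor<?x4xf32> into tensor<?x4xf32>
%s1 = tensor.insert_slice %b into %s0[%d0, 0] [%d1, 4] [1, 1]
    : tensor<?x4xf32> into tensor<?x4xf32>
```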

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 33 additions & 0 deletions
@@ -151,6 +151,39 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                          OpFoldResult step);
 
+/// Idiomatic saturated operations on values like offsets, sizes, and strides.
+struct SaturatedInteger {
+  static SaturatedInteger wrap(int64_t v) {
+    return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
+                                      : SaturatedInteger{false, v};
+  }
+  int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
+  FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
+    if (saturated && !other.saturated)
+      return other;
+    if (!saturated && !other.saturated && v != other.v)
+      return failure();
+    return *this;
+  }
+  bool operator==(SaturatedInteger other) {
+    return (saturated && other.saturated) ||
+           (!saturated && !other.saturated && v == other.v);
+  }
+  bool operator!=(SaturatedInteger other) { return !(*this == other); }
+  SaturatedInteger operator+(SaturatedInteger other) {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v + v};
+  }
+  SaturatedInteger operator*(SaturatedInteger other) {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v * v};
+  }
+  bool saturated = true;
+  int64_t v = 0;
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 18 additions & 58 deletions
@@ -26,43 +26,6 @@
 using namespace mlir;
 using namespace mlir::memref;
 
-namespace {
-/// Idiomatic saturated operations on offsets, sizes and strides.
-namespace saturated_arith {
-struct Wrapper {
-  static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper size(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
-  bool operator==(Wrapper other) {
-    return (saturated && other.saturated) ||
-           (!saturated && !other.saturated && v == other.v);
-  }
-  bool operator!=(Wrapper other) { return !(*this == other); }
-  Wrapper operator+(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v + v};
-  }
-  Wrapper operator*(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v * v};
-  }
-  bool saturated;
-  int64_t v;
-};
-} // namespace saturated_arith
-} // namespace
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -2208,11 +2171,11 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
     ReassociationIndices reassoc = std::get<0>(it);
     int64_t currentStrideToExpand = std::get<1>(it);
     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
-      using saturated_arith::Wrapper;
       reverseResultStrides.push_back(currentStrideToExpand);
-      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
-                               Wrapper::size(resultShape[shapeIndex--]))
-                                  .asStride();
+      currentStrideToExpand =
+          (SaturatedInteger::wrap(currentStrideToExpand) *
+           SaturatedInteger::wrap(resultShape[shapeIndex--]))
+              .asInteger();
     }
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
@@ -2332,10 +2295,9 @@ computeCollapsedLayoutMap(MemRefType srcType,
   unsigned resultStrideIndex = resultStrides.size() - 1;
   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
-    using saturated_arith::Wrapper;
-    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
-      stride = stride * Wrapper::size(srcShape[idx]);
+      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
       // Both source and result stride must have the same static value. In that
       // case, we can be sure, that the dimensions are collapsible (because they
@@ -2345,7 +2307,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
       // ops where obviously non-contiguous dims are collapsed, but accept ops
       // where we cannot be sure statically. Such ops may fail at runtime. See
      // the op documentation for details.
-      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
@@ -2371,11 +2333,11 @@ MemRefType CollapseShapeOp::computeCollapsedType(
   SmallVector<int64_t> resultShape;
   resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
-    using saturated_arith::Wrapper;
-    auto groupSize = Wrapper::size(1);
+    auto groupSize = SaturatedInteger::wrap(1);
     for (int64_t srcDim : group)
-      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
-    resultShape.push_back(groupSize.asSize());
+      groupSize =
+          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
+    resultShape.push_back(groupSize.asInteger());
   }
 
   if (srcType.getLayout().isIdentity()) {
@@ -2586,11 +2548,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetOffset =
-        (Wrapper::offset(targetOffset) +
-         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
-            .asOffset();
+    targetOffset = (SaturatedInteger::wrap(targetOffset) +
+                    SaturatedInteger::wrap(staticOffset) *
+                        SaturatedInteger::wrap(targetStride))
+                       .asInteger();
   }
 
   // Compute target stride whose value is:
@@ -2599,10 +2560,9 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   targetStrides.reserve(staticOffsets.size());
   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetStrides.push_back(
-        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
-            .asStride());
+    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
+                             SaturatedInteger::wrap(staticStride))
+                                .asInteger());
   }
 
   // The type is now known.
