
Commit df7b7a4

Rebase and address comments
- Simplify various implementations
- Make inferResultType always return RankedTensorType
- Import saturated_arith and reuse for concat
- Clarify reason for decomposing
- Use Affine folding helpers
- Switch test to a transform op
1 parent d531246 commit df7b7a4

11 files changed: 198 additions, 179 deletions


mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 7 additions & 3 deletions
@@ -132,8 +132,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
   let summary = "tensor concatenation operation";
   let description = [{
     The "concat" operation constructs a tensor out of a variadic list of input
-    tensors, concatenated along a static dimension. All inputs and the result
-    type must share the same rank.
+    tensors, concatenated along a static dimension number. All inputs and the
+    result type must share the same rank.

     `dim` specifies the dimension along which to concatenate. The size of the
     concatenated dimension in the result must be equal to the sum of the sizes
@@ -169,11 +169,15 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
     // types, being concatenated along `dim`. Because concatenation can specify
     // more static information than can automatically be inferred,
     // InferTypeOpInterface is not used.
-    static FailureOr<RankedTensorType> inferResultType(int64_t dim, TypeRange inputTypes);
+    static RankedTensorType inferResultType(int64_t dim, TypeRange inputTypes);

     RankedTensorType getResultType() {
       return ::llvm::cast<RankedTensorType>(getResult().getType());
     }
+
+    int64_t getRank() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
+    }
   }];

   let hasCanonicalizer = 1;
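The interface change above makes inferResultType infallible: callers now get a RankedTensorType directly instead of a FailureOr to unwrap. A minimal C++ sketch of calling the revised helper (not part of the commit; the function name, types, and shapes below are illustrative only):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical example: infer the concat result type for two static inputs
// joined along dim 0.
RankedTensorType exampleInferConcatType(MLIRContext &ctx) {
  auto f32 = Float32Type::get(&ctx);
  auto lhs = RankedTensorType::get({4, 8}, f32);
  auto rhs = RankedTensorType::get({12, 8}, f32);
  SmallVector<Type> inputs = {lhs, rhs};
  // Sizes along `dim` are summed (4 + 12 == 16); the remaining dimensions must
  // agree across inputs. No FailureOr unwrapping is needed after this change.
  return tensor::ConcatOp::inferResultType(/*dim=*/0, inputs); // tensor<16x8xf32>
}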

mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td

Lines changed: 12 additions & 0 deletions
@@ -15,6 +15,18 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"

+def ApplyDecomposeTensorConcatPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.decompose_concat",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.concat ops should be decomposed into a chain of
+    tensor.insert_slice operations inserting into a materialized destination.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+
 def ApplyDropRedundantInsertSliceRankExpansionPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.drop_redundant_insert_slice_rank_expansion",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
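A PatternDescriptorOpInterface op like the new one above is normally paired with a populatePatterns hook in TensorTransformOps.cpp that forwards to the corresponding populate function. That file is not shown in this excerpt, so the snippet below is a sketch of the expected hook under that assumption, not a quote from the commit:

#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"

using namespace mlir;

// Sketch: the pattern-descriptor hook simply forwards to the populate
// function declared in Tensor/Transforms/Transforms.h.
void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  tensor::populateDecomposeTensorConcatPatterns(patterns);
}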

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 3 additions & 1 deletion
@@ -69,7 +69,9 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,

 /// Populates `patterns` with patterns that decompose `tensor.concat` into
 /// `tensor.empty` of a tensor of the concatenated size, followed by a chain
-/// of `tensor.insert_slice` operations on the inputs.
+/// of `tensor.insert_slice` operations on the inputs. This is intended to be
+/// used as a fallback tensor -> tensor lowering that decomposes concat such
+/// that it can be bufferized into a sequence of copies.
 void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);

 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
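The decompose-concat populate function can also be driven directly with the greedy pattern rewriter, outside the transform dialect. A minimal sketch (the wrapper function decomposeConcatsGreedily is assumed for illustration; only populateDecomposeTensorConcatPatterns comes from this header):

#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Decompose every tensor.concat nested under `op` into a tensor.empty plus a
// chain of tensor.insert_slice ops, as a fallback ahead of bufferization.
LogicalResult decomposeConcatsGreedily(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  tensor::populateDecomposeTensorConcatPatterns(patterns);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}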

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 33 additions & 0 deletions
@@ -151,6 +151,39 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                          OpFoldResult step);

+/// Idiomatic saturated operations on values like offsets, sizes, and strides.
+struct SaturatedInteger {
+  static SaturatedInteger wrap(int64_t v) {
+    return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
+                                      : SaturatedInteger{false, v};
+  }
+  int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
+  FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
+    if (saturated && !other.saturated)
+      return other;
+    if (!saturated && !other.saturated && v != other.v)
+      return failure();
+    return *this;
+  }
+  bool operator==(SaturatedInteger other) {
+    return (saturated && other.saturated) ||
+           (!saturated && !other.saturated && v == other.v);
+  }
+  bool operator!=(SaturatedInteger other) { return !(*this == other); }
+  SaturatedInteger operator+(SaturatedInteger other) {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v + v};
+  }
+  SaturatedInteger operator*(SaturatedInteger other) {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v * v};
+  }
+  bool saturated = true;
+  int64_t v = 0;
+};
+
 } // namespace mlir

 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
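A short standalone illustration of the saturation semantics (the example function is assumed, not part of the patch): wrapping ShapedType::kDynamic yields a saturated value, arithmetic with any saturated operand stays saturated, and desaturate recovers a static value only when the operands do not conflict.

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"

using namespace mlir;

void saturatedIntegerExamples() {
  auto four = SaturatedInteger::wrap(4);
  auto eight = SaturatedInteger::wrap(8);
  auto dyn = SaturatedInteger::wrap(ShapedType::kDynamic);

  int64_t product = (four * eight).asInteger(); // 32: both operands static.
  int64_t sum = (four + dyn).asInteger();       // kDynamic: saturation wins.

  auto known = four.desaturate(dyn);   // Succeeds, yielding the static 4.
  auto clash = eight.desaturate(four); // failure(): conflicting static values.
  (void)product; (void)sum; (void)known; (void)clash;
}

This mirrors the old memref saturated_arith::Wrapper removed below, but with a single wrap/asInteger pair instead of per-kind offset/size/stride variants, which is why the MemRefOps.cpp call sites become shorter.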

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 18 additions & 58 deletions
@@ -26,43 +26,6 @@
 using namespace mlir;
 using namespace mlir::memref;

-namespace {
-/// Idiomatic saturated operations on offsets, sizes and strides.
-namespace saturated_arith {
-struct Wrapper {
-  static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper size(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
-  bool operator==(Wrapper other) {
-    return (saturated && other.saturated) ||
-           (!saturated && !other.saturated && v == other.v);
-  }
-  bool operator!=(Wrapper other) { return !(*this == other); }
-  Wrapper operator+(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v + v};
-  }
-  Wrapper operator*(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v * v};
-  }
-  bool saturated;
-  int64_t v;
-};
-} // namespace saturated_arith
-} // namespace
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -2208,11 +2171,11 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
     ReassociationIndices reassoc = std::get<0>(it);
     int64_t currentStrideToExpand = std::get<1>(it);
     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
-      using saturated_arith::Wrapper;
       reverseResultStrides.push_back(currentStrideToExpand);
-      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
-                               Wrapper::size(resultShape[shapeIndex--]))
-                                  .asStride();
+      currentStrideToExpand =
+          (SaturatedInteger::wrap(currentStrideToExpand) *
+           SaturatedInteger::wrap(resultShape[shapeIndex--]))
+              .asInteger();
     }
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
@@ -2332,10 +2295,9 @@ computeCollapsedLayoutMap(MemRefType srcType,
   unsigned resultStrideIndex = resultStrides.size() - 1;
   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
-    using saturated_arith::Wrapper;
-    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
-      stride = stride * Wrapper::size(srcShape[idx]);
+      stride = stride * SaturatedInteger::wrap(srcShape[idx]);

       // Both source and result stride must have the same static value. In that
       // case, we can be sure, that the dimensions are collapsible (because they
@@ -2345,7 +2307,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
       // ops where obviously non-contiguous dims are collapsed, but accept ops
      // where we cannot be sure statically. Such ops may fail at runtime. See
      // the op documentation for details.
-      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();

@@ -2371,11 +2333,11 @@ MemRefType CollapseShapeOp::computeCollapsedType(
   SmallVector<int64_t> resultShape;
   resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
-    using saturated_arith::Wrapper;
-    auto groupSize = Wrapper::size(1);
+    auto groupSize = SaturatedInteger::wrap(1);
     for (int64_t srcDim : group)
-      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
-    resultShape.push_back(groupSize.asSize());
+      groupSize =
+          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
+    resultShape.push_back(groupSize.asInteger());
   }

   if (srcType.getLayout().isIdentity()) {
@@ -2586,11 +2548,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetOffset =
-        (Wrapper::offset(targetOffset) +
-         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
-            .asOffset();
+    targetOffset = (SaturatedInteger::wrap(targetOffset) +
+                    SaturatedInteger::wrap(staticOffset) *
+                        SaturatedInteger::wrap(targetStride))
+                       .asInteger();
   }

   // Compute target stride whose value is:
@@ -2599,10 +2560,9 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   targetStrides.reserve(staticOffsets.size());
   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetStrides.push_back(
-        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
-            .asStride());
+    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
+                             SaturatedInteger::wrap(staticStride))
+                                .asInteger());
   }

   // The type is now known.
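To make the folded offset arithmetic concrete, a standalone sketch that mirrors the loop above (the helper name and example values are assumptions for illustration, not code from the commit):

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical helper: offset = source_offset + sum(static_offset_i * stride_i),
// computed with SaturatedInteger so that any dynamic term makes the result dynamic.
int64_t foldSubViewOffset(int64_t sourceOffset, ArrayRef<int64_t> staticOffsets,
                          ArrayRef<int64_t> sourceStrides) {
  SaturatedInteger target = SaturatedInteger::wrap(sourceOffset);
  for (auto [offset, stride] : llvm::zip(staticOffsets, sourceStrides))
    target = target +
             SaturatedInteger::wrap(offset) * SaturatedInteger::wrap(stride);
  return target.asInteger();
}
// foldSubViewOffset(0, {2, 3}, {16, 1}) == 35 (i.e. 2*16 + 3*1)
// foldSubViewOffset(0, {2, ShapedType::kDynamic}, {16, 1}) == ShapedType::kDynamic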
