Skip to content

Commit d4fbf83

Browse files
author
Nicolas Vasilache
committed
[mlir][EDSC] NFC - Move StructuredIndexed and IteratorType out of Linalg
Summary: This NFC revision will allow those classes to be reused to allow building structured vector operations. Reviewers: aartbik, ftynse Subscribers: arphaman, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D74279
1 parent 20344d3 commit d4fbf83

File tree

4 files changed

+68
-64
lines changed

4 files changed

+68
-64
lines changed

mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -87,57 +87,6 @@ template <typename LoopTy> class GenericLoopNestRangeBuilder {
8787
std::unique_ptr<BuilderType> builder;
8888
};
8989

90-
enum class IterType { Parallel, Reduction };
91-
92-
inline StringRef toString(IterType t) {
93-
switch (t) {
94-
case IterType::Parallel:
95-
return getParallelIteratorTypeName();
96-
case IterType::Reduction:
97-
return getReductionIteratorTypeName();
98-
}
99-
llvm_unreachable("Unsupported IterType");
100-
}
101-
102-
/// A StructuredIndexed represents an indexable quantity that is either:
103-
/// 1. a captured value, which is suitable for buffer and tensor operands, or;
104-
/// 2. a captured type, which is suitable for tensor return values.
105-
///
106-
/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
107-
/// It enable an idiomatic syntax for index expressions such as:
108-
///
109-
/// ```
110-
/// StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
111-
/// C(buffer_value_or_tensor_type);
112-
/// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
113-
/// ```
114-
struct StructuredIndexed : public ValueHandle {
115-
StructuredIndexed(Type type) : ValueHandle(type) {}
116-
StructuredIndexed(Value value) : ValueHandle(value) {}
117-
StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
118-
StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
119-
return StructuredIndexed(*this, indexings);
120-
}
121-
122-
ArrayRef<AffineExpr> getExprs() { return exprs; }
123-
124-
private:
125-
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
126-
: ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
127-
assert(t.isa<RankedTensorType>() && "RankedTensor expected");
128-
}
129-
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
130-
: ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
131-
assert((v.getType().isa<MemRefType>() ||
132-
v.getType().isa<RankedTensorType>()) &&
133-
"MemRef or RankedTensor expected");
134-
}
135-
StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
136-
: ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
137-
138-
SmallVector<AffineExpr, 4> exprs;
139-
};
140-
14190
inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
14291

14392
/// Build a `linalg.generic` op with the specified `inputs`, `outputs` and
@@ -157,7 +106,7 @@ inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
157106
/// restriction output tensor results would need to be reordered, which would
158107
/// result in surprising behavior when combined with region definition.
159108
Operation *makeGenericLinalgOp(
160-
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
109+
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
161110
ArrayRef<StructuredIndexed> outputs,
162111
function_ref<void(ArrayRef<BlockArgument>)> regionBuilder =
163112
defaultRegionBuilder,

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
8484
return res;
8585
}
8686

87+
/// Typed representation for loop type strings.
88+
enum class IteratorType { Parallel, Reduction };
89+
90+
inline StringRef toString(IteratorType t) {
91+
switch (t) {
92+
case IteratorType::Parallel:
93+
return getParallelIteratorTypeName();
94+
case IteratorType::Reduction:
95+
return getReductionIteratorTypeName();
96+
}
97+
llvm_unreachable("Unsupported IteratorType");
98+
}
99+
87100
} // end namespace mlir
88101

89102
#endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H

mlir/include/mlir/EDSC/Builders.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/AffineOps/AffineOps.h"
1818
#include "mlir/Dialect/LoopOps/LoopOps.h"
1919
#include "mlir/Dialect/StandardOps/Ops.h"
20+
#include "mlir/IR/AffineExpr.h"
2021
#include "mlir/IR/Builders.h"
2122
#include "mlir/Transforms/FoldUtils.h"
2223

@@ -493,6 +494,46 @@ class BlockHandle : public CapturableHandle {
493494
mlir::Block *block;
494495
};
495496

497+
/// A StructuredIndexed represents an indexable quantity that is either:
498+
/// 1. a captured value, which is suitable for buffer and tensor operands, or;
499+
/// 2. a captured type, which is suitable for tensor return values.
500+
///
501+
/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
502+
/// It enable an idiomatic syntax for index expressions such as:
503+
///
504+
/// ```
505+
/// StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
506+
/// C(buffer_value_or_tensor_type);
507+
/// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
508+
/// ```
509+
struct StructuredIndexed : public ValueHandle {
510+
StructuredIndexed(Type type) : ValueHandle(type) {}
511+
StructuredIndexed(Value value) : ValueHandle(value) {}
512+
StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
513+
StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
514+
return this->hasValue() ? StructuredIndexed(this->getValue(), indexings)
515+
: StructuredIndexed(this->getType(), indexings);
516+
}
517+
518+
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
519+
: ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
520+
assert(t.isa<RankedTensorType>() && "RankedTensor expected");
521+
}
522+
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
523+
: ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
524+
assert((v.getType().isa<MemRefType>() ||
525+
v.getType().isa<RankedTensorType>()) &&
526+
"MemRef or RankedTensor expected");
527+
}
528+
StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
529+
: ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
530+
531+
ArrayRef<AffineExpr> getExprs() { return exprs; }
532+
533+
private:
534+
SmallVector<AffineExpr, 4> exprs;
535+
};
536+
496537
template <typename Op, typename... Args>
497538
OperationHandle OperationHandle::create(Args... args) {
498539
return OperationHandle(ScopedContext::getBuilder()

mlir/lib/Dialect/Linalg/EDSC/Builders.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
1010
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
1111
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
12+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1213
#include "mlir/EDSC/Builders.h"
1314
#include "mlir/EDSC/Intrinsics.h"
1415
#include "mlir/IR/AffineExpr.h"
@@ -144,7 +145,7 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
144145
}
145146

146147
Operation *mlir::edsc::makeGenericLinalgOp(
147-
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
148+
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
148149
ArrayRef<StructuredIndexed> outputs,
149150
function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
150151
ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
@@ -240,8 +241,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
240241
Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
241242
StructuredIndexed I,
242243
StructuredIndexed O) {
243-
SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
244-
edsc::IterType::Parallel);
244+
SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
245+
IteratorType::Parallel);
245246
if (O.getType().isa<RankedTensorType>()) {
246247
auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
247248
assert(args.size() == 1 && "expected 1 block arguments");
@@ -270,8 +271,8 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
270271
StructuredIndexed I1,
271272
StructuredIndexed I2,
272273
StructuredIndexed O) {
273-
SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
274-
edsc::IterType::Parallel);
274+
SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
275+
IteratorType::Parallel);
275276
if (O.getType().isa<RankedTensorType>()) {
276277
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
277278
assert(args.size() == 2 && "expected 2 block arguments");
@@ -315,7 +316,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
315316
bindDims(ScopedContext::getContext(), m, n, k);
316317
StructuredIndexed A(vA), B(vB), C(vC);
317318
return makeGenericLinalgOp(
318-
{IterType::Parallel, IterType::Parallel, IterType::Reduction},
319+
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
319320
{A({m, k}), B({k, n})},
320321
{C({m, n})},
321322
macRegionBuilder);
@@ -329,7 +330,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
329330
bindDims(ScopedContext::getContext(), m, n, k);
330331
StructuredIndexed A(vA), B(vB), C(tC);
331332
return makeGenericLinalgOp(
332-
{IterType::Parallel, IterType::Parallel, IterType::Reduction},
333+
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
333334
{A({m, k}), B({k, n})},
334335
{C({m, n})},
335336
mulRegionBuilder);
@@ -343,7 +344,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
343344
bindDims(ScopedContext::getContext(), m, n, k);
344345
StructuredIndexed A(vA), B(vB), C(vC), D(tD);
345346
return makeGenericLinalgOp(
346-
{IterType::Parallel, IterType::Parallel, IterType::Reduction},
347+
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
347348
{A({m, k}), B({k, n}), C({m, n})},
348349
{D({m, n})},
349350
macRegionBuilder);
@@ -360,8 +361,8 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
360361
assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
361362

362363
// Some short names.
363-
auto par = IterType::Parallel;
364-
auto red = IterType::Reduction;
364+
auto par = IteratorType::Parallel;
365+
auto red = IteratorType::Reduction;
365366
auto s = strides;
366367
auto d = dilations;
367368

@@ -393,8 +394,8 @@ Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
393394
assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
394395

395396
// Some short names.
396-
auto par = IterType::Parallel;
397-
auto red = IterType::Reduction;
397+
auto par = IteratorType::Parallel;
398+
auto red = IteratorType::Reduction;
398399
auto s = strides;
399400
auto d = dilations;
400401

0 commit comments

Comments
 (0)