Skip to content

Commit 5bb8d28

Browse files
author
Nicolas Vasilache
committed
[mlir][Linalg] Add tensor support to Linalg EDSC Builders
Summary: This diff extends the Linalg EDSC builders so we can easily create mixed tensor/buffer linalg.generic ops. This is expected to be useful for HLO -> Linalg lowering. The StructuredIndexed struct is made to derive from ValueHandle and can now capture a type + indexing expressions. This is used to represent return tensors. Pointwise unary and binary builders are extended to allow both output buffers and return tensors. This has implications on the number of region arguments. Reviewers: ftynse, hanchung, asaadaldien Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D73149
1 parent f55b033 commit 5bb8d28

File tree

4 files changed

+110
-17
lines changed

4 files changed

+110
-17
lines changed

mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,37 +94,63 @@ inline StringRef toString(IterType t) {
9494
llvm_unreachable("Unsupported IterType");
9595
}
9696

97-
/// A StructuredIndexed represents a captured value that can be indexed and
98-
/// passed to the `makeGenericLinalgOp`. It allows writing intuitive index
99-
/// expressions such as:
97+
/// A StructuredIndexed represents an indexable quantity that is either:
98+
/// 1. a captured value, which is suitable for buffer and tensor operands, or;
99+
/// 2. a captured type, which is suitable for tensor return values.
100+
///
101+
/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
102+
/// It enables an idiomatic syntax for index expressions such as:
100103
///
101104
/// ```
102-
/// StructuredIndexed A(vA), B(vB), C(vC);
105+
/// StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
106+
/// C(buffer_value_or_tensor_type);
103107
/// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
104108
/// ```
105-
struct StructuredIndexed {
106-
StructuredIndexed(Value v) : value(v) {}
109+
struct StructuredIndexed : public ValueHandle {
110+
StructuredIndexed(Type type) : ValueHandle(type) {}
111+
StructuredIndexed(Value value) : ValueHandle(value) {}
112+
StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
107113
StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
108-
return StructuredIndexed(value, indexings);
114+
return StructuredIndexed(*this, indexings);
109115
}
110116

111-
operator Value() const /* implicit */ { return value; }
112117
ArrayRef<AffineExpr> getExprs() { return exprs; }
113118

114119
private:
120+
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
121+
: ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
122+
assert(t.isa<RankedTensorType>() && "RankedTensor expected");
123+
}
115124
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
116-
: value(v), exprs(indexings.begin(), indexings.end()) {
117-
assert(v.getType().isa<MemRefType>() && "MemRefType expected");
125+
: ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
126+
assert((v.getType().isa<MemRefType>() ||
127+
v.getType().isa<RankedTensorType>()) &&
128+
"MemRef or RankedTensor expected");
118129
}
119-
StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
120-
: StructuredIndexed(v.getValue(), indexings) {}
130+
StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
131+
: ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
121132

122-
Value value;
123133
SmallVector<AffineExpr, 4> exprs;
124134
};
125135

126136
inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
127137

138+
/// Build a `linalg.generic` op with the specified `inputs`, `outputs` and
139+
/// `region`.
140+
///
141+
/// `otherValues` and `otherAttributes` may be passed and will be appended as
142+
/// operands and attributes respectively.
143+
///
144+
/// Prerequisites:
145+
/// =============
146+
///
147+
/// 1. `inputs` may contain StructuredIndexed that capture either buffer or
148+
/// tensor values.
149+
/// 2. `outputs` may contain StructuredIndexed that capture either buffer values
150+
/// or tensor types. If both buffer values and tensor types are present, then
151+
/// all buffer values must appear before any tensor type. Without this
152+
/// restriction output tensor results would need to be reordered, which would
153+
/// result in surprising behavior when combined with region definition.
128154
Operation *makeGenericLinalgOp(
129155
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
130156
ArrayRef<StructuredIndexed> outputs,
@@ -189,7 +215,7 @@ Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
189215
StructuredIndexed O);
190216

191217
/// Build a linalg.pointwise with all `parallel` iterators and a region that
192-
/// computes `O = max(I!, I2)`. The client is responsible for specifying the
218+
/// computes `O = max(I1, I2)`. The client is responsible for specifying the
193219
/// proper indexings when creating the StructuredIndexed.
194220
Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
195221
StructuredIndexed O);

mlir/include/mlir/EDSC/Builders.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ class ValueHandle : public CapturableHandle {
339339

340340
/// Implicit conversion useful for automatic conversion to Container<Value>.
341341
operator Value() const { return getValue(); }
342+
operator Type() const { return getType(); }
342343
operator bool() const { return hasValue(); }
343344

344345
/// Generic mlir::Op create. This is the key to being extensible to the whole

mlir/lib/Dialect/Linalg/EDSC/Builders.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ Operation *mlir::edsc::makeGenericLinalgOp(
131131
ArrayRef<StructuredIndexed> outputs,
132132
function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
133133
ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
134+
for (unsigned i = 0, e = outputs.size(); i + 1 < e; ++i)
135+
assert(!(outputs[i].getType().isa<RankedTensorType>() &&
136+
outputs[i + 1].getType().isa<MemRefType>()) &&
137+
"output tensors must be passed after output buffers");
134138
auto &builder = edsc::ScopedContext::getBuilder();
135139
auto *ctx = builder.getContext();
136140
unsigned nInputs = inputs.size();
@@ -154,15 +158,19 @@ Operation *mlir::edsc::makeGenericLinalgOp(
154158
SmallVector<Value, 4> values;
155159
values.reserve(nViews);
156160
values.append(inputs.begin(), inputs.end());
157-
values.append(outputs.begin(), outputs.end());
161+
std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(values),
162+
[](StructuredIndexed s) { return s.hasValue(); });
163+
SmallVector<Type, 4> types;
164+
std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(types),
165+
[](StructuredIndexed s) { return !s.hasValue(); });
158166

159167
auto iteratorStrTypes = functional::map(toString, iteratorTypes);
160168
// clang-format off
161169
auto *op =
162170
edsc::ScopedContext::getBuilder()
163171
.create<linalg::GenericOp>(
164172
edsc::ScopedContext::getLocation(),
165-
ArrayRef<Type>{}, // TODO(ntv): support tensors
173+
types,
166174
values,
167175
IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
168176
IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
@@ -210,6 +218,14 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
210218
StructuredIndexed O) {
211219
SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
212220
edsc::IterType::Parallel);
221+
if (O.getType().isa<RankedTensorType>()) {
222+
auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
223+
assert(args.size() == 1 && "expected 1 block arguments");
224+
ValueHandle a(args[0]);
225+
linalg_yield(unaryOp(a));
226+
};
227+
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
228+
}
213229
auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
214230
assert(args.size() == 2 && "expected 2 block arguments");
215231
ValueHandle a(args[0]);
@@ -220,7 +236,6 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
220236

221237
Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
222238
StructuredIndexed O) {
223-
;
224239
using edsc::intrinsics::tanh;
225240
UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
226241
return linalg_pointwise(unOp, I, O);
@@ -233,6 +248,14 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
233248
StructuredIndexed O) {
234249
SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
235250
edsc::IterType::Parallel);
251+
if (O.getType().isa<RankedTensorType>()) {
252+
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
253+
assert(args.size() == 2 && "expected 2 block arguments");
254+
ValueHandle a(args[0]), b(args[1]);
255+
linalg_yield(binaryOp(a, b));
256+
};
257+
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
258+
}
236259
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
237260
assert(args.size() == 3 && "expected 3 block arguments");
238261
ValueHandle a(args[0]), b(args[1]);

mlir/test/EDSC/builder-api-test.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,49 @@ TEST_FUNC(linalg_pointwise_test) {
871871
f.erase();
872872
}
873873

874+
// clang-format off
875+
// CHECK-LABEL: func @linalg_pointwise_mixed_tensors
876+
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
877+
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
878+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
879+
// CHECK: addf
880+
// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
881+
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
882+
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
883+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
884+
// CHECK: cmpf "ogt"
885+
// CHECK: select
886+
// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
887+
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
888+
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
889+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
890+
// CHECK: tanh
891+
// CHECK: }: tensor<?x?xf32> -> tensor<?x?xf32>
892+
// clang-format on
893+
TEST_FUNC(linalg_pointwise_mixed_tensors_test) {
894+
using namespace edsc;
895+
using namespace edsc::ops;
896+
897+
auto f32Type = FloatType::getF32(&globalContext());
898+
auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
899+
auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
900+
auto f = makeFunction("linalg_pointwise_mixed_tensors", {},
901+
{tensorType, memrefType});
902+
903+
OpBuilder builder(f.getBody());
904+
ScopedContext scope(builder, f.getLoc());
905+
ValueHandle A(f.getArgument(0)), B(f.getArgument(1));
906+
AffineExpr i, j;
907+
bindDims(&globalContext(), i, j);
908+
StructuredIndexed SA(A), SB(B), SC(tensorType);
909+
linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
910+
linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
911+
linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
912+
913+
f.print(llvm::outs());
914+
f.erase();
915+
}
916+
874917
// clang-format off
875918
// CHECK-LABEL: func @linalg_matmul
876919
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,

0 commit comments

Comments
 (0)