9
9
#include " mlir/Dialect/Linalg/EDSC/Builders.h"
10
10
#include " mlir/Dialect/Linalg/EDSC/Intrinsics.h"
11
11
#include " mlir/Dialect/Linalg/IR/LinalgOps.h"
12
+ #include " mlir/Dialect/Utils/StructuredOpsUtils.h"
12
13
#include " mlir/EDSC/Builders.h"
13
14
#include " mlir/EDSC/Intrinsics.h"
14
15
#include " mlir/IR/AffineExpr.h"
@@ -144,7 +145,7 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
144
145
}
145
146
146
147
Operation *mlir::edsc::makeGenericLinalgOp (
147
- ArrayRef<IterType > iteratorTypes, ArrayRef<StructuredIndexed> inputs,
148
+ ArrayRef<IteratorType > iteratorTypes, ArrayRef<StructuredIndexed> inputs,
148
149
ArrayRef<StructuredIndexed> outputs,
149
150
function_ref<void (ArrayRef<BlockArgument>)> regionBuilder,
150
151
ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
@@ -240,8 +241,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
240
241
Operation *mlir::edsc::ops::linalg_pointwise (UnaryPointwiseOpBuilder unaryOp,
241
242
StructuredIndexed I,
242
243
StructuredIndexed O) {
243
- SmallVector<edsc::IterType , 4 > iterTypes (O.getExprs ().size (),
244
- edsc::IterType ::Parallel);
244
+ SmallVector<IteratorType , 4 > iterTypes (O.getExprs ().size (),
245
+ IteratorType ::Parallel);
245
246
if (O.getType ().isa <RankedTensorType>()) {
246
247
auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
247
248
assert (args.size () == 1 && " expected 1 block arguments" );
@@ -270,8 +271,8 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
270
271
StructuredIndexed I1,
271
272
StructuredIndexed I2,
272
273
StructuredIndexed O) {
273
- SmallVector<edsc::IterType , 4 > iterTypes (O.getExprs ().size (),
274
- edsc::IterType ::Parallel);
274
+ SmallVector<IteratorType , 4 > iterTypes (O.getExprs ().size (),
275
+ IteratorType ::Parallel);
275
276
if (O.getType ().isa <RankedTensorType>()) {
276
277
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
277
278
assert (args.size () == 2 && " expected 2 block arguments" );
@@ -315,7 +316,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
315
316
bindDims (ScopedContext::getContext (), m, n, k);
316
317
StructuredIndexed A (vA), B (vB), C (vC);
317
318
return makeGenericLinalgOp (
318
- {IterType ::Parallel, IterType ::Parallel, IterType ::Reduction},
319
+ {IteratorType ::Parallel, IteratorType ::Parallel, IteratorType ::Reduction},
319
320
{A ({m, k}), B ({k, n})},
320
321
{C ({m, n})},
321
322
macRegionBuilder);
@@ -329,7 +330,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
329
330
bindDims (ScopedContext::getContext (), m, n, k);
330
331
StructuredIndexed A (vA), B (vB), C (tC);
331
332
return makeGenericLinalgOp (
332
- {IterType ::Parallel, IterType ::Parallel, IterType ::Reduction},
333
+ {IteratorType ::Parallel, IteratorType ::Parallel, IteratorType ::Reduction},
333
334
{A ({m, k}), B ({k, n})},
334
335
{C ({m, n})},
335
336
mulRegionBuilder);
@@ -343,7 +344,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
343
344
bindDims (ScopedContext::getContext (), m, n, k);
344
345
StructuredIndexed A (vA), B (vB), C (vC), D (tD);
345
346
return makeGenericLinalgOp (
346
- {IterType ::Parallel, IterType ::Parallel, IterType ::Reduction},
347
+ {IteratorType ::Parallel, IteratorType ::Parallel, IteratorType ::Reduction},
347
348
{A ({m, k}), B ({k, n}), C ({m, n})},
348
349
{D ({m, n})},
349
350
macRegionBuilder);
@@ -360,8 +361,8 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
360
361
assert ((strides.empty () || strides.size () == 2 ) && " only 2-D conv atm" );
361
362
362
363
// Some short names.
363
- auto par = IterType ::Parallel;
364
- auto red = IterType ::Reduction;
364
+ auto par = IteratorType ::Parallel;
365
+ auto red = IteratorType ::Reduction;
365
366
auto s = strides;
366
367
auto d = dilations;
367
368
@@ -393,8 +394,8 @@ Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
393
394
assert ((strides.empty () || strides.size () == 2 ) && " only 2-D conv atm" );
394
395
395
396
// Some short names.
396
- auto par = IterType ::Parallel;
397
- auto red = IterType ::Reduction;
397
+ auto par = IteratorType ::Parallel;
398
+ auto red = IteratorType ::Reduction;
398
399
auto s = strides;
399
400
auto d = dilations;
400
401
0 commit comments