@@ -315,55 +315,54 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
315
315
// Builders for a contracting reshape whose result type is computed from
316
316
// `src` and `reassociation`.
317
317
OpBuilder<(ins "Value":$src,
318
- "ArrayRef<ReassociationExprs >":$reassociation,
318
+ "ArrayRef<ReassociationIndices >":$reassociation,
319
319
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
320
320
OpBuilder<(ins "Value":$src,
321
- "ArrayRef<ReassociationIndices >":$reassociation,
321
+ "ArrayRef<ReassociationExprs >":$reassociation,
322
322
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
323
323
[{
324
324
auto reassociationMaps =
325
- convertReassociationIndicesToMaps ($_builder, reassociation);
325
+ convertReassociationMapsToIndices ($_builder, reassociation);
326
326
build($_builder, $_state, src, reassociationMaps, attrs);
327
327
}]>,
328
328
329
329
// Builders for a reshape whose result type is passed explicitly. This may
330
330
// be either a contracting or expanding reshape.
331
331
OpBuilder<(ins "Type":$resultType, "Value":$src,
332
- "ArrayRef<ReassociationExprs >":$reassociation,
332
+ "ArrayRef<ReassociationIndices >":$reassociation,
333
333
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
334
334
OpBuilder<(ins "Type":$resultType, "Value":$src,
335
- "ArrayRef<ReassociationIndices >":$reassociation,
335
+ "ArrayRef<ReassociationExprs >":$reassociation,
336
336
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
337
337
[{
338
338
auto reassociationMaps =
339
- convertReassociationIndicesToMaps ($_builder, reassociation);
339
+ convertReassociationMapsToIndices ($_builder, reassociation);
340
340
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
341
341
}]>
342
342
];
343
343
344
344
code commonExtraClassDeclaration = [{
345
345
static StringRef getReassociationAttrName() { return "reassociation"; }
346
- SmallVector<AffineMap, 4> getReassociationMaps() {
347
- return llvm::to_vector<4>(llvm::map_range(reassociation(), [
348
- ](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
349
- }
350
- SmallVector<ReassociationExprs, 4> getReassociationExprs() {
351
- return
352
- llvm::to_vector<4>(llvm::map_range(reassociation(),
353
- [](Attribute a) {
354
- return llvm::to_vector<2>(
355
- a.cast<AffineMapAttr>().getValue().getResults());
356
- }));
357
- }
358
- }];
359
- let assemblyFormat = [{
360
- $src $reassociation attr-dict `:` type($src) `into` type(results)
346
+ SmallVector<AffineMap, 4> getReassociationMaps();
347
+ SmallVector<ReassociationExprs, 4> getReassociationExprs();
348
+ SmallVector<ReassociationIndices, 4> getReassociationIndices() {
349
+ SmallVector<ReassociationIndices, 4> reassociationIndices;
350
+ for (auto attr : reassociation())
351
+ reassociationIndices.push_back(llvm::to_vector<2>(
352
+ llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
353
+ return indexAttr.cast<IntegerAttr>().getInt();
354
+ })));
355
+ return reassociationIndices;
356
+ };
361
357
}];
362
358
}
363
359
360
+ def IndexListArrayAttr :
361
+ TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
362
+
364
363
def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
365
364
[DeclareOpInterfaceMethods<ViewLikeOpInterface>]>,
366
- Arguments<(ins AnyStridedMemRef:$src, AffineMapArrayAttr :$reassociation)>,
365
+ Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr :$reassociation)>,
367
366
Results<(outs AnyStridedMemRef:$result)> {
368
367
let summary = "linalg.reshape produces a new view into the operand view";
369
368
let description = [{
@@ -373,9 +372,7 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
373
372
and copies.
374
373
375
374
A reassociation is defined as a continuous grouping of dimensions and is
376
- represented with an affine map array attribute. In the future,
377
- non-continuous groupings may be allowed (i.e. permutations, reindexings
378
- etc).
375
+ represented with an array of I64ArrayAttr attribute.
379
376
380
377
For now, it is assumed that either:
381
378
1. a reassociation produces and consumes contiguous MemRefType or,
@@ -401,13 +398,13 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
401
398
402
399
```mlir
403
400
// Dimension collapse (i, j) -> i' and k -> k'
404
- %1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k) ] :
401
+ %1 = linalg.reshape %0 [[0, 1], [2] ] :
405
402
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
406
403
```
407
404
408
405
```mlir
409
406
// Dimension expansion i -> (i', j') and (k) -> (k')
410
- %1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k) ] :
407
+ %1 = linalg.reshape %0 [[0, 1], [2] ] :
411
408
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
412
409
```
413
410
}];
@@ -417,24 +414,24 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
417
414
}];
418
415
let hasFolder = 1;
419
416
let hasCanonicalizer = 1;
417
+ let printer = [{ return ::print(p, *this); }];
418
+ let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
420
419
}
421
420
422
421
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
423
422
"tensor_reshape",
424
423
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
425
424
["reifyReturnTypeShapesPerResultDim"]>]>,
426
425
Arguments<(ins AnyTensor:$src,
427
- AffineMapArrayAttr :$reassociation)>,
426
+ IndexListArrayAttr :$reassociation)>,
428
427
Results<(outs AnyTensor:$result)> {
429
428
let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
430
429
let description = [{
431
430
The `linalg.reshape` op produces a new tensor whose sizes are a
432
431
reassociation of the original `src`.
433
432
434
433
A reassociation is defined as a continuous grouping of dimensions and is
435
- represented with an affine map array attribute. In the future,
436
- non-continuous groupings may be allowed (i.e. permutations, reindexings
437
- etc).
434
+ represented with an array of I64ArrayAttr attribute.
438
435
439
436
A reshape may either collapse or expand dimensions, depending on the
440
437
relationship between source and target tensor ranks. The verification rule
@@ -453,14 +450,14 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
453
450
454
451
```mlir
455
452
// Dimension collapse (i, j) -> i' and k -> k'
456
- %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
457
- tensor<?x?x?xf32> into tensor<?x?xf32>
453
+ %b = linalg.tensor_reshape %a [[0, 1], [2]]
454
+ : tensor<?x?x?xf32> into tensor<?x?xf32>
458
455
```
459
456
460
457
```mlir
461
458
// Dimension expansion i -> (i', j') and (k) -> (k')
462
- %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
463
- tensor<?x?xf32> into tensor<?x?x?xf32>
459
+ %b = linalg.tensor_reshape %a [[0, 1], [2]]
460
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
464
461
```
465
462
}];
466
463
let extraClassDeclaration = commonExtraClassDeclaration # [{
@@ -473,6 +470,8 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
473
470
}];
474
471
let hasFolder = 1;
475
472
let hasCanonicalizer = 1;
473
+ let printer = [{ return ::print(p, *this); }];
474
+ let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
476
475
}
477
476
478
477
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
0 commit comments