Skip to content

Commit 2865d11

Browse files
committed
[mlir] Use ReassociationIndices instead of affine maps in linalg.reshape.
Differential Revision: https://reviews.llvm.org/D101861
1 parent e4eec51 commit 2865d11

File tree

13 files changed

+508
-742
lines changed

13 files changed

+508
-742
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -315,55 +315,54 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
315315
// Builders for a contracting reshape whose result type is computed from
316316
// `src` and `reassociation`.
317317
OpBuilder<(ins "Value":$src,
318-
"ArrayRef<ReassociationExprs>":$reassociation,
318+
"ArrayRef<ReassociationIndices>":$reassociation,
319319
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
320320
OpBuilder<(ins "Value":$src,
321-
"ArrayRef<ReassociationIndices>":$reassociation,
321+
"ArrayRef<ReassociationExprs>":$reassociation,
322322
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
323323
[{
324324
auto reassociationMaps =
325-
convertReassociationIndicesToMaps($_builder, reassociation);
325+
convertReassociationMapsToIndices($_builder, reassociation);
326326
build($_builder, $_state, src, reassociationMaps, attrs);
327327
}]>,
328328

329329
// Builders for a reshape whose result type is passed explicitly. This may
330330
// be either a contracting or expanding reshape.
331331
OpBuilder<(ins "Type":$resultType, "Value":$src,
332-
"ArrayRef<ReassociationExprs>":$reassociation,
332+
"ArrayRef<ReassociationIndices>":$reassociation,
333333
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
334334
OpBuilder<(ins "Type":$resultType, "Value":$src,
335-
"ArrayRef<ReassociationIndices>":$reassociation,
335+
"ArrayRef<ReassociationExprs>":$reassociation,
336336
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
337337
[{
338338
auto reassociationMaps =
339-
convertReassociationIndicesToMaps($_builder, reassociation);
339+
convertReassociationMapsToIndices($_builder, reassociation);
340340
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
341341
}]>
342342
];
343343

344344
code commonExtraClassDeclaration = [{
345345
static StringRef getReassociationAttrName() { return "reassociation"; }
346-
SmallVector<AffineMap, 4> getReassociationMaps() {
347-
return llvm::to_vector<4>(llvm::map_range(reassociation(), [
348-
](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
349-
}
350-
SmallVector<ReassociationExprs, 4> getReassociationExprs() {
351-
return
352-
llvm::to_vector<4>(llvm::map_range(reassociation(),
353-
[](Attribute a) {
354-
return llvm::to_vector<2>(
355-
a.cast<AffineMapAttr>().getValue().getResults());
356-
}));
357-
}
358-
}];
359-
let assemblyFormat = [{
360-
$src $reassociation attr-dict `:` type($src) `into` type(results)
346+
SmallVector<AffineMap, 4> getReassociationMaps();
347+
SmallVector<ReassociationExprs, 4> getReassociationExprs();
348+
SmallVector<ReassociationIndices, 4> getReassociationIndices() {
349+
SmallVector<ReassociationIndices, 4> reassociationIndices;
350+
for (auto attr : reassociation())
351+
reassociationIndices.push_back(llvm::to_vector<2>(
352+
llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
353+
return indexAttr.cast<IntegerAttr>().getInt();
354+
})));
355+
return reassociationIndices;
356+
};
361357
}];
362358
}
363359

360+
def IndexListArrayAttr :
361+
TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
362+
364363
def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
365364
[DeclareOpInterfaceMethods<ViewLikeOpInterface>]>,
366-
Arguments<(ins AnyStridedMemRef:$src, AffineMapArrayAttr:$reassociation)>,
365+
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
367366
Results<(outs AnyStridedMemRef:$result)> {
368367
let summary = "linalg.reshape produces a new view into the operand view";
369368
let description = [{
@@ -373,9 +372,7 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
373372
and copies.
374373

375374
A reassociation is defined as a continuous grouping of dimensions and is
376-
represented with an affine map array attribute. In the future,
377-
non-continuous groupings may be allowed (i.e. permutations, reindexings
378-
etc).
375+
represented with an array of I64ArrayAttr attribute.
379376

380377
For now, it is assumed that either:
381378
1. a reassociation produces and consumes contiguous MemRefType or,
@@ -401,13 +398,13 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
401398

402399
```mlir
403400
// Dimension collapse (i, j) -> i' and k -> k'
404-
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
401+
%1 = linalg.reshape %0 [[0, 1], [2]] :
405402
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
406403
```
407404

408405
```mlir
409406
// Dimension expansion i -> (i', j') and (k) -> (k')
410-
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
407+
%1 = linalg.reshape %0 [[0, 1], [2]] :
411408
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
412409
```
413410
}];
@@ -417,24 +414,24 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
417414
}];
418415
let hasFolder = 1;
419416
let hasCanonicalizer = 1;
417+
let printer = [{ return ::print(p, *this); }];
418+
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
420419
}
421420

422421
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
423422
"tensor_reshape",
424423
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
425424
["reifyReturnTypeShapesPerResultDim"]>]>,
426425
Arguments<(ins AnyTensor:$src,
427-
AffineMapArrayAttr:$reassociation)>,
426+
IndexListArrayAttr:$reassociation)>,
428427
Results<(outs AnyTensor:$result)> {
429428
let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
430429
let description = [{
431430
The `linalg.reshape` op produces a new tensor whose sizes are a
432431
reassociation of the original `src`.
433432

434433
A reassociation is defined as a continuous grouping of dimensions and is
435-
represented with an affine map array attribute. In the future,
436-
non-continuous groupings may be allowed (i.e. permutations, reindexings
437-
etc).
434+
represented with an array of I64ArrayAttr attribute.
438435

439436
A reshape may either collapse or expand dimensions, depending on the
440437
relationship between source and target tensor ranks. The verification rule
@@ -453,14 +450,14 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
453450

454451
```mlir
455452
// Dimension collapse (i, j) -> i' and k -> k'
456-
%b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
457-
tensor<?x?x?xf32> into tensor<?x?xf32>
453+
%b = linalg.tensor_reshape %a [[0, 1], [2]]
454+
: tensor<?x?x?xf32> into tensor<?x?xf32>
458455
```
459456

460457
```mlir
461458
// Dimension expansion i -> (i', j') and (k) -> (k')
462-
%b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
463-
tensor<?x?xf32> into tensor<?x?x?xf32>
459+
%b = linalg.tensor_reshape %a [[0, 1], [2]]
460+
: tensor<?x?xf32> into tensor<?x?x?xf32>
464461
```
465462
}];
466463
let extraClassDeclaration = commonExtraClassDeclaration # [{
@@ -473,6 +470,8 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
473470
}];
474471
let hasFolder = 1;
475472
let hasCanonicalizer = 1;
473+
let printer = [{ return ::print(p, *this); }];
474+
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
476475
}
477476

478477
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,

0 commit comments

Comments
 (0)