Skip to content

Commit f75f391

Browse files
committed
[MLIR][Linalg] Refactor transforms to use linalg::getDynOperands helper
getDynOperands behavior is commonly used in a number of passes. Refactored to use a helper function and avoid code reuse. Differential Revision: https://reviews.llvm.org/D94340
1 parent eefd420 commit f75f391

File tree

7 files changed

+73
-38
lines changed

7 files changed

+73
-38
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===- Utils.h - General transformation utilities ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines prototypes for various transformation utilities for
10+
// the StandardOps dialect. These are not passes by themselves but are used
11+
// either by passes, optimization sequences, or in turn by other transformation
12+
// utilities.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#ifndef MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H
17+
#define MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H
18+
19+
#include "mlir/IR/Value.h"
20+
21+
namespace mlir {
22+
23+
class Location;
24+
class OpBuilder;
25+
26+
/// Given an operation, retrieves the value of each dynamic dimension through
27+
/// constructing the necessary DimOp operators.
28+
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);
29+
30+
} // end namespace mlir
31+
32+
#endif // MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H

mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1414
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1515
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
16+
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
1617
#include "mlir/Dialect/Vector/VectorOps.h"
1718
#include "mlir/IR/BuiltinDialect.h"
1819
#include "mlir/IR/Operation.h"
@@ -21,18 +22,6 @@
2122
using namespace ::mlir;
2223
using namespace ::mlir::linalg;
2324

24-
static SmallVector<Value, 4> getDynOperands(Location loc, Value val,
25-
OpBuilder &b) {
26-
SmallVector<Value, 4> dynOperands;
27-
auto shapedType = val.getType().cast<ShapedType>();
28-
for (auto dim : llvm::enumerate(shapedType.getShape())) {
29-
if (dim.value() == TensorType::kDynamicSize) {
30-
dynOperands.push_back(b.create<DimOp>(loc, val, dim.index()));
31-
}
32-
}
33-
return dynOperands;
34-
}
35-
3625
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
3726
auto memrefType = memref.getType().cast<MemRefType>();
3827
auto alloc =

mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
#include "PassDetail.h"
1212
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
13+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1314
#include "mlir/Dialect/StandardOps/IR/Ops.h"
15+
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
1416
#include "mlir/Transforms/DialectConversion.h"
1517

1618
using namespace mlir;
@@ -62,18 +64,9 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
6264
// Extract static / dynamic shape mix from the first operand.
6365
Value firstOperand = operands.front();
6466
auto rankedTensorType = t.cast<RankedTensorType>();
65-
SmallVector<Value, 8> dynamicShape;
66-
SmallVector<int64_t, 8> staticShape;
67-
dynamicShape.reserve(rankedTensorType.getRank());
68-
staticShape.reserve(rankedTensorType.getRank());
69-
unsigned idx = 0;
70-
for (auto shape : rankedTensorType.getShape()) {
71-
staticShape.push_back(shape);
72-
if (rankedTensorType.isDynamicDim(idx))
73-
dynamicShape.push_back(b.create<DimOp>(loc, firstOperand, idx));
74-
++idx;
75-
}
76-
// Create init tensor.
67+
auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
68+
auto dynamicShape = getDynOperands(loc, firstOperand, b);
69+
7770
res.push_back(b.create<linalg::InitTensorOp>(
7871
loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
7972
}

mlir/lib/Dialect/StandardOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRStandard
22
IR/Ops.cpp
33
EDSC/Builders.cpp
44
EDSC/Intrinsics.cpp
5+
Utils/Utils.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements utilities for the Linalg dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
14+
15+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
16+
17+
using namespace mlir;
18+
19+
SmallVector<Value, 4> mlir::getDynOperands(Location loc, Value val,
20+
OpBuilder &b) {
21+
SmallVector<Value, 4> dynOperands;
22+
auto shapedType = val.getType().cast<ShapedType>();
23+
for (auto dim : llvm::enumerate(shapedType.getShape())) {
24+
if (dim.value() == TensorType::kDynamicSize)
25+
dynOperands.push_back(b.create<DimOp>(loc, val, dim.index()));
26+
}
27+
return dynOperands;
28+
}

mlir/lib/Transforms/BufferDeallocation.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
#include "PassDetail.h"
5555
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
5656
#include "mlir/Dialect/StandardOps/IR/Ops.h"
57+
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
5758
#include "mlir/IR/Operation.h"
5859
#include "mlir/Interfaces/ControlFlowInterfaces.h"
5960
#include "mlir/Interfaces/LoopLikeInterface.h"
@@ -394,13 +395,8 @@ class BufferDeallocation : BufferPlacementTransformationBase {
394395

395396
// Extract information about dynamically shaped types by
396397
// extracting their dynamic dimensions.
397-
SmallVector<Value, 4> dynamicOperands;
398-
for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
399-
if (!ShapedType::isDynamic(shapeElement.value()))
400-
continue;
401-
dynamicOperands.push_back(builder.create<DimOp>(
402-
terminator->getLoc(), sourceValue, shapeElement.index()));
403-
}
398+
auto dynamicOperands =
399+
getDynOperands(terminator->getLoc(), sourceValue, builder);
404400

405401
// TODO: provide a generic interface to create dialect-specific
406402
// Alloc and CopyOp nodes.

mlir/lib/Transforms/PipelineDataTransfer.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Analysis/LoopAnalysis.h"
1818
#include "mlir/Analysis/Utils.h"
1919
#include "mlir/Dialect/Affine/IR/AffineOps.h"
20+
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
2021
#include "mlir/IR/Builders.h"
2122
#include "mlir/Transforms/LoopUtils.h"
2223
#include "mlir/Transforms/Utils.h"
@@ -83,13 +84,8 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
8384
// The double buffer is allocated right before 'forOp'.
8485
OpBuilder bOuter(forOp);
8586
// Put together alloc operands for any dynamic dimensions of the memref.
86-
SmallVector<Value, 4> allocOperands;
87-
unsigned dynamicDimCount = 0;
88-
for (auto dimSize : oldMemRefType.getShape()) {
89-
if (dimSize == -1)
90-
allocOperands.push_back(
91-
bOuter.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
92-
}
87+
88+
auto allocOperands = getDynOperands(forOp.getLoc(), oldMemRef, bOuter);
9389

9490
// Create and place the alloc right before the 'affine.for' operation.
9591
Value newMemRef =

0 commit comments

Comments
 (0)