Skip to content

Commit ae1ea0b

Browse files
committed
[mlir] Decompose Bufferization Clone operation into Memref Alloc and Copy.
This patch introduces a new conversion to convert bufferization.clone operations into a memref.alloc and a memref.copy operation. This transformation is needed to transform all remaining clones which "survive" all previous transformations, before a given program is lowered further (to LLVM e.g.). Otherwise, these operations cannot be handled anymore and lead to compile errors. See: https://llvm.discourse.group/t/bufferization-error-related-to-memref-clone/4665 Differential Revision: https://reviews.llvm.org/D114233
1 parent 3356d88 commit ae1ea0b

File tree

7 files changed

+180
-0
lines changed

7 files changed

+180
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- BufferizationToMemRef.h - Bufferization to MemRef conversion -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
10+
#define MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
class RewritePatternSet;
17+
18+
/// Collect a set of patterns to convert memory-related operations from the
19+
/// Bufferization dialect to the MemRef dialect.
20+
void populateBufferizationToMemRefConversionPatterns(
21+
RewritePatternSet &patterns);
22+
23+
std::unique_ptr<Pass> createBufferizationToMemRefPass();
24+
} // namespace mlir
25+
26+
#endif // MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
1515
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
1616
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
17+
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
1718
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
1819
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
1920
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,17 @@ def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> {
126126
];
127127
}
128128

129+
//===----------------------------------------------------------------------===//
130+
// BufferizationToMemRef
131+
//===----------------------------------------------------------------------===//
132+
133+
def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
134+
let summary = "Convert operations from the Bufferization dialect to the "
135+
"MemRef dialect";
136+
let constructor = "mlir::createBufferizationToMemRefPass()";
137+
let dependentDialects = ["arith::ArithmeticDialect", "memref::MemRefDialect"];
138+
}
139+
129140
//===----------------------------------------------------------------------===//
130141
// ComplexToLLVM
131142
//===----------------------------------------------------------------------===//
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns to convert Bufferization dialect to MemRef
10+
// dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "../PassDetail.h"
15+
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
16+
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19+
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/Support/LogicalResult.h"
21+
#include "mlir/Transforms/DialectConversion.h"
22+
23+
using namespace mlir;
24+
25+
namespace {
26+
/// The CloneOpConversion transforms all bufferization clone operations into
27+
/// memref alloc and memref copy operations. In the dynamic-shape case, it also
28+
/// emits additional dim and constant operations to determine the shape. This
29+
/// conversion does not resolve memory leaks if it is used alone.
30+
struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
31+
using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
32+
33+
LogicalResult
34+
matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
35+
ConversionPatternRewriter &rewriter) const override {
36+
// Check for unranked memref types which are currently not supported.
37+
Type type = op.getType();
38+
if (type.isa<UnrankedMemRefType>()) {
39+
return rewriter.notifyMatchFailure(
40+
op, "UnrankedMemRefType is not supported.");
41+
}
42+
43+
// Transform a clone operation into alloc + copy operation and pay
44+
// attention to the shape dimensions.
45+
MemRefType memrefType = type.cast<MemRefType>();
46+
Location loc = op->getLoc();
47+
SmallVector<Value, 4> dynamicOperands;
48+
for (int i = 0; i < memrefType.getRank(); ++i) {
49+
if (!memrefType.isDynamicDim(i))
50+
continue;
51+
Value size = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
52+
Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.input(), size);
53+
dynamicOperands.push_back(dim);
54+
}
55+
Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
56+
dynamicOperands);
57+
rewriter.create<memref::CopyOp>(loc, op.input(), alloc);
58+
return success();
59+
}
60+
};
61+
} // namespace
62+
63+
void mlir::populateBufferizationToMemRefConversionPatterns(
64+
RewritePatternSet &patterns) {
65+
patterns.add<CloneOpConversion>(patterns.getContext());
66+
}
67+
68+
namespace {
69+
struct BufferizationToMemRefPass
70+
: public ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
71+
BufferizationToMemRefPass() = default;
72+
73+
void runOnOperation() override {
74+
RewritePatternSet patterns(&getContext());
75+
populateBufferizationToMemRefConversionPatterns(patterns);
76+
77+
ConversionTarget target(getContext());
78+
target.addLegalDialect<memref::MemRefDialect>();
79+
target.addLegalOp<arith::ConstantOp>();
80+
target.addIllegalDialect<bufferization::BufferizationDialect>();
81+
82+
if (failed(applyPartialConversion(getOperation(), target,
83+
std::move(patterns))))
84+
signalPassFailure();
85+
}
86+
};
87+
} // namespace
88+
89+
std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
90+
return std::make_unique<BufferizationToMemRefPass>();
91+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
add_mlir_conversion_library(MLIRBufferizationToMemRef
2+
BufferizationToMemRef.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/BufferizationToMemRef
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRBufferization
12+
)

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_subdirectory(ArithmeticToLLVM)
33
add_subdirectory(ArithmeticToSPIRV)
44
add_subdirectory(ArmNeon2dToIntr)
55
add_subdirectory(AsyncToLLVM)
6+
add_subdirectory(BufferizationToMemRef)
67
add_subdirectory(ComplexToLLVM)
78
add_subdirectory(ComplexToStandard)
89
add_subdirectory(GPUCommon)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-opt -verify-diagnostics -convert-bufferization-to-memref -split-input-file %s | FileCheck %s
2+
3+
// CHECK-LABEL: @conversion_static
4+
func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
5+
%0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
6+
memref.dealloc %arg0 : memref<2xf32>
7+
return %0 : memref<2xf32>
8+
}
9+
10+
// CHECK: %[[ALLOC:.*]] = memref.alloc
11+
// CHECK-NEXT: memref.copy %[[ARG:.*]], %[[ALLOC]]
12+
// CHECK-NEXT: memref.dealloc %[[ARG]]
13+
// CHECK-NEXT: return %[[ALLOC]]
14+
15+
// -----
16+
17+
// CHECK-LABEL: @conversion_dynamic
18+
func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
19+
%1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
20+
memref.dealloc %arg0 : memref<?xf32>
21+
return %1 : memref<?xf32>
22+
}
23+
24+
// CHECK: %[[CONST:.*]] = arith.constant
25+
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
26+
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
27+
// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
28+
// CHECK-NEXT: memref.dealloc %[[ARG]]
29+
// CHECK-NEXT: return %[[ALLOC]]
30+
31+
// -----
32+
33+
func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
34+
// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
35+
%1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
36+
memref.dealloc %arg0 : memref<*xf32>
37+
return %1 : memref<*xf32>
38+
}

0 commit comments

Comments
 (0)