
Commit faa66b1

[mlir] Bufferize tensor constant ops
We lower them to a std.global_memref (uniqued by constant value) + a
std.get_global_memref to produce the corresponding memref value. This allows
removing Linalg's somewhat hacky lowering of tensor constants, now that std
properly supports this.

Differential Revision: https://reviews.llvm.org/D91306
1 parent ad2f9f6 commit faa66b1
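
To make the new lowering concrete, here is a minimal before/after sketch; it
mirrors the @basic test case added in this commit (see
tensor-constant-bufferize.mlir below). Before:

  func @basic() -> tensor<3x4xf32> {
    %0 = constant dense<7.0> : tensor<3x4xf32>
    return %0 : tensor<3x4xf32>
  }

After -tensor-constant-bufferize, the constant becomes a module-level global,
and a tensor_load is materialized for consumers that still expect a tensor
value:

  global_memref "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>

  func @basic() -> tensor<3x4xf32> {
    %0 = get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
    %1 = tensor_load %0 : memref<3x4xf32>
    return %1 : tensor<3x4xf32>
  }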

13 files changed: +230 -81 lines changed


mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
 def LinalgBufferize : Pass<"linalg-bufferize", "ModuleOp"> {
   let summary = "Bufferize the linalg dialect";
   let constructor = "mlir::createLinalgBufferizePass()";
-  let dependentDialects = ["linalg::LinalgDialect", "vector::VectorDialect"];
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgLowerToParallelLoops

mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,9 @@ std::unique_ptr<Pass> createStdBufferizePass();
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
+/// Creates an instance of tensor constant bufferization pass.
+std::unique_ptr<Pass> createTensorConstantBufferizePass();
+
 /// Creates an instance of the StdExpand pass that legalizes Std
 /// dialect ops to be convertible to LLVM. For example,
 /// `std.ceildivi_signed` gets transformed to a number of std operations,

mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
@@ -51,4 +51,17 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
   let constructor = "mlir::createFuncBufferizePass()";
 }
 
+def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
+  let summary = "Bufferize tensor constants.";
+  let description = [{
+    This pass bufferizes tensor constants.
+
+    This pass needs to be a module pass because it inserts std.global_memref
+    ops into the module, which cannot be done safely from a function pass due to
+    multi-threading. Most other bufferization passes can run in parallel at
+    function granularity.
+  }];
+  let constructor = "mlir::createTensorConstantBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s

mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
 // RUN:   -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \

mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN:   -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %const = constant dense<10.0> : tensor<2xf32>
+  %insert_val = constant dense<20.0> : tensor<1xf32>
+  %inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
+
+  %unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
+  // CHECK-NEXT: [20, 10]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)

mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s

mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,11 +1,11 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
 // RUN:   -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
 
 // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,2,3" -linalg-bufferize \
-// RUN:   -scf-bufferize -std-bufferize -func-bufferize -convert-linalg-to-loops \
+// RUN:   -scf-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize -convert-linalg-to-loops \
 // RUN:   -convert-scf-to-std -convert-linalg-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \

mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

Lines changed: 2 additions & 57 deletions
@@ -325,60 +325,6 @@ class SubTensorInsertOpConverter
     return success();
   }
 };
-
-/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
-/// stored in memory. A linalg.reshape is introduced to convert to the desired
-/// n-D buffer form.
-class TensorConstantOpConverter : public OpConversionPattern<ConstantOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-
-    RankedTensorType rankedTensorType =
-        op.getType().dyn_cast<RankedTensorType>();
-    if (!rankedTensorType)
-      return failure();
-    if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
-          return s == 0 || ShapedType::isDynamic(s);
-        }))
-      return failure();
-
-    int64_t nElements = 1;
-    for (int64_t s : rankedTensorType.getShape())
-      nElements *= s;
-    Type elementType = rankedTensorType.getElementType();
-    MemRefType memrefType =
-        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
-    VectorType flatVectorType = VectorType::get({nElements}, elementType);
-    MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
-    MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
-
-    Location loc = op.getLoc();
-    auto attr = op.getValue().cast<DenseElementsAttr>();
-    Value alloc =
-        rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
-    Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
-                                               attr.reshape(flatVectorType));
-    rewriter.create<StoreOp>(loc, cstVec, alloc);
-
-    Value memref =
-        rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
-    if (rankedTensorType.getRank() > 1) {
-      // Introduce a linalg.reshape to flatten the memref.
-      AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
-          /*numDims=*/rankedTensorType.getRank(), op.getContext());
-      memref = rewriter.create<linalg::ReshapeOp>(
-          loc, memrefType, memref,
-          rewriter.getAffineMapArrayAttr(collapseAllDims));
-    }
-    rewriter.replaceOp(op, memref);
-
-    return success();
-  }
-};
 } // namespace
 
 namespace {
@@ -391,7 +337,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
     BufferizeTypeConverter typeConverter;
 
     // Mark all Standard operations legal.
-    target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
+    target.addLegalDialect<StandardOpsDialect>();
     target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();
 
     // Mark all Linalg operations illegal as long as they work on tensors.
@@ -422,8 +368,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
   patterns.insert<
       // clang-format off
       SubTensorOpConverter,
-      SubTensorInsertOpConverter,
-      TensorConstantOpConverter
+      SubTensorInsertOpConverter
      // clang-format on
      >(typeConverter, context);
 }
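
For contrast, the TensorConstantOpConverter removed above materialized each
tensor constant in-function. Reconstructed from the CHECK lines of the test
case deleted from bufferize.mlir further down, a tensor<2x3xf32> constant
lowered to roughly:

  %0 = alloc() : memref<vector<6xf32>>
  %cst = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
  store %cst, %0[] : memref<vector<6xf32>>
  %1 = vector.type_cast %0 : memref<vector<6xf32>> to memref<6xf32>
  %2 = linalg.reshape %1 [#map] : memref<6xf32> into memref<2x3xf32>

That is an alloc plus a store per constant op, a dependence on the vector
dialect, and a bail-out on zero-element shapes; the global_memref approach
replaces all of this with a single module-level constant.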

mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
   ExpandTanh.cpp
   FuncBufferize.cpp
   FuncConversions.cpp
+  TensorConstantBufferize.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms

mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp

Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
+//===- Bufferize.cpp - Bufferization for std ops --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of tensor-valued std.constant ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Transforms/Bufferize.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+// This class creates global ops for all tensor-valued constants in the program.
+// It creates them with pretty names and makes sure that duplicate globals
+// aren't created.
+class GlobalCreator {
+public:
+  explicit GlobalCreator(ModuleOp module);
+  GlobalMemrefOp getGlobalFor(Attribute attr) {
+    assert(globals.find(attr) != globals.end() && "unknown constant attr");
+    return globals[attr];
+  }
+
+private:
+  DenseMap<Attribute, GlobalMemrefOp> globals;
+};
+
+GlobalCreator::GlobalCreator(ModuleOp module) {
+  BufferizeTypeConverter typeConverter;
+  // Create a builder without an insertion point. We will insert using the
+  // symbol table to guarantee unique names.
+  OpBuilder globalBuilder(module.getContext());
+  SymbolTable symbolTable(module);
+  module.walk([&](ConstantOp op) {
+    // We only want tensor constants for now.
+    auto type = op.getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return;
+    // If we already have a global for this constant value, no need to do
+    // anything else.
+    auto it = globals.find(op.getValue());
+    if (it != globals.end())
+      return;
+
+    // Create a pretty name.
+    SmallString<64> buf;
+    llvm::raw_svector_ostream os(buf);
+    interleave(type.getShape(), os, "x");
+    os << "x" << type.getElementType();
+
+    auto global = globalBuilder.create<GlobalMemrefOp>(
+        op.getLoc(), (Twine("__constant_") + os.str()).str(),
+        /*sym_visibility=*/globalBuilder.getStringAttr("private"),
+        /*type=*/
+        TypeAttr::get(typeConverter.convertType(type)), /*initial_value=*/
+        op.getValue().cast<ElementsAttr>(), /*constant=*/true);
+    symbolTable.insert(global);
+    // The symbol table inserts at the end of the module, but globals are a bit
+    // nicer if they are at the beginning.
+    global.getOperation()->moveBefore(&module.front());
+    globals[op.getValue()] = global;
+  });
+}
+} // namespace
+
+namespace {
+class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
+public:
+  BufferizeTensorConstantOp(GlobalCreator &globals,
+                            TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<ConstantOp>(typeConverter, context, /*benefit=*/1),
+        globals(globals) {}
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = op.getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return failure();
+
+    auto globalMemref = globals.getGlobalFor(op.value());
+    rewriter.replaceOpWithNewOp<GetGlobalMemrefOp>(op, globalMemref.type(),
+                                                   globalMemref.getName());
+    return success();
+  }
+  GlobalCreator &globals;
+};
+} // namespace
+
+namespace {
+struct TensorConstantBufferizePass
+    : public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    GlobalCreator globals(module);
+
+    auto *context = &getContext();
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    target.addLegalDialect<StandardOpsDialect>();
+    patterns.insert<BufferizeTensorConstantOp>(globals, typeConverter, context);
+    target.addDynamicallyLegalOp<ConstantOp>(
+        [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
+    if (failed(applyPartialConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
+  return std::make_unique<TensorConstantBufferizePass>();
+}
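
Two details of GlobalCreator worth spelling out. The "pretty name" is the
shape interleaved with "x" followed by the element type, so a tensor<3x4xf32>
constant becomes @__constant_3x4xf32. Uniquing is keyed on the value
attribute: identical constants share one global, while same-typed constants
with different values each get their own, and SymbolTable::insert renames the
second to avoid the symbol clash. A hypothetical result for the
@multiple_constants test below (the test checks only that two globals exist,
not the uniqued suffix):

  global_memref "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
  // Suffix chosen by SymbolTable uniquing; shown here as an assumption.
  global_memref "private" constant @__constant_3x4xf32_0 : memref<3x4xf32> = dense<8.000000e+00>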

mlir/test/Dialect/Linalg/bufferize.mlir

Lines changed: 0 additions & 18 deletions
@@ -94,24 +94,6 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
 
 // -----
 
-// Check lowering of tensor-valued std.constant's
-// TODO: Move this to std-bufferize.
-
-// CHECK-LABEL: func @constant() -> tensor<2x3xf32> {
-// CHECK: %[[VECTOR_MEMREF:.*]] = alloc() : memref<vector<6xf32>>
-// CHECK: %[[VECTOR_CONST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
-// CHECK: store %[[VECTOR_CONST]], %[[VECTOR_MEMREF]][] : memref<vector<6xf32>>
-// CHECK: %[[MEMREF:.*]] = vector.type_cast %[[VECTOR_MEMREF]] : memref<vector<6xf32>> to memref<6xf32>
-// CHECK: %[[FINAL_SHAPE:.*]] = linalg.reshape %[[MEMREF]] [#map] : memref<6xf32> into memref<2x3xf32>
-// CHECK: %[[RESULT:.*]] = tensor_load %[[FINAL_SHAPE]] : memref<2x3xf32>
-// CHECK: return %[[RESULT]] : tensor<2x3xf32>
-func @constant() -> tensor<2x3xf32> {
-  %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
-  return %0: tensor<2x3xf32>
-}
-
-// -----
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i, k)>,
   affine_map<(i, j, k) -> (i, j)>

tensor-constant-bufferize.mlir

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: module {
+// We check the debug name too since we put some effort into making that readable.
+// The name isn't load-bearing though.
+// CHECK: global_memref "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// CHECK: @basic
+func @basic() -> tensor<3x4xf32> {
+  // CHECK: %[[MEMREF:.*]] = get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
+  // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
+  %0 = constant dense<7.0> : tensor<3x4xf32>
+  // CHECK: return %[[TENSOR]]
+  return %0 : tensor<3x4xf32>
+}
+
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: module {
+
+// Only one global is created.
+// CHECK: global_memref
+// CHECK-NOT: global_memref
+func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
+  %0 = constant dense<7.0> : tensor<3x4xf32>
+  %1 = constant dense<7.0> : tensor<3x4xf32>
+  return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
+}
+
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: module {
+
+// Two globals are created.
+// CHECK: global_memref
+// CHECK: global_memref
+// CHECK-NOT: global_memref
+func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
+  %0 = constant dense<7.0> : tensor<3x4xf32>
+  %1 = constant dense<8.0> : tensor<3x4xf32>
+  return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
+}
+
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: module {
+// We don't convert non-tensor globals.
+// CHECK-NOT: global_memref
+func @non_tensor() {
+  %0 = constant 7 : i32
+  return
+}
+
+// CHECK: }
