Skip to content

[mlir][MLProgram] Add MLProgram to MemRef bufferization pass #75103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace ml_program {
/// Registers the external BufferizableOpInterface models of the ml_program
/// dialect with `registry`, so that ml_program ops holding tensors can be
/// handled by bufferization.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace ml_program
} // namespace mlir

#endif // MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
Expand Down Expand Up @@ -160,6 +161,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerMemorySlotExternalModels(registry);
ml_program::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerValueBoundsOpInterfaceExternalModels(registry);
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
<< "\n//===-------------------------------------------===//\n");
}

// Return early if the top-level op is entirely gone.
if (erasedOps.contains(op))
return success();

// Fold all to_memref(to_tensor(x)) pairs.
for (Operation *op : toMemrefOps) {
rewriter.setInsertionPoint(op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}

// Bufferize all other ops.
for (Operation &op : moduleOp.getOps()) {
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
if (isa<func::FuncOp>(&op))
continue;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::ml_program;

namespace mlir {
namespace ml_program {
namespace {

/// Common defaults shared by the ml_program external models in this file:
/// no value is reported as aliasing any operand, and the relation between a
/// result buffer and the operand buffers is reported as unknown.
template <typename Interface, typename Op>
struct ExternalModelBase
    : public BufferizableOpInterface::ExternalModel<Interface, Op> {

  /// Returns an empty alias list for every operand.
  AliasingValueList getAliasingValues(Operation * /*op*/,
                                      OpOperand & /*opOperand*/,
                                      const AnalysisState & /*state*/) const {
    return {};
  }

  /// Returns `BufferRelation::Unknown` for every result.
  BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
                                const AnalysisState & /*state*/) const {
    return BufferRelation::Unknown;
  }
};

/// Bufferization of ml_program.global into a memref.global.
///
/// The global carries no operands or results, so tensor semantics are decided
/// from the type stored on the op itself.
struct GlobalOpInterface
    : public ExternalModelBase<GlobalOpInterface, GlobalOp> {

  /// The op has no tensor operands whose buffers could be read.
  bool bufferizesToMemoryRead(Operation *, OpOperand &,
                              const AnalysisState &) const {
    return false;
  }

  /// The op has no tensor operands whose buffers could be written.
  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
                               const AnalysisState &) const {
    return false;
  }

  /// Only globals whose declared type is a tensor need to be bufferized.
  /// Globals of any other type are left untouched instead of being handed to
  /// `bufferize`, which can only convert tensor types.
  bool hasTensorSemantics(Operation *op) const {
    return isa<TensorType>(cast<GlobalOp>(op).getType());
  }

  /// Replaces the global with a memref.global of the same symbol name,
  /// visibility and initial value. The memref.global is marked constant when
  /// the ml_program.global is not mutable.
  ///
  /// Emits an error and fails if the global has no value, is not of tensor
  /// type, or is of a tensor type that does not convert to a ranked memref.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &) const {
    auto globalOp = cast<GlobalOp>(op);
    if (!globalOp.getValue().has_value())
      return globalOp.emitError("global op must have a value");

    // Checked casts: report unsupported types as diagnostics rather than
    // asserting inside `cast`.
    auto tensorType = dyn_cast<TensorType>(globalOp.getType());
    if (!tensorType)
      return globalOp.emitError("global op must have a tensor type");

    auto memrefType =
        dyn_cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(tensorType));
    if (!memrefType)
      return globalOp.emitError("global op must have a ranked tensor type");

    replaceOpWithNewBufferizedOp<memref::GlobalOp>(
        rewriter, globalOp, globalOp.getSymName(),
        /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
        /*type=*/memrefType,
        /*initial_value=*/globalOp.getValue().value(),
        /*constant=*/!globalOp.getIsMutable(),
        /*alignment=*/nullptr);

    return success();
  }
};

/// Bufferization of ml_program.global_load into a memref.get_global.
///
/// The resulting buffer is reported as not writable, so in-place writes into
/// the loaded tensor are not performed directly on it.
struct GlobalLoadOpInterface
    : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {

  /// Reports that no operand buffer is read.
  bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
                              const AnalysisState & /*state*/) const {
    return false;
  }

  /// Reports that no operand buffer is written.
  bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
                               const AnalysisState & /*state*/) const {
    return false;
  }

  /// Values produced by this op are never writable.
  bool isWritable(Operation * /*op*/, Value /*value*/,
                  const AnalysisState & /*state*/) const {
    return false;
  }

  /// Replaces the load with a memref.get_global that references the same
  /// symbol and yields the static-identity-layout memref type corresponding
  /// to the loaded tensor type.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions & /*options*/) const {
    auto loadOp = cast<GlobalLoadOp>(op);

    auto bufferType = getMemRefTypeWithStaticIdentityLayout(
        cast<TensorType>(loadOp.getType()));

    replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
        rewriter, loadOp, bufferType,
        loadOp.getGlobalAttr().getLeafReference());
    return success();
  }
};

/// Bufferization of ml_program.global_store into a memref.get_global and
/// memcpy
struct GlobalStoreOpInterface
    : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {

  /// The stored tensor's buffer is the source of the memcpy emitted by
  /// `bufferize`, so its contents are read. Reporting the read keeps the
  /// analysis from treating the operand's data as dead.
  bool bufferizesToMemoryRead(Operation *, OpOperand &,
                              const AnalysisState &) const {
    return true;
  }

  /// Reported as a write, unchanged from the original model. Together with
  /// the read above this is the conservative combination: it can at most
  /// cost an extra copy of the source.
  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
                               const AnalysisState &) const {
    return true;
  }

  /// Lowers the store to a memref.get_global of the target symbol followed by
  /// a memcpy (created through `options.createMemCpy`) from the buffer of the
  /// stored value into that global, then erases the store.
  ///
  /// Fails if the buffer of the stored value cannot be obtained or the memcpy
  /// cannot be created.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto globalStoreOp = cast<GlobalStoreOp>(op);

    auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
    auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);

    auto loc = globalStoreOp.getLoc();
    auto targetMemref = rewriter.create<memref::GetGlobalOp>(
        loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());

    auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
    if (failed(sourceMemref))
      return failure();

    auto memcpy =
        options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
    if (failed(memcpy))
      return failure();

    rewriter.eraseOp(globalStoreOp);
    return success();
  }
};
} // namespace

/// Adds a dialect extension to `registry` that attaches the external
/// bufferization models defined in this file to ml_program.global,
/// ml_program.global_load and ml_program.global_store.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  auto attachModels = +[](MLIRContext *context,
                          MLProgramDialect * /*dialect*/) {
    GlobalOp::attachInterface<GlobalOpInterface>(*context);
    GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*context);
    GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*context);
  };
  registry.addExtension(attachModels);
}
} // namespace ml_program
} // namespace mlir
1 change: 1 addition & 0 deletions mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMLProgramTransforms
BufferizableOpInterfaceImpl.cpp
PipelineGlobalOps.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
52 changes: 52 additions & 0 deletions mlir/test/Dialect/MLProgram/one-shot-bufferize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: mlir-opt %s -one-shot-bufferize -split-input-file | FileCheck %s

// CHECK-LABEL: memref.global "private" @global
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>

// Loads the global, multiplies its scalar by 127, inserts the product back
// into the loaded tensor and stores the tensor to the global again. The
// expected output performs the insert on a fresh allocation initialized from
// the global, and lowers the store to a copy from that allocation into a
// second memref.get_global of the same symbol.
// CHECK-LABEL: func.func @global_load_store
func.func @global_load_store() -> i64 {
// CHECK-DAG: %[[CST127:.*]] = arith.constant 127
// CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global
// CHECK: %[[VALUE:.*]] = memref.load %[[GLOBAL_1]][]
// CHECK: %[[NEW_VALUE:.*]] = arith.muli %[[VALUE]], %[[CST127]]
// CHECK: %[[ALLOC:.*]] = memref.alloc()
// CHECK: memref.copy %[[GLOBAL_1]], %[[ALLOC]]
// CHECK: memref.store %[[NEW_VALUE]], %[[ALLOC]][]
// CHECK: %[[GLOBAL_2:.*]] = memref.get_global @global
// CHECK: memref.copy %[[ALLOC]], %[[GLOBAL_2]]
// CHECK: return %[[NEW_VALUE]]
%c127 = arith.constant 127 : i64
%0 = ml_program.global_load @global : tensor<i64>
%extracted = tensor.extract %0[] : tensor<i64>
%1 = arith.muli %extracted, %c127 : i64
%inserted = tensor.insert %1 into %0[] : tensor<i64>
ml_program.global_store @global = %inserted : tensor<i64>
return %1 : i64
}

// -----

// CHECK-LABEL: memref.global "private" @global
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>

// Loads the same global twice, inserts 127 into the first loaded tensor and
// extracts from the second before storing the modified tensor. The expected
// output writes 127 into a separate allocation and reads the returned value
// from the global's buffer before the allocation is copied into the global,
// so the returned value is the one the global held prior to the store.
// CHECK-LABEL: func.func @raw_hazard
func.func @raw_hazard() -> i64 {
// CHECK-DAG: %[[CST127:.*]] = arith.constant 127
// CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global
// CHECK-DAG: %[[GLOBAL_2:.*]] = memref.get_global @global
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
// CHECK: memref.copy %[[GLOBAL_1]], %[[ALLOC]]
// CHECK: memref.store %[[CST127]], %[[ALLOC]][]
// CHECK: %[[VAL:.*]] = memref.load %[[GLOBAL_2]][]
// CHECK: %[[GLOBAL_3:.*]] = memref.get_global @global
// CHECK: memref.copy %[[ALLOC]], %[[GLOBAL_3]]
// CHECK: return %[[VAL]]
%c127 = arith.constant 127 : i64
%0 = ml_program.global_load @global : tensor<i64>
%1 = ml_program.global_load @global : tensor<i64>
%inserted = tensor.insert %c127 into %0[] : tensor<i64>
%extracted = tensor.extract %1[] : tensor<i64>
ml_program.global_store @global = %inserted : tensor<i64>
return %extracted : i64
}