Skip to content

[mlir][ods] Allow sharding of op definitions #89423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
add_subdirectory(tools/mlir-linalg-ods-gen)
add_subdirectory(tools/mlir-pdll)
add_subdirectory(tools/mlir-tblgen)
add_subdirectory(tools/mlir-src-sharder)
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")

add_subdirectory(include/mlir)
add_subdirectory(lib)
Expand Down
38 changes: 38 additions & 0 deletions mlir/cmake/modules/AddMLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
tablegen(MLIR ${ARGV})
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
PARENT_SCOPE)

# Get the current set of include paths for this td file.
cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
# Filter out any empty include items.
list(REMOVE_ITEM tblgen_includes "")

# Build the absolute path for the current input file.
if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
else()
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
endif()

# Append the includes used for this file to the tablegen_compile_commands
# file.
file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
"--- !FileInfo:\n"
" filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
" includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
)
endfunction()

# Clear out any pre-existing compile_commands file before processing. This
Expand Down Expand Up @@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
endfunction()

# Declare sharded dialect operation declarations and definitions.
#
# Arguments:
#   ops_target  - Base name of the op definition files. `${ops_target}.td` is
#                 the TableGen input and `${ops_target}.cpp` is the source file
#                 handed to the source sharder.
#   shard_count - Number of shards to split the op definitions into. Must be a
#                 positive integer.
#
# Generates `${ops_target}.h.inc` and `${ops_target}.cpp.inc` with
# `-op-shard-count=${shard_count}`, then one `${ops_target}.<index>.cpp` per
# shard index in [0, shard_count). The list of generated shard sources is
# returned to the caller through `SHARDED_SRCS`, and the public tablegen target
# `MLIR${ops_target}ShardGen` is declared for the generated files.
function(add_sharded_ops ops_target shard_count)
  # A shard count below one cannot produce any shard source, so reject it up
  # front instead of handing a negative bound to `foreach(RANGE)`.
  if(NOT "${shard_count}" MATCHES "^[0-9]+$" OR "${shard_count}" LESS 1)
    message(FATAL_ERROR
      "add_sharded_ops: shard_count must be a positive integer, got '${shard_count}'")
  endif()

  set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
  mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
  mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})
  set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)

  # `foreach(RANGE <stop>)` includes <stop> itself, so stop at shard_count - 1.
  # Iterating up to shard_count would create one extra shard source whose
  # GET_OP_DEFS_<index> section is never emitted into the .cpp.inc file.
  math(EXPR last_shard_index "${shard_count} - 1")
  foreach(index RANGE ${last_shard_index})
    set(SHARDED_SRC ${ops_target}.${index}.cpp)
    list(APPEND SHARDED_SRCS ${SHARDED_SRC})
    tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
    set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
  endforeach()
  add_public_tablegen_target(MLIR${ops_target}ShardGen)
  set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
endfunction()

# Declare an interface in the include directory
function(add_mlir_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
Expand Down
2 changes: 2 additions & 0 deletions mlir/cmake/modules/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# Refer to the best host mlir-tblgen, which might be a host-optimized version
set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
Expand Down Expand Up @@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
Expand Down
1 change: 1 addition & 0 deletions mlir/cmake/modules/MLIRConfig.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
Expand Down
12 changes: 8 additions & 4 deletions mlir/include/mlir/TableGen/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,22 @@ class NamespaceEmitter {
///
class StaticVerifierFunctionEmitter {
public:
/// Create a constraint uniquer with a unique prefix derived from the record
/// keeper with an optional tag.
StaticVerifierFunctionEmitter(raw_ostream &os,
const llvm::RecordKeeper &records);
const llvm::RecordKeeper &records,
StringRef tag = "");

/// Collect and unique all the constraints used by operations.
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);

/// Collect and unique all compatible type, attribute, successor, and region
/// constraints from the operations in the file and emit them at the top of
/// the generated file.
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);

/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
Expand Down Expand Up @@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
/// Emit pattern constraints.
void emitPatternConstraints();

/// Collect and unique all the constraints used by operations.
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Collect and unique all pattern constraints.
void collectPatternConstraints(ArrayRef<DagLeaf> constraints);

Expand Down
15 changes: 6 additions & 9 deletions mlir/lib/TableGen/CodeGenHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ using namespace mlir::tblgen;

/// Generate a unique label based on the current file name to prevent name
/// collisions if multiple generated files are included at once.
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
StringRef tag) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();

Expand All @@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
nameRef.consume_back(".td");

// Sanitize any invalid characters.
std::string uniqueName;
std::string uniqueName(tag);
for (char c : nameRef) {
if (llvm::isAlnum(c) || c == '_')
uniqueName.push_back(c);
Expand All @@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
}

StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const llvm::RecordKeeper &records)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}

void StaticVerifierFunctionEmitter::emitOpConstraints(
ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
collectOpConstraints(opDefs);
if (emitDecl)
return;

ArrayRef<llvm::Record *> opDefs) {
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraints();
emitAttrConstraints();
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/mlir-tblgen/shard-op-defs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS

include "mlir/IR/OpBase.td"

// Minimal dialect used as the owner of the test operations. Its C++ namespace
// is `test`, which is the namespace the expected registration functions in the
// checks below are qualified with.
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "test";
}

// Convenience base class that binds every test op to Test_Dialect.
class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;

// Three trivial ops generated with a shard count of two. The checks below
// expect OpA and OpB in the first shard and OpC in the second.
def OpA : Test_Op<"a">;
def OpB : Test_Op<"b">;
def OpC : Test_Op<"c">;

// DECLS: OpA
// DECLS: OpB
// DECLS: OpC
// DECLS: registerTestDialectOperations(
// DECLS: registerTestDialectOperations0(
// DECLS: registerTestDialectOperations1(

// DEFS-LABEL: GET_OP_DEFS_0
// DEFS: void test::registerTestDialectOperations(
// DEFS: void test::registerTestDialectOperations0(
// DEFS: OpAAdaptor
// DEFS: OpBAdaptor

// DEFS-LABEL: GET_OP_DEFS_1
// DEFS: void test::registerTestDialectOperations1(
// DEFS: OpCAdaptor
14 changes: 14 additions & 0 deletions mlir/tools/mlir-src-sharder/CMakeLists.txt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLIR src sharder makes me think it's related to .mlir files rather than ODS ones. Did you consider making this a mlir-tblgen "function" (such as attribute gen or doc gen) and then call it 2x?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mlir-tblgen only ingests .td files as records. Do you want me to inject a hook into its main function to sniff the command and change its operating mode?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow: can you expand on the command line issue with mlir-tblgen?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mlir-tblgen calls into the TableGen parser and then calls into a function based on the command line with the parsed records. In order to make it ingest a C++ file (or another kind of file), I have to intercept it in the main function:

Turn this

// Generator that prints records.
GenRegistration printRecords("print-records", "Print all records to stdout",
                             [](const RecordKeeper &records, raw_ostream &os) {
                               os << records;
                               return false;
                             });

int main(int argc, char **argv) { return MlirTblgenMain(argc, argv); }

Into this:

int main(int argc, char **argv) { 
  if (argv[1] == "shard-src-files") return shardSourceFiles(argc, argv)
  return MlirTblgenMain(argc, argv); 
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joker-eph ping!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping should be for @jpienaar who started in this direction first :)

I agree though that using mlir-tblgen for non-tablegen file does not seem right.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't resolved?

Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Build configuration for the mlir-src-sharder tool.

# LLVM components and MLIR libraries the tool is built against.
set(LLVM_LINK_COMPONENTS Support)
set(LIBS MLIRSupport)

# Register the tool through add_tablegen under the MLIR_SRC_SHARDER project
# name, the same name passed to `tablegen(MLIR_SRC_SHARDER ...)` when shard
# sources are generated.
add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
mlir-src-sharder.cpp

DEPENDS
${LIBS}
)

# Group the target under the "Tablegenning" folder and link the MLIR libraries.
set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})

mlir_check_all_link_libraries(mlir-src-sharder)
114 changes: 114 additions & 0 deletions mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//===- mlir-src-sharder.cpp - A tool for sharding generated source files --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/ToolOutputFile.h"

using namespace mlir;

/// Emit the dependency file requested through the `-d` option.
///
/// The file consists of a single rule, `<outputFilename>:`, with no
/// prerequisites. It exists for the benefit of the build system and mirrors
/// the equivalent option in TableGen.
///
/// Fails, after printing a diagnostic to stderr, when the output is going to
/// stdout (`-`) or when `dependencyFile` cannot be opened for writing.
static LogicalResult createDependencyFile(StringRef outputFilename,
                                          StringRef dependencyFile) {
  // A rule needs a named target, so `-d` is meaningless without `-o`.
  if (outputFilename == "-") {
    llvm::errs() << "error: the option -d must be used together with -o\n";
    return failure();
  }

  std::string openError;
  auto depFile = openOutputFile(dependencyFile, &openError);
  if (depFile) {
    depFile->os() << outputFilename << ":\n";
    // Keep the file on disk once the rule has been written.
    depFile->keep();
    return success();
  }

  llvm::errs() << openError << "\n";
  return failure();
}

int main(int argc, char **argv) {
// FIXME: This is necessary because we link in TableGen, which defines its
// options as static variables.. some of which overlap with our options.
llvm::cl::ResetCommandLineParser();

llvm::cl::opt<unsigned> opShardIndex(
"op-shard-index", llvm::cl::desc("The current shard index"));
llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::list<std::string> includeDirs(
"I", llvm::cl::desc("Directory of include files"),
llvm::cl::value_desc("directory"), llvm::cl::Prefix);
llvm::cl::opt<std::string> dependencyFilename(
"d", llvm::cl::desc("Dependency filename"),
llvm::cl::value_desc("filename"), llvm::cl::init(""));
llvm::cl::opt<bool> writeIfChanged(
"write-if-changed",
llvm::cl::desc("Only write to the output file if it changed"));

llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);

// Open the input file.
std::string errorMessage;
std::unique_ptr<llvm::MemoryBuffer> inputFile =
openInputFile(inputFilename, &errorMessage);
if (!inputFile) {
llvm::errs() << errorMessage << "\n";
return 1;
}

// Write the output to a buffer.
std::string outputStr;
llvm::raw_string_ostream os(outputStr);
os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
<< inputFile->getBuffer();

// Determine whether we need to write the output file.
bool shouldWriteOutput = true;
if (writeIfChanged) {
// Only update the real output file if there are any differences. This
// prevents recompilation of all the files depending on it if there aren't
// any.
if (auto existingOrErr =
llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
if (std::move(existingOrErr.get())->getBuffer() == os.str())
shouldWriteOutput = false;
}

// Populate the output file if necessary.
if (shouldWriteOutput) {
std::unique_ptr<llvm::ToolOutputFile> outputFile =
openOutputFile(outputFilename, &errorMessage);
if (!outputFile) {
llvm::errs() << errorMessage << "\n";
return 1;
}
outputFile->os() << os.str();
outputFile->keep();
}

// Always write the depfile, even if the main output hasn't changed. If it's
// missing, Ninja considers the output dirty.
if (!dependencyFilename.empty())
if (failed(createDependencyFile(outputFilename, dependencyFilename)))
return 1;

return 0;
}
Loading