Skip to content

Commit 4155632

Browse files
desertfire authored and pytorchmergebot committed
[cpp_wrapper] Change CppWrapperCodeCache to use faster python binding (pytorch#117693)
Summary: Using faster binding following pytorch#117500. torch.utils.cpp_extension.load_inline builds a lot of things and is very slow. With this change, later we can further reduce the included header files using the ABI-compatible mode and thus further speed up the compilation. Result: ``` python test/inductor/test_cuda_cpp_wrapper.py -k test_relu_cuda_cuda_wrapper Before: Ran 1 test in 32.843s After: Ran 1 test in 26.229s ``` Pull Request resolved: pytorch#117693 Approved by: https://github.com/jansel
1 parent 7f474da commit 4155632

File tree

2 files changed

+65
-98
lines changed

2 files changed

+65
-98
lines changed

torch/_inductor/codecache.py

Lines changed: 56 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from ctypes import c_void_p, cdll, CDLL
3434
from dataclasses import field
3535
from functools import partial
36-
from importlib import abc
3736
from pathlib import Path
3837
from threading import Thread
3938
from time import sleep, time
@@ -46,7 +45,7 @@
4645
get_interface_for_device,
4746
get_registered_device_interfaces,
4847
)
49-
from torch._dynamo.utils import counters
48+
from torch._dynamo.utils import counters, dynamo_timed
5049
from torch._inductor import config, exc
5150
from torch._inductor.codegen.cuda import cuda_env
5251
from torch._inductor.utils import cache_dir, developer_warning, is_linux
@@ -1701,6 +1700,7 @@ def cpp_prefix() -> str:
17011700

17021701
# Given a path to an input cpp file and an output path,
17031702
# Attempts to compile the file, storing the output in "output_path"
1703+
@dynamo_timed
17041704
def compile_file(
17051705
input_path: Union[str, List[str]], output_path: str, cmd: List[str]
17061706
) -> None:
@@ -1783,7 +1783,8 @@ def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
17831783
raise
17841784

17851785
@classmethod
1786-
def load(cls, source_code: str) -> Union[CDLL, ModuleType]:
1786+
def load(cls, source_code: str, cuda: bool = False) -> Union[CDLL, ModuleType]:
1787+
cls.cpp_compile_command_flags.update({"cuda": cuda})
17871788
picked_vec_isa = pick_vec_isa()
17881789
cpp_command = repr(
17891790
cpp_compile_command(
@@ -1821,9 +1822,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
18211822
"include_pytorch": True,
18221823
"shared": True,
18231824
}
1825+
entry_function = "kernel"
1826+
call_entry_function = "kernel(%s);Py_RETURN_NONE;"
1827+
extra_parse_arg = ""
18241828
suffix_template = textwrap.dedent(
18251829
"""
1826-
// Python bindings to call kernel():
1830+
// Python bindings to call %s():
18271831
#define PY_SSIZE_T_CLEAN
18281832
#include <Python.h>
18291833
#include <sstream>
@@ -1844,14 +1848,15 @@ class CppPythonBindingsCodeCache(CppCodeCache):
18441848
return result;
18451849
}
18461850
1847-
static PyObject* kernel_py(PyObject* self, PyObject* args) {
1851+
%s
1852+
1853+
static PyObject* %s_py(PyObject* self, PyObject* args) {
18481854
try {
18491855
if(!PyTuple_CheckExact(args))
18501856
[[unlikely]] throw std::runtime_error("tuple args required");
18511857
if(PyTuple_GET_SIZE(args) != %s)
18521858
[[unlikely]] throw std::runtime_error("requires %s args");
1853-
kernel(%s);
1854-
Py_RETURN_NONE;
1859+
%s
18551860
} catch(std::exception const& e) {
18561861
PyErr_SetString(PyExc_RuntimeError, e.what());
18571862
return nullptr;
@@ -1862,13 +1867,13 @@ class CppPythonBindingsCodeCache(CppCodeCache):
18621867
}
18631868
18641869
static PyMethodDef py_methods[] = {
1865-
{"kernel", kernel_py, METH_VARARGS, ""},
1870+
{"%s", %s_py, METH_VARARGS, ""},
18661871
{NULL, NULL, 0, NULL}};
18671872
18681873
static struct PyModuleDef py_module =
1869-
{PyModuleDef_HEAD_INIT, "kernel", NULL, -1, py_methods};
1874+
{PyModuleDef_HEAD_INIT, "%s", NULL, -1, py_methods};
18701875
1871-
PyMODINIT_FUNC PyInit_kernel(void) {
1876+
PyMODINIT_FUNC PyInit_%s(void) {
18721877
const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR");
18731878
if(!str_addr) {
18741879
PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set");
@@ -1890,29 +1895,62 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
18901895
torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined]
18911896
)
18921897
return importlib.machinery.ExtensionFileLoader(
1893-
f"{key}.kernel", path
1898+
f"{key}.{cls.entry_function}", path
18941899
).load_module() # type: ignore[call-arg]
18951900

18961901
@classmethod
1897-
def load_pybinding(cls, argtypes: List[str], source_code: str) -> Any:
1902+
def load_pybinding(
1903+
cls, argtypes: List[str], source_code: str, cuda: bool = False
1904+
) -> Any:
18981905
"""
18991906
Wrap a C++ function in fast Python bindings.
19001907
19011908
Args:
1902-
argtypes: The types of args to kernel(), e.g. ["float*", "long"]
1903-
source_code: C++ source code containing a kernel() function
1909+
argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
1910+
source_code: C++ source code containing a ENTRY_FUNCTION() function
19041911
19051912
Returns:
1906-
A python version of kernel()
1913+
A python version of ENTRY_FUNCTION()
19071914
"""
19081915
parseargs = ", ".join(
19091916
f"parse_arg<{argtype.replace('const ', '')}>(args, {n})"
19101917
for n, argtype in enumerate(argtypes)
19111918
)
1912-
suffix = cls.suffix_template % (len(argtypes), len(argtypes), parseargs)
1913-
result = cls.load(source_code + suffix)
1919+
suffix = cls.suffix_template % (
1920+
cls.entry_function,
1921+
cls.extra_parse_arg,
1922+
cls.entry_function,
1923+
len(argtypes),
1924+
len(argtypes),
1925+
cls.call_entry_function % parseargs,
1926+
cls.entry_function,
1927+
cls.entry_function,
1928+
cls.entry_function,
1929+
cls.entry_function,
1930+
)
1931+
result = cls.load(source_code + suffix, cuda)
19141932
assert isinstance(result, ModuleType)
1915-
return result.kernel
1933+
return getattr(result, cls.entry_function)
1934+
1935+
1936+
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
1937+
cache: Dict[str, Union[CDLL, ModuleType]] = {}
1938+
clear = staticmethod(cache.clear)
1939+
cpp_compile_command_flags = {
1940+
"include_pytorch": True,
1941+
"shared": True,
1942+
}
1943+
entry_function = "inductor_entry_cpp"
1944+
call_entry_function = "return THPVariable_WrapList(inductor_entry_cpp(%s));"
1945+
extra_parse_arg = textwrap.dedent(
1946+
"""
1947+
#include <torch/csrc/autograd/python_variable.h>
1948+
1949+
template <> inline std::vector<at::Tensor> parse_arg<std::vector<at::Tensor>>(PyObject* args, size_t n) {
1950+
return THPVariable_UnpackList(PyTuple_GET_ITEM(args, n));
1951+
}
1952+
"""
1953+
)
19161954

19171955

19181956
class PyCodeCache:
@@ -1998,81 +2036,6 @@ def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
19982036
return parse_stack_trace(entry)
19992037

20002038

2001-
class CppWrapperCodeCache:
2002-
cache: Dict[str, CDLL] = dict()
2003-
clear = staticmethod(cache.clear)
2004-
2005-
@classmethod
2006-
def load(cls, source_code: str, func_name: str, key: str, cuda: bool) -> CDLL:
2007-
name = f"inline_extension_{key}"
2008-
cpp_wrapper_dir = cpp_wrapper_cache_dir(name)
2009-
os.makedirs(cpp_wrapper_dir, exist_ok=True)
2010-
2011-
ext = "so"
2012-
filepath = os.path.join(cpp_wrapper_dir, f"{name}.{ext}")
2013-
log.debug("Cpp wrapper code path %s", filepath)
2014-
2015-
if key not in cls.cache:
2016-
log.debug("Cpp wrapper cache miss for %s", filepath)
2017-
from filelock import FileLock
2018-
2019-
lock_dir = get_lock_dir()
2020-
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
2021-
with lock:
2022-
if not os.path.exists(filepath):
2023-
log.debug("Cpp wrapper building %s", filepath)
2024-
2025-
_cpp_flags = cpp_flags()
2026-
_opt_flags = optimization_flags()
2027-
_shared = get_shared()
2028-
_warning_all_flag = get_warning_all_flag()
2029-
(
2030-
_ipaths,
2031-
_lpaths,
2032-
_libs,
2033-
_macros,
2034-
_build_arch_flags,
2035-
) = get_include_and_linking_paths(
2036-
vec_isa=pick_vec_isa(),
2037-
cuda=cuda,
2038-
)
2039-
_use_custom_generated_macros = use_custom_generated_macros()
2040-
_cpp_wrapper_flags = cpp_wrapper_flags()
2041-
2042-
extra_cflags = f"{_cpp_flags} {_opt_flags} {_warning_all_flag} {_build_arch_flags} {_macros} \
2043-
{_cpp_wrapper_flags} {_use_custom_generated_macros}"
2044-
# For CPP wrapper, add -ffast-math during linking to make CPU flush denormals.
2045-
# CPP wrapper leverages cpp_extension which will do the compilation and linking in two stages.
2046-
# We need to explicitly add -ffast-math as a linking flag.
2047-
# For the default python wrapper, the compilation and linking are done in one command thus -ffast-math
2048-
# will take effect in both compilation and linking.
2049-
extra_ldflags = f"{_shared} {_lpaths} {_libs} -ffast-math"
2050-
2051-
mod = torch.utils.cpp_extension.load_inline(
2052-
name=name,
2053-
build_directory=cpp_wrapper_dir,
2054-
cpp_sources=[source_code],
2055-
functions=[func_name],
2056-
extra_cflags=[extra_cflags],
2057-
extra_ldflags=[extra_ldflags],
2058-
extra_include_paths=_ipaths,
2059-
use_pch=True,
2060-
)
2061-
log.debug("Cpp wrapper done building %s", filepath)
2062-
else:
2063-
log.debug("Found target .so, cpp wrapper loading %s", filepath)
2064-
spec = importlib.util.spec_from_file_location(name, filepath) # type: ignore[attr-defined]
2065-
assert spec is not None
2066-
mod = importlib.util.module_from_spec(spec) # type: ignore[attr-defined]
2067-
assert isinstance(spec.loader, abc.Loader)
2068-
spec.loader.exec_module(mod)
2069-
log.debug("Cpp wrapper done loading %s", filepath)
2070-
2071-
cls.cache[key] = mod
2072-
2073-
return cls.cache[key]
2074-
2075-
20762039
class TritonCodeCache:
20772040
@classmethod
20782041
def load(cls, kernel_name: str, source_code: str) -> ModuleType:

torch/_inductor/codegen/wrapper.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,12 @@ def write_header(self):
14491449
if config.aot_inductor.abi_compatible:
14501450
self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
14511451
else:
1452+
if not V.graph.aot_mode:
1453+
self.header.splice(
1454+
"""
1455+
#include <pybind11/pybind11.h>
1456+
"""
1457+
)
14521458
self.header.splice(
14531459
"""
14541460
#include <ATen/ATen.h>
@@ -1622,7 +1628,7 @@ def write_wrapper_decl(self):
16221628
else:
16231629
self.prefix.splice(
16241630
"""
1625-
py::gil_scoped_release release;
1631+
pybind11::gil_scoped_release release;
16261632
"""
16271633
)
16281634

@@ -1978,11 +1984,9 @@ def generate_end(self, result):
19781984
return
19791985

19801986
result.writeline("'''\n)")
1981-
# get the hash of the wrapper code to name the extension
1982-
wrapper_call_hash = codecache.code_hash(result.getvalue())
19831987
result.splice(
19841988
f"""
1985-
module = CppWrapperCodeCache.load(cpp_wrapper_src, '{self.call_func_name}', '{wrapper_call_hash}', {self.cuda})
1989+
inductor_entry = CppWrapperCodeCache.load_pybinding(["std::vector<at::Tensor>"], cpp_wrapper_src, {self.cuda})
19861990
"""
19871991
)
19881992

@@ -2024,7 +2028,7 @@ def g(args):
20242028
{args_str}
20252029
{return_str}
20262030
return g
2027-
call = _wrap_func(module.{self.call_func_name})
2031+
call = _wrap_func(inductor_entry)
20282032
"""
20292033
)
20302034

0 commit comments

Comments (0)