33
33
from ctypes import c_void_p , cdll , CDLL
34
34
from dataclasses import field
35
35
from functools import partial
36
- from importlib import abc
37
36
from pathlib import Path
38
37
from threading import Thread
39
38
from time import sleep , time
46
45
get_interface_for_device ,
47
46
get_registered_device_interfaces ,
48
47
)
49
- from torch ._dynamo .utils import counters
48
+ from torch ._dynamo .utils import counters , dynamo_timed
50
49
from torch ._inductor import config , exc
51
50
from torch ._inductor .codegen .cuda import cuda_env
52
51
from torch ._inductor .utils import cache_dir , developer_warning , is_linux
@@ -1701,6 +1700,7 @@ def cpp_prefix() -> str:
1701
1700
1702
1701
# Given a path to an input cpp file and an output path,
1703
1702
# Attempts to compile the file, storing the output in "output_path"
1703
+ @dynamo_timed
1704
1704
def compile_file (
1705
1705
input_path : Union [str , List [str ]], output_path : str , cmd : List [str ]
1706
1706
) -> None :
@@ -1783,7 +1783,8 @@ def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
1783
1783
raise
1784
1784
1785
1785
@classmethod
1786
- def load (cls , source_code : str ) -> Union [CDLL , ModuleType ]:
1786
+ def load (cls , source_code : str , cuda : bool = False ) -> Union [CDLL , ModuleType ]:
1787
+ cls .cpp_compile_command_flags .update ({"cuda" : cuda })
1787
1788
picked_vec_isa = pick_vec_isa ()
1788
1789
cpp_command = repr (
1789
1790
cpp_compile_command (
@@ -1821,9 +1822,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
1821
1822
"include_pytorch" : True ,
1822
1823
"shared" : True ,
1823
1824
}
1825
+ entry_function = "kernel"
1826
+ call_entry_function = "kernel(%s);Py_RETURN_NONE;"
1827
+ extra_parse_arg = ""
1824
1828
suffix_template = textwrap .dedent (
1825
1829
"""
1826
- // Python bindings to call kernel ():
1830
+ // Python bindings to call %s ():
1827
1831
#define PY_SSIZE_T_CLEAN
1828
1832
#include <Python.h>
1829
1833
#include <sstream>
@@ -1844,14 +1848,15 @@ class CppPythonBindingsCodeCache(CppCodeCache):
1844
1848
return result;
1845
1849
}
1846
1850
1847
- static PyObject* kernel_py(PyObject* self, PyObject* args) {
1851
+ %s
1852
+
1853
+ static PyObject* %s_py(PyObject* self, PyObject* args) {
1848
1854
try {
1849
1855
if(!PyTuple_CheckExact(args))
1850
1856
[[unlikely]] throw std::runtime_error("tuple args required");
1851
1857
if(PyTuple_GET_SIZE(args) != %s)
1852
1858
[[unlikely]] throw std::runtime_error("requires %s args");
1853
- kernel(%s);
1854
- Py_RETURN_NONE;
1859
+ %s
1855
1860
} catch(std::exception const& e) {
1856
1861
PyErr_SetString(PyExc_RuntimeError, e.what());
1857
1862
return nullptr;
@@ -1862,13 +1867,13 @@ class CppPythonBindingsCodeCache(CppCodeCache):
1862
1867
}
1863
1868
1864
1869
static PyMethodDef py_methods[] = {
1865
- {"kernel ", kernel_py , METH_VARARGS, ""},
1870
+ {"%s ", %s_py , METH_VARARGS, ""},
1866
1871
{NULL, NULL, 0, NULL}};
1867
1872
1868
1873
static struct PyModuleDef py_module =
1869
- {PyModuleDef_HEAD_INIT, "kernel ", NULL, -1, py_methods};
1874
+ {PyModuleDef_HEAD_INIT, "%s ", NULL, -1, py_methods};
1870
1875
1871
- PyMODINIT_FUNC PyInit_kernel (void) {
1876
+ PyMODINIT_FUNC PyInit_%s (void) {
1872
1877
const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR");
1873
1878
if(!str_addr) {
1874
1879
PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set");
@@ -1890,29 +1895,62 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
1890
1895
torch ._C ._dynamo .guards ._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined]
1891
1896
)
1892
1897
return importlib .machinery .ExtensionFileLoader (
1893
- f"{ key } .kernel " , path
1898
+ f"{ key } .{ cls . entry_function } " , path
1894
1899
).load_module () # type: ignore[call-arg]
1895
1900
1896
1901
@classmethod
def load_pybinding(
    cls, argtypes: List[str], source_code: str, cuda: bool = False
) -> Any:
    """
    Wrap a C++ function in fast Python bindings.

    Args:
        argtypes: The C++ types of the args to the entry function,
            e.g. ``["float*", "long"]``.
        source_code: C++ source code containing an entry function whose
            name is given by ``cls.entry_function``.
        cuda: Forwarded to ``cls.load`` to select CUDA compilation.

    Returns:
        A Python-callable wrapper around the compiled entry function.
    """
    # One parse_arg<T>(args, i) expression per positional argument.
    # "const " is stripped: constness is irrelevant when parsing by value.
    parse_exprs = ", ".join(
        f"parse_arg<{t.replace('const ', '')}>(args, {i})"
        for i, t in enumerate(argtypes)
    )
    arity = len(argtypes)
    # Fill the binding template. The positional order below must match the
    # %s placeholders in cls.suffix_template exactly (the entry-function
    # name is substituted in several places).
    bindings = cls.suffix_template % (
        cls.entry_function,
        cls.extra_parse_arg,
        cls.entry_function,
        arity,
        arity,
        cls.call_entry_function % parse_exprs,
        cls.entry_function,
        cls.entry_function,
        cls.entry_function,
        cls.entry_function,
    )
    module = cls.load(source_code + bindings, cuda)
    assert isinstance(module, ModuleType)
    return getattr(module, cls.entry_function)
1934
+
1935
+
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
    """
    Code cache for the cpp-wrapper entry point, reusing the Python-binding
    machinery of the parent class with a different entry function
    ("inductor_entry_cpp") that takes and returns lists of tensors.
    """

    # Per-key cache of loaded shared libraries / extension modules.
    cache: Dict[str, Union[CDLL, ModuleType]] = {}
    clear = staticmethod(cache.clear)
    cpp_compile_command_flags = {
        "include_pytorch": True,
        "shared": True,
    }
    entry_function = "inductor_entry_cpp"
    # The entry function returns a vector of tensors; wrap it back into a
    # Python list instead of returning None.
    call_entry_function = "return THPVariable_WrapList(inductor_entry_cpp(%s));"
    # Extra C++ spliced into the binding template: a parse_arg
    # specialization that unpacks a Python sequence of tensors.
    extra_parse_arg = textwrap.dedent(
        """
        #include <torch/csrc/autograd/python_variable.h>

        template <> inline std::vector<at::Tensor> parse_arg<std::vector<at::Tensor>>(PyObject* args, size_t n) {
            return THPVariable_UnpackList(PyTuple_GET_ITEM(args, n));
        }
        """
    )
1916
1954
1917
1955
1918
1956
class PyCodeCache :
@@ -1998,81 +2036,6 @@ def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
1998
2036
return parse_stack_trace (entry )
1999
2037
2000
2038
2001
- class CppWrapperCodeCache :
2002
- cache : Dict [str , CDLL ] = dict ()
2003
- clear = staticmethod (cache .clear )
2004
-
2005
- @classmethod
2006
- def load (cls , source_code : str , func_name : str , key : str , cuda : bool ) -> CDLL :
2007
- name = f"inline_extension_{ key } "
2008
- cpp_wrapper_dir = cpp_wrapper_cache_dir (name )
2009
- os .makedirs (cpp_wrapper_dir , exist_ok = True )
2010
-
2011
- ext = "so"
2012
- filepath = os .path .join (cpp_wrapper_dir , f"{ name } .{ ext } " )
2013
- log .debug ("Cpp wrapper code path %s" , filepath )
2014
-
2015
- if key not in cls .cache :
2016
- log .debug ("Cpp wrapper cache miss for %s" , filepath )
2017
- from filelock import FileLock
2018
-
2019
- lock_dir = get_lock_dir ()
2020
- lock = FileLock (os .path .join (lock_dir , key + ".lock" ), timeout = LOCK_TIMEOUT )
2021
- with lock :
2022
- if not os .path .exists (filepath ):
2023
- log .debug ("Cpp wrapper building %s" , filepath )
2024
-
2025
- _cpp_flags = cpp_flags ()
2026
- _opt_flags = optimization_flags ()
2027
- _shared = get_shared ()
2028
- _warning_all_flag = get_warning_all_flag ()
2029
- (
2030
- _ipaths ,
2031
- _lpaths ,
2032
- _libs ,
2033
- _macros ,
2034
- _build_arch_flags ,
2035
- ) = get_include_and_linking_paths (
2036
- vec_isa = pick_vec_isa (),
2037
- cuda = cuda ,
2038
- )
2039
- _use_custom_generated_macros = use_custom_generated_macros ()
2040
- _cpp_wrapper_flags = cpp_wrapper_flags ()
2041
-
2042
- extra_cflags = f"{ _cpp_flags } { _opt_flags } { _warning_all_flag } { _build_arch_flags } { _macros } \
2043
- { _cpp_wrapper_flags } { _use_custom_generated_macros } "
2044
- # For CPP wrapper, add -ffast-math during linking to make CPU flush denormals.
2045
- # CPP wrapper leverages cpp_extension which will do the compilation and linking in two stages.
2046
- # We need to explicitly add -ffast-math as a linking flag.
2047
- # For the default python wrapper, the compilation and linking are done in one command thus -ffast-math
2048
- # will take effect in both compilation and linking.
2049
- extra_ldflags = f"{ _shared } { _lpaths } { _libs } -ffast-math"
2050
-
2051
- mod = torch .utils .cpp_extension .load_inline (
2052
- name = name ,
2053
- build_directory = cpp_wrapper_dir ,
2054
- cpp_sources = [source_code ],
2055
- functions = [func_name ],
2056
- extra_cflags = [extra_cflags ],
2057
- extra_ldflags = [extra_ldflags ],
2058
- extra_include_paths = _ipaths ,
2059
- use_pch = True ,
2060
- )
2061
- log .debug ("Cpp wrapper done building %s" , filepath )
2062
- else :
2063
- log .debug ("Found target .so, cpp wrapper loading %s" , filepath )
2064
- spec = importlib .util .spec_from_file_location (name , filepath ) # type: ignore[attr-defined]
2065
- assert spec is not None
2066
- mod = importlib .util .module_from_spec (spec ) # type: ignore[attr-defined]
2067
- assert isinstance (spec .loader , abc .Loader )
2068
- spec .loader .exec_module (mod )
2069
- log .debug ("Cpp wrapper done loading %s" , filepath )
2070
-
2071
- cls .cache [key ] = mod
2072
-
2073
- return cls .cache [key ]
2074
-
2075
-
2076
2039
class TritonCodeCache :
2077
2040
@classmethod
2078
2041
def load (cls , kernel_name : str , source_code : str ) -> ModuleType :
0 commit comments