@@ -213,53 +213,38 @@ def attributes(self):
213
213
214
214
215
215
class ObjectCode :
216
- """Represent a compiled program that was loaded onto the device.
216
+ """Represent a compiled program to be loaded onto the device.
217
217
218
218
This object provides a unified interface for different types of
219
- compiled programs that are loaded onto the device.
220
-
221
- Loads the module library with specified module code and JIT options.
219
+ compiled programs that will be loaded onto the device.
222
220
223
221
Note
224
222
----
225
- The public constructor assumes that ``module`` is of code type "cubin".
226
- For all other possible code types (ex: "ptx"), only :class:`~cuda.core.experimental.Program`
227
- accepts them and returns an `ObjectCode` instance with its ``compile`` method.
223
+ This class has no default constructor. If you already have a cubin that you would
224
+ like to load, use the :meth:`from_cubin` alternative constructor. For all other
225
+ possible code types (ex: "ptx"), only :class:`~cuda.core.experimental.Program`
226
+ accepts them and returns an :class:`ObjectCode` instance with its
227
+ :meth:`~cuda.core.experimental.Program.compile` method.
228
228
229
229
Note
230
230
----
231
231
Usage under CUDA 11.x will only load to the current device
232
232
context.
233
-
234
- Parameters
235
- ----------
236
- module : Union[bytes, str]
237
- Either a bytes object containing the cubin to load, or
238
- a file path string pointing to the cubin to load.
239
- symbol_mapping : Optional[dict]
240
- A dictionary specifying how the unmangled symbol names (as keys)
241
- should be mapped to the mangled names before trying to retrieve
242
- them (default to no mappings).
243
233
"""
244
234
245
235
__slots__ = ("_handle" , "_backend_version" , "_code_type" , "_module" , "_loader" , "_sym_map" )
246
236
_supported_code_type = ("cubin" , "ptx" , "ltoir" , "fatbin" )
247
237
248
- def __init__ (self , module : Union [bytes , str ], * , symbol_mapping : Optional [dict ] = None ):
249
- _lazy_init ()
250
-
251
- # handle is assigned during _lazy_load
252
- self ._handle = None
253
- self ._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000 ) else "old"
254
- self ._loader = _backend [self ._backend_version ]
255
- self ._code_type = "cubin"
256
- self ._module = module
257
- self ._sym_map = {} if symbol_mapping is None else symbol_mapping
238
+ def __init__ (self ):
239
+ raise NotImplementedError (
240
+ "directly creating an ObjectCode object can be ambiguous. Please either call Program.compile() "
241
+ "or one of the ObjectCode.from_*() constructors"
242
+ )
258
243
244
+ @staticmethod
259
245
def _init (module , code_type , * , symbol_mapping : Optional [dict ] = None ):
260
246
self = ObjectCode .__new__ (ObjectCode )
261
- if code_type not in self ._supported_code_type :
262
- raise ValueError
247
+ assert code_type in self ._supported_code_type , f"{ code_type = } is not supported"
263
248
_lazy_init ()
264
249
265
250
# handle is assigned during _lazy_load
@@ -274,6 +259,22 @@ def _init(module, code_type, *, symbol_mapping: Optional[dict] = None):
274
259
275
260
return self
276
261
262
+ @staticmethod
263
+ def from_cubin (module : Union [bytes , str ], * , symbol_mapping : Optional [dict ] = None ) -> "ObjectCode" :
264
+ """Create an :class:`ObjectCode` instance from an existing cubin.
265
+
266
+ Parameters
267
+ ----------
268
+ module : Union[bytes, str]
269
+ Either a bytes object containing the in-memory cubin to load, or
270
+ a file path string pointing to the on-disk cubin to load.
271
+ symbol_mapping : Optional[dict]
272
+ A dictionary specifying how the unmangled symbol names (as keys)
273
+ should be mapped to the mangled names before trying to retrieve
274
+ them (default to no mappings).
275
+ """
276
+ return ObjectCode ._init (module , "cubin" , symbol_mapping = symbol_mapping )
277
+
277
278
# TODO: do we want to unload in a finalizer? Probably not..
278
279
279
280
def _lazy_load_module (self , * args , ** kwargs ):
@@ -307,12 +308,12 @@ def get_kernel(self, name):
307
308
Newly created kernel object.
308
309
309
310
"""
311
+ if self ._code_type not in ("cubin" , "ptx" , "fatbin" ):
312
+ raise RuntimeError (f"get_kernel() is not supported for { self ._code_type } " )
310
313
try :
311
314
name = self ._sym_map [name ]
312
315
except KeyError :
313
316
name = name .encode ()
314
317
315
318
data = handle_return (self ._loader ["kernel" ](self ._handle , name ))
316
319
return Kernel ._from_obj (data , self )
317
-
318
- # TODO: implement from_handle()
0 commit comments