@@ -435,34 +435,45 @@ def suggested_fixes():
435
435
#
436
436
# Currently, the steps to register a custom op for use by ``torch.export`` are:
437
437
#
438
- # - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
439
- # as with any other custom op
438
+ # - If you’re writing custom ops purely in Python, use torch.library.custom_op.
440
439
441
- from torch .library import Library , impl
440
+ import torch .library
441
+ import numpy as np
442
442
443
- m = Library ("my_custom_library" , "DEF" )
444
-
445
- m .define ("custom_op(Tensor input) -> Tensor" )
446
-
447
- @impl (m , "custom_op" , "CompositeExplicitAutograd" )
448
- def custom_op (x ):
449
- print ("custom_op called!" )
450
- return torch .relu (x )
443
+ @torch .library .custom_op ("mylib::sin" , mutates_args = ())
444
+ def sin (x ):
445
+ x_np = x .numpy ()
446
+ y_np = np .sin (x_np )
447
+ return torch .from_numpy (y_np )
451
448
452
449
######################################################################
453
- # - Define a ``"Meta"`` implementation of the custom op that returns an empty
454
- # tensor with the same shape as the expected output
450
+ # - You will need to provide abstract implementation so that PT2 can trace through it.
455
451
456
- @impl ( m , "custom_op" , "Meta " )
457
- def custom_op_meta (x ):
452
+ @torch . library . register_fake ( "mylib::sin " )
453
+ def _ (x ):
458
454
return torch .empty_like (x )
459
455
456
+ # - Sometimes, the custom op you are exporting has data-dependent output, meaning
457
+ # we can't determine the shape of the output at compile time. In this case, you can do
458
+ # following:
459
+ @torch .library .register_fake ("mylib::op_that_has_data_dependent" )
460
+ def _ (x ):
461
+ # Number of nonzero-elements is data-dependent.
462
+ # Since we cannot peek at the data in an abstract impl,
463
+ # we use the ctx object to construct a new symint that
464
+ # represents the data-dependent size.
465
+ ctx = torch .library .get_ctx ()
466
+ nnz = ctx .new_dynamic_size ()
467
+ shape = [nnz , x .dim ()]
468
+ result = x .new_empty (shape , dtype = torch .int64 )
469
+ return result
470
+
460
471
######################################################################
461
472
# - Call the custom op from the code you want to export using ``torch.ops``
462
473
463
474
def custom_op_example (x ):
464
475
x = torch .sin (x )
465
- x = torch .ops .my_custom_library . custom_op (x )
476
+ x = torch .ops .mylib . sin (x )
466
477
x = torch .cos (x )
467
478
return x
468
479
0 commit comments