Improve custom ops tutorials #3020

Merged
merged 2 commits on Aug 29, 2024
2 changes: 2 additions & 0 deletions advanced_source/cpp_custom_ops.rst
@@ -174,6 +174,8 @@ To add ``torch.compile`` support for an operator, we must add a FakeTensor kernel (also
known as a "meta kernel" or "abstract impl"). FakeTensors are Tensors that have
metadata (such as shape, dtype, device) but no data: the FakeTensor kernel for an
operator specifies how to compute the metadata of output tensors given the metadata of input tensors.
The FakeTensor kernel should return dummy Tensors of your choice with
the correct Tensor metadata (shape/strides/``dtype``/device).
Contributor:

Should this be in double ticks?

Suggested change:
the correct Tensor metadata (shape/strides/``dtype``/device).
the correct Tensor metadata (``shape/strides/<dtype>/device``).

Contributor Author:

I just wanted to put ``dtype`` in double ticks because it's not an English word.

We recommend that this be done from Python via the `torch.library.register_fake` API,
though it is possible to do this from C++ as well (see
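For orientation, a ``torch.library.register_fake`` registration from Python typically looks like the sketch below. The operator name ``mylib::mymuladd`` and its signature are illustrative (not taken from this diff), and the sketch assumes the operator itself was already defined elsewhere, for example in C++ via ``TORCH_LIBRARY``.

import torch

# Hypothetical operator, assumed to be defined already (e.g. via TORCH_LIBRARY
# in C++): mylib::mymuladd(Tensor a, Tensor b, float c) -> Tensor
@torch.library.register_fake("mylib::mymuladd")
def _(a, b, c):
    # a and b arrive as FakeTensors: they carry shape/strides/dtype/device
    # but no storage. Validate the metadata, then return an empty Tensor
    # with the metadata the real kernel would produce.
    torch._check(a.shape == b.shape)
    torch._check(a.device == b.device)
    return torch.empty_like(a)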
15 changes: 10 additions & 5 deletions advanced_source/python_custom_ops.py
@@ -66,7 +66,7 @@ def display(img):
######################################################################
# ``crop`` is not handled effectively out-of-the-box by
# ``torch.compile``: ``torch.compile`` induces a
# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
# on functions it is unable to handle and graph breaks are bad for performance.
# The following code demonstrates this by raising an error
# (``torch.compile`` with ``fullgraph=True`` raises an error if a
@@ -85,9 +85,9 @@ def f(img):
#
# 1. wrap the function into a PyTorch custom operator.
# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.
# Given the metadata (e.g. shapes)
# of the input Tensors, this function says how to compute the metadata
# of the output Tensor(s).
# Given some ``FakeTensors`` inputs (dummy Tensors that don't have storage),
# this function should return dummy Tensors of your choice with the correct
# Tensor metadata (shape/strides/``dtype``/device).


from typing import Sequence
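A rough sketch of those two steps, assuming a simplified Tensor-slicing ``crop`` rather than the tutorial's exact implementation:

import torch
from typing import Sequence

# Step 1: wrap the Python function into a custom operator. mutates_args=()
# declares that crop does not mutate its inputs.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    x0, y0, x1, y1 = box
    return pic[..., y0:y1, x0:x1].clone()

# Step 2: the FakeTensor kernel. Given FakeTensor inputs (no storage), return
# a dummy Tensor with the correct shape/strides/dtype/device.
@crop.register_fake
def _(pic, box):
    x0, y0, x1, y1 = box
    return pic.new_empty(pic.shape[:-2] + (y1 - y0, x1 - x0))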
@@ -130,6 +130,11 @@ def f(img):
# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
# has led to) silent incorrectness when composed with ``torch.compile``.
#
# If you don't need training support, there is no need to use
# ``torch.library.register_autograd``.
# If you end up training with a ``custom_op`` that doesn't have an autograd
# registration, we'll raise an error message.
#
# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
# derivation as an exercise to the reader). Let's first wrap ``paste`` into a
# custom operator:
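(The diff is truncated here. For orientation only, a backward registration generally follows the pattern below; the gradient logic is a simplified stand-in for the hypothetical ``mylib::crop`` sketched above, not the tutorial's actual ``paste``-based code.)

import torch

def backward(ctx, grad_output):
    # Gradient of crop: paste the upstream gradient back into a zero image
    # at the cropped location. The non-Tensor input (box) gets None.
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    x0, y0, x1, y1 = ctx.box
    grad_input[..., y0:y1, x0:x1] = grad_output
    return grad_input, None

def setup_context(ctx, inputs, output):
    # Stash what backward needs; only non-Tensor metadata is saved here.
    pic, box = inputs
    ctx.pic_shape = pic.shape
    ctx.box = box

torch.library.register_autograd(
    "mylib::crop", backward, setup_context=setup_context
)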
@@ -203,7 +208,7 @@ def setup_context(ctx, inputs, output):
######################################################################
# Mutable Python Custom operators
# -------------------------------
# You can also wrap a Python function that mutates its inputs into a custom
# operator.
# Functions that mutate inputs are common because that is how many low-level
# kernels are written; for example, a kernel that computes ``sin`` may take in
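(Also truncated. The mechanic being introduced is declaring mutated arguments via ``mutates_args``; below is a minimal sketch with an illustrative ``sin``-style out kernel, not the tutorial's exact example.)

import torch

# The kernel writes its result into a caller-provided ``out`` Tensor.
# Listing ``out`` in mutates_args tells PyTorch that this argument is
# mutated in place.
@torch.library.custom_op("mylib::sin_out", mutates_args={"out"})
def sin_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(torch.sin(x))

x = torch.randn(3)
out = torch.empty(3)
sin_out(x, out)  # out now holds sin(x)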