Skip to content

[BE][export] add data-dependent section to export tutorial #3244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 23, 2025
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions intermediate_source/torch_export_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,188 @@ def forward(self, x, y):
"bool_val": None,
}

######################################################################
# Data-dependent errors
# ---------------------
#
# While trying to export models, you have may have encountered errors like ``Could not guard on data-dependent expression`` or ``Could not extract specialized integer from data-dependent expression``.
# Obscure as they may seem, the reasoning behind their existence, and their resolution, is actually quite straightforward.
#
# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties
# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (for example, they may have the same or equivalent symbolic properties

# - sizes, strides, dtypes, etc.), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# - sizes, strides, dtypes, etc.), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may
# - sizes, strides, dtypes, and so on), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may

# struggle with user code that relies on data values. In short, if the compiler requires a concrete, specialized value that is dependent on tensor data in order to proceed, it will error, complaining that
# FakeTensor tracing isn't providing the information required.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally speaking, you should highlight data-dependence as the main cause rather than making it about fake tensors. Also, export programming model has sections on shape vs. data dependence.

#
# Let's talk about where data-dependent values appear in programs. Common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors.
# How are these values represented in the exported program? In the ``Constraints/Dynamic Shapes`` section, we talked about allocating symbols to represent dynamic input dimensions, and the same happens here -
# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts
# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked" ``SymInts``, in contrast to the "backed" symbols/``SymInts``

# allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence, or absence, of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler how to proceed.
#
# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.
# For dynamic input shapes (backed ``SymInts``), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.

# On the other hand, data-dependent values are derived from FakeTensors during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols/SymInts".
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# On the other hand, data-dependent values are derived from FakeTensors during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols/SymInts".
# On the other hand, data-dependent values are derived from ``FakeTensors`` during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols" or ``SymInts``.

#
# Let's see how these show up in exported programs, with this example:

class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
b = y.tolist()
return b + [a]

inps = (
torch.tensor(1),
torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)

######################################################################
# The result is that 3 unbacked symbols (prefixed with ``u``) are allocated and returned; 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call. Note from the
# ``Range constraints`` field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols.

######################################################################
# Guards, torch._check()
# ^^^^^^^^^^^^^^^^^^^^^^
#
# But the case above is easy to export, because the compiler doesn't need the concrete values of the unbacked symbols for anything. All that's relevant is that the return values are unbacked symbols.
# The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered:

class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
if a // 2 >= 5:
return y + 2
else:
return y * 5

######################################################################
# Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5``. Because the hint isn't available, the expression ``a // 2 >= 5``
# can't be concretely evaluated, and export errors out with ``Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)``.
#
# So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here's some basic options:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cut "actually"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"this" what? maybe "this code"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are some options:

#
# 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by guarding undesired branches with ``torch.compiler.is_compiling()``.
# 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()``, keeping both branches alive.
#
# While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and ``torch.cond()`` is not a comprehensive system for handling data-dependent errors;
# there are data-dependent errors that do not involve control-flow.
#
# The generally recommended approach is to start with ``torch._check()`` calls. While these give the impression of purely being assert statements, they are in fact a system of informing the compiler regarding properties of symbols.
# While a ``torch._check()`` call does act as an assertion at runtime, when traced at compile-time, the checked expression is deferred as a runtime assert, and any symbol properties that follow from the expression being true
# inform the symbolic shapes subsystem (provided it's smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to describe properties that are generally true for these symbols via
# ``torch._check()`` calls, we can potentially bypass data-dependent guards without rewriting the offending model code.
#
# For example in the model above, inserting ``torch._check(a >= 10)`` tells the compiler that ``return y + 2`` can always be traced, and ``torch._check(a == 4)`` tells it to trace ``return y * 5``.
# See what happens when we re-export this model.

class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
torch._check(a >= 10)
torch._check(a <= 60)
if a // 2 >= 5:
return y + 2
else:
return y * 5

inps = (
torch.tensor(32),
torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)

######################################################################
# Export succeeds, and note from the ``Range constraints`` field that the ``torch._check()`` calls have informed the compiler, giving ``u0`` a range of ``[10, 60]``.
#
# So what information do ``torch._check()`` calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are accepted:
#
# 1. Equality with simple, non-data-dependent expressions: ``torch._check()`` calls that communicate expressions like ``u0 == s0 + 4`` or ``u0 == 5``.
# 2. Range refinement: calls that provide lower or upper bounds for symbols refine symbol ranges.
# 3. Some basic reasoning around more complicated expressions: for example, a complicated expression like ``torch._check(a ** 2 - 3 * a <= 10)`` will get you past a guard with the same expression.
#
# As mentioned previously, ``torch._check()`` calls have applicability outside of data-dependent control flow. For example, here's a model where ``torch._check()`` insertion
# prevails while manual specialization & ``torch.cond()`` do not:

class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
return y[a]

inps = (
torch.tensor(32),
torch.randn(60),
)
export(Foo(), inps)

######################################################################
# Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with
# ``Could not guard on data-dependent expression -u0 > 60``, implying that the compiler doesn't know if this is a valid indexing operation;
# if the value of ``x`` is out-of-bounds for ``y`` or not. Here, manual specialization is too prohibitive, and ``torch.cond()`` has no place.
# Instead, informing the compiler of ``u0``'s range is sufficient:

class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
torch._check(a >= 0)
torch._check(a <= y.shape[0])
return y[a]

inps = (
torch.tensor(32),
torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)

######################################################################
# Specialized values
# ^^^^^^^^^^^^^^^^^^
#
# Another category of data-dependent error happens when the program attempts to extract a concrete data-dependent integer/float value
# while tracing. This looks something like ``Could not extract specialized integer from data-dependent expression``, and is analogous to
# the previous class of errors; if these occur when attempting to evaluate concrete integer/float values, data-dependent guard errors arise
# with evaluating concrete boolean values.
#
# This error typically occurs when there is an explicit or implicit ``int()`` cast on a data-dependent expression. For example, list comprehension
# in Python requires an ``int()`` cast on the size of the list:

class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
b = torch.cat([y for y in range(a)], dim=0)
return b + int(a)

inps = (
torch.tensor(32),
torch.randn(60),
)
export(Foo(), inps, strict=False)

######################################################################
# In this case, some basic options you have are:
#
# 1. Avoid unnecessary ``int()`` cast calls, in this case the ``int(a)`` in the return statement.
# 2. Use ``torch._check()`` calls; unfortunately all you may be able to do in this case is specialize (e.g. with ``torch._check(a == 60)``).
# 3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a ``repeat()`` op, which doesn't involve an ``int()`` cast. Therefore, the following rewrite avoids this error.

class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
b = y.unsqueeze(0).repeat(a, 1)
return b + a

inps = (
torch.tensor(32),
torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)

######################################################################
# Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: ``torch._check_is_size()``, ``guard_size_oblivious()``, or real-tensor tracing, as starters.
# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors <https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs>`.

######################################################################
# Custom Ops
# ----------
Expand Down
Loading