intermediate_source/torch_export_tutorial.py (13 additions, 12 deletions)
@@ -634,19 +634,19 @@ def forward(self, x, y):
# ---------------------
#
# While trying to export models, you may have encountered errors like "Could not guard on data-dependent expression" or "Could not extract specialized integer from data-dependent expression".
-# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. For example, they may have equivalent symbolic properties
-# (e.g. sizes, strides, dtypes), but diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may struggle
-# with parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that
-# FakeTensor tracing isn't providing the information required.
+# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. While these have equivalent symbolic properties
+# (e.g. sizes, strides, dtypes), they diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may be
+# unable to out-of-the-box compile parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out,
+# complaining that the value is not available.
#
# Data-dependent values appear in many places, and common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors.
# How are these values represented in the exported program? In the `Constraints/Dynamic Shapes <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes>`_
# section, we talked about allocating symbols to represent dynamic input dimensions.
-# The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts",
-# in contrast to the "backed" symbols/SymInts allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence/absence of a "hint" for the symbol:
-# a concrete value backing the symbol, that can inform the compiler on how to proceed.
+# The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols,
+# in contrast to the "backed" symbols allocated for input dimensions. The `"backed/unbacked" <https://pytorch.org/docs/main/export.programming_model.html#basics-of-symbolic-shapes>`_
+# nomenclature refers to the presence/absence of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler on how to proceed.
#
-# In the input shape symbol case (backed SymInts), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties.
+# In the input shape symbol case (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties.
# For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, and so the compiler doesn't know the actual values (hints) that these symbols would take on.
#
# Let's see how these show up in exported programs:
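(The tutorial's own example listing is elided between these hunks. As a stand-in, here is a minimal sketch, not the file's actual code, of the kind of module described below: one ``item()`` call plus a ``tolist()`` call on a 2-element tensor, matching the three unbacked symbols discussed next.)

import torch

class Example(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()    # scalar read from a 0-dim tensor -> 1 unbacked symbol
        b = y.tolist()  # one unbacked symbol per element of y
        return a, b

# Assuming a 2-element y, export should allocate 3 unbacked symbols (u0, u1, u2);
# they show up in ep.range_constraints with range [-int_oo, int_oo].
ep = torch.export.export(Example(), (torch.tensor(7), torch.tensor([3, 4])))
print(ep)
print(ep.range_constraints)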
@@ -668,14 +668,14 @@ def forward(self, x, y):
# The result is that 3 unbacked symbols (notice they're prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned:
# 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call.
# Note from the range constraints field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols,
-# since we literally have no information on what these values are - they don't represent sizes, so don't necessarily have positive values.
+# since we have no information on what these values are - they don't represent sizes, so don't necessarily have positive values.
# But the case above is easy to export, because the concrete values of these symbols aren't used in any compiler decision-making; all that's relevant is that the return values are unbacked symbols.
-# The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered:
+# The data-dependent errors highlighted in this section are cases like the following, where `data-dependent guards <https://pytorch.org/docs/main/export.programming_model.html#control-flow-static-vs-dynamic>`_ are encountered:

class Foo(torch.nn.Module):
    def forward(self, x, y):
@@ -689,7 +689,7 @@ def forward(self, x, y):
# Here we actually need the "hint", or the concrete value of ``a``, for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5`` as the output.
# Because we trace with FakeTensors, we don't know what ``a // 2 >= 5`` actually evaluates to, and export errors out with "Could not guard on data-dependent expression ``u0 // 2 >= 5 (unhinted)``".
#
-# So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here's some basic options:
+# So how do we export this toy model? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here are some basic options:
#
# 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or using ``torch.compiler.is_compiling()`` to guard what's traced at compile-time.
# 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()`` so we don't specialize on a branch (see the sketch after this list).
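(A sketch of option 2, assuming the toy model's predicate comes from ``a = x.item()``, which the ``u0`` in the error message suggests but the elided code would confirm. ``torch.cond()`` traces both branches, so the compiler never needs a concrete value for the predicate.)

import torch

class CondFoo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # Both branches are traced; no data-dependent guard is emitted for
        # the predicate ``a // 2 >= 5``.
        return torch.cond(a // 2 >= 5, lambda y: y + 2, lambda y: y * 5, (y,))

ep = torch.export.export(CondFoo(), (torch.tensor(12), torch.ones(3)))

Note that ``torch.cond()`` requires both branches to return tensors with matching metadata, which holds here: ``y + 2`` and ``y * 5`` agree in shape and dtype.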
# Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: ``torch._check_is_size()``, ``guard_size_oblivious()``, or real-tensor tracing, as starters.
-# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors <https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs>`_.
+# For more in-depth guides, please refer to the `Export Programming Model <https://pytorch.org/docs/main/export.programming_model.html>`_,
+# or `Dealing with GuardOnDataDependentSymNode errors <https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs>`_.
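(As a taste of that toolkit, here is a hedged sketch, again assuming the toy model above, of how a ``torch._check()`` assertion can discharge the data-dependent guard by narrowing the unbacked symbol's value range.)

import torch

class CheckedFoo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # Promise the compiler that a >= 10; with u0's range narrowed to
        # [10, int_oo], the guard u0 // 2 >= 5 should resolve statically to True.
        torch._check(a >= 10)
        if a // 2 >= 5:
            return y + 2
        return y * 5

ep = torch.export.export(CheckedFoo(), (torch.tensor(12), torch.ones(3)))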