Skip to content

Commit 41ccfc8

Browse files
jamesjwupytorchmergebot
authored andcommitted
Log chromium event for automatic dynamic reasons (pytorch#137491)
Log a chromium event so that we can see the reasons for invoking automatic dynamic shapes in aggregate internally. Run following code: ``` import torch @torch.compile(backend="eager") def foo(t, x): return t.sin() + x torch._dynamo.config.automatic_dynamic_shapes = True torch._dynamo.config.assume_static_by_default = True # Change size x = torch.randn([1,2]) foo(x, 2) x = torch.randn([2,2]) foo(x, 2) torch._dynamo.reset() # Change dimensionality x = torch.randn([1,2]) foo(x, 2) x = torch.randn([1,2,3]) foo(x, 2) torch._dynamo.reset() # Change stride x = torch.randn([3,3]) foo(x, 2) x = torch.as_strided(x, [3,3], [2,2]) foo(x, 2) torch._dynamo.reset() # Change scalar x = torch.randn([1,2]) foo(x, 2) foo(x, 3) ``` Internal link to perfetto: https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html?url=https%3A%2F%2Finterncache-all.fbcdn.net%2Fmanifold%2Ftlparse_reports%2Ftree%2Flogs%2Fjjwu%2Fcustom%2Fchromium_events.json#!/viewer?url=https%3A%2F%2Finterncache-all.fbcdn.net%2Fmanifold%2Ftlparse_reports%2Ftree%2Flogs%2Fjjwu%2Fcustom%2Fchromium_events.json&local_cache_key The events look like this: <img width="639" alt="image" src="https://github.com/user-attachments/assets/23916333-7f24-47c7-934b-201f33aebeab"> <img width="638" alt="image" src="https://github.com/user-attachments/assets/9f927c8d-04bb-4431-8802-685b032df656"> <img width="640" alt="image" src="https://github.com/user-attachments/assets/342e9e11-0dfc-422d-bd0b-01a8574d38ba"> <img width="635" alt="image" src="https://github.com/user-attachments/assets/dc2c97cd-7180-4069-b019-d6e63ee490bc"> Differential Revision: [D64184625](https://our.internmc.facebook.com/intern/diff/D64184625) Pull Request resolved: pytorch#137491 Approved by: https://github.com/Skylion007, https://github.com/oulgen
1 parent a06d49a commit 41ccfc8

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

torch/_dynamo/variables/builder.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import random
1515
import re
1616
import sys
17+
import time
1718
import types
1819
import warnings
1920
import weakref
@@ -33,6 +34,7 @@
3334

3435
import torch
3536
from torch import SymInt
37+
from torch._dynamo.utils import get_chromium_event_logger
3638
from torch._guards import GuardSource, TracingContext
3739
from torch._higher_order_ops.torchbind import call_torchbind
3840
from torch._ops import HigherOrderOperator
@@ -1760,6 +1762,17 @@ def update_frame_state(value):
17601762
value,
17611763
frame_state_entry.scalar,
17621764
)
1765+
get_chromium_event_logger().log_instant_event(
1766+
"automatic_dynamic",
1767+
time.time_ns(),
1768+
{
1769+
"name": name,
1770+
"dim_changed": "scalar",
1771+
"reason": "scalar change",
1772+
"cached": str(frame_state_entry.scalar),
1773+
"new": str(value),
1774+
},
1775+
)
17631776
if self.source.guard_source().is_unspecialized_nn_module():
17641777
log.info(
17651778
"%s",
@@ -2466,6 +2479,17 @@ def update_frame_state(size, stride):
24662479
len(size),
24672480
frame_state_entry.size,
24682481
)
2482+
get_chromium_event_logger().log_instant_event(
2483+
"automatic_dynamic",
2484+
time.time_ns(),
2485+
{
2486+
"name": name,
2487+
"dim_changed": "all",
2488+
"reason": "dimensionality change",
2489+
"cached": str(frame_state_entry.size),
2490+
"new": str(size),
2491+
},
2492+
)
24692493
frame_state_entry.size = None
24702494
frame_state_entry.stride = None
24712495
else:
@@ -2483,6 +2507,17 @@ def update_frame_state(size, stride):
24832507
size[i],
24842508
dim,
24852509
)
2510+
get_chromium_event_logger().log_instant_event(
2511+
"automatic_dynamic",
2512+
time.time_ns(),
2513+
{
2514+
"name": name,
2515+
"dim_changed": i,
2516+
"reason": "size change",
2517+
"cached": str(dim),
2518+
"new": str(size[i]),
2519+
},
2520+
)
24862521
frame_state_entry.size[i] = None
24872522
has_size_changed = (
24882523
has_size_changed or frame_state_entry.size[i] is None
@@ -2513,6 +2548,17 @@ def update_frame_state(size, stride):
25132548
stride[i],
25142549
dim,
25152550
)
2551+
get_chromium_event_logger().log_instant_event(
2552+
"automatic_dynamic",
2553+
time.time_ns(),
2554+
{
2555+
"name": name,
2556+
"dim_changed": i,
2557+
"reason": "stride change",
2558+
"cached": str(dim),
2559+
"new": str(stride[i]),
2560+
},
2561+
)
25162562
frame_state_entry.stride[i] = None
25172563
tx.output.frame_state[name] = frame_state_entry
25182564

0 commit comments

Comments
 (0)