Skip to content

Commit 4a4baff

Browse files
williamwen42pytorchmergebot
authored andcommitted
[dynamo, 3.12] force LOAD_SUPER_ATTR second bit on (pytorch#123686)
This was pretty painful to find haha Pull Request resolved: pytorch#123686 Approved by: https://github.com/jansel
1 parent d60135e commit 4a4baff

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

test/dynamo/test_misc.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9951,6 +9951,20 @@ def fn(x):
99519951
opt_fn = torch.compile(fn, backend="eager")
99529952
opt_fn(torch.randn(5, 5))
99539953

9954+
def test_super_after_graph_break(self):
9955+
class Foo(torch.nn.Sequential):
9956+
def __init__(self, layers):
9957+
torch._dynamo.graph_break()
9958+
super().__init__(*layers)
9959+
9960+
def fn(x):
9961+
layers = [torch.nn.Linear(3, 3) for _ in range(3)]
9962+
mod = Foo(layers)
9963+
return mod(x)
9964+
9965+
opt_fn = torch.compile(fn, backend="eager")
9966+
opt_fn(torch.randn(3, 3))
9967+
99549968
def test_raises_importerror1(self):
99559969
@torch.compile(backend="eager")
99569970
def fn(x):

torch/_dynamo/bytecode_transformation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,8 +1031,11 @@ def should_compute_arg():
10311031
elif instructions[i].opname == "LOAD_SUPER_ATTR":
10321032
assert instructions[i].arg is not None
10331033
assert instructions[i].argval is not _NotProvided
1034-
instructions[i].arg = (names[instructions[i].argval] << 2) + (
1035-
cast(int, instructions[i].arg) % 4
1034+
# Copy low bit, force second bit on for explicit super (the "+ 2")
1035+
instructions[i].arg = (
1036+
(names[instructions[i].argval] << 2)
1037+
+ (cast(int, instructions[i].arg) % 2)
1038+
+ 2
10361039
)
10371040
elif instructions[i].opcode in HAS_LOCAL:
10381041
if should_compute_arg():
@@ -1141,6 +1144,7 @@ def clean_and_assemble_instructions(
11411144
code_options["co_exceptiontable"] = assemble_exception_table(
11421145
compute_exception_table(instructions)
11431146
)
1147+
11441148
return instructions, types.CodeType(*[code_options[k] for k in keys])
11451149

11461150

torch/_dynamo/symbolic_convert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,8 +1905,7 @@ def LOAD_FAST_AND_CLEAR(self, inst):
19051905
self.symbolic_locals[inst.argval] = NullVariable()
19061906

19071907
def LOAD_SUPER_ATTR(self, inst):
1908-
super_vt, cls_vt, self_vt = self.popn(3)
1909-
self.call_function(super_vt, [cls_vt, self_vt], {})
1908+
self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
19101909
if inst.arg & 1:
19111910
self.LOAD_METHOD(inst)
19121911
else:

0 commit comments

Comments
 (0)