Skip to content

Commit 3928ca2

Browse files
williamwen42pytorchmergebot
authored andcommitted
[dynamo] update call map to allow multiple input parameters (pytorch#130748)
Fixes pytorch#128072. Commandeering pytorch#128282 since the issue is now hi pri. Pull Request resolved: pytorch#130748 Approved by: https://github.com/Skylion007, https://github.com/anijain2305
1 parent 6f32dc0 commit 3928ca2

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

test/dynamo/test_repros.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5327,6 +5327,27 @@ def random_op(tensor, params):
53275327
tensor = torch.randn([2, 3])
53285328
res = random_op(tensor, params)
53295329

5330+
# https://github.com/pytorch/pytorch/issues/128072
5331+
def test_map_with_multiple_args(self):
5332+
def f(a, b):
5333+
return a[0] * b[0] + a[1] * b[1]
5334+
5335+
def gen_inps(len_x, len_y):
5336+
x = [torch.randn(5) for _ in range(len_x)]
5337+
y = [torch.randn(5) for _ in range(len_y)]
5338+
return x, y
5339+
5340+
def g(x, y):
5341+
return tuple(map(f, x, y))
5342+
5343+
opt_g = torch.compile(g, fullgraph=True, backend="eager")
5344+
5345+
inps = gen_inps(3, 3)
5346+
self.assertEqual(g(*inps), opt_g(*inps))
5347+
5348+
inps = gen_inps(3, 5)
5349+
self.assertEqual(g(*inps), opt_g(*inps))
5350+
53305351

53315352
instantiate_parametrized_tests(ReproTests)
53325353

torch/_dynamo/variables/builtin.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,9 +1470,10 @@ def call_hasattr(self, tx, obj, attr):
14701470
return variables.ConstantVariable(hasattr(obj.fn, name))
14711471
return obj.call_hasattr(tx, name)
14721472

1473-
def call_map(self, tx, fn, seq):
1474-
if seq.has_unpack_var_sequence(tx):
1475-
items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)]
1473+
def call_map(self, tx, fn, *seqs):
1474+
if all(seq.has_unpack_var_sequence(tx) for seq in seqs):
1475+
unpacked = [seq.unpack_var_sequence(tx) for seq in seqs]
1476+
items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)]
14761477
return variables.TupleVariable(items)
14771478

14781479
def call_sum(self, tx, seq, start=_SENTINEL):

0 commit comments

Comments
 (0)