Skip to content

Commit 0b3f2db

Browse files
Thiago Crepaldi and guilhermeleobas
authored and committed
Add huggingface gpt2 fake tensor unit test for torch.onnx.dynamo_export (pytorch#115380)
open llama, dolly v2 and falcon are still broken regardless of `ExportedProgram`, so they were not moved from `test_fx_to_onnx.py` to `fx_to_onnx_onnxruntime.py`. Dolly and falcon already have tracking issues, but a tracking issue was created for open llama: pytorch#115552 A tracking issue was created for `xfail_if_model_type_is_exportedprogram` and `xfail_if_model_type_is_not_exportedprogram` issues with unexpected success runs: pytorch#115747 Pull Request resolved: pytorch#115380 Approved by: https://github.com/titaiwangms
1 parent 65418d2 commit 0b3f2db

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

test/onnx/test_fx_to_onnx_with_onnxruntime.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,7 +1125,7 @@ def create_kwargs():
11251125
"AssertionError: Dynamic shape check failed for graph inputs",
11261126
skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
11271127
)
1128-
def test_large_scale_exporter_with_tiny_gpt2(self):
1128+
def test_fake_tensor_mode_huggingface_tiny_gpt2(self):
11291129
model_name = "sshleifer/tiny-gpt2"
11301130
device = "cpu"
11311131

@@ -1345,6 +1345,49 @@ def create_model():
13451345
model_type=self.model_type,
13461346
)
13471347

1348+
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
1349+
"AssertionError: Expected 5 inputs, got 3"
1350+
"Github issue: https://github.com/pytorch/pytorch/issues/115745"
1351+
)
1352+
@pytorch_test_common.skip_dynamic_fx_test(
1353+
"AssertionError: Dynamic shape check failed for graph inputs",
1354+
skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1355+
)
1356+
def test_fake_tensor_mode_huggingface_gpt2(self):
1357+
config = transformers.GPT2Config(
1358+
vocab_size=8096, n_positions=256, n_embd=256, n_layer=2, n_head=2
1359+
)
1360+
1361+
def create_model():
1362+
return transformers.GPT2Model(config).eval()
1363+
1364+
def create_args():
1365+
return tuple()
1366+
1367+
def create_kwargs():
1368+
batch, seq = 4, 256
1369+
1370+
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
1371+
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
1372+
position_ids = torch.arange(0, seq, dtype=torch.long)
1373+
position_ids = position_ids.unsqueeze(0).view(-1, seq)
1374+
1375+
return {
1376+
"input_ids": input_ids,
1377+
"attention_mask": attention_mask,
1378+
"position_ids": position_ids,
1379+
}
1380+
1381+
self._test_fake_tensor_mode_exporter(
1382+
"huggingface_gpt2",
1383+
create_model,
1384+
create_args,
1385+
create_kwargs,
1386+
load_checkpoint_during_init=self.load_checkpoint_during_init,
1387+
export_within_fake_mode=self.export_within_fake_mode,
1388+
model_type=self.model_type,
1389+
)
1390+
13481391

13491392
if __name__ == "__main__":
13501393
common_utils.run_tests()

0 commit comments

Comments
 (0)