|
5 | 5 | # This source code is licensed under the BSD-style license found in the
|
6 | 6 | # LICENSE file in the root directory of this source tree.
|
7 | 7 |
|
| 8 | +""" |
| 9 | +This script is run by CI after building the executorch wheel. Before running |
| 10 | +this, the job will install the matching torch package as well as the newly-built |
| 11 | +executorch package and its dependencies. |
| 12 | +""" |
| 13 | + |
| 14 | +# Import this first. If it can't find the torch.so libraries, the dynamic load |
| 15 | +# will fail and the process will exit. |
| 16 | +from executorch.extension.pybindings import portable_lib # usort: skip |
| 17 | + |
| 18 | +# Import this after importing the ExecuTorch pybindings. If the pybindings |
| 19 | +# links against a different torch.so than this uses, there will be a set of |
| 20 | +# symbol comflicts; the process will either exit now, or there will be issues |
| 21 | +# later in the smoke test. |
| 22 | +import torch # usort: skip |
| 23 | + |
| 24 | +# Import everything else later to help isolate the critical imports above. |
| 25 | +import os |
| 26 | +import tempfile |
| 27 | +from typing import Tuple |
| 28 | + |
| 29 | +from executorch.exir import to_edge |
| 30 | +from torch.export import export |
| 31 | + |
| 32 | + |
| 33 | +class LinearModel(torch.nn.Module): |
| 34 | + """Runs Linear on its input, which should have shape [4].""" |
| 35 | + |
| 36 | + def __init__(self): |
| 37 | + super().__init__() |
| 38 | + self.linear = torch.nn.Linear(4, 2) |
| 39 | + |
| 40 | + def forward(self, x: torch.Tensor): |
| 41 | + """Expects a single tensor of shape [4].""" |
| 42 | + return self.linear(x) |
| 43 | + |
| 44 | + |
| 45 | +def linear_model_inputs() -> Tuple[torch.Tensor]: |
| 46 | + """Returns some example inputs compatible with LinearModel.""" |
| 47 | + # The model takes a single tensor of shape [4] as an input. |
| 48 | + return (torch.ones(4),) |
| 49 | + |
| 50 | + |
| 51 | +def export_linear_model() -> bytes: |
| 52 | + """Exports LinearModel and returns the .pte data.""" |
| 53 | + |
| 54 | + # This helps the exporter understand the shapes of tensors used in the model. |
| 55 | + # Since our model only takes one input, this is a one-tuple. |
| 56 | + example_inputs = linear_model_inputs() |
| 57 | + |
| 58 | + # Export the pytorch model and process for ExecuTorch. |
| 59 | + print("Exporting program...") |
| 60 | + exported_program = export(LinearModel(), example_inputs) |
| 61 | + print("Lowering to edge...") |
| 62 | + edge_program = to_edge(exported_program) |
| 63 | + print("Creating ExecuTorch program...") |
| 64 | + et_program = edge_program.to_executorch() |
| 65 | + |
| 66 | + return et_program.buffer |
| 67 | + |
8 | 68 |
|
9 | 69 | def main():
|
10 |
| - """ |
11 |
| - Run ExecuTorch binary smoke tests. This is a placeholder for future tests. See |
12 |
| - https://github.com/pytorch/test-infra/wiki/Using-Nova-Reusable-Build-Workflows |
13 |
| - for more information about Nova binary workflow. |
14 |
| - """ |
| 70 | + """Tests the export and execution of a simple model.""" |
| 71 | + |
| 72 | + # If the pybindings loaded correctly, we should be able to ask for the set |
| 73 | + # of operators. |
| 74 | + ops = portable_lib._get_operator_names() |
| 75 | + assert len(ops) > 0, "Empty operator list" |
| 76 | + print(f"Found {len(ops)} operators; first element '{ops[0]}'") |
| 77 | + |
| 78 | + # Export LinearModel to .pte data. |
| 79 | + pte_data: bytes = export_linear_model() |
| 80 | + |
| 81 | + # Try saving to and loading from a file. |
| 82 | + with tempfile.TemporaryDirectory() as tempdir: |
| 83 | + pte_file = os.path.join(tempdir, "linear.pte") |
| 84 | + |
| 85 | + # Save the .pte data to a file. |
| 86 | + with open(pte_file, "wb") as file: |
| 87 | + file.write(pte_data) |
| 88 | + print(f"ExecuTorch program saved to {pte_file} ({len(pte_data)} bytes).") |
| 89 | + |
| 90 | + # Load the model from disk. |
| 91 | + m = portable_lib._load_for_executorch(pte_file) |
| 92 | + |
| 93 | + # Run the model. |
| 94 | + outputs = m.forward(linear_model_inputs()) |
| 95 | + |
| 96 | + # Should see a single output with shape [2]. |
| 97 | + assert len(outputs) == 1, f"Unexpected output length {len(outputs)}: {outputs}" |
| 98 | + assert outputs[0].shape == (2,), f"Unexpected output size {outputs[0].shape}" |
| 99 | + |
| 100 | + print("PASS") |
15 | 101 |
|
16 | 102 |
|
17 | 103 | if __name__ == "__main__":
|
|
0 commit comments