
sdpython/onnx-diagnostic


onnx-diagnostic: investigate onnx models

MIT License

The main feature is a set of patches that help export PyTorch models to ONNX. They are mostly designed for LLMs that use dynamic caches.

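import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
# model, args, kwargs and dynamic_shapes are defined by the user
with torch_export_patches(patch_transformers=True) as f:
    ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
    # ...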
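The patching idea can be illustrated with the standard library alone: a context manager swaps a function for an export-friendly replacement while the `with` block runs and restores the original when the block exits, even if the export raises. This is a minimal sketch of that general pattern, not onnx-diagnostic's implementation; `patched`, `Greeter` and `export_friendly_greet` are hypothetical names.

```python
import contextlib


@contextlib.contextmanager
def patched(owner, name, replacement):
    """Temporarily replaces ``owner.name`` with ``replacement`` (hypothetical helper)."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original  # the caller may still use the original if needed
    finally:
        # Always restore, even if the body raised an exception.
        setattr(owner, name, original)


class Greeter:  # hypothetical class standing in for a library class
    def greet(self):
        return "original"


def export_friendly_greet(self):  # hypothetical replacement
    return "patched"


with patched(Greeter, "greet", export_friendly_greet):
    inside = Greeter().greet()  # the replacement is active here
outside = Greeter().greet()  # the original is back after the block
print(inside, outside)  # patched original
```

Restoring in a `finally` clause is the important design choice: a failed export must not leave patched functions in place for the rest of the program.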

It also implements tools to investigate and validate exported models (ExportedProgram, ONNXProgram, ...). See the documentation of onnx-diagnostic and of torch_export_patches.

Getting started

git clone https://github.com/sdpython/onnx-diagnostic.git
cd onnx-diagnostic
pip install -e .

or

pip install onnx-diagnostic

Enlightening Examples

Where to start to export a model

Torch Export

Investigate ONNX models

Snapshot of useful tools

torch_export_patches

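import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
# model, args, kwargs and dynamic_shapes are defined by the user
with torch_export_patches(patch_transformers=True) as f:
    ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
    # ...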

torch_export_rewrite

with torch_export_rewrite(rewrite=[Model.forward]) as f:
    ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
    # ...
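torch.export cannot export a Python `if` whose condition depends on tensor data; a functional conditional such as `torch.cond` avoids that. As a rough illustration of this kind of source rewriting, assuming it is the sort of transformation applied to the listed methods, the sketch below uses the standard `ast` module to turn `if test: return a` / `else: return b` into `return cond(test, lambda: a, lambda: b)`. `cond` is a hypothetical stand-in for `torch.cond` (the real one passes branch inputs explicitly through an `operands` argument), `IfReturnToCond` is a hypothetical name, only this simplest pattern is handled, and this is not how torch_export_rewrite is implemented.

```python
import ast


def cond(pred, true_fn, false_fn):
    """Hypothetical stand-in for torch.cond: calls one of the two branch functions."""
    return true_fn() if pred else false_fn()


class IfReturnToCond(ast.NodeTransformer):
    """Rewrites ``if t: return a`` / ``else: return b`` into
    ``return cond(t, lambda: a, lambda: b)``; other ``if`` forms are left untouched."""

    @staticmethod
    def _thunk(expr):
        # Builds ``lambda: expr`` with an empty argument list.
        no_args = ast.arguments(
            posonlyargs=[], args=[], vararg=None, kwonlyargs=[],
            kw_defaults=[], kwarg=None, defaults=[],
        )
        return ast.Lambda(args=no_args, body=expr)

    def visit_If(self, node):
        # Rewrite nested ifs first, so an elif chain collapses bottom-up.
        self.generic_visit(node)
        simple = (
            len(node.body) == 1 and isinstance(node.body[0], ast.Return)
            and node.body[0].value is not None
            and len(node.orelse) == 1 and isinstance(node.orelse[0], ast.Return)
            and node.orelse[0].value is not None
        )
        if not simple:
            return node
        call = ast.Call(
            func=ast.Name(id="cond", ctx=ast.Load()),
            args=[
                node.test,
                self._thunk(node.body[0].value),
                self._thunk(node.orelse[0].value),
            ],
            keywords=[],
        )
        return ast.copy_location(ast.Return(value=call), node)


source = """
def forward(x):
    if x > 0:
        return x * 2
    else:
        return -x
"""

tree = IfReturnToCond().visit(ast.parse(source))
ast.fix_missing_locations(tree)
print(ast.unparse(tree))  # ast.unparse needs Python 3.9+

# Compile the rewritten function with ``cond`` available as a global.
namespace = {"cond": cond}
exec(compile(tree, "<rewritten>", "exec"), namespace)
rewritten_forward = namespace["forward"]
print(rewritten_forward(3), rewritten_forward(-4))  # 6 4
```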

string_type

import torch
from onnx_diagnostic.helpers import string_type

inputs = (
    torch.rand((3, 4), dtype=torch.float16),
    [torch.rand((5, 6), dtype=torch.float16), torch.rand((5, 6, 7), dtype=torch.float16)],
)

# with shapes
print(string_type(inputs, with_shape=True))
>>> (T10s3x4,#2[T10s5x6,T10s5x6x7])
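Judging from the printed string, `T10s3x4` reads as a tensor with ONNX element type 10 (float16) and shape 3x4, and `#2[...]` as a list of two items. The sketch below reproduces that notation for a hypothetical `FakeTensor` holding only an element type and a shape; it is inferred from the output above, not taken from onnx-diagnostic's source, and covers only tuples, lists and tensors.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: ONNX element type and shape only."""
    itype: int  # ONNX element type, e.g. 10 for FLOAT16
    shape: Tuple[int, ...]


def sketch_string_type(obj):
    """Mimics the notation printed above (an assumption inferred from that output)."""
    if isinstance(obj, FakeTensor):
        # T<element type>s<dim>x<dim>...
        return f"T{obj.itype}s" + "x".join(str(d) for d in obj.shape)
    if isinstance(obj, tuple):
        return "(" + ",".join(sketch_string_type(o) for o in obj) + ")"
    if isinstance(obj, list):
        # #<length>[...] for lists
        return f"#{len(obj)}[" + ",".join(sketch_string_type(o) for o in obj) + "]"
    raise TypeError(f"unsupported type {type(obj).__name__}")


inputs = (
    FakeTensor(10, (3, 4)),
    [FakeTensor(10, (5, 6)), FakeTensor(10, (5, 6, 7))],
)
print(sketch_string_type(inputs))  # (T10s3x4,#2[T10s5x6,T10s5x6x7])
```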

onnx_dtype_name

import onnx
from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name

itype = onnx.TensorProto.BFLOAT16
print(onnx_dtype_name(itype))
print(onnx_dtype_name(7))
>>> BFLOAT16
>>> INT64
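The integers are the element-type values of the `onnx.TensorProto` enum, which is why `7` prints `INT64`. A plain lookup table of the first sixteen values shows the mapping; `dtype_name_sketch` is a hypothetical helper, and newer ONNX versions define further types (float8 variants, int4, ...) that are not listed here.

```python
# Element types of onnx.TensorProto (a subset), keyed by their integer value.
ONNX_ELEMENT_TYPES = {
    1: "FLOAT",
    2: "UINT8",
    3: "INT8",
    4: "UINT16",
    5: "INT16",
    6: "INT32",
    7: "INT64",
    8: "STRING",
    9: "BOOL",
    10: "FLOAT16",
    11: "DOUBLE",
    12: "UINT32",
    13: "UINT64",
    14: "COMPLEX64",
    15: "COMPLEX128",
    16: "BFLOAT16",
}


def dtype_name_sketch(itype: int) -> str:
    """Hypothetical lookup returning the enum name of an ONNX element type."""
    try:
        return ONNX_ELEMENT_TYPES[itype]
    except KeyError:
        raise ValueError(f"unknown or unlisted ONNX element type {itype}") from None


print(dtype_name_sketch(16))  # BFLOAT16
print(dtype_name_sketch(7))  # INT64
```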

max_diff

import torch
from onnx_diagnostic.helpers import max_diff

print(
    max_diff(
        (torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
        (torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
    )
)
>>> {'abs': 0.0, 'rel': 0.0, 'sum': 0.0, 'n': 4.0, 'dnan': 0.0}
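The reported keys can be reproduced on plain Python numbers: flatten both nested structures in the same order, then take the largest absolute difference (`abs`), the largest relative difference (`rel`), the sum of absolute differences (`sum`), the number of compared values (`n`) and the number of positions where only one side is NaN (`dnan`). These definitions, in particular the formulas for `rel` and `dnan`, are assumptions made for this sketch; `max_diff_sketch` is a hypothetical name and does not handle tensors.

```python
import math


def flatten(obj):
    """Yields the numbers of nested tuples/lists in depth-first order."""
    if isinstance(obj, (list, tuple)):
        for item in obj:
            yield from flatten(item)
    else:
        yield float(obj)


def max_diff_sketch(expected, got, eps=1e-5):
    """Hypothetical re-implementation of the reported metrics on plain floats."""
    a, b = list(flatten(expected)), list(flatten(got))
    if len(a) != len(b):
        raise ValueError(f"size mismatch: {len(a)} != {len(b)}")
    abs_max = rel_max = total = 0.0
    dnan = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # Count positions where only one side is NaN, then skip them.
            dnan += math.isnan(x) != math.isnan(y)
            continue
        d = abs(x - y)
        abs_max = max(abs_max, d)
        rel_max = max(rel_max, d / (abs(x) + eps))  # assumed relative-error formula
        total += d
    return {"abs": abs_max, "rel": rel_max, "sum": total,
            "n": float(len(a)), "dnan": float(dnan)}


print(max_diff_sketch(([1, 2], ([1, 2],)), ([1, 2], ([1, 2],))))
# {'abs': 0.0, 'rel': 0.0, 'sum': 0.0, 'n': 4.0, 'dnan': 0.0}
```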

guess_dynamic_shapes

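import torch
from onnx_diagnostic.export import ModelInputs
class Model(torch.nn.Module):  # hypothetical model taking two tensors
    def forward(self, x, y):
        return x + y

model = Model()
inputs = [
    (torch.randn((5, 6)), torch.randn((1, 6))),
    (torch.randn((7, 8)), torch.randn((1, 8))),
]
ds = ModelInputs(model, inputs).guess_dynamic_shapes(auto="dim")
print(ds)
>>> (({0: 'dim_0I0', 1: 'dim_0I1'}, {1: 'dim_1I1'}), {})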
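The guess relies on comparing several input sets: an axis whose size changes between calls is marked dynamic, while one that keeps the same size (like the leading 1 of the second input) stays static. A simplified sketch on plain shape tuples, ignoring keyword arguments (the trailing `{}` above), could look like this; `guess_dynamic_dims` is hypothetical and the `dim_<input>I<axis>` names only mimic the output above.

```python
def guess_dynamic_dims(shape_sets):
    """Marks an axis dynamic when its size differs across the observed input sets.

    shape_sets: one entry per call, each a tuple with one shape per positional input.
    Returns one {axis: name} dict per input; the naming pattern is an assumption.
    """
    first = shape_sets[0]
    result = []
    for i in range(len(first)):
        ranks = {len(shapes[i]) for shapes in shape_sets}
        if len(ranks) != 1:
            raise ValueError(f"input {i} has different ranks across calls: {sorted(ranks)}")
        dims = {}
        for axis in range(len(first[i])):
            sizes = {shapes[i][axis] for shapes in shape_sets}
            if len(sizes) > 1:  # the size changed between calls
                dims[axis] = f"dim_{i}I{axis}"
        result.append(dims)
    return tuple(result)


shapes = [
    ((5, 6), (1, 6)),
    ((7, 8), (1, 8)),
]
print(guess_dynamic_dims(shapes))  # ({0: 'dim_0I0', 1: 'dim_0I1'}, {1: 'dim_1I1'})
```

With a single input set every axis would look static, so at least two sets with different sizes are needed for a useful guess.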
