Skip to content

Documentation #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`77`: supports ConcatOfShape and Slice with the light API
* :pr:`76`: add a mode to compare models without execution
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
Expand Down
30 changes: 29 additions & 1 deletion _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,35 @@ def test_constant_of_shape(self):
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)

def test_constant_of_shape_value(self):
onx = (
start()
.vin("X", TensorProto.INT64, shape=[None, None])
.ConstantOfShape(value=np.array([1], dtype=np.float32))
.vout(shape=[])
.to_onnx()
)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got)

def test_slice(self):
onx = (
start(opset=18, ir_version=9)
.cst(np.array([1], dtype=np.int64), name="one")
.cst(np.array([2], dtype=np.int64), name="two")
.vin("X", TensorProto.INT64, shape=[None, None])
.ConstantOfShape(value=np.array([1], dtype=np.float32))
.rename("CX")
.bring("CX", "one", "two", "one")
.Slice()
.vout(shape=[])
.to_onnx()
)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got)


if __name__ == "__main__":
TestLightApi().test_add()
unittest.main(verbosity=2)
7 changes: 7 additions & 0 deletions onnx_array_api/light_api/_op_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,13 @@ def Selu(
def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
return self.make_node("Shrink", self, bias=bias, lambd=lambd)

def Slice(
self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None
) -> "Var":
if steps is None:
return self.make_node("Slice", self, starts, ends, axes)
return self.make_node("Slice", self, starts, ends, axes, steps)

def Softmax(self, axis: int = -1) -> "Var":
return self.make_node("Softmax", self, axis=axis)

Expand Down
2 changes: 1 addition & 1 deletion onnx_array_api/light_api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def to_onnx(self) -> GRAPH_PROTO:
return graph
model = make_model(graph, opset_imports=opsets)
if self.ir_version:
model.ir_version = ir_version
model.ir_version = self.ir_version
if not is_windows() or not is_azure():
# check_model fails sometimes on Windows
check_model(model)
Expand Down