Skip to content

Commit cf61c73

Browse files
Merge branch 'main' into pylint
2 parents 09078f9 + 1b76af3 commit cf61c73

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1632
-1267
lines changed

.github/release.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# This file has been mostly taken verbatim from https://github.com/pymc-devs/pymc/blob/main/.github/release.yml
2+
#
3+
# This file contains configuration for the automatic generation of release notes in GitHub.
4+
# It's not perfect, but it makes it a little less laborious to write informative release notes.
5+
# Also see https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes
6+
changelog:
7+
exclude:
8+
labels:
9+
- no releasenotes
10+
categories:
11+
- title: Major Changes 🛠
12+
labels:
13+
- major
14+
- title: New Features & Bugfixes 🎉
15+
labels:
16+
- bug
17+
- enhancements
18+
- feature-request
19+
- title: Docs & Maintenance 🔧
20+
labels:
21+
- docs
22+
- installation
23+
- maintenance
24+
- pre-commit
25+
- tests
26+
- "*"

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ Contributing
114114
We welcome bug reports and fixes and improvements to the documentation.
115115

116116
For more information on contributing, please see the
117-
`contributing guide <https://github.com/pymc-devs/pytensor/CONTRIBUTING.md>`.
117+
`contributing guide <https://github.com/pymc-devs/pytensor/CONTRIBUTING.md>`__.
118118

119119
A good place to start contributing is by looking through the issues
120-
`here <https://github.com/pymc-devs/pytensor/issues`.
120+
`here <https://github.com/pymc-devs/pytensor/issues>`__.
121121

122122

123123
.. |Project Name| replace:: PyTensor

doc/library/compile/shared.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
.. class:: SharedVariable
1515

16-
Variable with Storage that is shared between functions that it appears in.
16+
Variable with storage that is shared between the compiled functions that it appears in.
1717
These variables are meant to be created by registered *shared constructors*
1818
(see :func:`shared_constructor`).
1919

@@ -68,18 +68,17 @@
6868

6969
A container to use for this SharedVariable when it is an implicit function parameter.
7070

71-
:type: class:`Container`
7271

7372
.. autofunction:: shared
7473

7574
.. function:: shared_constructor(ctor)
7675

7776
Append `ctor` to the list of shared constructors (see :func:`shared`).
7877

79-
Each registered constructor ``ctor`` will be called like this:
78+
Each registered constructor `ctor` will be called like this:
8079

8180
.. code-block:: python
8281
8382
ctor(value, name=name, strict=strict, **kwargs)
8483
85-
If it do not support given value, it must raise a TypeError.
84+
If it do not support given value, it must raise a `TypeError`.

pytensor/compile/builders.py

Lines changed: 84 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import OrderedDict
33
from copy import copy
44
from functools import partial
5-
from typing import List, Optional, Sequence, cast
5+
from typing import Dict, List, Optional, Sequence, Tuple, cast
66

77
import pytensor.tensor as at
88
from pytensor import function
@@ -19,7 +19,6 @@
1919
clone_replace,
2020
graph_inputs,
2121
io_connection_pattern,
22-
replace_nominals_with_dummies,
2322
)
2423
from pytensor.graph.fg import FunctionGraph
2524
from pytensor.graph.null_type import NullType
@@ -82,6 +81,81 @@ def local_traverse(out):
8281
return ret
8382

8483

84+
def construct_nominal_fgraph(
85+
inputs: Sequence[Variable], outputs: Sequence[Variable]
86+
) -> Tuple[
87+
FunctionGraph,
88+
Sequence[Variable],
89+
Dict[Variable, Variable],
90+
Dict[Variable, Variable],
91+
]:
92+
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
93+
dummy_inputs = []
94+
for n, inp in enumerate(inputs):
95+
if (
96+
not isinstance(inp, Variable)
97+
or isinstance(inp, Constant)
98+
or isinstance(inp, SharedVariable)
99+
):
100+
raise TypeError(
101+
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
102+
)
103+
104+
dummy_inputs.append(inp.type())
105+
106+
dummy_shared_inputs = []
107+
shared_inputs = []
108+
for var in graph_inputs(outputs, inputs):
109+
if isinstance(var, SharedVariable):
110+
# To correctly support shared variables the inner-graph should
111+
# not see them; otherwise, there will be problems with
112+
# gradients.
113+
# That's why we collect the shared variables and replace them
114+
# with dummies.
115+
shared_inputs.append(var)
116+
dummy_shared_inputs.append(var.type())
117+
elif var not in inputs and not isinstance(var, Constant):
118+
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
119+
120+
replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
121+
122+
new = rebuild_collect_shared(
123+
cast(Sequence[Variable], outputs),
124+
inputs=inputs + shared_inputs,
125+
replace=replacements,
126+
copy_inputs_over=False,
127+
)
128+
(
129+
local_inputs,
130+
local_outputs,
131+
(clone_d, update_d, update_expr, new_shared_inputs),
132+
) = new
133+
134+
assert len(local_inputs) == len(inputs) + len(shared_inputs)
135+
assert len(local_outputs) == len(outputs)
136+
assert not update_d
137+
assert not update_expr
138+
assert not new_shared_inputs
139+
140+
fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
141+
142+
# The inputs need to be `NominalVariable`s so that we can merge
143+
# inner-graphs
144+
nominal_local_inputs = tuple(
145+
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
146+
)
147+
148+
fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
149+
150+
for i, inp in enumerate(fgraph.inputs):
151+
nom_inp = nominal_local_inputs[i]
152+
fgraph.inputs[i] = nom_inp
153+
fgraph.clients.pop(inp, None)
154+
fgraph.add_input(nom_inp)
155+
156+
return fgraph, shared_inputs, update_d, update_expr
157+
158+
85159
class OpFromGraph(Op, HasInnerGraph):
86160
r"""
87161
This creates an `Op` from inputs and outputs lists of variables.
@@ -333,66 +407,21 @@ def __init__(
333407
if not (isinstance(inputs, list) and isinstance(outputs, list)):
334408
raise TypeError("Inputs and outputs must be lists")
335409

336-
for i in inputs + outputs:
337-
if not isinstance(i, Variable):
410+
for out in outputs:
411+
if not isinstance(out, Variable):
338412
raise TypeError(
339-
f"Inputs and outputs must be Variable instances; got {i}"
413+
f"Inputs and outputs must be Variable instances; got {out}"
340414
)
341-
if i in inputs:
342-
if isinstance(i, Constant):
343-
raise TypeError(f"Constants not allowed as inputs; {i}")
344-
if isinstance(i, SharedVariable):
345-
raise TypeError(f"SharedVariables not allowed as inputs; {i}")
346-
347-
for var in graph_inputs(outputs, inputs):
348-
if var not in inputs and not isinstance(var, (Constant, SharedVariable)):
349-
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
350415

351416
if "updates" in kwargs or "givens" in kwargs:
352-
raise NotImplementedError("Updates and givens are not allowed here")
417+
raise NotImplementedError("Updates and givens are not supported")
353418

354419
self.is_inline = inline
355420

356-
# To correctly support shared variables the inner fct should
357-
# not see them. Otherwise there is a problem with the gradient.
358-
self.shared_inputs = []
359-
for var in graph_inputs(outputs):
360-
if isinstance(var, SharedVariable):
361-
self.shared_inputs.append(var)
362-
363-
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
364-
365-
# The inputs should be `NominalVariable`s, so that graphs can be merged
366-
replacements = {}
367-
for n, v in enumerate(inputs):
368-
replacements[v] = NominalVariable(n, v.type)
369-
370-
shared_vars = [
371-
NominalVariable(n, var.type)
372-
for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
373-
]
374-
375-
replacements.update(dict(zip(self.shared_inputs, shared_vars)))
376-
377-
new = rebuild_collect_shared(
378-
cast(Sequence[Variable], outputs),
379-
inputs=inputs + shared_vars,
380-
replace=replacements,
381-
copy_inputs_over=False,
421+
self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
422+
inputs, outputs
382423
)
383-
(
384-
local_inputs,
385-
local_outputs,
386-
(clone_d, update_d, update_expr, shared_inputs),
387-
) = new
388-
389-
assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
390-
assert len(local_outputs) == len(outputs)
391-
assert not update_d
392-
assert not update_expr
393-
assert not shared_inputs
394-
395-
self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
424+
396425
self.kwargs = kwargs
397426
self.input_types = [inp.type for inp in inputs]
398427
self.output_types = [out.type for out in outputs]
@@ -415,6 +444,7 @@ def __init__(
415444
else:
416445
self.set_lop_overrides("default")
417446
self._lop_type = "lop"
447+
418448
self.set_rop_overrides(rop_overrides)
419449

420450
self._connection_pattern = connection_pattern

pytensor/compile/function/pfunc.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
44
"""
55

6-
import logging
76
from copy import copy
87
from typing import Optional
98

@@ -16,11 +15,6 @@
1615
from pytensor.graph.fg import FunctionGraph
1716

1817

19-
_logger = logging.getLogger("pytensor.compile.function.pfunc")
20-
21-
__docformat__ = "restructuredtext en"
22-
23-
2418
def rebuild_collect_shared(
2519
outputs,
2620
inputs=None,
@@ -78,10 +72,12 @@ def rebuild_collect_shared(
7872
shared_inputs = []
7973

8074
def clone_v_get_shared_updates(v, copy_inputs_over):
81-
"""
82-
Clones a variable and its inputs recursively until all are in clone_d.
83-
Also appends all shared variables met along the way to shared inputs,
84-
and their default_update (if applicable) to update_d and update_expr.
75+
r"""Clones a variable and its inputs recursively until all are in `clone_d`.
76+
77+
Also, it appends all `SharedVariable`\s met along the way to
78+
`shared_inputs` and their corresponding
79+
`SharedVariable.default_update`\s (when applicable) to `update_d` and
80+
`update_expr`.
8581
8682
"""
8783
# this co-recurses with clone_a
@@ -103,7 +99,7 @@ def clone_v_get_shared_updates(v, copy_inputs_over):
10399
elif isinstance(v, SharedVariable):
104100
if v not in shared_inputs:
105101
shared_inputs.append(v)
106-
if hasattr(v, "default_update"):
102+
if v.default_update is not None:
107103
# Check that v should not be excluded from the default
108104
# updates list
109105
if no_default_updates is False or (
@@ -419,22 +415,24 @@ def construct_pfunc_ins_and_outs(
419415
givens = []
420416

421417
if not isinstance(params, (list, tuple)):
422-
raise Exception("in pfunc() the first argument must be a list or a tuple")
418+
raise TypeError("The `params` argument must be a list or a tuple")
423419

424420
if not isinstance(no_default_updates, bool) and not isinstance(
425421
no_default_updates, list
426422
):
427-
raise TypeError("no_default_update should be either a boolean or a list")
423+
raise TypeError("The `no_default_update` argument must be a boolean or list")
428424

429-
if len(updates) > 0 and any(
430-
isinstance(v, Variable) for v in iter_over_pairs(updates)
425+
if len(updates) > 0 and not all(
426+
isinstance(pair, (tuple, list))
427+
and len(pair) == 2
428+
and isinstance(pair[0], Variable)
429+
for pair in iter_over_pairs(updates)
431430
):
432-
raise ValueError(
433-
"The updates parameter must be an OrderedDict/dict or a list of "
434-
"lists/tuples with 2 elements"
431+
raise TypeError(
432+
"The `updates` parameter must be an ordered mapping or a list of pairs"
435433
)
436434

437-
# transform params into pytensor.compile.In objects.
435+
# Transform params into pytensor.compile.In objects.
438436
inputs = [
439437
_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) for p in params
440438
]

pytensor/compile/function/types.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333

3434
if TYPE_CHECKING:
35+
from pytensor.compile.mode import Mode
3536
from pytensor.link.vm import VM
3637

3738

@@ -1391,16 +1392,24 @@ def check_unused_inputs(inputs, outputs, on_unused_input):
13911392

13921393
@staticmethod
13931394
def prepare_fgraph(
1394-
inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile
1395+
inputs,
1396+
outputs,
1397+
additional_outputs,
1398+
fgraph: FunctionGraph,
1399+
mode: "Mode",
1400+
profile,
13951401
):
13961402

1403+
rewriter = mode.optimizer
1404+
13971405
try:
13981406
start_rewriter = time.perf_counter()
13991407

14001408
rewriter_profile = None
14011409
rewrite_time = None
14021410

14031411
with config.change_flags(
1412+
mode=mode,
14041413
compute_test_value=config.compute_test_value_opt,
14051414
traceback__limit=config.traceback__compile_limit,
14061415
):
@@ -1440,7 +1449,7 @@ def prepare_fgraph(
14401449
stacklevel=3,
14411450
)
14421451

1443-
if not hasattr(linker, "accept"):
1452+
if not hasattr(mode.linker, "accept"):
14441453
raise ValueError(
14451454
"'linker' parameter of FunctionMaker should be "
14461455
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}"
@@ -1511,12 +1520,8 @@ def __init__(
15111520

15121521
self.fgraph = fgraph
15131522

1514-
rewriter, linker = mode.optimizer, copy.copy(mode.linker)
1515-
15161523
if not no_fgraph_prep:
1517-
self.prepare_fgraph(
1518-
inputs, outputs, found_updates, fgraph, rewriter, linker, profile
1519-
)
1524+
self.prepare_fgraph(inputs, outputs, found_updates, fgraph, mode, profile)
15201525

15211526
assert len(fgraph.outputs) == len(outputs + found_updates)
15221527

@@ -1528,6 +1533,8 @@ def __init__(
15281533
if not spec.borrow
15291534
]
15301535

1536+
linker = copy.copy(mode.linker)
1537+
15311538
if no_borrow:
15321539
self.linker = linker.accept(
15331540
fgraph,

0 commit comments

Comments
 (0)