Skip to content

Don't change constant names in Features #348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytensor/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,8 @@ class PreserveVariableAttributes(Feature):
"""

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if r.name is not None and new_r.name is None:
# Don't change the name of constants
if r.owner and r.name is not None and new_r.name is None:
new_r.name = r.name
if (
getattr(r.tag, "nan_guard_mode_check", False)
Expand Down
4 changes: 0 additions & 4 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,6 @@ def process_atomic(self, fgraph, c):
sig = c.merge_signature()
other_c = self.atomic_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
self.scheduled.append([[(c, other_c, "merge")]])
else:
# this is a new constant
Expand Down
3 changes: 2 additions & 1 deletion pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4236,7 +4236,8 @@ def fgraph(self):
# the fgraph to be set to the variable as we need to pickle
# them for the cache of c module to work.
fgraph = FunctionGraph(self.inputs, self.outputs)
MergeOptimizer().rewrite(fgraph)
with config.change_flags(optimizer_verbose=False):
MergeOptimizer().rewrite(fgraph)
for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp):
raise TypeError(
Expand Down