Skip to content

Commit ccbf2e9

Browse files
Fix typing issues in aesara.scan
1 parent 05da60c commit ccbf2e9

File tree

4 files changed

+101
-77
lines changed

4 files changed

+101
-77
lines changed

aesara/scan/opt.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import dataclasses
55
import logging
66
from sys import maxsize
7-
from typing import Dict, List, Tuple
7+
from typing import Dict, List, Optional, Tuple
88

99
import numpy as np
1010

@@ -15,8 +15,8 @@
1515
from aesara.compile.function.types import deep_copy_op
1616
from aesara.configdefaults import config
1717
from aesara.graph.basic import (
18+
Apply,
1819
Constant,
19-
Node,
2020
Variable,
2121
clone_replace,
2222
equal_computations,
@@ -30,6 +30,7 @@
3030
from aesara.graph.op import compute_test_value
3131
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
3232
from aesara.graph.optdb import EquilibriumDB, SequenceDB
33+
from aesara.graph.type import HasShape
3334
from aesara.graph.utils import InconsistencyError
3435
from aesara.scan.op import Scan, ScanInfo
3536
from aesara.scan.utils import (
@@ -652,7 +653,7 @@ def inner_sitsot_only_last_step_used(
652653

653654
if len(fgraph.clients[outer_var]) == 1:
654655
client = fgraph.clients[outer_var][0][0]
655-
if client != "output" and isinstance(client.op, Subtensor):
656+
if isinstance(client, Apply) and isinstance(client.op, Subtensor):
656657
lst = get_idx_list(client.inputs, client.op.idx_list)
657658
if len(lst) == 1 and at.extract_constant(lst[0]) == -1:
658659
return True
@@ -662,22 +663,24 @@ def inner_sitsot_only_last_step_used(
662663

663664
def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int:
664665
"""Determine the number of dimension a variable would have if it was pushed out of a `Scan`."""
666+
assert isinstance(var.type, HasShape)
667+
665668
if var in scan_args.inner_in_non_seqs or isinstance(var, Constant):
666-
outer_ndim = var.ndim
669+
outer_ndim = var.type.ndim
667670
else:
668-
outer_ndim = var.ndim + 1
671+
outer_ndim = var.type.ndim + 1
669672

670673
return outer_ndim
671674

672675

673676
def push_out_inner_vars(
674677
fgraph: FunctionGraph,
675678
inner_vars: List[Variable],
676-
old_scan_node: Node,
679+
old_scan_node: Apply,
677680
old_scan_args: ScanArgs,
678681
) -> Tuple[List[Variable], ScanArgs, Dict[Variable, Variable]]:
679682

680-
outer_vars = [None] * len(inner_vars)
683+
tmp_outer_vars: List[Optional[Variable]] = []
681684
new_scan_node = old_scan_node
682685
new_scan_args = old_scan_args
683686
replacements: Dict[Variable, Variable] = {}
@@ -688,26 +691,32 @@ def push_out_inner_vars(
688691

689692
var = inner_vars[idx]
690693

694+
new_outer_var: Optional[Variable] = None
695+
691696
if var in old_scan_args.inner_in_seqs:
692697
idx_seq = old_scan_args.inner_in_seqs.index(var)
693-
outer_vars[idx] = old_scan_args.outer_in_seqs[idx_seq]
698+
new_outer_var = old_scan_args.outer_in_seqs[idx_seq]
694699

695700
elif var in old_scan_args.inner_in_non_seqs:
696701
idx_non_seq = old_scan_args.inner_in_non_seqs.index(var)
697-
outer_vars[idx] = old_scan_args.outer_in_non_seqs[idx_non_seq]
702+
new_outer_var = old_scan_args.outer_in_non_seqs[idx_non_seq]
698703

699704
elif isinstance(var, Constant):
700-
outer_vars[idx] = var.clone()
705+
new_outer_var = var.clone()
701706

702707
elif var in old_scan_args.inner_out_nit_sot:
703708
idx_nitsot = old_scan_args.inner_out_nit_sot.index(var)
704-
outer_vars[idx] = old_scan_args.outer_out_nit_sot[idx_nitsot]
709+
new_outer_var = old_scan_args.outer_out_nit_sot[idx_nitsot]
710+
711+
tmp_outer_vars.append(new_outer_var)
705712

706713
# For the inner_vars that don't already exist in the outer graph, add
707714
# them as new nitsot outputs to the scan node.
708-
idx_add_as_nitsots = [i for i in range(len(outer_vars)) if outer_vars[i] is None]
715+
idx_add_as_nitsots = [i for i, v in enumerate(tmp_outer_vars) if v is None]
709716
add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]
710717

718+
new_outs: List[Variable] = []
719+
711720
if len(add_as_nitsots) > 0:
712721

713722
new_scan_node, replacements = add_nitsot_outputs(
@@ -724,18 +733,25 @@ def push_out_inner_vars(
724733
)
725734

726735
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :]
727-
for i in range(len(new_outs)):
728-
outer_vars[idx_add_as_nitsots[i]] = new_outs[i]
736+
737+
outer_vars: List[Variable] = []
738+
739+
for i, v in enumerate(tmp_outer_vars):
740+
if i in idx_add_as_nitsots:
741+
outer_vars.append(new_outs.pop(0))
742+
else:
743+
assert v is not None
744+
outer_vars.append(v)
729745

730746
return outer_vars, new_scan_args, replacements
731747

732748

733749
def add_nitsot_outputs(
734750
fgraph: FunctionGraph,
735-
old_scan_node: Node,
751+
old_scan_node: Apply,
736752
old_scan_args: ScanArgs,
737753
new_outputs_inner,
738-
) -> Tuple[Node, Dict[Variable, Variable]]:
754+
) -> Tuple[Apply, Dict[Variable, Variable]]:
739755

740756
nb_new_outs = len(new_outputs_inner)
741757

@@ -764,7 +780,10 @@ def add_nitsot_outputs(
764780
)
765781

766782
# Create the Apply node for the scan op
767-
new_scan_node = new_scan_op(*new_scan_args.outer_inputs, return_list=True)[0].owner
783+
new_scan_outs = new_scan_op(*new_scan_args.outer_inputs, return_list=True)
784+
assert isinstance(new_scan_outs, list)
785+
new_scan_node = new_scan_outs[0].owner
786+
assert new_scan_node is not None
768787

769788
# Modify the outer graph to make sure the outputs of the new scan are
770789
# used instead of the outputs of the old scan
@@ -781,7 +800,7 @@ def add_nitsot_outputs(
781800
# replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs))
782801
# replacements["remove"] = [old_scan_node]
783802
# return new_scan_node, replacements
784-
fgraph.replace_all_validate_remove(
803+
fgraph.replace_all_validate_remove( # type: ignore
785804
list(zip(old_scan_node.outputs, new_node_old_outputs)),
786805
remove=[old_scan_node],
787806
reason="scan_pushout_add",

aesara/scan/scan_perform_ext.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import os
1010
import sys
1111
from importlib import reload
12+
from types import ModuleType
13+
from typing import Optional
1214

1315
import aesara
1416
from aesara.compile.compilelock import lock_ctx
@@ -24,6 +26,7 @@
2426
version = 0.312 # must match constant returned in function get_version()
2527

2628
need_reload = False
29+
scan_perform: Optional[ModuleType] = None
2730

2831

2932
def try_import():
@@ -107,7 +110,10 @@ def try_reload():
107110

108111
from scan_perform import scan_perform as scan_c
109112

110-
assert scan_perform._version == scan_c.get_version()
113+
assert (
114+
scan_perform is not None
115+
and scan_perform._version == scan_c.get_version()
116+
)
111117

112118
_logger.info(f"New version {scan_perform._version}")
113119

aesara/scan/utils.py

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import logging
66
import warnings
77
from collections import OrderedDict, namedtuple
8-
from typing import TYPE_CHECKING, Callable, List, Optional, Set, Tuple, Union
8+
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Tuple, Union
9+
from typing import cast as type_cast
910

1011
import numpy as np
1112

@@ -21,8 +22,9 @@
2122
graph_inputs,
2223
)
2324
from aesara.graph.op import get_test_value
25+
from aesara.graph.type import HasDataType
2426
from aesara.graph.utils import TestValueError
25-
from aesara.tensor.basic import AllocEmpty, get_scalar_constant_value
27+
from aesara.tensor.basic import AllocEmpty, cast, get_scalar_constant_value
2628
from aesara.tensor.subtensor import set_subtensor
2729
from aesara.tensor.var import TensorConstant
2830

@@ -55,8 +57,11 @@ def safe_new(
5557
nw_name = None
5658

5759
if isinstance(x, Constant):
58-
if dtype and x.dtype != dtype:
59-
casted_x = x.astype(dtype)
60+
# TODO: Do something better about this
61+
assert isinstance(x.type, HasDataType)
62+
63+
if dtype and x.type.dtype != dtype:
64+
casted_x = cast(x, dtype)
6065
nwx = type(x)(casted_x.type, x.data, x.name)
6166
nwx.tag = copy.copy(x.tag)
6267
return nwx
@@ -89,8 +94,12 @@ def safe_new(
8994
pass
9095

9196
# Cast `x` if needed. If `x` has a test value, this will also cast it.
92-
if dtype and x.dtype != dtype:
93-
x = x.astype(dtype)
97+
if dtype:
98+
# TODO: Do something better about this
99+
assert isinstance(x.type, HasDataType)
100+
101+
if x.type.dtype != dtype:
102+
x = cast(x, dtype)
94103

95104
nw_x = x.type()
96105
nw_x.name = nw_name
@@ -717,13 +726,13 @@ class ScanArgs:
717726

718727
def __init__(
719728
self,
720-
outer_inputs,
721-
outer_outputs,
722-
_inner_inputs,
723-
_inner_outputs,
724-
info,
725-
as_while,
726-
clone=True,
729+
outer_inputs: Sequence[Variable],
730+
outer_outputs: Sequence[Variable],
731+
_inner_inputs: Sequence[Variable],
732+
_inner_outputs: Sequence[Variable],
733+
info: "ScanInfo",
734+
as_while: bool,
735+
clone: Optional[bool] = True,
727736
):
728737
self.n_steps = outer_inputs[0]
729738
self.as_while = as_while
@@ -745,16 +754,16 @@ def __init__(
745754
q = 0
746755

747756
n_seqs = info.n_seqs
748-
self.outer_in_seqs = outer_inputs[p : p + n_seqs]
749-
self.inner_in_seqs = inner_inputs[q : q + n_seqs]
757+
self.outer_in_seqs = list(outer_inputs[p : p + n_seqs])
758+
self.inner_in_seqs = list(inner_inputs[q : q + n_seqs])
750759
p += n_seqs
751760
q += n_seqs
752761

753762
n_mit_mot = info.n_mit_mot
754763
n_mit_sot = info.n_mit_sot
755764

756-
self.mit_mot_in_slices = info.tap_array[:n_mit_mot]
757-
self.mit_sot_in_slices = info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot]
765+
self.mit_mot_in_slices = list(info.tap_array[:n_mit_mot])
766+
self.mit_sot_in_slices = list(info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot])
758767

759768
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
760769
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
@@ -775,63 +784,63 @@ def __init__(
775784
qq += len(sl)
776785
q += n_mit_sot_ins
777786

778-
self.outer_in_mit_mot = outer_inputs[p : p + n_mit_mot]
787+
self.outer_in_mit_mot = list(outer_inputs[p : p + n_mit_mot])
779788
p += n_mit_mot
780-
self.outer_in_mit_sot = outer_inputs[p : p + n_mit_sot]
789+
self.outer_in_mit_sot = list(outer_inputs[p : p + n_mit_sot])
781790
p += n_mit_sot
782791

783792
n_sit_sot = info.n_sit_sot
784-
self.outer_in_sit_sot = outer_inputs[p : p + n_sit_sot]
785-
self.inner_in_sit_sot = inner_inputs[q : q + n_sit_sot]
793+
self.outer_in_sit_sot = list(outer_inputs[p : p + n_sit_sot])
794+
self.inner_in_sit_sot = list(inner_inputs[q : q + n_sit_sot])
786795
p += n_sit_sot
787796
q += n_sit_sot
788797

789798
n_shared_outs = info.n_shared_outs
790-
self.outer_in_shared = outer_inputs[p : p + n_shared_outs]
791-
self.inner_in_shared = inner_inputs[q : q + n_shared_outs]
799+
self.outer_in_shared = list(outer_inputs[p : p + n_shared_outs])
800+
self.inner_in_shared = list(inner_inputs[q : q + n_shared_outs])
792801
p += n_shared_outs
793802
q += n_shared_outs
794803

795804
n_nit_sot = info.n_nit_sot
796-
self.outer_in_nit_sot = outer_inputs[p : p + n_nit_sot]
805+
self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot])
797806
p += n_nit_sot
798807

799-
self.outer_in_non_seqs = outer_inputs[p:]
800-
self.inner_in_non_seqs = inner_inputs[q:]
808+
self.outer_in_non_seqs = list(outer_inputs[p:])
809+
self.inner_in_non_seqs = list(inner_inputs[q:])
801810

802811
# now for the outputs
803812
p = 0
804813
q = 0
805814

806815
self.mit_mot_out_slices = info.mit_mot_out_slices
807816
n_mit_mot_outs = info.n_mit_mot_outs
808-
self.outer_out_mit_mot = outer_outputs[p : p + n_mit_mot]
809-
iomm = inner_outputs[q : q + n_mit_mot_outs]
810-
self.inner_out_mit_mot = ()
817+
self.outer_out_mit_mot = list(outer_outputs[p : p + n_mit_mot])
818+
iomm = list(inner_outputs[q : q + n_mit_mot_outs])
819+
self.inner_out_mit_mot: Tuple[List[Variable], ...] = ()
811820
qq = 0
812821
for sl in self.mit_mot_out_slices:
813822
self.inner_out_mit_mot += (iomm[qq : qq + len(sl)],)
814823
qq += len(sl)
815824
p += n_mit_mot
816825
q += n_mit_mot_outs
817826

818-
self.outer_out_mit_sot = outer_outputs[p : p + n_mit_sot]
819-
self.inner_out_mit_sot = inner_outputs[q : q + n_mit_sot]
827+
self.outer_out_mit_sot = list(outer_outputs[p : p + n_mit_sot])
828+
self.inner_out_mit_sot = list(inner_outputs[q : q + n_mit_sot])
820829
p += n_mit_sot
821830
q += n_mit_sot
822831

823-
self.outer_out_sit_sot = outer_outputs[p : p + n_sit_sot]
824-
self.inner_out_sit_sot = inner_outputs[q : q + n_sit_sot]
832+
self.outer_out_sit_sot = list(outer_outputs[p : p + n_sit_sot])
833+
self.inner_out_sit_sot = list(inner_outputs[q : q + n_sit_sot])
825834
p += n_sit_sot
826835
q += n_sit_sot
827836

828-
self.outer_out_nit_sot = outer_outputs[p : p + n_nit_sot]
829-
self.inner_out_nit_sot = inner_outputs[q : q + n_nit_sot]
837+
self.outer_out_nit_sot = list(outer_outputs[p : p + n_nit_sot])
838+
self.inner_out_nit_sot = list(inner_outputs[q : q + n_nit_sot])
830839
p += n_nit_sot
831840
q += n_nit_sot
832841

833-
self.outer_out_shared = outer_outputs[p : p + n_shared_outs]
834-
self.inner_out_shared = inner_outputs[q : q + n_shared_outs]
842+
self.outer_out_shared = list(outer_outputs[p : p + n_shared_outs])
843+
self.inner_out_shared = list(inner_outputs[q : q + n_shared_outs])
835844
p += n_shared_outs
836845
q += n_shared_outs
837846

@@ -978,12 +987,18 @@ def get_alt_field(
978987
-------
979988
The alternate variable.
980989
"""
990+
_var_info: FieldInfo
981991
if not isinstance(var_info, FieldInfo):
982-
var_info = self.find_among_fields(var_info)
992+
find_var_info = self.find_among_fields(var_info)
993+
if find_var_info is None:
994+
raise ValueError(f"Couldn't find {var_info} among fields")
995+
_var_info = find_var_info
996+
else:
997+
_var_info = var_info
983998

984-
alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :]
985-
alt_var = getattr(self, f"{alt_prefix}_{alt_type}")[var_info.index]
986-
return alt_var
999+
alt_type = _var_info.name[(_var_info.name.index("_", 6) + 1) :]
1000+
alt_var = getattr(self, f"{alt_prefix}_{alt_type}")[_var_info.index]
1001+
return type_cast(Variable, alt_var)
9871002

9881003
def find_among_fields(
9891004
self, i: Variable, field_filter: Callable[[str], bool] = default_filter
@@ -1054,7 +1069,7 @@ def _remove_from_fields(
10541069
return field_info
10551070

10561071
def get_dependent_nodes(
1057-
self, i: Variable, seen: Optional[Set[int]] = None
1072+
self, i: Variable, seen: Optional[Set[Variable]] = None
10581073
) -> Set[Variable]:
10591074
if seen is None:
10601075
seen = {i}

0 commit comments

Comments
 (0)