19
19
"""
20
20
import collections
21
21
22
- from typing import Dict , List , Optional , Union
22
+ from functools import partial
23
+ from typing import Callable , Dict , Optional , TypeVar
23
24
24
25
import numpy as np
25
26
26
27
__all__ = ["DictToArrayBijection" ]
27
28
29
+
30
+ T = TypeVar ("T" )
31
+ PointType = Dict [str , np .ndarray ]
32
+
28
33
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
29
34
# each of the raveled variables.
30
35
RaveledVars = collections .namedtuple ("RaveledVars" , "data, point_map_info" )
@@ -38,7 +43,7 @@ class DictToArrayBijection:
38
43
"""
39
44
40
45
@staticmethod
41
- def map (var_dict : Dict [ str , np . ndarray ] ) -> RaveledVars :
46
+ def map (var_dict : PointType ) -> RaveledVars :
42
47
"""Map a dictionary of names and variables to a concatenated 1D array space."""
43
48
vars_info = tuple ((v , k , v .shape , v .dtype ) for k , v in var_dict .items ())
44
49
raveled_vars = [v [0 ].ravel () for v in vars_info ]
@@ -50,42 +55,41 @@ def map(var_dict: Dict[str, np.ndarray]) -> RaveledVars:
50
55
51
56
@staticmethod
52
57
def rmap (
53
- array : RaveledVars , as_list : Optional [bool ] = False
54
- ) -> Union [Dict [str , np .ndarray ], List [np .ndarray ]]:
58
+ array : RaveledVars ,
59
+ start_point : Optional [PointType ] = None ,
60
+ ) -> PointType :
55
61
"""Map 1D concatenated array to a dictionary of variables in their original spaces.
56
62
57
63
Parameters
58
64
==========
59
65
array
60
66
The array to map.
61
- as_list
62
- When ``True``, return a list of the original variables instead of a
63
- ``dict`` keyed each variable's name.
67
+ start_point
68
+ An optional dictionary of initial values.
69
+
64
70
"""
65
- if as_list :
66
- res = []
71
+ if start_point :
72
+ res = dict ( start_point )
67
73
else :
68
74
res = {}
69
75
70
76
if not isinstance (array , RaveledVars ):
71
- raise TypeError ("`apt ` must be a `RaveledVars` type" )
77
+ raise TypeError ("`array ` must be a `RaveledVars` type" )
72
78
73
79
last_idx = 0
74
80
for name , shape , dtype in array .point_map_info :
75
81
arr_len = np .prod (shape , dtype = int )
76
82
var = array .data [last_idx : last_idx + arr_len ].reshape (shape ).astype (dtype )
77
- if as_list :
78
- res .append (var )
79
- else :
80
- res [name ] = var
83
+ res [name ] = var
81
84
last_idx += arr_len
82
85
83
86
return res
84
87
85
88
@classmethod
86
- def mapf (cls , f ):
87
- """
88
- function f: DictSpace -> T to ArraySpace -> T
89
+ def mapf (cls , f : Callable [[PointType ], T ], start_point : Optional [PointType ] = None ) -> T :
90
+ """Create a callable that first maps back to ``dict`` inputs and then applies a function.
91
+
92
+ function f: DictSpace -> T to ArraySpace -> T
89
93
90
94
Parameters
91
95
----------
@@ -95,7 +99,7 @@ def mapf(cls, f):
95
99
-------
96
100
f: array -> T
97
101
"""
98
- return Compose (f , cls .rmap )
102
+ return Compose (f , partial ( cls .rmap , start_point = start_point ) )
99
103
100
104
101
105
class Compose :
0 commit comments