13
13
14
14
@forward Normalizer. os OnlineStats. mean, OnlineStats. std, Base. iterate, normalize, Base. length
15
15
16
-
17
-
18
16
# Treats last dim as batch dim
19
17
function OnlineStats. fit! (n:: Normalizer , data:: AbstractArray )
20
18
for d in eachslice (data, dims = ndims (data))
@@ -47,6 +45,47 @@ function OnlineStats.fit!(n::Normalizer, data::Number)
47
45
n
48
46
end
49
47
48
+ """
49
+ normalize(os::Moments, x)
50
+
51
+ Given an Moments estimate of the elements of x, a vector of scalar traces,
52
+ normalizes x elementwise to zero mean, and unit variance.
53
+ """
54
+ function normalize (os:: Moments , x)
55
+ T = eltype (x)
56
+ m, s = T (mean (os)), T (std (os))
57
+ return (x .- m) ./ s
58
+ end
59
+
60
+ """
61
+ normalize(os::Group{<:AbstractVector{<:Moments}}, x)
62
+
63
+ Given an os::Group{<:Tuple{Moments}}, that is, a multivariate estimator of the moments of each element of x,
64
+ normalizes each element of x to zero mean, and unit variance. Treats the last dimension as a batch dimension if `ndims(x) >= 2`.
65
+ """
66
+ function normalize (os:: Group{<:AbstractVector{<:Moments}} , x:: AbstractVector )
67
+ T = eltype (x)
68
+ m = [T (mean (stat)) for stat in os]
69
+ s = [T (std (stat)) for stat in os]
70
+ return (x .- m) ./ s
71
+ end
72
+
73
+ function normalize (os:: Group{<:AbstractVector{<:Moments}} , x:: AbstractArray )
74
+ xn = similar (x)
75
+ for (i, slice) in enumerate (eachslice (x, dims = ndims (x)))
76
+ xn[repeat ([:], ndims (x)- 1 )... , i] .= reshape (normalize (os, vec (slice)), size (x)[1 : end - 1 ]. .. )
77
+ end
78
+ return xn
79
+ end
80
+
81
+ function normalize (os:: Group{<:AbstractVector{<:Moments}} , x:: AbstractVector{<:AbstractArray} )
82
+ xn = similar (x)
83
+ for (i,el) in enumerate (x)
84
+ xn[i] = normalize (os, vec (el))
85
+ end
86
+ return xn
87
+ end
88
+
50
89
"""
51
90
scalar_normalizer(;weights = OnlineStats.EqualWeight())
52
91
@@ -65,7 +104,6 @@ See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/sta
65
104
"""
66
105
array_normalizer (size:: NTuple{N,Int} ; weight:: Weight = EqualWeight ()) where N = Normalizer (Group ([Moments (weight = weight) for _ in 1 : prod (size)]))
67
106
68
-
69
107
"""
70
108
NormalizedTrace(trace::Trace, normalizer::Normalizer)
71
109
@@ -87,91 +125,32 @@ t = Trajectory(
87
125
)
88
126
89
127
"""
90
- struct NormalizedTraces{T <: AbstractTraces , names, N} # <: AbstractTraces
128
+ struct NormalizedTraces{names, TT, T <: AbstractTraces{names, TT} , normnames, N} <: AbstractTraces{names, TT}
91
129
traces:: T
92
- normalizers:: NamedTuple{names, N}
93
- function NormalizedTraces (traces:: AbstractTraces ; trace_normalizer_pairs... )
94
- for key in keys (trace_normalizer_pairs)
95
- @assert key in keys (traces) " Traces do not have key $key , valid keys are $(keys (traces)) ."
96
- end
97
- nt = (; trace_normalizer_pairs... )
98
- new {typeof(traces), keys(nt), typeof(values(nt))} (traces, nt)
99
- end
130
+ normalizers:: NamedTuple{normnames, N}
100
131
end
101
132
102
- @forward NormalizedTraces. traces Base. length, Base. lastindex, Base. firstindex, Base. getindex, Base. view, Base. pop!, Base. popfirst!, Base. empty!, Base. prepend!, Base. pushfirst!
103
-
104
- function Base. push! (nt:: NormalizedTraces , x)
105
- for key in intersect (keys (nt. normalizers), keys (x))
106
- fit! (nt. normalizers[key], x[key])
133
+ function NormalizedTraces (traces:: AbstractTraces{names, TT} ; trace_normalizer_pairs... ) where names where TT
134
+ for key in keys (trace_normalizer_pairs)
135
+ @assert key in keys (traces) " Traces do not have key $key , valid keys are $(keys (traces)) ."
107
136
end
108
- push! (nt. traces, x)
137
+ nt = (; trace_normalizer_pairs... )
138
+ NormalizedTraces {names, TT, typeof(traces), keys(nt), typeof(values(nt))} (traces, nt)
109
139
end
110
140
111
- function Base. append! (nt:: NormalizedTraces , x)
112
- for key in intersect (keys (nt. normalizers), keys (x))
113
- fit! (nt. normalizers[key], x[key])
114
- end
115
- append! (nt. traces, x)
116
- end
117
-
118
- """
119
- normalize!(os::Moments, x)
120
-
121
- Given an Moments estimate of the elements of x, a vector of scalar traces,
122
- normalizes x elementwise to zero mean, and unit variance.
123
- """
124
- function normalize (os:: Moments , x:: AbstractVector )
125
- T = eltype (x)
126
- m, s = T (mean (os)), T (std (os))
127
- return (x .- m) ./ s
128
- end
129
-
130
- """
131
- normalize!(os::Group{<:AbstractVector{<:Moments}}, x)
132
-
133
- Given an os::Group{<:Tuple{Moments}}, that is, a multivariate estimator of the moments of each element of x,
134
- normalizes each element of x to zero mean, and unit variance. Treats the last dimension as a batch dimension if `ndims(x) >= 2`.
135
- """
136
- function normalize (os:: Group{<:AbstractVector{<:Moments}} , x:: AbstractVector )
137
- T = eltype (x)
138
- m = [T (mean (stat)) for stat in os]
139
- s = [T (std (stat)) for stat in os]
140
- return (x .- m) ./ s
141
- end
141
+ @forward NormalizedTraces. traces Base. length, Base. size, Base. lastindex, Base. firstindex, Base. getindex, Base. view, Base. pop!, Base. popfirst!, Base. empty!, Base. parent
142
142
143
- function normalize (os:: Group{<:AbstractVector{<:Moments}} , x:: AbstractArray )
144
- xn = similar (x)
145
- for (i, slice) in enumerate (eachslice (x, dims = ndims (x)))
146
- xn[repeat ([:], ndims (x)- 1 )... , i] .= reshape (normalize (os, vec (slice)), size (x)[1 : end - 1 ]. .. )
147
- end
148
- return xn
149
- end
150
-
151
- function normalize (os:: Group{<:AbstractVector{<:Moments}} , x:: AbstractVector{<:AbstractArray} )
152
- xn = similar (x)
153
- for (i,el) in enumerate (x)
154
- xn[i] = normalize (os, vec (el))
143
+ for f in (:push! , :pushfirst! , :append! , :prepend! )
144
+ @eval function Base. $f (nt:: NormalizedTraces , x:: NamedTuple )
145
+ for key in intersect (keys (nt. normalizers), keys (x))
146
+ fit! (nt. normalizers[key], x[key])
147
+ end
148
+ $ f (nt. traces, x)
155
149
end
156
- return xn
157
150
end
158
151
159
- function fetch (nt:: NormalizedTraces , inds)
160
- batch = fetch (nt. traces, inds)
161
- return (; (key => (key in keys (nt. normalizers) ? normalize (nt. normalizers[key]. os, data) : data) for (key, data) in pairs (batch)). .. )
152
+ function sample (s:: BatchSampler , nt:: NormalizedTraces , names)
153
+ inds = rand (s. rng, 1 : length (nt), s. batch_size)
154
+ maybe_normalize (data, key) = key in keys (nt. normalizers) ? normalize (nt. normalizers[key], data) : data
155
+ NamedTuple {names} (s. transformer (maybe_normalize (nt[x][inds], x) for x in names))
162
156
end
163
-
164
- #=
165
- abstract type AbstractFoo{T} end
166
-
167
- struct Foo{T} <: AbstractFoo{T}
168
- x::T
169
- end
170
-
171
- struct Baz{F <: Foo} <: supertype(Type(F))
172
- foo::F
173
- end
174
-
175
- struct Bal{F <: AbstractFoo{T where T}} <: AbstractFoo{T}
176
- foo::F{T}
177
- end=#
0 commit comments