Skip to content

Commit e004bef

Browse files
committed
decouple vec_to_tril functions
1 parent 67bcfc4 commit e004bef

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

src/ReinforcementLearningCore/src/utils/networks.jl

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -249,24 +249,42 @@ function (model::CovGaussianNetwork)(state::AbstractMatrix, action::AbstractMatr
249249
return dropdims(output, dims=2)
250250
end
251251

252+
"""
253+
    cholesky_matrix_to_vector_index(i, j, da)
254+
255+
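Return the position in a `cholesky_vec` (of length `da * (da + 1) ÷ 2`) of the element of the `da x da` lower triangular matrix at coordinates `(i, j)`, with `i >= j`. The elements of the lower triangle are stored column by column.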
256+
257+
For example if `cholesky_vec = [1,2,3,4,5,6]`, the corresponding lower triangular matrix is
258+
```
259+
L = [1 0 0
260+
2 4 0
261+
3 5 6]
262+
```
263+
and `cholesky_matrix_to_vector_index(3, 2, 3) == 5`
264+
265+
"""
266+
cholesky_matrix_to_vector_index(i, j, da) = i + div((j - 1) * (2da - j), 2) # entries stored in columns 1:j-1, plus the row offset i
267+
# A softer softplus, `beta * log(1 + exp(x / beta))`, used to avoid vanishing values.
# It is evaluated branch-wise through `log1p` so that it neither overflows to `Inf` for
# large `x` nor rounds to exactly zero for very negative `x`. Both branches are the same
# smooth function, so the derivative is unchanged.
function softplusbeta(x, beta = 10f0)
    t = x / beta
    # For t > 0 use log(1 + exp(t)) == t + log(1 + exp(-t)) so that exp never sees a large argument.
    scaled = t > zero(t) ? t + log1p(exp(-t)) : log1p(exp(t))
    return scaled * beta
end
268+
269+
function cholesky_columns(cholesky_vec, j, batch_size, da) # Return a slice (da x 1 x batch_size) containing the jth column of the lower triangular Cholesky factor of the covariance: j-1 leading zeros, then the transformed diagonal entry, then the below-diagonal entries.
270+
diag_idx = cholesky_matrix_to_vector_index(j, j, da)
271+
tc_diag = softplusbeta.(cholesky_vec[diag_idx:diag_idx, :, :]) .+ 1f-5 # the diagonal entry goes through softplusbeta and is offset by 1f-5, keeping it strictly positive
272+
other_idxs = cholesky_matrix_to_vector_index(j, j, da)+1:cholesky_matrix_to_vector_index(j + 1, j + 1, da)-1 # indices strictly between this column's diagonal entry and the next column's diagonal entry, i.e. the below-diagonal entries of column j (empty when j == da)
273+
tc_other = cholesky_vec[other_idxs, :, :]
274+
zs = ignore_derivatives() do # NOTE(review): presumably ChainRulesCore's ignore_derivatives, so the constant zero padding is built outside of differentiation — confirm against the file's imports
275+
zs = similar(cholesky_vec, da - size(tc_other, 1) - 1, 1, batch_size) # the rows above the diagonal: da minus the below-diagonal rows minus the diagonal itself
276+
zs .= zero(eltype(cholesky_vec))
277+
return zs
278+
end
279+
[zs; tc_diag; tc_other]
280+
end
281+
252282
"""
253283
Transform a vector containing the non-zero elements of a lower triangular da x da matrix into that matrix.

`cholesky_vec` is indexed along three dimensions and its third dimension is taken as the batch size. The result is assembled one column per `j in 1:da`, joined with `hcat`. The diagonal entries are not copied verbatim: each one is passed through `softplusbeta` and offset by `1f-5`. Entries above the diagonal are filled with zeros.
254284
"""
255285
function vec_to_tril(cholesky_vec, da)
256-
batch_size = size(cholesky_vec, 3)
257-
c2idx(i, j) = ((2da - j) * (j - 1)) ÷ 2 + i #return the position in cholesky_vec of the element of the triangular matrix at coordinates (i,j)
258-
softplusbeta(x) = log(exp(0.1f0 * x) +1)*10f0 #a softer softplus to avoid vanishing values
259-
function f(j) #return a slice (da x 1 x batchsize) containing the jth columns of the lower triangular cholesky decomposition of the covariance
260-
tc_diag = softplusbeta.(cholesky_vec[c2idx(j, j):c2idx(j, j), :, :]) .+ 1f-5
261-
tc_other = cholesky_vec[c2idx(j, j)+1:c2idx(j + 1, j + 1)-1, :, :]
262-
zs = ignore_derivatives() do
263-
zs = similar(cholesky_vec, da - size(tc_other, 1) - 1, 1, batch_size)
264-
zs .= zero(eltype(cholesky_vec))
265-
return zs
266-
end
267-
[zs; tc_diag; tc_other]
268-
end
269-
return mapreduce(f, hcat, 1:da)
286+
batch_size = size(cholesky_vec, 3)
287+
return mapreduce(j->cholesky_columns(cholesky_vec, j, batch_size, da), hcat, 1:da)
270288
end
271289

272290
#####

0 commit comments

Comments
 (0)