Skip to content

Commit 90db327

Browse files
author
Will Tebbutt
committed
Eltype in ones
1 parent 79fb7fa commit 90db327

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/basekernels/fbm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ _mod(x::RowVecs) = vec(sum(abs2, x.X; dims=2))
5050

5151
function kernelmatrix::FBMKernel, x::AbstractVector)
5252
modx = _mod(x)
53-
modx_wide = modx * ones(1, length(modx)) # ad perf hack -- is unit tested
53+
modx_wide = modx * ones(eltype(modx), 1, length(modx)) # ad perf hack -- is unit tested
5454
modxx = pairwise(SqEuclidean(), x)
5555
return _fbm.(modx_wide, modx_wide', modxx, only.h))
5656
end
@@ -64,8 +64,8 @@ end
6464

6565
function kernelmatrix::FBMKernel, x::AbstractVector, y::AbstractVector)
6666
modxy = pairwise(SqEuclidean(), x, y)
67-
modx_wide = _mod(x) * ones(1, length(y)) # ad perf hack -- is unit tested
68-
mody_wide = _mod(y) * ones(1, length(x)) # ad perf hack -- is unit tested
67+
modx_wide = _mod(x) * ones(eltype(modxy), 1, length(y)) # ad perf hack -- is unit tested
68+
mody_wide = _mod(y) * ones(eltype(modxy), 1, length(x)) # ad perf hack -- is unit tested
6969
return _fbm.(modx_wide, mody_wide', modxy, only.h))
7070
end
7171

src/kernels/normalizedkernel.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ end
2121

2222
function kernelmatrix::NormalizedKernel, x::AbstractVector, y::AbstractVector)
2323
x_diag = kernelmatrix_diag.kernel, x)
24-
x_diag_wide = x_diag * ones(1, length(y)) # ad perf hack. Is unit tested.
24+
x_diag_wide = x_diag * ones(eltype(x_diag), 1, length(y)) # ad perf hack. Is unit tested
2525
y_diag = kernelmatrix_diag.kernel, y)
26-
y_diag_wide = y_diag * ones(1, length(x)) # ad perf hack. Is unit tested.
26+
y_diag_wide = y_diag * ones(eltype(y_diag), 1, length(x)) # ad perf hack. Is unit tested
2727
return kernelmatrix.kernel, x, y) ./ sqrt.(x_diag_wide .* y_diag_wide')
2828
end
2929

3030
function kernelmatrix::NormalizedKernel, x::AbstractVector)
3131
x_diag = kernelmatrix_diag.kernel, x)
32-
x_diag_wide = x_diag * ones(1, length(x_diag)) # ad perf hack. Is unit tested.
32+
x_diag_wide = x_diag * ones(eltype(x_diag), 1, length(x)) # ad perf hack. Is unit tested
3333
return kernelmatrix.kernel, x) ./ sqrt.(x_diag_wide .* x_diag_wide')
3434
end
3535

0 commit comments

Comments
 (0)