Skip to content

Commit 7f61196

Browse files
committed
Improve type inferrence of checkpoints, fix points for 1d x 2d
1 parent 4c6f85a commit 7f61196

File tree

4 files changed

+28
-5
lines changed

4 files changed

+28
-5
lines changed

src/Domains/ProductDomain.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ function pushappendpts!(ret, xx, pts)
1919
push!(ret,Vec(xx...))
2020
else
2121
for x in pts[1]
22-
pushappendpts!(ret,(xx...,x),pts[2:end])
22+
pushappendpts!(ret,(xx...,x...),pts[2:end])
2323
end
2424
end
2525
ret
2626
end
2727

2828
function checkpoints(d::ProductDomain)
29-
pts=map(checkpoints,d.domains)
30-
ret=Vector{Vec{length(d.domains),float(mapreduce(eltype,promote_type,d.domains))}}(undef, 0)
29+
pts = checkpoints.(d.domains)
30+
ret=Vector{Vec{sum(dimension.(d.domains)),float(promote_type(eltype.(eltype.(d.domains))...))}}(undef, 0)
3131

3232
pushappendpts!(ret,(),pts)
3333
ret

src/Multivariate/TensorSpace.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ TensorSpace(A::ProductDomain) = TensorSpace(tuple(map(Space,A.domains)...))
246246
(A::Space,B::TensorSpace) = TensorSpace(A,B.spaces...)
247247
(A::Space,B::Space) = TensorSpace(A,B)
248248

249-
domain(f::TensorSpace) = mapreduce(domain,×,f.spaces)
249+
domain(f::TensorSpace) = ×(domain.(f.spaces)...)
250250
Space(sp::ProductDomain) = TensorSpace(sp)
251251

252252
setdomain(sp::TensorSpace, d::ProductDomain) = TensorSpace(setdomain.(factors(sp), factors(d)))

src/Space.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ isambiguous(sp::Space) = isambiguous(rangetype(sp))
9898

9999
#TODO: should it default to canonicalspace?
100100
points(d::Space,n) = points(domain(d),n)
101+
points(d::Space) = points(d, dimension(d))
101102

102103

103104

test/SpacesTest.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ApproxFunBase, Test
2-
import ApproxFunBase: PointSpace, HeavisideSpace, PiecewiseSegment
2+
import ApproxFunBase: PointSpace, HeavisideSpace, PiecewiseSegment, dimension, Vec, checkpoints
33

44
@testset "Spaces" begin
55
@testset "PointSpace" begin
@@ -34,6 +34,10 @@ import ApproxFunBase: PointSpace, HeavisideSpace, PiecewiseSegment
3434

3535
S = HeavisideSpace([-1.0,0.0,1.0])
3636
@test Derivative(S) === Derivative(S,1)
37+
38+
a = HeavisideSpace(0:0.25:1)
39+
@test dimension(a) == 4
40+
@test @inferred(points(a)) == 0.125:0.25:0.875
3741
end
3842

3943
@testset "DiracDelta integration and differentiation" begin
@@ -57,4 +61,22 @@ import ApproxFunBase: PointSpace, HeavisideSpace, PiecewiseSegment
5761
@test h(2) == 0.3+1im
5862
@test h(3) == 3.3+1im
5963
end
64+
65+
@testset "Multivariate" begin
66+
a = HeavisideSpace(0:0.25:1)
67+
@test @inferred(dimension(a^2)) == dimension(a)^2
68+
@test @inferred(domain(a^2)) == domain(a)^2
69+
@test @inferred(points(a^2)) == vec(Vec.(points(a), points(a)'))
70+
@test @inferred(checkpoints(a^2)) == vec(Vec.(checkpoints(a)', checkpoints(a)))
71+
72+
aa2 = TensorSpace(a , a^2)
73+
@test dimension(aa2) == dimension(a)^3
74+
@test @inferred(domain(aa2)) == domain(a)^3
75+
@test @inferred(points(aa2)) == vec(Vec.(points(a), points(a)', reshape(points(a), 1,1,4)))
76+
@test @inferred(checkpoints(aa2)) == vec(Vec.(reshape(checkpoints(a), 1,1,length(checkpoints(a))), checkpoints(a)', checkpoints(a)))
77+
78+
@test dimension(a^3) == dimension(a)^3
79+
@test @inferred(domain(a^3)) == domain(a)^3
80+
@test_broken @inferred(points(a^3)) == vec(Vec.(points(a), points(a)', reshape(points(a), 1,1,4)))
81+
end
6082
end

0 commit comments

Comments
 (0)