Skip to content

Commit 6a91055

Browse files
committed
simplify iterate
1 parent f6b7258 commit 6a91055

File tree

3 files changed

+19
-36
lines changed

3 files changed

+19
-36
lines changed

src/common.jl

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -664,47 +664,34 @@ Iteration =#
664664
# `pairs(p)`: `i => pᵢ` possibly skipping over values of `i` with `pᵢ == 0` (SparsePolynomial)
665665
# and possibly non ordered (SparsePolynomial)
666666
# `monomials(p)`: iterates over pᵢ ⋅ basis(p, i) i ∈ keys(p)
667-
function Base.iterate(p::AbstractPolynomial, state=nothing)
668-
i = firstindex(p)
669-
if state == nothing
670-
return (p[i], i)
671-
else
672-
j = lastindex(p)
673-
if i <= state < j
674-
return (p[state+1], state+1)
675-
end
676-
return nothing
677-
end
667+
function _iterate(p, state)
668+
firstindex(p) <= state <= lastindex(p) || return nothing
669+
return p[state], state+1
678670
end
671+
Base.iterate(p::AbstractPolynomial, state = firstindex(p)) = _iterate(p, state)
679672

680673
# pairs map i -> aᵢ *possibly* skipping over ai == 0
681674
# cf. abstractdict.jl
682-
struct PolynomialKeys{P}
675+
struct PolynomialKeys{P} <: AbstractSet{Int}
683676
p::P
684677
end
685-
struct PolynomialValues{P}
678+
struct PolynomialValues{P, T} <: AbstractSet{T}
686679
p::P
680+
681+
PolynomialValues{P}(p::P) where {P} = new{P, eltype(p)}(p)
682+
PolynomialValues(p::P) where {P} = new{P, eltype(p)}(p)
687683
end
688684
Base.keys(p::AbstractPolynomial) = PolynomialKeys(p)
689685
Base.values(p::AbstractPolynomial) = PolynomialValues(p)
690686
Base.length(p::PolynomialValues) = length(p.p.coeffs)
691687
Base.length(p::PolynomialKeys) = length(p.p.coeffs)
692688
Base.size(p::Union{PolynomialValues, PolynomialKeys}) = (length(p),)
693-
function Base.iterate(v::PolynomialKeys, state=nothing)
694-
i = firstindex(v.p)
695-
state==nothing && return (i, i)
696-
j = lastindex(v.p)
697-
i <= state < j && return (state+1, state+1)
698-
return nothing
689+
function Base.iterate(v::PolynomialKeys, state = firstindex(v.p))
690+
firstindex(v.p) <= state <= lastindex(v.p) || return nothing
691+
return state, state+1
699692
end
700693

701-
function Base.iterate(v::PolynomialValues, state=nothing)
702-
i = firstindex(v.p)
703-
state==nothing && return (v.p[i], i)
704-
j = lastindex(v.p)
705-
i <= state < j && return (v.p[state+1], state+1)
706-
return nothing
707-
end
694+
Base.iterate(v::PolynomialValues, state = firstindex(v.p)) = _iterate(v.p, state)
708695

709696

710697
# iterate over monomials of the polynomial

src/polynomials/Poly.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,9 @@ _eltype(::Type{<:Poly{T}}) where {T} = T
4747
_eltype(::Type{Poly}) = Float64
4848

4949
# when interating over poly return monomials
50-
function Base.iterate(p::Poly, state=nothing)
51-
i = 0
52-
state == nothing && return (p[i]*one(p), i)
53-
j = degree(p)
54-
s = state + 1
55-
i <= state < j && return (p[s]*Polynomials.basis(p,s), s)
56-
return nothing
50+
function Base.iterate(p::Poly, state = firstindex(p))
51+
firstindex(p) <= state <= lastindex(p) || return nothing
52+
return p[state] * Polynomials.basis(p,state), state+1
5753
end
5854
Base.collect(p::Poly) = [pᵢ for pᵢ p]
5955

src/polynomials/SparsePolynomial.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ function Base.setindex!(p::SparsePolynomial, value::Number, idx::Int)
128128
end
129129

130130

131-
Base.firstindex(p::SparsePolynomial) = sort(collect(keys(p.coeffs)), by=x->x[1])[1]
132-
Base.lastindex(p::SparsePolynomial) = sort(collect(keys(p.coeffs)), by=x->x[1])[end]
133-
Base.eachindex(p::SparsePolynomial) = sort(collect(keys(p.coeffs)), by=x->x[1])
131+
Base.firstindex(p::SparsePolynomial) = sort!(collect(keys(p.coeffs)), by=x->x[1])[1]
132+
Base.lastindex(p::SparsePolynomial) = sort!(collect(keys(p.coeffs)), by=x->x[1])[end]
133+
Base.eachindex(p::SparsePolynomial) = sort!(collect(keys(p.coeffs)), by=x->x[1])
134134

135135
# pairs iterates only over non-zero
136136
# inherits order for underlying dictionary

0 commit comments

Comments
 (0)