Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Add episode length for Bit Flipping Env #125

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/environments/examples/BitFlippingEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@ In Bit Flipping Environment we have n bits. The actions are 1 to n where executi
Refer [Hindsight Experience Replay paper](https://arxiv.org/pdf/1707.01495.pdf) for the motivation behind the environment.
"""

struct BitFlippingEnv <: AbstractEnv
mutable struct BitFlippingEnv <: AbstractEnv
N::Int
rng::AbstractRNG
state::BitArray{1}
goal_state::BitArray{1}
max_steps::Int
t::Int
end

function BitFlippingEnv(; N = 8, rng = Random.GLOBAL_RNG)
function BitFlippingEnv(; N = 8, T = N,rng = Random.GLOBAL_RNG)
state = bitrand(rng, N)
goal_state = bitrand(rng, N)
BitFlippingEnv(N, rng, state, goal_state)
max_steps = T
BitFlippingEnv(N, rng, state, goal_state, max_steps, 0)
end

Random.seed!(env::BitFlippingEnv, s) = Random.seed!(env.rng, s)
Expand All @@ -26,6 +29,7 @@ RLBase.action_space(env::BitFlippingEnv) = Base.OneTo(env.N)
RLBase.legal_action_space(env::BitFlippingEnv) = Base.OneTo(env.N)

function (env::BitFlippingEnv)(action::Int)
env.t += 1
if 1 <= action <= env.N
env.state[action] = !env.state[action]
nothing
Expand All @@ -39,9 +43,10 @@ RLBase.state(env::BitFlippingEnv, ::Observation) = env.state
RLBase.state(env::BitFlippingEnv, ::GoalState) = env.goal_state
RLBase.state_space(env::BitFlippingEnv, ::Observation) = Space(fill(false..true, env.N))
RLBase.state_space(env::BitFlippingEnv, ::GoalState) = Space(fill(false..true, env.N))
RLBase.is_terminated(env::BitFlippingEnv) = env.state == env.goal_state
RLBase.is_terminated(env::BitFlippingEnv) = (env.state == env.goal_state) || (env.t >= env.max_steps)

function RLBase.reset!(env::BitFlippingEnv)
env.t = 0
env.state .= bitrand(env.rng, env.N)
env.goal_state .= bitrand(env.rng, env.N)
end
Expand Down
6 changes: 5 additions & 1 deletion test/environments/examples/bit_flipping_env.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
@testset "bit_flipping_env" begin
rng = StableRNG(123)
env = BitFlippingEnv(; N = 7, rng = rng)
env = BitFlippingEnv(; N = 7, T = 2, rng = rng)
test_state = state(env, GoalState())
env(1)
@test is_terminated(env) == false
env(1)
@test is_terminated(env) == true
RLBase.test_interfaces!(env)
RLBase.test_runnable!(env)

Expand Down