|
| 1 | +export BitFlippingEnv, GoalState |
| 2 | + |
"""
    GoalState{T} <: RLBase.AbstractStateStyle

State style used to ask an environment for its goal (target) state instead of
its current observation. `GoalState()` is shorthand for `GoalState{Any}()`.
"""
struct GoalState{T} <: RLBase.AbstractStateStyle end
GoalState() = GoalState{Any}()

"""
    BitFlippingEnv <: AbstractEnv

In the bit-flipping environment the state consists of `N` bits. The actions are
`1` to `N`, where executing the `i`-th action flips the `i`-th bit of the state.
For every episode, an initial state and a target state are each sampled
uniformly at random.

Refer to the [Hindsight Experience Replay paper](https://arxiv.org/pdf/1707.01495.pdf)
for the motivation behind the environment.
"""
struct BitFlippingEnv{R<:AbstractRNG} <: AbstractEnv
    # Number of bits, which is also the number of actions.
    N::Int
    # Random number generator used to sample the state and the goal.
    # Parametrized so the field is concretely typed.
    rng::R
    # Current bits; the struct is immutable but this array is mutated in place.
    state::BitArray{1}
    # Target bits the agent has to reach.
    goal_state::BitArray{1}
end
| 18 | + |
"""
    BitFlippingEnv(; N = 8, rng = Random.GLOBAL_RNG)

Create a bit-flipping environment with `N` bits, drawing both the initial state
and the goal state from `rng`.

# Keywords
- `N`: number of bits (and of actions).
- `rng`: random number generator stored in the environment and used for sampling.
"""
function BitFlippingEnv(; N = 8, rng = Random.GLOBAL_RNG)
    # A single keyword method is enough: keyword arguments do not take part in
    # dispatch, so a separate `(; N = 8)` method would only be overwritten.
    state = bitrand(rng, N)
    goal_state = bitrand(rng, N)
    return BitFlippingEnv(N, rng, state, goal_state)
end
| 31 | + |
# Reseed the random number generator stored in the environment with `s`.
function Random.seed!(env::BitFlippingEnv, s)
    return Random.seed!(env.rng, s)
end
| 33 | + |
# Actions are the bit indices `1:N`; action `i` flips bit `i`.
function RLBase.action_space(env::BitFlippingEnv)
    n_bits = env.N
    return Base.OneTo(n_bits)
end
| 35 | + |
# The legal actions are the same range as the full action space: `1:N`.
function RLBase.legal_action_space(env::BitFlippingEnv)
    n_bits = env.N
    return Base.OneTo(n_bits)
end
| 37 | + |
# Apply `action` to the environment: flip the bit at index `action`.
# An index outside `1:env.N` is reported through the logger and the state is
# left untouched. Returns `nothing` in both cases.
function (env::BitFlippingEnv)(action::Int)
    if !(1 <= action <= env.N)
        @error "Invalid Action"
        return nothing
    end
    # XOR with `true` toggles the selected bit in place.
    env.state[action] = xor(env.state[action], true)
    return nothing
end
| 46 | + |
# Default state: the current bits, i.e. the `Observation` view.
# The call is qualified with `RLBase.` so it does not rely on `state` being
# exported into the enclosing module's scope.
RLBase.state(env::BitFlippingEnv) = RLBase.state(env, Observation{BitArray{1}}())

# Current bits of the environment (the stored array itself, not a copy).
RLBase.state(env::BitFlippingEnv, ::Observation) = env.state

# Target bits of the current episode (the stored array itself, not a copy).
RLBase.state(env::BitFlippingEnv, ::GoalState) = env.goal_state
# Space of observations: one `false..true` interval per bit.
function RLBase.state_space(env::BitFlippingEnv, ::Observation)
    bit_range = false..true
    return Space(fill(bit_range, env.N))
end

# Space of goal states: same shape as the observation space.
function RLBase.state_space(env::BitFlippingEnv, ::GoalState)
    bit_range = false..true
    return Space(fill(bit_range, env.N))
end
# The episode is over once the current bits equal the goal bits.
function RLBase.is_terminated(env::BitFlippingEnv)
    return env.state == env.goal_state
end
| 53 | + |
# Start a new episode by resampling both the current state and the goal state
# from the environment's RNG. Returns `env.goal_state`, as before.
function RLBase.reset!(env::BitFlippingEnv)
    # `rand!` fills the existing BitArrays in place, avoiding the temporary
    # arrays that `x .= bitrand(rng, N)` would allocate on every reset.
    Random.rand!(env.rng, env.state)
    return Random.rand!(env.rng, env.goal_state)
end
| 58 | + |
# Reward is `0.0` when the state matches the goal and `-1.0` otherwise.
function RLBase.reward(env::BitFlippingEnv)
    reached_goal = env.state == env.goal_state
    return reached_goal ? 0.0 : -1.0
end
| 66 | + |
# Trait declarations for the environment.

# Agents and turn order.
RLBase.NumAgentStyle(::BitFlippingEnv) = SINGLE_AGENT
RLBase.DynamicStyle(::BitFlippingEnv) = SEQUENTIAL

# Actions and information.
RLBase.ActionStyle(::BitFlippingEnv) = MINIMAL_ACTION_SET
RLBase.InformationStyle(::BitFlippingEnv) = PERFECT_INFORMATION

# Two state views are offered: the current bits and the goal bits.
# `BitVector` is Base's alias for `BitArray{1}`.
RLBase.StateStyle(::BitFlippingEnv) = (Observation{BitVector}(), GoalState{BitVector}())

# Rewards, utility and randomness.
RLBase.RewardStyle(::BitFlippingEnv) = STEP_REWARD
RLBase.UtilityStyle(::BitFlippingEnv) = GENERAL_SUM
RLBase.ChanceStyle(::BitFlippingEnv) = STOCHASTIC
0 commit comments