This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit b9920ae

brettkoonce authored and rxwei committed
add gym blackjack qlearning demo (#173)
1 parent edd734e commit b9920ae

File tree

3 files changed, +188 −0 lines changed

Gym/Blackjack/main.swift

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Python
import TensorFlow

let gym = Python.import("gym")
let environment = gym.make("Blackjack-v0")

let iterationCount = 10000
let learningPhase = iterationCount * 5 / 100

// A playing decision: true = hit, false = stay.
typealias Strategy = Bool

class BlackjackState {
    var playerSum: Int = 0
    var dealerCard: Int = 0
    var useableAce: Int = 0

    init(pythonState: PythonObject) {
        self.playerSum = Int(pythonState[0]) ?? 0
        self.dealerCard = Int(pythonState[1]) ?? 0
        self.useableAce = Int(pythonState[2]) ?? 0
    }
}

enum SolverType: CaseIterable {
    case random, markov, qlearning, normal
}

class Solver {
    // Q-table indexed as Q[playerSum][dealerCard][useableAce][action].
    var Q: [[[[Float]]]] = []
    var alpha: Float = 0.5
    let gamma: Float = 0.2

    let playerStateCount = 32 // 21 + 10 + 1 offset
    let dealerVisibleStateCount = 11 // 10 + 1 offset
    let aceStateCount = 2 // useable / not bool
    let playerActionCount = 2 // hit / stay

    init() {
        Q = Array(repeating: Array(repeating: Array(repeating: Array(repeating: 0.0,
                                                                      count: playerActionCount),
                                                    count: aceStateCount),
                                   count: dealerVisibleStateCount),
                  count: playerStateCount)
    }

    func updateQLearningStrategy(prior: BlackjackState,
                                 action: Int,
                                 reward: Int,
                                 post: BlackjackState) {
        let oldQ = Q[prior.playerSum][prior.dealerCard][prior.useableAce][action]
        let priorQ = (1 - alpha) * oldQ

        let maxReward = max(Q[post.playerSum][post.dealerCard][post.useableAce][0],
                            Q[post.playerSum][post.dealerCard][post.useableAce][1])
        let postQ = alpha * (Float(reward) + gamma * maxReward)

        // Standard tabular Q-learning update:
        // Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a'))
        Q[prior.playerSum][prior.dealerCard][prior.useableAce][action] = priorQ + postQ
    }

    func qLearningStrategy(observation: BlackjackState, iteration: Int) -> Strategy {
        let qLookup = Q[observation.playerSum][observation.dealerCard][observation.useableAce]
        let stayReward = qLookup[0]
        let hitReward = qLookup[1]

        if iteration < Int.random(in: 1...learningPhase) {
            return randomStrategy()
        } else {
            // quit learning after initial phase
            if iteration > learningPhase { alpha = 0.0 }
        }

        if hitReward == stayReward {
            return randomStrategy()
        } else {
            return hitReward > stayReward
        }
    }

    func randomStrategy() -> Strategy {
        return Strategy.random()
    }

    func markovStrategy(observation: BlackjackState) -> Strategy {
        // hit @ 80% probability unless the player sum is 18 or higher, in which case do the reverse
        let flip = Float.random(in: 0..<1)
        let threshHold: Float = 0.8

        if observation.playerSum < 18 {
            return flip < threshHold
        } else {
            return flip > threshHold
        }
    }

    func normalStrategyLookup(playerSum: Int) -> String {
        // see figure 11: https://ieeexplore.ieee.org/document/1299399/
        switch playerSum {
        case 10: return "HHHHHSSHHH"
        case 11: return "HHSSSSSSHH"
        case 12: return "HSHHHHHHHH"
        case 13: return "HSSHHHHHHH"
        case 14: return "HSHHHHHHHH"
        case 15: return "HSSHHHHHHH"
        case 16: return "HSSSSSHHHH"
        case 17: return "HSSSSHHHHH"
        case 18: return "SSSSSSSSSS"
        case 19: return "SSSSSSSSSS"
        case 20: return "SSSSSSSSSS"
        case 21: return "SSSSSSSSSS"
        default: return "HHHHHHHHHH"
        }
    }

    func normalStrategy(observation: BlackjackState) -> Strategy {
        if observation.playerSum == 0 {
            return true
        }
        let lookupString = normalStrategyLookup(playerSum: observation.playerSum)
        return Array(lookupString)[observation.dealerCard - 1] == "H"
    }

    func strategy(observation: BlackjackState, solver: SolverType, iteration: Int) -> Strategy {
        switch solver {
        case .random:
            return randomStrategy()
        case .markov:
            return markovStrategy(observation: observation)
        case .qlearning:
            return qLearningStrategy(observation: observation, iteration: iteration)
        case .normal:
            return normalStrategy(observation: observation)
        }
    }
}

let learner = Solver()

for solver in SolverType.allCases {
    var totalReward = 0

    for i in 1...iterationCount {
        var isDone = false
        environment.reset()

        while !isDone {
            let priorState = BlackjackState(pythonState: environment._get_obs())
            let action: Int = learner.strategy(observation: priorState,
                                               solver: solver,
                                               iteration: i) ? 1 : 0

            let (pythonPostState, reward, done, _) = environment.step(action).tuple4

            if solver == .qlearning {
                let postState = BlackjackState(pythonState: pythonPostState)
                learner.updateQLearningStrategy(prior: priorState,
                                                action: action,
                                                reward: Int(reward) ?? 0,
                                                post: postState)
            }

            if done == true {
                totalReward += Int(reward) ?? 0
                isDone = true
            }
        }
    }
    print("Solver: \(solver), Total reward: \(totalReward) / \(iterationCount) trials")
}
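
For readers tracing `updateQLearningStrategy` above: it implements the standard tabular Q-learning rule, Q(s, a) ← (1 − α)·Q(s, a) + α·(r + γ·max_a′ Q(s′, a′)), with α = `alpha` and γ = `gamma`. The snippet below is a minimal standalone sketch of a single update with made-up numbers; the values are illustrative and are not part of this commit.

```swift
// One tabular Q-learning update worked by hand. All values are illustrative;
// alpha and gamma match the defaults used in the demo above.
let alpha: Float = 0.5       // learning rate
let gamma: Float = 0.2       // discount factor

var q: Float = 0.3           // current estimate of Q(s, a)
let reward: Float = 1.0      // reward observed after taking action a in state s
let maxNextQ: Float = 0.4    // max over a' of Q(s', a')

// Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a'))
q = (1 - alpha) * q + alpha * (reward + gamma * maxNextQ)
print(q)  // ~0.69
```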

Gym/README.md

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,10 @@ This directory contains reinforcement learning algorithms in [OpenAI Gym](https:
 
 > The agent controls the movement of a character in a grid world. Some tiles of the grid are walkable, and others lead to the agent falling into the water. Additionally, the movement direction of the agent is uncertain and only partially depends on the chosen direction. The agent is rewarded for finding a walkable path to a goal tile.
 
+## [Blackjack](https://gym.openai.com/envs/Blackjack-v0)
+
+> This demonstrates four different approaches to playing the game Blackjack, including a q-learning approach.
+
 ## Setup
 
 To begin, you'll need the [latest version of Swift for
@@ -26,4 +30,5 @@ To build and run the models, run:
 ```bash
 swift run Gym-CartPole
 swift run Gym-FrozenLake
+swift run Gym-Blackjack
 ```
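
A practical note on the new `swift run Gym-Blackjack` target: the demo loads Gym through Swift's Python interoperability (`Python.import("gym")`), so the `gym` package must be installed in the Python environment the toolchain binds to; if it is missing, `Python.import` aborts the program. A small, hypothetical pre-flight check, assuming the toolchain's Python interop exposes `Python.attemptImport` as PythonKit does, could look like this:

```swift
import Python

// Hypothetical pre-flight check (not part of this commit): attemptImport throws
// instead of trapping, so a missing `gym` package can be reported gracefully.
if let gym = try? Python.attemptImport("gym") {
    let environment = gym.make("Blackjack-v0")
    print("gym is available, created environment:", environment)
} else {
    print("The Python `gym` package is not installed; install it before running the demo.")
}
```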

Package.swift

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ let package = Package(
         .target(name: "Catch", path: "Catch"),
         .target(name: "Gym-FrozenLake", path: "Gym/FrozenLake"),
         .target(name: "Gym-CartPole", path: "Gym/CartPole"),
+        .target(name: "Gym-Blackjack", path: "Gym/Blackjack"),
         .target(name: "MNIST", path: "MNIST"),
         .target(name: "MiniGo", path: "MiniGo", exclude: ["main.swift"]),
         .target(name: "MiniGoDemo", dependencies: ["MiniGo"], path: "MiniGo",
