Skip to content

Commit 867f2de

Browse files
authored
Merge pull request #19 from phelps-sg/openai-scala-client-15
RetryHelpers trait to implement non-blocking retries
2 parents 2348e8e + 82603af commit 867f2de

File tree

7 files changed

+346
-11
lines changed

7 files changed

+346
-11
lines changed

README.md

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ This extension of the standard chat completion is currently supported by the fol
242242
- `gpt-3.5-turbo-0613` (default), `gpt-3.5-turbo-16k-0613`, `gpt-4-0613`, and `gpt-4-32k-0613`.
243243

244244

245-
**✔️ Important Note**: After you are done using the service, you should close it by calling `service.close`. Otherwise, the underlying resources/threads won't be released.
245+
**✔️ Important Note**: After you are done using the service, you should close it by calling (🔥 new) `service.close`. Otherwise, the underlying resources/threads won't be released.
246246

247247
**III. Using multiple services (🔥 new)**
248248

@@ -276,17 +276,48 @@ This extension of the standard chat completion is currently supported by the fol
276276
}
277277
```
278278

279+
- Create completion and retry on transient errors (e.g. rate limit error)
280+
```scala
281+
import akka.actor.{ActorSystem, Scheduler}
282+
import io.cequence.openaiscala.RetryHelpers
283+
import io.cequence.openaiscala.RetryHelpers.RetrySettings
284+
import io.cequence.openaiscala.domain.{ChatRole, MessageSpec}
285+
import io.cequence.openaiscala.service.{OpenAIService, OpenAIServiceFactory}
286+
287+
import javax.inject.Inject
288+
import scala.concurrent.duration.DurationInt
289+
import scala.concurrent.{ExecutionContext, Future}
290+
291+
class MyCompletionService @Inject() (
292+
val actorSystem: ActorSystem,
293+
implicit val ec: ExecutionContext,
294+
implicit val scheduler: Scheduler
295+
)(val apiKey: String)
296+
extends RetryHelpers {
297+
val service: OpenAIService = OpenAIServiceFactory(apiKey)
298+
implicit val retrySettings: RetrySettings =
299+
RetrySettings(interval = 10.seconds)
300+
301+
def ask(prompt: String): Future[String] =
302+
for {
303+
completion <- service
304+
.createChatCompletion(
305+
List(MessageSpec(ChatRole.User, prompt))
306+
)
307+
.retryOnFailure
308+
} yield completion.choices.head.message.content
309+
}
310+
```
311+
279312
- Retries with `OpenAIRetryServiceAdapter`
280313

281314
```scala
282315
val serviceAux = ... // your service
283316

317+
implicit val retrySettings: RetrySettings =
318+
RetrySettings(maxAttempts = 10).constantInterval(10.seconds)
284319
// wrap it with the retry adapter
285-
val service = OpenAIRetryServiceAdapter(
286-
serviceAux,
287-
maxAttempts = 10,
288-
sleepOnFailureMs = Some(1000) // 1 second
289-
)
320+
val service = OpenAIRetryServiceAdapter(serviceAux)
290321

291322
service.listModels.map { models =>
292323
models.foreach(println)

build.sbt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,35 @@ import sbt.Keys.test
44
val scala212 = "2.12.18"
55
val scala213 = "2.13.11"
66
val scala3 = "3.2.2"
7+
val AkkaVersion = "2.6.1"
78

89
ThisBuild / organization := "io.cequence"
910
ThisBuild / scalaVersion := scala212
1011
ThisBuild / version := "0.4.0"
1112
ThisBuild / isSnapshot := false
1213

14+
lazy val commonSettings = Seq(
15+
libraryDependencies += "org.scalactic" %% "scalactic" % "3.2.16",
16+
libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.16" % Test,
17+
libraryDependencies += "org.mockito" %% "mockito-scala-scalatest" % "1.17.14" % Test,
18+
libraryDependencies += "com.typesafe.akka" %% "akka-actor-testkit-typed" % AkkaVersion % Test
19+
)
20+
1321
lazy val core = (project in file("openai-core"))
22+
.settings(commonSettings: _*)
1423

1524
lazy val client = (project in file("openai-client"))
25+
.settings(commonSettings: _*)
1626
.dependsOn(core)
1727
.aggregate(core)
1828

1929
lazy val client_stream = (project in file("openai-client-stream"))
30+
.settings(commonSettings: _*)
2031
.dependsOn(client)
2132
.aggregate(client)
2233

2334
lazy val guice = (project in file("openai-guice"))
35+
.settings(commonSettings: _*)
2436
.dependsOn(client)
2537
.aggregate(client_stream)
2638

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package io.cequence.openaiscala
2+
3+
import akka.actor.{ActorSystem, Scheduler}
4+
import akka.pattern.after
5+
import akka.stream.Materializer
6+
import io.cequence.openaiscala.RetryHelpers.{RetrySettings, retry}
7+
8+
import scala.concurrent.duration.{DurationInt, FiniteDuration}
9+
import scala.concurrent.{ExecutionContext, Future}
10+
import scala.util.control.NonFatal
11+
12+
object RetryHelpers {
13+
private[openaiscala] def delay(
14+
n: Integer
15+
)(implicit retrySettings: RetrySettings): FiniteDuration =
16+
FiniteDuration(
17+
scala.math.round(
18+
retrySettings.delayOffset.length + scala.math.pow(
19+
retrySettings.delayBase,
20+
n.doubleValue()
21+
)
22+
),
23+
retrySettings.delayOffset.unit
24+
)
25+
26+
private[openaiscala] def retry[T](
27+
attempt: () => Future[T],
28+
attempts: Int
29+
)(implicit
30+
ec: ExecutionContext,
31+
scheduler: Scheduler,
32+
retrySettings: RetrySettings
33+
): Future[T] = {
34+
try {
35+
if (attempts > 0) {
36+
attempt().recoverWith { case Retryable(_) =>
37+
after(delay(attempts), scheduler) {
38+
retry(attempt, attempts - 1)
39+
}
40+
}
41+
} else {
42+
attempt()
43+
}
44+
} catch {
45+
case NonFatal(error) => Future.failed(error)
46+
}
47+
}
48+
49+
final case class RetrySettings(
50+
maxRetries: Int = 5,
51+
delayOffset: FiniteDuration = 2.seconds,
52+
delayBase: Double = 2
53+
) {
54+
def constantInterval(interval: FiniteDuration): RetrySettings =
55+
copy(delayBase = 0).copy(delayOffset = interval)
56+
}
57+
58+
object RetrySettings {
59+
def apply(interval: FiniteDuration): RetrySettings =
60+
RetrySettings().constantInterval(
61+
interval
62+
)
63+
64+
}
65+
66+
}
67+
68+
trait RetryHelpers {
69+
70+
def actorSystem: ActorSystem
71+
implicit val materializer: Materializer = Materializer(actorSystem)
72+
73+
implicit class FutureWithRetry[T](f: Future[T]) {
74+
75+
def retryOnFailure(implicit
76+
retrySettings: RetrySettings,
77+
ec: ExecutionContext,
78+
scheduler: Scheduler
79+
): Future[T] = {
80+
retry(() => f, retrySettings.maxRetries)
81+
}
82+
}
83+
84+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package io.cequence.openaiscala.service
2+
3+
import akka.actor.{ActorSystem, Scheduler}
4+
import io.cequence.openaiscala.RetryHelpers
5+
import io.cequence.openaiscala.RetryHelpers.RetrySettings
6+
7+
import scala.concurrent.{ExecutionContext, Future}
8+
9+
private class OpenAIRetryServiceAdapter(
10+
underlying: OpenAIService,
11+
val actorSystem: ActorSystem,
12+
implicit val ec: ExecutionContext,
13+
implicit val retrySettings: RetrySettings,
14+
implicit val scheduler: Scheduler
15+
) extends OpenAIServiceWrapper
16+
with RetryHelpers {
17+
18+
override def close: Unit =
19+
underlying.close
20+
21+
override protected def wrap[T](
22+
fun: OpenAIService => Future[T]
23+
): Future[T] = {
24+
fun(underlying).retryOnFailure
25+
}
26+
}
27+
28+
object OpenAIRetryServiceAdapter {
29+
def apply(underlying: OpenAIService)(implicit
30+
ec: ExecutionContext,
31+
retrySettings: RetrySettings,
32+
scheduler: Scheduler,
33+
actorSystem: ActorSystem
34+
): OpenAIService =
35+
new OpenAIRetryServiceAdapter(
36+
underlying,
37+
actorSystem,
38+
ec,
39+
retrySettings,
40+
scheduler
41+
)
42+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package io.cequence.openaiscala
2+
3+
import akka.actor.{ActorSystem, Scheduler}
4+
import akka.testkit.TestKit
5+
import io.cequence.openaiscala.RetryHelpers.{RetrySettings, delay, retry}
6+
import org.mockito.scalatest.MockitoSugar
7+
import org.scalatest.RecoverMethods._
8+
import org.scalatest.concurrent.ScalaFutures
9+
import org.scalatest.matchers.should.Matchers
10+
import org.scalatest.wordspec.AnyWordSpecLike
11+
import org.scalatest.{BeforeAndAfterAll, Succeeded}
12+
13+
import scala.concurrent.ExecutionContext.Implicits.global
14+
import scala.concurrent.duration._
15+
import scala.concurrent.{Future, Promise}
16+
17+
class RetryHelpersSpec
18+
extends TestKit(ActorSystem("RetryHelpersSpec"))
19+
with AnyWordSpecLike
20+
with Matchers
21+
with BeforeAndAfterAll
22+
with MockitoSugar
23+
with ScalaFutures
24+
with RetryHelpers {
25+
val successfulResult = 42
26+
27+
implicit val patience: PatienceConfig = PatienceConfig(timeout = 10.seconds)
28+
29+
override def afterAll(): Unit = {
30+
TestKit.shutdownActorSystem(system)
31+
}
32+
33+
"RetrySettings" should {
34+
"allow easy configuration of a constant interval" in {
35+
val interval = 10.seconds
36+
val result = RetrySettings(interval)
37+
result.delayBase shouldBe 0
38+
result.delayOffset shouldBe interval
39+
}
40+
}
41+
42+
"RetryHelpers" should {
43+
44+
"retry when encountering a retryable failure" in {
45+
val attempts = 2
46+
val ex = new OpenAIScalaClientTimeoutException("retryable test exception")
47+
testWithException(ex) { (mockRetryable, result) =>
48+
result.futureValue shouldBe successfulResult
49+
verifyNumAttempts(n = attempts, result, mockRetryable)
50+
}
51+
}
52+
53+
"not retry when encountering a non-retryable failure" in {
54+
val ex = new OpenAIScalaClientUnknownHostException(
55+
"non retryable test exception"
56+
)
57+
testWithException(ex) { (mockRetryable, result) =>
58+
val f = for {
59+
_ <- recoverToExceptionIf[OpenAIScalaClientUnknownHostException](
60+
result
61+
)
62+
} yield mockRetryable
63+
verifyNumAttempts(n = 1, f, mockRetryable)
64+
}
65+
}
66+
67+
"not retry on success" in {
68+
testWithResults(attempts = 2, Seq(Future.successful(successfulResult))) {
69+
(mockRetryable, result) =>
70+
result.futureValue shouldBe successfulResult
71+
verifyNumAttempts(n = 1, result, mockRetryable)
72+
}
73+
}
74+
75+
"fail when max retries exceeded" in {
76+
val ex = Future.failed {
77+
new OpenAIScalaClientTimeoutException("retryable exception")
78+
}
79+
testWithResults(
80+
attempts = 2,
81+
Seq(ex, ex, ex, Future.successful(successfulResult))
82+
) { (_, result) =>
83+
recoverToSucceededIf[OpenAIScalaClientTimeoutException](
84+
result
85+
).futureValue shouldBe Succeeded
86+
}
87+
}
88+
89+
"compute the correct delay when using constant interval" in {
90+
val interval = 10.seconds
91+
val settings = RetrySettings(interval)
92+
delay(1)(settings) shouldBe interval
93+
delay(5)(settings) shouldBe interval
94+
}
95+
96+
"compute the correct delay when using strictly positive base" in {
97+
val settings = RetrySettings(
98+
maxRetries = 5,
99+
delayOffset = 2.seconds,
100+
delayBase = 2
101+
)
102+
delay(1)(settings) shouldBe 4.seconds
103+
delay(2)(settings) shouldBe 6.seconds
104+
delay(3)(settings) shouldBe 10.seconds
105+
}
106+
107+
}
108+
109+
implicit val scheduler: Scheduler = actorSystem.scheduler
110+
111+
override def patienceConfig: PatienceConfig = patience
112+
implicit val retrySettings: RetrySettings = RetrySettings(
113+
maxRetries = 5,
114+
delayOffset = 0.seconds,
115+
delayBase = 1
116+
)
117+
118+
def testWithException(ex: OpenAIScalaClientException)(
119+
test: (Retryable, Future[Int]) => Unit
120+
): Unit = {
121+
val results = Seq(Future.failed(ex), Future.successful(successfulResult))
122+
testWithResults(results.length, results)(test)
123+
}
124+
125+
def testWithResults(attempts: Int, results: Seq[Future[Int]])(
126+
test: (Retryable, Future[Int]) => Unit
127+
): Unit = {
128+
val future = Promise[Int]().future
129+
val mockRetryable = mock[Retryable]
130+
when(mockRetryable.attempt())
131+
.thenReturn(results.head, results.takeRight(results.length - 1): _*)
132+
val result = retry(() => mockRetryable.attempt(), attempts)
133+
test(mockRetryable, result)
134+
}
135+
136+
def verifyNumAttempts[T](n: Int, f: Future[T], mock: Retryable): Unit =
137+
whenReady(f) { _ =>
138+
verify(mock, times(n)).attempt()
139+
}
140+
141+
override def actorSystem: ActorSystem = system
142+
}
143+
144+
trait Retryable {
145+
def attempt(): Future[Int]
146+
}

0 commit comments

Comments
 (0)