Skip to content

Commit 91cb881

Browse files
authored
Merge pull request #58 from cequence-io/token_count_polishing
Token count polishing
2 parents 8182b67 + ac671a7 commit 91cb881

File tree

4 files changed

+74
-19
lines changed

4 files changed

+74
-19
lines changed

README.md

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,16 +363,62 @@ For this to work you need to use `OpenAIServiceStreamedFactory` from `openai-sca
363363
}
364364
```
365365

366-
- 🔥 **New**: Count expected used tokens before calling `createChatCompletions` or `createChatFunCompletions`, this help you select proper model ex. `gpt-3.5-turbo` or `gpt-3.5-turbo-16k` and reduce costs. This is an experimental feature and it may not work for all models. Requires `openai-scala-count-tokens` lib.
366+
- 🔥 **New**: Count expected used tokens before calling `createChatCompletions` or `createChatFunCompletions`. This helps you select a proper model, e.g. `gpt-3.5-turbo` or `gpt-3.5-turbo-16k`, and reduce costs. This is an experimental feature and it may not work for all models. Requires the `openai-scala-count-tokens` lib.
367367

368+
An example of how to count message tokens:
369+
```scala
370+
import io.cequence.openaiscala.domain.{AssistantMessage, BaseMessage, FunctionSpec, ModelId, SystemMessage, UserMessage}
371+
372+
class MyCompletionService extends OpenAICountTokensHelper {
373+
def exec = {
374+
val model = ModelId.gpt_4_turbo_2024_04_09
375+
376+
// messages to be sent to OpenAI
377+
val messages: Seq[BaseMessage] = Seq(
378+
SystemMessage("You are a helpful assistant."),
379+
UserMessage("Who won the world series in 2020?"),
380+
AssistantMessage("The Los Angeles Dodgers won the World Series in 2020."),
381+
UserMessage("Where was it played?"),
382+
)
383+
384+
val tokens = countMessageTokens(model, messages)
385+
}
386+
}
387+
```
388+
389+
An example of how to count message tokens when a function is involved:
368390
```scala
369391
import io.cequence.openaiscala.service.OpenAICountTokensHelper
370392
import io.cequence.openaiscala.domain.{ChatRole, FunMessageSpec, FunctionSpec}
371393

394+
// TODO: simpler example
395+
import io.cequence.openaiscala.domain.{BaseMessage, FunctionSpec, ModelId, SystemMessage, UserMessage}
396+
372397
class MyCompletionService extends OpenAICountTokensHelper {
373398
def exec = {
374-
val messages: Seq[FunMessageSpec] = ??? // messages to be sent to OpenAI
375-
val function: FunctionSpec = ??? // function to be called
399+
val model = ModelId.gpt_4_turbo_2024_04_09
400+
401+
// messages to be sent to OpenAI
402+
val messages: Seq[BaseMessage] =
403+
Seq(
404+
SystemMessage("You are a helpful assistant."),
405+
UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")
406+
)
407+
408+
// function to be called
409+
val function: FunctionSpec = FunctionSpec(
410+
name = "getWeather",
411+
parameters = Map(
412+
"type" -> "object",
413+
"properties" -> Map(
414+
"location" -> Map(
415+
"type" -> "string",
416+
"description" -> "The city to get the weather for"
417+
),
418+
"unit" -> Map("type" -> "string", "enum" -> List("celsius", "fahrenheit"))
419+
)
420+
)
421+
)
376422

377423
val tokens = countFunMessageTokens(model, messages, Seq(function), Some(function.name))
378424
}

openai-count-tokens/README.md

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,25 @@ or to *pom.xml* (if you use maven)
2727

2828
## Usage
2929

30+
An example of how to count message tokens:
3031
```scala
31-
import io.cequence.openaiscala.service.OpenAICountTokensHelper
32-
import io.cequence.openaiscala.domain.{ChatRole, FunMessageSpec, FunctionSpec}
33-
34-
val messages: Seq[FunMessageSpec] = ??? // messages to be sent to OpenAI
35-
val function: FunctionSpec = ??? // function to be called
36-
37-
val service = new OpenAICountTokensService()
38-
39-
val tokens = service.countFunMessageTokens(messages, List(function), Some(function.name))
32+
import io.cequence.openaiscala.domain.{AssistantMessage, BaseMessage, FunctionSpec, ModelId, SystemMessage, UserMessage}
33+
34+
class MyCompletionService extends OpenAICountTokensHelper {
35+
def exec = {
36+
val model = ModelId.gpt_4_turbo_2024_04_09
37+
38+
// messages to be sent to OpenAI
39+
val messages: Seq[BaseMessage] = Seq(
40+
SystemMessage("You are a helpful assistant."),
41+
UserMessage("Who won the world series in 2020?"),
42+
AssistantMessage("The Los Angeles Dodgers won the World Series in 2020."),
43+
UserMessage("Where was it played?"),
44+
)
45+
46+
val tokens = countMessageTokens(model, messages)
47+
}
48+
}
4049
```
4150

4251

openai-count-tokens/src/main/scala/io/cequence/openaiscala/service/OpenAICountTokensHelper.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ trait OpenAICountTokensHelper {
4949

5050
private def tokensPerMessageAndName(model: String): (Int, Int) =
5151
model match {
52-
case "gpt-3.5-turbo-0301" =>
52+
case ModelId.gpt_3_5_turbo_0301 =>
5353
// every message follows <|start|>{role/name}\n{content}<|end|>\n
5454
// if there's a name, the role is omitted
5555
(4, -1)
56-
case "gpt-3.5-turbo-0613" | "gpt-3.5-turbo-16k-0613" | "gpt-4-0314" | "gpt-4-32k-0314" |
57-
"gpt-4-0613" | "gpt-4-32k-0613" =>
56+
case ModelId.gpt_3_5_turbo_0613 | ModelId.gpt_3_5_turbo_16k_0613 | ModelId.gpt_4_0613 |
57+
ModelId.gpt_4_32k_0613 | ModelId.gpt_4_turbo_2024_04_09 =>
5858
(3, 1)
59-
case "gpt-3.5-turbo" => tokensPerMessageAndName("gpt-3.5-turbo-0613")
60-
case "gpt-4" => tokensPerMessageAndName("gpt-4-0613")
61-
case _ =>
59+
case ModelId.gpt_3_5_turbo => tokensPerMessageAndName(ModelId.gpt_3_5_turbo_0613)
60+
case ModelId.gpt_4 => tokensPerMessageAndName(ModelId.gpt_4_0613)
61+
case _ =>
6262
// failover to (3, 1)
6363
(3, 1)
6464
}

openai-count-tokens/src/test/scala/io/cequence/openaiscala/service/OpenAICountTokensServiceSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import org.scalatestplus.mockito.MockitoSugar
1616
import org.scalatest.concurrent.ScalaFutures
1717
import org.scalatest.matchers.should.Matchers
1818
import org.scalatest.wordspec.AnyWordSpecLike
19-
import org.scalatest.{BeforeAndAfterAll, Ignore}
19+
import org.scalatest.BeforeAndAfterAll
2020

2121
import scala.collection.immutable.ListMap
2222
import scala.concurrent.ExecutionContext.Implicits.global

0 commit comments

Comments
 (0)