Skip to content

Commit 33a99a7

Browse files
committed
add embeddings
1 parent b7c5ede commit 33a99a7

File tree

13 files changed

+1808
-19
lines changed

13 files changed

+1808
-19
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package embeddings
2+
3+
import com.cjcrafter.openai.embeddings.embeddingsRequest
4+
import com.cjcrafter.openai.openAI
5+
import io.github.cdimascio.dotenv.dotenv
6+
7+
/**
8+
* In this Kotlin example, we will be using the embeddings API to generate the
9+
* embeddings of a list of strings.
10+
*/
11+
fun main() {
12+
13+
// To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
14+
// dependency. Then you can add a .env file in your project directory.
15+
val key = dotenv()["OPENAI_TOKEN"]
16+
val openai = openAI { apiKey(key) }
17+
18+
val request = embeddingsRequest {
19+
input("hi")
20+
model("text-embedding-ada-002")
21+
}
22+
23+
val response = openai.createEmbeddings(request)
24+
println(response)
25+
}

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import com.cjcrafter.openai.chat.tool.ToolChoice
55
import com.cjcrafter.openai.completions.CompletionRequest
66
import com.cjcrafter.openai.completions.CompletionResponse
77
import com.cjcrafter.openai.completions.CompletionResponseChunk
8+
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
9+
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
810
import com.cjcrafter.openai.util.OpenAIDslMarker
911
import com.fasterxml.jackson.annotation.JsonAutoDetect
1012
import com.fasterxml.jackson.annotation.JsonInclude
@@ -88,6 +90,17 @@ interface OpenAI {
8890
@Contract(pure = true)
8991
fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk>
9092

93+
/**
94+
* Calls the [embeddings](https://beta.openai.com/docs/api-reference/embeddings)
95+
* API endpoint to generate the vector representation of text. The returned
96+
* vector can be used in Machine Learning models. This method is blocking.
97+
*
98+
* @param request The request to send to the API
99+
* @return The response from the API
100+
*/
101+
@Contract(pure = true)
102+
fun createEmbeddings(request: EmbeddingsRequest): EmbeddingsResponse
103+
91104
@OpenAIDslMarker
92105
open class Builder internal constructor() {
93106
protected var apiKey: String? = null

src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import com.cjcrafter.openai.chat.*
44
import com.cjcrafter.openai.completions.CompletionRequest
55
import com.cjcrafter.openai.completions.CompletionResponse
66
import com.cjcrafter.openai.completions.CompletionResponseChunk
7+
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
8+
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
79
import com.fasterxml.jackson.databind.JavaType
810
import com.fasterxml.jackson.databind.node.ObjectNode
911
import okhttp3.*
@@ -138,8 +140,14 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
138140
}
139141
}
140142

143+
override fun createEmbeddings(request: EmbeddingsRequest): EmbeddingsResponse {
144+
val httpRequest = buildRequest(request, EMBEDDINGS_ENDPOINT)
145+
return executeRequest(httpRequest, EmbeddingsResponse::class.java)
146+
}
147+
141148
companion object {
142149
const val COMPLETIONS_ENDPOINT = "v1/completions"
143150
const val CHAT_ENDPOINT = "v1/chat/completions"
151+
const val EMBEDDINGS_ENDPOINT = "v1/embeddings"
144152
}
145153
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.cjcrafter.openai.embeddings
2+
3+
/**
4+
* Represents 1 embedding as a vector of floats or strings.
5+
*
6+
* @property embedding
7+
* @property index
8+
* @constructor Create empty Embedding
9+
*/
10+
data class Embedding(
11+
val embedding: List<Any>,
12+
val index: Int,
13+
) {
14+
/**
15+
* Returns the embedding as a list of floats. Make sure to use [EncodingFormat.FLOAT].
16+
*/
17+
fun asDoubles() = embedding.map { it as Double }
18+
19+
/**
20+
* Returns the embedding as a list of strings. Make sure to use [EncodingFormat.BASE64].
21+
*/
22+
fun asBase64() = embedding.map { it as String }
23+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package com.cjcrafter.openai.embeddings
2+
3+
import com.cjcrafter.openai.util.OpenAIDslMarker
4+
import com.fasterxml.jackson.annotation.JsonProperty
5+
6+
/**
7+
* Holds the options sent to the [embeddings](https://beta.openai.com/docs/api-reference/embeddings) endpoint.
8+
* The generated embeddings can be used in Machine Learning models.
9+
*
10+
* [input] can be either a string or a list of strings.
11+
*
12+
* @property input The input(s) to convert to embeddings.
13+
* @property model Which [model](https://platform.openai.com/docs/models/embeddings) to use to generate the embeddings.
14+
* @property encodingFormat Determines how the embeddings are encoded. Defaults to [EncodingFormat.FLOAT].
15+
* @property user The user ID to associate with this request.
16+
* @constructor Create empty Embeddings request
17+
*/
18+
data class EmbeddingsRequest internal constructor(
19+
var input: Any,
20+
var model: String,
21+
@JsonProperty("encoding_format") var encodingFormat: EncodingFormat? = null,
22+
var user: String? = null,
23+
) {
24+
25+
/**
26+
* A builder design pattern for constructing an [EmbeddingsRequest] instance.
27+
*/
28+
@OpenAIDslMarker
29+
class Builder internal constructor() {
30+
private var input: Any? = null
31+
private var model: String? = null
32+
private var encodingFormat: EncodingFormat? = null
33+
private var user: String? = null
34+
35+
fun input(input: String) = apply { this.input = input }
36+
fun input(input: List<String>) = apply { this.input = input }
37+
fun model(model: String) = apply { this.model = model }
38+
fun encodingFormat(encodingFormat: EncodingFormat) = apply { this.encodingFormat = encodingFormat }
39+
fun user(user: String) = apply { this.user = user }
40+
41+
fun build(): EmbeddingsRequest {
42+
return EmbeddingsRequest(
43+
input = input ?: throw IllegalStateException("input must be defined to use EmbeddingsRequest"),
44+
model = model ?: throw IllegalStateException("model must be defined to use EmbeddingsRequest"),
45+
encodingFormat = encodingFormat,
46+
user = user
47+
)
48+
}
49+
}
50+
51+
companion object {
52+
53+
/**
54+
* Returns a builder to construct an [EmbeddingsRequest] instance.
55+
*/
56+
@JvmStatic
57+
fun builder() = Builder()
58+
}
59+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package com.cjcrafter.openai.embeddings
2+
3+
/**
4+
* Creates an [EmbeddingsRequest] using the [EmbeddingsRequest.Builder] using Kotlin DSL.
5+
*/
6+
fun embeddingsRequest(block: EmbeddingsRequest.Builder.() -> Unit) = EmbeddingsRequest.builder().apply(block).build()
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.cjcrafter.openai.embeddings
2+
3+
/**
4+
* The API response from the [EmbeddingsRequest].
5+
*
6+
* @property data The embeddings data
7+
* @property model The exact model used to generate the embeddings
8+
* @property usage How many tokens were used by the API request
9+
*/
10+
data class EmbeddingsResponse(
11+
val data: List<Embedding>,
12+
val model: String,
13+
val usage: EmbeddingsUsage,
14+
) {
15+
/**
16+
* Returns the [data] at the given [index].
17+
*/
18+
operator fun get(index: Int): Embedding {
19+
return data[index]
20+
}
21+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.cjcrafter.openai.embeddings
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty
4+
5+
/**
6+
* Holds the number of tokens used by the API request. Exact pricing can vary
7+
* based on the model used.
8+
*
9+
* @property promptTokens How many tokens were taken by the input strings
10+
* @property totalTokens The total number of tokens used
11+
*/
12+
data class EmbeddingsUsage(
13+
@JsonProperty("prompt_tokens") val promptTokens: Int,
14+
@JsonProperty("total_tokens") val totalTokens: Int,
15+
)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.cjcrafter.openai.embeddings
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty
4+
5+
/**
6+
* Determines how the embeddings are encoded.
7+
*/
8+
enum class EncodingFormat {
9+
10+
/**
11+
* The default encoding format.
12+
*/
13+
@JsonProperty("text")
14+
FLOAT,
15+
16+
/**
17+
* Encodes the embedding as a base64 string.
18+
*/
19+
@JsonProperty("base64")
20+
BASE64
21+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package com.cjcrafter.openai
2+
3+
import okhttp3.OkHttpClient
4+
import okhttp3.mockwebserver.MockWebServer
5+
import org.junit.jupiter.api.AfterEach
6+
import org.junit.jupiter.api.BeforeEach
7+
8+
abstract class MockedTest {
9+
10+
protected val mockWebServer = MockWebServer()
11+
protected lateinit var client: OkHttpClient
12+
13+
@BeforeEach
14+
fun setUp() {
15+
mockWebServer.start()
16+
client = OkHttpClient.Builder().build()
17+
}
18+
19+
@AfterEach
20+
fun tearDown() {
21+
mockWebServer.shutdown()
22+
}
23+
24+
fun readResource(resource: String): String {
25+
return this::class.java.classLoader.getResource(resource)?.readText() ?: throw Exception("Resource '$resource' not found")
26+
}
27+
}

src/test/kotlin/com/cjcrafter/openai/chat/StreamChatCompletionTest.kt

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.cjcrafter.openai.chat
22

3+
import com.cjcrafter.openai.MockedTest
34
import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage
45
import com.cjcrafter.openai.chat.tool.ToolType
56
import com.cjcrafter.openai.openAI
@@ -11,21 +12,7 @@ import org.junit.jupiter.api.Assertions.assertEquals
1112
import org.junit.jupiter.api.BeforeEach
1213
import org.junit.jupiter.api.Test
1314

14-
class StreamChatCompletionTest {
15-
16-
private val mockWebServer = MockWebServer()
17-
private lateinit var client: OkHttpClient
18-
19-
@BeforeEach
20-
fun setUp() {
21-
mockWebServer.start()
22-
client = OkHttpClient.Builder().build()
23-
}
24-
25-
@AfterEach
26-
fun tearDown() {
27-
mockWebServer.shutdown()
28-
}
15+
class StreamChatCompletionTest : MockedTest() {
2916

3017
@Test
3118
fun `test stream`() {
@@ -81,8 +68,4 @@ class StreamChatCompletionTest {
8168
assertEquals(null, message.toolCalls)
8269
assertEquals(null, message.toolCallId)
8370
}
84-
85-
private fun readResource(resource: String): String {
86-
return this::class.java.classLoader.getResource(resource)?.readText() ?: throw Exception("Resource '$resource' not found")
87-
}
8871
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package com.cjcrafter.openai.embeddings
2+
3+
import com.cjcrafter.openai.MockedTest
4+
import com.cjcrafter.openai.openAI
5+
import okhttp3.mockwebserver.MockResponse
6+
import org.junit.jupiter.api.Test
7+
import org.junit.jupiter.api.assertThrows
8+
import org.junit.jupiter.api.Assertions.*
9+
10+
class CreateEmbeddingsTest : MockedTest() {
11+
12+
@Test
13+
fun `test create embeddings list`() {
14+
mockWebServer.enqueue(MockResponse().setBody(readResource("create_embeddings.txt")))
15+
16+
val openai = openAI {
17+
apiKey("sk-123456789")
18+
client(client)
19+
baseUrl(mockWebServer.url("/").toString())
20+
}
21+
22+
val dummyRequest = embeddingsRequest {
23+
input(listOf("Once upon a time", "There was a frog"))
24+
model("text-embedding-ada-002")
25+
}
26+
27+
val response = openai.createEmbeddings(dummyRequest)
28+
response[0].asDoubles() // This will throw an exception if it is not a list of floats
29+
assertThrows<ClassCastException> {
30+
response[0].asBase64()
31+
}
32+
33+
assertEquals(1, response.usage.promptTokens)
34+
assertEquals(1, response.usage.totalTokens)
35+
}
36+
}

0 commit comments

Comments
 (0)