
Commit 0d15093

Merge pull request #9 from cequence-io/feature/generate_embeddings

Generate embeddings

2 parents 75db95f + 648c794, commit 0d15093

10 files changed: +303 / -14 lines
README.md

Lines changed: 32 additions & 2 deletions
@@ -61,7 +61,7 @@ Then you can obtain a service (pod or serverless-based) in one of the following
 - Custom config
 ```scala
 val config = ConfigFactory.load("path_to_my_custom_config")
-val service = PineconeIndexServiceFactory(config)
+val service = PineconeIndexServiceFactory(config)
 ```

 - Without config for pod-based service (with env)
@@ -81,7 +81,7 @@ Then you can obtain a service (pod or serverless-based) in one of the following

 **Ib. Obtaining `PineconeVectorService`**

-Same as with `PineconeIndexService`, you need to first provide implicit execution context and akka materializer. Then you can obtain a service in one of the following ways.
+Same as with `PineconeIndexService`, you need to first provide implicit execution context and Akka materializer. Then you can obtain a service in one of the following ways.

 - Default config (expects env. variable(s) to be set as defined in `Config` section). Note that if the index with a given name is not available, the factory will return `None`.
 ```scala
@@ -93,6 +93,23 @@ Same as with `PineconeIndexService`, you need to first provide implicit executio
   }
 ```

+**Ic. Obtaining `PineconeInferenceService`**
+
+Same as with `PineconeIndexService`, you need to first provide implicit execution context and Akka materializer. Then you can obtain a service in one of the following ways.
+
+With config
+```scala
+val config = ConfigFactory.load("path_to_my_custom_config")
+val service = PineconeInferenceServiceFactory(config)
+```
+
+Directly with api-key
+```scala
+val service = PineconeInferenceServiceFactory(
+  apiKey = "your_api_key"
+)
+```
+
 - Custom config
 ```scala
 val config = ConfigFactory.load("path_to_my_custom_config")
@@ -374,6 +391,19 @@ Examples:
     println(stats)
   )
 ```
+
+**Inference Operations**
+
+- Generate embeddings
+
+```scala
+  pineconeInferenceService.createEmbeddings(Seq("The quick brown fox jumped over the lazy dog")).map { embeddings =>
+    println(embeddings.data.mkString("\n"))
+  }
+```

 ## Demo

 For ready-to-run demos pls. refer to separate seed projects:
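Putting the README additions together, the snippet below sketches the embedding flow end to end: building a `PineconeInferenceService` and requesting embeddings with the default settings. It is an editorial illustration, not part of the commit; the API key is a placeholder, and the `EmbeddingsExample` wrapper, the `Await` call, and the printed fields are assumptions based on the new classes in this PR.

```scala
import akka.actor.ActorSystem
import akka.stream.Materializer
import io.cequence.pineconescala.service.PineconeInferenceServiceFactory

import scala.concurrent.{Await, ExecutionContext}
import scala.concurrent.duration._

object EmbeddingsExample extends App {
  implicit val ec: ExecutionContext = ExecutionContext.global
  implicit val materializer: Materializer = Materializer(ActorSystem())

  // Placeholder key; in practice it would come from config or an env variable.
  val service = PineconeInferenceServiceFactory(apiKey = "your_api_key")

  // Uses DefaultSettings.GenerateEmbeddings (multilingual-e5-large) from PineconeServiceConsts.
  val response = Await.result(
    service.createEmbeddings(Seq("The quick brown fox jumped over the lazy dog")),
    1.minute
  )

  println(response.data.head.values.take(5).mkString(", "))
  println(s"total tokens: ${response.usage.total_tokens}")

  service.close()
}
```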

pinecone-client/src/main/scala/io/cequence/pineconescala/JsonFormats.scala

Lines changed: 20 additions & 1 deletion
@@ -1,9 +1,11 @@
 package io.cequence.pineconescala

 import io.cequence.pineconescala.domain.response._
+import io.cequence.pineconescala.domain.settings.{EmbeddingsInputType, EmbeddingsTruncate}
+import io.cequence.pineconescala.domain.settings.EmbeddingsInputType.{Passage, Query}
 import io.cequence.pineconescala.domain.{Metric, PVector, PodType, SparseVector}
 import io.cequence.wsclient.JsonUtil.enumFormat
-import play.api.libs.json.{Format, Json}
+import play.api.libs.json.{Format, JsString, Json, Reads, Writes}

 object JsonFormats {
   // vector-stuff formats
@@ -77,4 +79,21 @@ object JsonFormats {
     Json.format[ServerlessIndexSpec]
   implicit val serverlessIndexInfoFormat: Format[ServerlessIndexInfo] =
     Json.format[ServerlessIndexInfo]
+
+  // embeddings
+  implicit val embeddingUsageInfoReads: Reads[EmbeddingsUsageInfo] =
+    Json.reads[EmbeddingsUsageInfo]
+  implicit val embeddingInfoReads: Reads[EmbeddingsInfo] = Json.reads[EmbeddingsInfo]
+  implicit val embeddingValuesReads: Reads[EmbeddingsValues] = Json.reads[EmbeddingsValues]
+  implicit val embeddingResponseReads: Reads[GenerateEmbeddingsResponse] = Json.reads[GenerateEmbeddingsResponse]
+
+  implicit val embeddingsInputTypeWrites: Writes[EmbeddingsInputType] = enumFormat(
+    Query,
+    Passage
+  )
+
+  implicit val embeddingsTruncateWrites: Writes[EmbeddingsTruncate] = enumFormat(
+    EmbeddingsTruncate.None,
+    EmbeddingsTruncate.End
+  )
 }
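As a quick sanity check on the new `Reads`, the sketch below decodes a response of the shape implied by the new case classes. The JSON literal is an illustrative assumption derived from those classes, not a captured Pinecone response.

```scala
import io.cequence.pineconescala.JsonFormats._
import io.cequence.pineconescala.domain.response.GenerateEmbeddingsResponse
import play.api.libs.json.Json

// Illustrative payload matching the new case classes (embedding values truncated).
val json = Json.parse(
  """{
    |  "data": [ { "values": [0.012, -0.034, 0.056] } ],
    |  "model": "multilingual-e5-large",
    |  "usage": { "total_tokens": 16 }
    |}""".stripMargin
)

val parsed = json.as[GenerateEmbeddingsResponse]
// parsed.data.head.values   -> Seq(0.012, -0.034, 0.056)
// parsed.usage.total_tokens -> 16
```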

pinecone-client/src/main/scala/io/cequence/pineconescala/service/EndPoint.scala

Lines changed: 8 additions & 6 deletions
@@ -1,16 +1,13 @@
 package io.cequence.pineconescala.service

-import io.cequence.pineconescala.domain.Metric
-import io.cequence.pineconescala.domain.settings.IndexSettings.{
-  CreatePodBasedIndexSettings,
-  CreateServerlessIndexSettings
-}
-import io.cequence.wsclient.domain.NamedEnumValue
+import io.cequence.pineconescala.domain.settings.IndexSettings.{CreatePodBasedIndexSettings, CreateServerlessIndexSettings}
+import io.cequence.wsclient.domain.{EnumValue, NamedEnumValue}

 sealed abstract class EndPoint(value: String = "") extends NamedEnumValue(value)

 object EndPoint {
   case object describe_index_stats extends EndPoint
+  case object embed extends EndPoint
   case object query extends EndPoint
   case object vectors_delete extends EndPoint("vectors/delete")
   case object vectors_fetch extends EndPoint("vectors/fetch")
@@ -56,6 +53,11 @@ object Tag {
   case object region extends Tag
   case object spec extends Tag
   case object shards extends Tag
+  case object inputs extends Tag
+  case object input_type extends Tag
+  case object model extends Tag
+  case object parameters extends Tag
+  case object truncate extends Tag

   // TODO: move elsewhere
   def fromCreatePodBasedIndexSettings(
pinecone-client/src/main/scala/io/cequence/pineconescala/service/PineconeInferenceServiceImpl.scala

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
package io.cequence.pineconescala.service

import akka.stream.Materializer
import com.typesafe.config.Config
import io.cequence.pineconescala.domain.response.GenerateEmbeddingsResponse
import io.cequence.pineconescala.domain.settings.{GenerateEmbeddingsSettings, IndexSettings}
import io.cequence.wsclient.JsonUtil.{JsonOps, toJson}
import io.cequence.wsclient.service.ws.{Timeouts, WSRequestHelper}
import play.api.libs.json.{JsArray, JsObject, JsValue, Json}
import io.cequence.pineconescala.JsonFormats._
import io.cequence.pineconescala.PineconeScalaClientException
import io.cequence.wsclient.domain.WsRequestContext
import play.api.libs.ws.StandaloneWSRequest

import scala.concurrent.{ExecutionContext, Future}

private class PineconeInferenceServiceImpl(
  apiKey: String,
  explicitTimeouts: Option[Timeouts] = None
)(
  implicit val ec: ExecutionContext,
  val materializer: Materializer
) extends PineconeInferenceService
    with WSRequestHelper {

  override protected type PEP = EndPoint
  override protected type PT = Tag
  override val coreUrl: String = "https://api.pinecone.io/"
  override protected val requestContext = WsRequestContext(explTimeouts = explicitTimeouts)

  /**
   * Uses the specified model to generate embeddings for the input sequence.
   *
   * @param inputs
   *   Input sequence for which to generate embeddings.
   * @param settings
   * @return
   *   list of embeddings inside an envelope
   */
  override def createEmbeddings(
    inputs: Seq[String],
    settings: GenerateEmbeddingsSettings
  ): Future[GenerateEmbeddingsResponse] = {
    val basicParams: Seq[(Tag, Option[JsValue])] = jsonBodyParams(
      Tag.inputs -> Some(JsArray(inputs.map(input => JsObject(Seq("text" -> toJson(input)))))),
      Tag.model -> Some(settings.model)
    )
    val otherParams: (Tag, Option[JsValue]) = {
      Tag.parameters -> Some(
        JsObject(
          Seq(
            Tag.input_type.toString() -> Json.toJson(settings.input_type),
            Tag.truncate.toString() -> Json.toJson(settings.truncate)
          )
        )
      )
    }
    execPOST(
      EndPoint.embed,
      bodyParams = basicParams :+ otherParams
    ).map(
      _.asSafe[GenerateEmbeddingsResponse]
    )
  }

  override def addHeaders(request: StandaloneWSRequest) = {
    val apiKeyHeader = ("Api-Key", apiKey)
    val versionHeader = ("X-Pinecone-API-Version", "2024-07")
    request
      .addHttpHeaders(apiKeyHeader)
      .addHttpHeaders(versionHeader)
  }

  override protected def handleErrorCodes(
    httpCode: Int,
    message: String
  ): Nothing =
    throw new PineconeScalaClientException(s"Code ${httpCode} : ${message}")

  override def close(): Unit =
    client.close()
}

object PineconeInferenceServiceFactory extends PineconeServiceFactoryHelper {

  def apply[S <: IndexSettings](
    apiKey: String,
    timeouts: Option[Timeouts] = None
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): PineconeInferenceService = {
    new PineconeInferenceServiceImpl(apiKey, timeouts)
  }

  def apply(
    config: Config
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): PineconeInferenceService = {
    val timeouts = loadTimeouts(config)

    apply(
      apiKey = config.getString(s"$configPrefix.apiKey"),
      timeouts = timeouts.toOption
    )
  }
}
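For orientation, the body that `createEmbeddings` posts to the `embed` endpoint should look roughly like the document constructed below. This is an editorial sketch derived from the parameter assembly above; the input strings and parameter values are made up, and the exact wire format is an assumption, not an excerpt from the Pinecone API docs.

```scala
import play.api.libs.json.{JsArray, JsObject, Json}

// Mirrors the body built in createEmbeddings for two inputs with
// input_type = "passage" and truncate = "END" (hypothetical values).
val body = JsObject(
  Seq(
    "inputs" -> JsArray(
      Seq(
        JsObject(Seq("text" -> Json.toJson("first passage"))),
        JsObject(Seq("text" -> Json.toJson("second passage")))
      )
    ),
    "model" -> Json.toJson("multilingual-e5-large"),
    "parameters" -> JsObject(
      Seq(
        "input_type" -> Json.toJson("passage"),
        "truncate" -> Json.toJson("END")
      )
    )
  )
)

println(Json.prettyPrint(body))
```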
pinecone-client/src/test/scala/io/cequence/pineconescala/service/ServerlessPineconeInferenceServiceImplSpec.scala

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
package io.cequence.pineconescala.service

import akka.actor.ActorSystem
import akka.stream.Materializer
import com.typesafe.config.{Config, ConfigFactory}
import org.scalatest.matchers.must.Matchers
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import org.scalatest.wordspec.AsyncWordSpec
import org.scalatest.GivenWhenThen

import scala.concurrent.{ExecutionContext, Future}

class ServerlessPineconeInferenceServiceImplSpec
  extends AsyncWordSpec
  with GivenWhenThen
  with ServerlessFixtures
  with Matchers
  with PineconeServiceConsts {

  implicit val ec: ExecutionContext = ExecutionContext.global
  implicit val materializer: Materializer = Materializer(ActorSystem())

  val serverlessConfig: Config = ConfigFactory.load("serverless.conf")

  def inferenceServiceBuilder: PineconeInferenceService =
    PineconeInferenceServiceFactory(serverlessConfig)

  "Pinecone Inference Service" when {

    "create embeddings should provide embeddings for input data" in {
      val service = inferenceServiceBuilder
      for {
        embeddings <- service.createEmbeddings(
          Seq("The quick brown fox jumped over the lazy dog"),
          settings = DefaultSettings.GenerateEmbeddings.withPassageInputType.withEndTruncate
        )
      } yield {
        embeddings.data.size should be(1)
        embeddings.data(0).values should not be empty
        embeddings.usage.total_tokens should be(16)
      }
    }
  }
}
GenerateEmbeddingsResponse.scala

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
package io.cequence.pineconescala.domain.response

case class GenerateEmbeddingsResponse(
  data: Seq[EmbeddingsValues],
  model: String,
  usage: EmbeddingsUsageInfo
)

case class EmbeddingsValues(values: Seq[Double])

case class EmbeddingsInfo(
  embedding: Seq[Double],
  index: Int
)

case class EmbeddingsUsageInfo(
  total_tokens: Int
)
GenerateEmbeddingsSettings.scala

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
package io.cequence.pineconescala.domain.settings

import io.cequence.wsclient.domain.{EnumValue, NamedEnumValue}

case class GenerateEmbeddingsSettings(
  // ID of the model to use.
  model: String,

  // Common property used to distinguish between types of data (e.g. query vs. passage).
  input_type: Option[EmbeddingsInputType] = None,

  // How to handle inputs longer than those supported by the model.
  truncate: EmbeddingsTruncate = EmbeddingsTruncate.End
) {
  def withPassageInputType = copy(input_type = Some(EmbeddingsInputType.Passage))
  def withQueryInputType = copy(input_type = Some(EmbeddingsInputType.Query))
  def withoutTruncate = copy(truncate = EmbeddingsTruncate.None)
  def withEndTruncate = copy(truncate = EmbeddingsTruncate.End)
}

sealed abstract class EmbeddingsInputType(name: String) extends NamedEnumValue(name)

object EmbeddingsInputType {
  case object Passage extends EmbeddingsInputType(name = "passage")
  case object Query extends EmbeddingsInputType(name = "query")
}

sealed abstract class EmbeddingsTruncate(name: String) extends NamedEnumValue(name)

object EmbeddingsTruncate {
  case object None extends EmbeddingsTruncate(name = "NONE")
  case object End extends EmbeddingsTruncate(name = "END")
}

sealed trait EmbeddingsEncodingFormat extends EnumValue

object EmbeddingsEncodingFormat {
  case object float extends EmbeddingsEncodingFormat
  case object base64 extends EmbeddingsEncodingFormat
}
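The copy-based helpers on `GenerateEmbeddingsSettings` chain naturally. A small illustrative sketch follows; the model name mirrors the default added in `PineconeServiceConsts`, and the chosen options are arbitrary.

```scala
import io.cequence.pineconescala.domain.settings.{EmbeddingsInputType, EmbeddingsTruncate, GenerateEmbeddingsSettings}

// Start from an explicit model and adjust the optional parameters fluently.
val settings =
  GenerateEmbeddingsSettings(model = "multilingual-e5-large")
    .withQueryInputType // input_type = Some(EmbeddingsInputType.Query)
    .withoutTruncate    // truncate = EmbeddingsTruncate.None

assert(settings.input_type.contains(EmbeddingsInputType.Query))
assert(settings.truncate == EmbeddingsTruncate.None)
```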
PineconeInferenceService.scala

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
package io.cequence.pineconescala.service

import io.cequence.pineconescala.domain.response.GenerateEmbeddingsResponse
import io.cequence.pineconescala.domain.settings.GenerateEmbeddingsSettings
import io.cequence.wsclient.service.CloseableService

import scala.concurrent.Future

trait PineconeInferenceService extends CloseableService with PineconeServiceConsts {

  /**
   * Uses the specified model to generate embeddings for the input sequence.
   *
   * @param inputs
   *   Input sequence for which to generate embeddings.
   * @param settings
   * @return
   *   list of embeddings inside an envelope
   */
  def createEmbeddings(
    inputs: Seq[String],
    settings: GenerateEmbeddingsSettings = DefaultSettings.GenerateEmbeddings
  ): Future[GenerateEmbeddingsResponse]
}

pinecone-core/src/main/scala/io/cequence/pineconescala/service/PineconeServiceConsts.scala

Lines changed: 5 additions & 4 deletions
@@ -1,9 +1,6 @@
 package io.cequence.pineconescala.service

-import io.cequence.pineconescala.domain.settings.IndexSettings.{
-  CreatePodBasedIndexSettings,
-  CreateServerlessIndexSettings
-}
+import io.cequence.pineconescala.domain.settings.IndexSettings.{CreatePodBasedIndexSettings, CreateServerlessIndexSettings}
 import io.cequence.pineconescala.domain.{Metric, PodType}
 import io.cequence.pineconescala.domain.settings._

@@ -38,5 +35,9 @@ trait PineconeServiceConsts {
       CloudProvider.AWS,
       Region.EUWest1
     )
+
+    val GenerateEmbeddings = GenerateEmbeddingsSettings(
+      model = "multilingual-e5-large"
+    )
   }
 }

project/build.properties

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-sbt.version = 1.8.2
+sbt.version = 1.9.0
