Skip to content

Commit f1b36a8

Browse files
committed
Fix AI builders for Java consumers
1 parent 0a880cc commit f1b36a8

File tree

3 files changed

+70
-5
lines changed

3 files changed

+70
-5
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,17 @@ constructor(public val role: String? = "user", public val parts: List<Part>) {
4444
public class Builder {
4545

4646
/** The producer of the content. Must be either 'user' or 'model'. By default, it's "user". */
47-
public var role: String? = "user"
47+
@JvmField public var role: String? = "user"
4848

4949
/**
5050
* The mutable list of [Part]s comprising the [Content].
5151
*
5252
* Prefer using the provided helper methods over modifying this list directly.
5353
*/
54-
public var parts: MutableList<Part> = arrayListOf()
54+
@JvmField public var parts: MutableList<Part> = arrayListOf()
55+
56+
public fun setRole(role: String?): Content.Builder = apply { this.role = role }
57+
public fun setParts(parts: MutableList<Part>): Content.Builder = apply { this.parts = parts }
5558

5659
/** Adds a new [Part] to [parts]. */
5760
@JvmName("addPart")

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/GenerationConfig.kt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,36 @@ private constructor(
136136
@JvmField public var responseSchema: Schema? = null
137137
@JvmField public var responseModalities: List<ResponseModality>? = null
138138

139+
public fun setTemperature(temperature: Float?): Builder = apply {
140+
this.temperature = temperature
141+
}
142+
public fun setTopK(topK: Int?): Builder = apply { this.topK = topK }
143+
public fun setTopP(topP: Float?): Builder = apply { this.topP = topP }
144+
public fun setCandidateCount(candidateCount: Int?): Builder = apply {
145+
this.candidateCount = candidateCount
146+
}
147+
public fun setMaxOutputTokens(maxOutputTokens: Int?): Builder = apply {
148+
this.maxOutputTokens = maxOutputTokens
149+
}
150+
public fun setPresencePenalty(presencePenalty: Float?): Builder = apply {
151+
this.presencePenalty = presencePenalty
152+
}
153+
public fun setFrequencyPenalty(frequencyPenalty: Float?): Builder = apply {
154+
this.frequencyPenalty = frequencyPenalty
155+
}
156+
public fun setStopSequences(stopSequences: List<String>?): Builder = apply {
157+
this.stopSequences = stopSequences
158+
}
159+
public fun setResponseMimeType(responseMimeType: String?): Builder = apply {
160+
this.responseMimeType = responseMimeType
161+
}
162+
public fun setResponseSchema(responseSchema: Schema?): Builder = apply {
163+
this.responseSchema = responseSchema
164+
}
165+
public fun setResponseModalities(responseModalities: List<ResponseModality>?): Builder = apply {
166+
this.responseModalities = responseModalities
167+
}
168+
139169
/** Create a new [GenerationConfig] with the attached arguments. */
140170
public fun build(): GenerationConfig =
141171
GenerationConfig(

firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,13 @@
5757
import com.google.firebase.ai.type.PublicPreviewAPI;
5858
import com.google.firebase.ai.type.ResponseModality;
5959
import com.google.firebase.ai.type.SafetyRating;
60+
import com.google.firebase.ai.type.Schema;
6061
import com.google.firebase.ai.type.SpeechConfig;
6162
import com.google.firebase.ai.type.TextPart;
6263
import com.google.firebase.ai.type.UsageMetadata;
6364
import com.google.firebase.ai.type.Voices;
6465
import com.google.firebase.concurrent.FirebaseExecutors;
66+
import java.util.ArrayList;
6567
import java.util.Calendar;
6668
import java.util.List;
6769
import java.util.Map;
@@ -92,8 +94,37 @@ public void initializeJava() throws Exception {
9294
}
9395

9496
private GenerationConfig getConfig() {
95-
return new GenerationConfig.Builder().build();
96-
// TODO b/406558430 GenerationConfig.Builder.setParts returns void
97+
return new GenerationConfig.Builder()
98+
.setTopK(10)
99+
.setTopP(11.0F)
100+
.setTemperature(32.0F)
101+
.setCandidateCount(1)
102+
.setMaxOutputTokens(0xCAFEBABE)
103+
.setFrequencyPenalty(1.0F)
104+
.setPresencePenalty(2.0F)
105+
.setStopSequences(List.of("foo", "bar"))
106+
.setResponseMimeType("image/jxl")
107+
.setResponseModalities(List.of(ResponseModality.TEXT, ResponseModality.TEXT))
108+
.setResponseSchema(getSchema())
109+
.build();
110+
}
111+
112+
private Schema getSchema() {
113+
return Schema.obj(
114+
Map.of(
115+
"foo", Schema.numInt(),
116+
"bar", Schema.numInt("Some integer"),
117+
"baz", Schema.numInt("Some integer", false),
118+
"qux", Schema.numDouble(),
119+
"quux", Schema.numFloat("Some floating point number"),
120+
"xyzzy", Schema.array(Schema.numInt(), "A list of integers"),
121+
"fee", Schema.numLong(),
122+
"ber",
123+
Schema.obj(
124+
Map.of(
125+
"bez", Schema.array(Schema.numDouble("Nullable double", true)),
126+
"qez", Schema.enumeration(List.of("A", "B", "C"), "One of 3 letters"),
127+
"qeez", Schema.str("A funny string")))));
97128
}
98129

99130
private LiveGenerationConfig getLiveConfig() {
@@ -113,13 +144,14 @@ private LiveGenerationConfig getLiveConfig() {
113144
private void testFutures(GenerativeModelFutures futures) throws Exception {
114145
Content content =
115146
new Content.Builder()
147+
.setParts(new ArrayList<>())
116148
.addText("Fake prompt")
117149
.addFileData("fakeuri", "image/png")
118150
.addInlineData(new byte[] {}, "text/json")
119151
.addImage(Bitmap.createBitmap(0, 0, Bitmap.Config.HARDWARE))
120152
.addPart(new FunctionCallPart("fakeFunction", Map.of("fakeArg", JsonNull.INSTANCE)))
153+
.setRole("user")
121154
.build();
122-
// TODO b/406558430 Content.Builder.setParts and Content.Builder.setRole return void
123155
Executor executor = FirebaseExecutors.directExecutor();
124156
ListenableFuture<CountTokensResponse> countResponse = futures.countTokens(content);
125157
validateCountTokensResponse(countResponse.get());

0 commit comments

Comments
 (0)