Skip to content

Moderations #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/src/main/java/moderations/ModerationsExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package moderations;

import com.cjcrafter.openai.OpenAI;
import com.cjcrafter.openai.moderations.CreateModerationRequest;
import com.cjcrafter.openai.moderations.Moderation;
import io.github.cdimascio.dotenv.Dotenv;

import java.util.Comparator;
import java.util.Scanner;

public class ModerationsExample {

// To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
// dependency. Then you can add a .env file in your project directory.
public static final OpenAI openai = OpenAI.builder()
.apiKey(Dotenv.load().get("OPENAI_TOKEN"))
.build();

public static final Scanner scan = new Scanner(System.in);

public static void main(String[] args) {
while (true) {
System.out.print("Input: ");
String input = scan.nextLine();
CreateModerationRequest request = CreateModerationRequest.builder()
.input(input)
.build();

Moderation moderation = openai.moderations().create(request);
Moderation.Result result = moderation.getResults().get(0);

// Finds the category with the highest score
String highest = result.getCategoryScores().keySet().stream()
.max(Comparator.comparing(a -> result.getCategoryScores().get(a)))
.orElseThrow(() -> new RuntimeException("No categories found!"));

System.out.println("Highest category: " + highest + ", with a score of " + result.getCategoryScores().get(highest));
}
}
}
25 changes: 25 additions & 0 deletions examples/src/main/kotlin/moderations/ModerationsExample.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package moderations

import com.cjcrafter.openai.moderations.create
import com.cjcrafter.openai.openAI
import io.github.cdimascio.dotenv.dotenv


fun main() {

// To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
// dependency. Then you can add a .env file in your project directory.
val key = dotenv()["OPENAI_TOKEN"]
val openai = openAI { apiKey(key) }

while (true) {
print("Input: ")
val input = readln()
val moderation = openai.moderations.create {
input(input)
}

val max = moderation.results[0].categoryScores.entries.maxBy { it.value }
println("Highest category: ${max.key} with a score of ${max.value}")
}
}
14 changes: 14 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
import com.cjcrafter.openai.files.*
import com.cjcrafter.openai.moderations.ModerationHandler
import com.cjcrafter.openai.threads.ThreadHandler
import com.cjcrafter.openai.threads.message.TextAnnotation
import com.cjcrafter.openai.util.OpenAIDslMarker
Expand Down Expand Up @@ -135,6 +136,19 @@ interface OpenAI {
@Contract(pure = true)
fun files(): FileHandler = files

/**
* Returns the handler for the moderations endpoint. This handler can be used
* to create moderations.
*/
val moderations: ModerationHandler

/**
* Returns the handler for the moderations endpoint. This method is purely
* syntactic sugar for Java users.
*/
@Contract(pure = true)
fun moderations(): ModerationHandler = moderations

/**
* Returns the handler for the assistants endpoint. This handler can be used
* to create, retrieve, and delete assistants.
Expand Down
25 changes: 16 additions & 9 deletions src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
import com.cjcrafter.openai.files.*
import com.cjcrafter.openai.moderations.ModerationHandler
import com.cjcrafter.openai.moderations.ModerationHandlerImpl
import com.cjcrafter.openai.threads.ThreadHandler
import com.cjcrafter.openai.threads.ThreadHandlerImpl
import com.fasterxml.jackson.databind.JavaType
Expand Down Expand Up @@ -127,23 +129,28 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
return requestHelper.executeRequest(httpRequest, EmbeddingsResponse::class.java)
}

private var files0: FileHandlerImpl? = null
override val files: FileHandler
get() = files0 ?: FileHandlerImpl(requestHelper, FILES_ENDPOINT).also { files0 = it }
override val files: FileHandler by lazy {
FileHandlerImpl(requestHelper, FILES_ENDPOINT)
}

private var assistants0: AssistantHandlerImpl? = null
override val assistants: AssistantHandler
get() = assistants0 ?: AssistantHandlerImpl(requestHelper, ASSISTANTS_ENDPOINT).also { assistants0 = it }
override val moderations: ModerationHandler by lazy {
ModerationHandlerImpl(requestHelper, MODERATIONS_ENDPOINT)
}

private var threads0: ThreadHandlerImpl? = null
override val threads: ThreadHandler
get() = threads0 ?: ThreadHandlerImpl(requestHelper, THREADS_ENDPOINT).also { threads0 = it }
override val assistants: AssistantHandler by lazy {
AssistantHandlerImpl(requestHelper, ASSISTANTS_ENDPOINT)
}

override val threads: ThreadHandler by lazy {
ThreadHandlerImpl(requestHelper, THREADS_ENDPOINT)
}

companion object {
const val COMPLETIONS_ENDPOINT = "v1/completions"
const val CHAT_ENDPOINT = "v1/chat/completions"
const val EMBEDDINGS_ENDPOINT = "v1/embeddings"
const val FILES_ENDPOINT = "v1/files"
const val MODERATIONS_ENDPOINT = "v1/moderations"
const val ASSISTANTS_ENDPOINT = "v1/assistants"
const val THREADS_ENDPOINT = "v1/threads"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.cjcrafter.openai.moderations

import com.cjcrafter.openai.util.OpenAIDslMarker

/**
* Represents a request to create a new moderation request.
*
* @property input The input to moderate
* @property model The model to use for moderation
*/
data class CreateModerationRequest internal constructor(
var input: Any,
var model: String? = null
) {

@OpenAIDslMarker
class Builder internal constructor() {
private var input: Any? = null
private var model: String? = null

/**
* Sets the input to moderate.
*
* @param input The input to moderate
*/
fun input(input: String) = apply { this.input = input }

/**
* Sets the input to moderate.
*
* @param input The input to moderate
*/
fun input(input: List<String>) = apply { this.input = input }

/**
* Sets the model to use for moderation.
*
* @param model The model to use for moderation
*/
fun model(model: String) = apply { this.model = model }

/**
* Builds the [CreateModerationRequest] instance.
*/
fun build(): CreateModerationRequest {
return CreateModerationRequest(
input = input ?: throw IllegalStateException("input must be defined to use CreateModerationRequest"),
model = model
)
}
}

companion object {
/**
* Returns a builder to construct a [CreateModerationRequest] instance.
*/
@JvmStatic
fun builder() = Builder()
}
}
29 changes: 29 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/moderations/Moderation.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.cjcrafter.openai.moderations

import com.fasterxml.jackson.annotation.JsonProperty

/**
* A moderation object returned by the moderations api.
*
* @property id The id of the moderation request. Always starts with "modr-".
* @property model The model which was used to moderate the content.
* @property results The results of the moderation request.
*/
data class Moderation(
@JsonProperty(required = true) val id: String,
@JsonProperty(required = true) val model: String,
@JsonProperty(required = true) val results: List<Result>,
) {
/**
* The results of the moderation request.
*
* @property flagged If any categories were flagged.
* @property categories The categories that were flagged.
* @property categoryScores The scores of each category.
*/
data class Result(
@JsonProperty(required = true) val flagged: Boolean,
@JsonProperty(required = true) val categories: Map<String, Boolean>,
@JsonProperty("category_scores", required = true) val categoryScores: Map<String, Double>,
)
}
10 changes: 10 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/moderations/ModerationDsl.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.cjcrafter.openai.moderations

fun createModerationRequest(block: CreateModerationRequest.Builder.() -> Unit): CreateModerationRequest {
return CreateModerationRequest.builder().apply(block).build()
}

fun ModerationHandler.create(block: CreateModerationRequest.Builder.() -> Unit): Moderation {
val request = createModerationRequest(block)
return create(request)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.cjcrafter.openai.moderations

/**
* Handler used to interact with [Moderation] objects.
*/
interface ModerationHandler {

/**
* Creates a new moderation request with the given options.
*
* @param request The values of the moderation to create
* @return The created moderation
*/
fun create(request: CreateModerationRequest): Moderation
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.cjcrafter.openai.moderations

import com.cjcrafter.openai.RequestHelper

class ModerationHandlerImpl(
private val requestHelper: RequestHelper,
private val endpoint: String,
): ModerationHandler {
override fun create(request: CreateModerationRequest): Moderation {
val httpRequest = requestHelper.buildRequest(request, endpoint).build()
return requestHelper.executeRequest(httpRequest, Moderation::class.java)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package com.cjcrafter.openai.chat

import com.cjcrafter.openai.MockedTest
import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage
import com.cjcrafter.openai.chat.tool.FunctionToolCall
import com.cjcrafter.openai.chat.tool.Tool
import com.cjcrafter.openai.chat.tool.ToolCall
import okhttp3.mockwebserver.MockResponse
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -46,9 +49,9 @@ class MockedChatStreamTest : MockedTest() {

// Assertions
assertEquals(ChatUser.ASSISTANT, toolMessage.role, "Tool call should be from the assistant")
assertEquals(ToolType.FUNCTION, toolMessage.toolCalls?.get(0)?.type, "Tool call should be a function")
assertEquals("solve_math_problem", toolMessage.toolCalls?.get(0)?.function?.name)
assertEquals("3/2", toolMessage.toolCalls?.get(0)?.function?.tryParseArguments()?.get("equation")?.asText())
assertEquals(Tool.Type.FUNCTION, toolMessage.toolCalls?.get(0)?.type, "Tool call should be a function")
assertEquals("solve_math_problem", (toolMessage.toolCalls?.get(0) as? FunctionToolCall)?.function?.name)
assertEquals("3/2", (toolMessage.toolCalls?.get(0) as? FunctionToolCall)?.function?.tryParseArguments()?.get("equation")?.asText())

assertEquals(ChatUser.ASSISTANT, message.role, "Message should be from the assistant")
assertEquals("The result of 3 divided by 2 is 1.5.", message.content)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class FunctionCallTest {
name("enum_checker")
description("This function is used to test the enum parameter")
addEnumParameter("enum", mutableListOf("a", "b", "c"))
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"enum_checker\", \"arguments\": \"{\\\"enum\\\": \\\"d\\\"}\"}" // d is not a valid enum
Expand All @@ -37,7 +37,7 @@ class FunctionCallTest {
name("enum_checker")
description("This function is used to test the enum parameter")
addEnumParameter("enum", mutableListOf("a", "b", "c"))
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"enum_checker\", \"arguments\": \"{\\\"enum\\\": \\\"a\\\"}\"}" // a is a valid enum
Expand All @@ -55,7 +55,7 @@ class FunctionCallTest {
name("integer_checker")
description("This function is used to test the integer parameter")
addIntegerParameter("integer", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"integer_checker\", \"arguments\": \"{\\\"integer\\\": \\\"not an integer\\\"}\"}" // not an integer
Expand All @@ -73,7 +73,7 @@ class FunctionCallTest {
name("integer_checker")
description("This function is used to test the integer parameter")
addIntegerParameter("integer", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"integer_checker\", \"arguments\": \"{\\\"integer\\\": 1}\"}" // 1 is an integer
Expand All @@ -91,7 +91,7 @@ class FunctionCallTest {
name("boolean_checker")
description("This function is used to test the boolean parameter")
addBooleanParameter("is_true", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"boolean_checker\", \"arguments\": \"{\\\"boolean\\\": \\\"not a boolean\\\"}\"}" // not a boolean
Expand All @@ -109,7 +109,7 @@ class FunctionCallTest {
name("boolean_checker")
description("This function is used to test the boolean parameter")
addBooleanParameter("is_true", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"boolean_checker\", \"arguments\": \"{\\\"is_true\\\": true}\"}" // true is a boolean
Expand All @@ -128,7 +128,7 @@ class FunctionCallTest {
description("This function is used to test the required parameter")
addIntegerParameter("required", "test parameter", required = true)
addBooleanParameter("not_required", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"required_parameter_function\", \"arguments\": \"{\\\"not_required\\\": true}\"}" // missing required parameter
Expand All @@ -147,7 +147,7 @@ class FunctionCallTest {
description("This function is used to test the required parameter")
addIntegerParameter("required", "test parameter", required = true)
addBooleanParameter("not_required", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"required_parameter_function\", \"arguments\": \"{\\\"required\\\": 1, \\\"not_required\\\": true}\"}" // has required parameter
Expand All @@ -165,7 +165,7 @@ class FunctionCallTest {
name("function_name_checker")
description("This function is used to test the function name")
noParameters()
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"invalid_function_name\", \"arguments\": \"{}\"}" // invalid function name
Expand Down