Skip to content

Commit 486d60c

Browse files
authored
Merge pull request #32 from hotip/master
feat(settings): 选择模型后,只显示对模型的参数列表
2 parents a0f16e9 + 946f5a0 commit 486d60c

File tree

12 files changed

+561
-37
lines changed

12 files changed

+561
-37
lines changed

src/main/kotlin/cc/unitmesh/devti/llms/LLMProviderFactory.kt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@ import cc.unitmesh.devti.llms.azure.AzureOpenAIProvider
44
import cc.unitmesh.devti.llms.custom.CustomLLMProvider
55
import cc.unitmesh.devti.llms.openai.OpenAIProvider
66
import cc.unitmesh.devti.llms.xianghuo.XingHuoProvider
7+
import cc.unitmesh.devti.settings.AIEngines
78
import cc.unitmesh.devti.settings.AutoDevSettingsState
89
import com.intellij.openapi.components.Service
910
import com.intellij.openapi.project.Project
1011

1112
@Service
1213
class LLMProviderFactory {
13-
private val aiEngine: String
14-
get() = AutoDevSettingsState.getInstance().aiEngine
14+
private val aiEngine: AIEngines
15+
get() = AIEngines.values().find { it.name.lowercase() == AutoDevSettingsState.getInstance().aiEngine.lowercase() } ?: AIEngines.OpenAI
1516
fun connector(project: Project): LLMProvider {
1617
return when (aiEngine) {
17-
// TODO use mapping and avoid hard code engine name
18-
"OpenAI" -> project.getService(OpenAIProvider::class.java)
19-
"Custom" -> project.getService(CustomLLMProvider::class.java)
20-
"Azure" -> project.getService(AzureOpenAIProvider::class.java)
21-
"XingHuo" -> project.getService(XingHuoProvider::class.java)
18+
AIEngines.OpenAI -> project.getService(OpenAIProvider::class.java)
19+
AIEngines.Custom -> project.getService(CustomLLMProvider::class.java)
20+
AIEngines.Azure -> project.getService(AzureOpenAIProvider::class.java)
21+
AIEngines.XingHuo -> project.getService(XingHuoProvider::class.java)
2222
else -> project.getService(OpenAIProvider::class.java)
2323
}
2424
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package cc.unitmesh.devti.llms.palm2
2+
3+
import cc.unitmesh.devti.llms.LLMProvider
4+
import cc.unitmesh.devti.llms.custom.CustomRequest
5+
import cc.unitmesh.devti.settings.AutoDevSettingsState
6+
import com.intellij.openapi.components.Service
7+
import com.intellij.openapi.project.Project
8+
import kotlinx.serialization.Serializable
9+
import kotlinx.serialization.encodeToString
10+
import kotlinx.serialization.json.Json
11+
import okhttp3.MediaType.Companion.toMediaTypeOrNull
12+
import okhttp3.OkHttpClient
13+
import okhttp3.Request
14+
import okhttp3.RequestBody
15+
import okhttp3.RequestBody.Companion.toRequestBody
16+
17+
@Serializable
18+
data class PaLM2Request(val prompt: String, val input: String)
19+
20+
@Service(Service.Level.PROJECT)
21+
class PaLM2Provider(val project: Project) : LLMProvider {
22+
private val key: String
23+
get() {
24+
return AutoDevSettingsState.getInstance().openAiKey
25+
}
26+
override fun prompt(input: String): String {
27+
// val requestContent = Json.encodeToString(CustomRequest(input, input))
28+
// val body = requestContent.toRequestBody("application/json; charset=utf-8".toMediaTypeOrNull())
29+
// val builder = Request.Builder()
30+
// .url("https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=$key")
31+
// .post(body)
32+
// OkHttpClient().newCall(builder.build()).execute().use { response ->
33+
// if (!response.isSuccessful) throw Exception("Unexpected code $response")
34+
// return response.body!!.string()
35+
// }
36+
TODO()
37+
}
38+
}

src/main/kotlin/cc/unitmesh/devti/llms/xianghuo/XingHuoProvider.kt

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import javax.crypto.spec.SecretKeySpec
2525
class XingHuoProvider(val project: Project) : LLMProvider {
2626
private val autoDevSettingsState = AutoDevSettingsState.getInstance()
2727
private val secrectKey: String
28-
get() = autoDevSettingsState.xingHuoSecrectKey
28+
get() = autoDevSettingsState.xingHuoApiSecrect
2929

3030

3131
private val appid: String
@@ -37,12 +37,13 @@ class XingHuoProvider(val project: Project) : LLMProvider {
3737
private val hmacsha256Algorithms = "hmacsha256"
3838
private val uid = UUID.randomUUID().toString().substring(0, 32)
3939

40-
private val hmacsha256 by lazy {
41-
val hmac = Mac.getInstance(hmacsha256Algorithms)
42-
val keySpec = SecretKeySpec(secrectKey.toByteArray(), hmacsha256Algorithms)
43-
hmac.init(keySpec)
44-
hmac
45-
}
40+
private val hmacsha256: Mac
41+
get() {
42+
val hmac = Mac.getInstance(hmacsha256Algorithms)
43+
val keySpec = SecretKeySpec(secrectKey.toByteArray(), hmacsha256Algorithms)
44+
hmac.init(keySpec)
45+
return hmac
46+
}
4647

4748
override fun prompt(promptText: String): String {
4849
// prompt 接口看似是无用的废弃接口,因为所有 LLM 请求都只能异步返回,不可能直接返回同步结果
@@ -54,12 +55,9 @@ class XingHuoProvider(val project: Project) : LLMProvider {
5455
override fun stream(promptText: String, systemPrompt: String): Flow<String> {
5556
return callbackFlow {
5657
val client = OkHttpClient()
57-
client.newWebSocket(request, MyListener(this, onSocketOpend = {
58+
client.newWebSocket(request, MyListener(this, onSocketOpen = {
5859
val msg = getSendBody(promptText)
59-
println("sending $msg")
6060
send(msg)
61-
}, onSocketClosed = {
62-
close()
6361
}))
6462
awaitClose()
6563
}
@@ -68,20 +66,18 @@ class XingHuoProvider(val project: Project) : LLMProvider {
6866

6967
class MyListener(
7068
private val producerScope: ProducerScope<String>,
71-
private val onSocketOpend: WebSocket.() -> Unit,
72-
private val onSocketClosed: WebSocket.() -> Unit
69+
private val onSocketOpen: WebSocket.() -> Unit,
7370
) : WebSocketListener() {
7471

7572
private var sockedOpen = false
7673
override fun onOpen(webSocket: WebSocket, response: Response) {
77-
webSocket.onSocketOpend()
78-
producerScope.trySend("WebSocket connected\n")
74+
webSocket.onSocketOpen()
7975
sockedOpen = true
8076
}
8177

82-
override fun onMessage(webSocket: WebSocket, body: String) {
83-
return runCatching {
84-
val element = Json.parseToJsonElement(body)
78+
override fun onMessage(webSocket: WebSocket, text: String) {
79+
runCatching {
80+
val element = Json.parseToJsonElement(text)
8581
val choices = element.jsonObject["payload"]!!.jsonObject["choices"]!!
8682
val statusCode: Int = choices.jsonObject["status"]?.jsonPrimitive?.int!!
8783
val message = choices.jsonObject["text"]!!.jsonArray[0]
@@ -102,6 +98,7 @@ class XingHuoProvider(val project: Project) : LLMProvider {
10298

10399
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
104100
// WebSocket connection failed
101+
println("failure ${t.message} ${response?.body} ${response?.message} ${response?.code}")
105102
producerScope.trySend("onFailure ${response?.body} ${response?.message} ${response?.code}")
106103
producerScope.close()
107104
}
@@ -119,10 +116,8 @@ class XingHuoProvider(val project: Project) : LLMProvider {
119116
|GET /v1.1/chat HTTP/1.1
120117
""".trimMargin()
121118
val signature = hmacsha256.doFinal(header.toByteArray()).encodeBase64()
122-
System.err.println(signature)
123119
val authorization =
124120
"""api_key="$apikey", algorithm="hmac-sha256", headers="host date request-line", signature="$signature""""
125-
System.err.println(authorization)
126121

127122
val params = mapOf(
128123
"authorization" to authorization.toByteArray().encodeBase64(),

src/main/kotlin/cc/unitmesh/devti/settings/AppSettingsComponent.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ class AppSettingsComponent(settings: AutoDevSettingsState) {
242242
settings.maxTokenLength != getMaxTokenLength() ||
243243
settings.xingHuoAppId != getXingHuoAppId() ||
244244
settings.xingHuoApiKey != getXingHuoApiKey() ||
245-
settings.xingHuoSecrectKey != getXingHuoAppSecret()
245+
settings.xingHuoApiSecrect != getXingHuoAppSecret()
246246

247247
}
248248

@@ -264,7 +264,7 @@ class AppSettingsComponent(settings: AutoDevSettingsState) {
264264
maxTokenLength = getMaxTokenLength()
265265
xingHuoAppId = getXingHuoAppId()
266266
xingHuoApiKey = getXingHuoApiKey()
267-
xingHuoSecrectKey = getXingHuoAppSecret()
267+
xingHuoApiSecrect = getXingHuoAppSecret()
268268
}
269269
}
270270

@@ -286,7 +286,7 @@ class AppSettingsComponent(settings: AutoDevSettingsState) {
286286
setMaxTokenLength(it.maxTokenLength)
287287
setXingHuoAppId(it.xingHuoAppId)
288288
setXingHuoAppKey(it.xingHuoApiKey)
289-
setXingHuoApiSecret(it.xingHuoSecrectKey)
289+
setXingHuoApiSecret(it.xingHuoApiSecrect)
290290
}
291291
}
292292
}

src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsConfigurable.kt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,19 @@ import org.jetbrains.annotations.Nullable
66
import javax.swing.JComponent
77

88
class AutoDevSettingsConfigurable : Configurable {
9-
private lateinit var component: AppSettingsComponent
9+
private val component: LLMSettingComponent = LLMSettingComponent(AutoDevSettingsState.getInstance())
1010

1111
@Nls(capitalization = Nls.Capitalization.Title)
1212
override fun getDisplayName(): String {
1313
return "AutoDev"
1414
}
1515

16-
override fun getPreferredFocusedComponent(): JComponent {
17-
return component.preferredFocusedComponent
16+
override fun getPreferredFocusedComponent(): JComponent? {
17+
return null
1818
}
1919

2020
@Nullable
2121
override fun createComponent(): JComponent {
22-
component = AppSettingsComponent(AutoDevSettingsState.getInstance())
2322
return component.panel
2423
}
2524

@@ -29,7 +28,7 @@ class AutoDevSettingsConfigurable : Configurable {
2928
}
3029

3130
override fun apply() {
32-
component.exportSettings(target = AutoDevSettingsState.getInstance())
31+
component.exportSettings(AutoDevSettingsState.getInstance())
3332
}
3433

3534
override fun reset() {

src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class AutoDevSettingsState : PersistentStateComponent<AutoDevSettingsState> {
1919
var customPrompts = ""
2020

2121
var xingHuoAppId = ""
22-
var xingHuoSecrectKey = ""
22+
var xingHuoApiSecrect = ""
2323
var xingHuoApiKey = ""
2424

2525
/**

src/main/kotlin/cc/unitmesh/devti/settings/Constants.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ package cc.unitmesh.devti.settings
22

33
val OPENAI_MODEL = arrayOf("gpt-3.5-turbo","gpt-3.5-turbo-16k", "gpt-4")
44
val AI_ENGINES = arrayOf("OpenAI", "Custom", "Azure", "XingHuo")
5+
6+
enum class AIEngines {
7+
OpenAI, Custom, Azure, XingHuo
8+
}
9+
510
val DEFAULT_AI_ENGINE = AI_ENGINES[0]
611

712
val HUMAN_LANGUAGES = arrayOf("English", "中文")

0 commit comments

Comments
 (0)