Skip to content

Commit 44d00c6

Browse files
committed
feat(prompting): add PromptOptimizer to trim code spaces #317
Introduce `PromptOptimizer` to remove leading/trailing spaces and empty lines from code prompts. Integrate it into `CustomLLMProvider` for optional prompt trimming based on settings. Include comprehensive unit tests for the new functionality.
1 parent 643bb41 commit 44d00c6

File tree

3 files changed

+204
-1
lines changed

3 files changed

+204
-1
lines changed

core/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cc.unitmesh.devti.llms.custom
22

33
import cc.unitmesh.devti.gui.chat.message.ChatRole
44
import cc.unitmesh.devti.llms.LLMProvider
5+
import cc.unitmesh.devti.prompting.optimizer.PromptOptimizer
56
import cc.unitmesh.devti.settings.AutoDevSettingsState
67
import cc.unitmesh.devti.settings.coder.coderSetting
78
import com.intellij.openapi.diagnostic.logger
@@ -92,7 +93,13 @@ class CustomLLMProvider(val project: Project) : LLMProvider, CustomSSEProcessor(
9293
}
9394

9495
fun prompt(instruction: String, input: String): String {
95-
messages += Message("user", instruction)
96+
val prompt = if (project.coderSetting.state.trimCodeBeforeSend) {
97+
PromptOptimizer.trimCodeSpace(instruction)
98+
} else {
99+
instruction
100+
}
101+
102+
messages += Message("user", prompt)
96103
val customRequest = CustomRequest(messages)
97104
val requestContent = Json.encodeToString<CustomRequest>(customRequest)
98105

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package cc.unitmesh.devti.prompting.optimizer
2+
3+
object PromptOptimizer {
    /** Marker that opens and closes a Markdown fenced code block. */
    private const val FENCE = "```"

    /**
     * Fence info strings whose languages use leading whitespace as syntax.
     * Stripping the indentation of such code would change its meaning.
     */
    private val INDENT_SENSITIVE_LANGUAGES = setOf("python", "py", "yaml", "yml")

    /**
     * Shrinks a prompt by dropping blank lines and surrounding whitespace on every line.
     *
     * For plain text and most code this is roughly equivalent to:
     * ```bash
     * grep -Ev '^[[:space:]]*$' input.rs | sed -E 's|^[[:space:]]+||; s|[[:space:]]+$||'
     * ```
     *
     * Lines inside a fenced code block whose info string names an indentation-sensitive
     * language (see [INDENT_SENSITIVE_LANGUAGES]) keep their leading whitespace; only
     * trailing whitespace and blank lines are removed there. The fence lines themselves
     * are always fully trimmed.
     *
     * @param prompt the raw prompt text; any of `\n`, `\r\n` or `\r` is accepted as a line break.
     * @return the remaining lines joined with `\n`, or an empty string when every line is blank.
     */
    fun trimCodeSpace(prompt: String): String {
        val kept = mutableListOf<String>()
        var insideFence = false
        var keepIndent = false

        for (line in prompt.lines()) {
            if (line.isBlank()) continue

            val trimmed = line.trim()
            if (trimmed.startsWith(FENCE)) {
                if (insideFence) {
                    // Closing fence: go back to trimming both sides.
                    insideFence = false
                    keepIndent = false
                } else {
                    // Opening fence: the first word after the backticks is the language tag.
                    insideFence = true
                    val language = trimmed.trimStart('`').trim().substringBefore(' ').lowercase()
                    keepIndent = language in INDENT_SENSITIVE_LANGUAGES
                }
                kept += trimmed
                continue
            }

            kept += if (keepIndent) line.trimEnd() else trimmed
        }

        return kept.joinToString("\n")
    }
}
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package cc.unitmesh.devti.prompting.optimizer
2+
3+
import org.junit.Test
4+
import org.assertj.core.api.Assertions.assertThat
5+
import org.intellij.lang.annotations.Language
6+
7+
/**
 * Unit tests for [PromptOptimizer.trimCodeSpace].
 *
 * Fixtures that depend on leading or trailing whitespace spell it out with explicit
 * escapes, because invisible whitespace inside raw strings is easily lost to editors
 * and formatters.
 */
class PromptOptimizerTest {
    @Test
    fun should_trim_leading_and_trailing_spaces_from_each_line() {
        // given: whitespace is written with escapes so that it stays visible and stable
        val input = listOf(
            "  // Leading space in first line",
            "\t// Leading tab in second line",
            "Third line with trailing space  ",
            "Fourth line with trailing tab\t",
        ).joinToString("\n")

        // when
        val result = PromptOptimizer.trimCodeSpace(input)

        // then
        val expected = listOf(
            "// Leading space in first line",
            "// Leading tab in second line",
            "Third line with trailing space",
            "Fourth line with trailing tab",
        ).joinToString("\n")
        assertThat(result).isEqualTo(expected)
    }

    @Test
    fun should_remove_empty_lines_from_the_prompt() {
        // given
        val input = """
            First line

            Second line

            Third line
        """.trimIndent()

        // when
        val result = PromptOptimizer.trimCodeSpace(input)

        // then
        val expected = """
            First line
            Second line
            Third line
        """.trimIndent()
        assertThat(result).isEqualTo(expected)
    }

    @Test
    fun should_handle_prompt_with_only_empty_lines() {
        // given
        val input = """



        """.trimIndent()

        // when
        val result = PromptOptimizer.trimCodeSpace(input)

        // then
        val expected = ""
        assertThat(result).isEqualTo(expected)
    }

    @Test
    fun should_return_original_prompt_if_no_spaces_or_empty_lines() {
        // given
        val input = """
            First line
            Second line
            Third line
        """.trimIndent()

        // when
        val result = PromptOptimizer.trimCodeSpace(input)

        // then
        val expected = """
            First line
            Second line
            Third line
        """.trimIndent()
        assertThat(result).isEqualTo(expected)
    }

    @Test
    fun should_handle_for_rust_code_in_issue() {
        // given: a realistically indented Rust snippet containing blank lines
        @Language("Rust")
        val code = """
            use crate::{find_target, Plot};
            use anyhow::{Context, Result};
            use std::{env, process};

            impl Plot {
                pub fn generate_plot(&mut self) -> Result<(), anyhow::Error> {
                    eprintln!("Generating plot");

                    self.target =
                        find_target(&self.target).context("⚠️ couldn't find the target for plotting")?;

                    // The cargo executable
                    let cargo = env::var("CARGO").unwrap_or_else(|_| String::from("cargo"));

                    let fuzzer_data_dir = format!(
                        "{}/{}/afl/{}/",
                        &self.ziggy_output.display(),
                        &self.target,
                        &self.input
                    );

                    let plot_dir = self
                        .output
                        .display()
                        .to_string()
                        .replace("{ziggy_output}", &self.ziggy_output.display().to_string())
                        .replace("{target_name}", &self.target);
                    println!("{plot_dir}");
                    println!("{}", self.target);

                    // We run the afl-plot command
                    process::Command::new(cargo)
                        .args(["afl", "plot", &fuzzer_data_dir, &plot_dir])
                        .spawn()
                        .context("⚠️ couldn't spawn afl plot")?
                        .wait()
                        .context("⚠️ couldn't wait for the afl plot")?;

                    Ok(())
                }
            }
        """.trimIndent()

        // when
        val result = PromptOptimizer.trimCodeSpace(code)

        // then: every line is flush left and no blank lines remain
        val expected = """
            use crate::{find_target, Plot};
            use anyhow::{Context, Result};
            use std::{env, process};
            impl Plot {
            pub fn generate_plot(&mut self) -> Result<(), anyhow::Error> {
            eprintln!("Generating plot");
            self.target =
            find_target(&self.target).context("⚠️ couldn't find the target for plotting")?;
            // The cargo executable
            let cargo = env::var("CARGO").unwrap_or_else(|_| String::from("cargo"));
            let fuzzer_data_dir = format!(
            "{}/{}/afl/{}/",
            &self.ziggy_output.display(),
            &self.target,
            &self.input
            );
            let plot_dir = self
            .output
            .display()
            .to_string()
            .replace("{ziggy_output}", &self.ziggy_output.display().to_string())
            .replace("{target_name}", &self.target);
            println!("{plot_dir}");
            println!("{}", self.target);
            // We run the afl-plot command
            process::Command::new(cargo)
            .args(["afl", "plot", &fuzzer_data_dir, &plot_dir])
            .spawn()
            .context("⚠️ couldn't spawn afl plot")?
            .wait()
            .context("⚠️ couldn't wait for the afl plot")?;
            Ok(())
            }
            }
        """.trimIndent()
        // The original test built `expected` but never compared it with `result`.
        assertThat(result).isEqualTo(expected)
    }
}

0 commit comments

Comments
 (0)