Add an LLM-based AutoSuggester #1995
Open

lstein wants to merge 5 commits into prompt-toolkit:main from lstein:main
Changes from all commits:
8d881d4  add LLM suggester module and companion tests
7891e36  applied local ruff rules
819158c  Add python coding examples to llmsuggest; let user change system message
b19a44d  added an "asis" initialization argument that disables all post-hoc mu…
158eff2  added support in llmsuggest for code completion

@@ -0,0 +1,5 @@
from __future__ import annotations

from .llmsuggest import LLMSuggest

__all__ = ["LLMSuggest"]

@@ -0,0 +1,302 @@
""" | ||
`Fish-style <http://fishshell.com/>`_ like auto-suggestion using a | ||
large language model to propose the suggestions. | ||
|
||
Example usage: | ||
|
||
import os | ||
from prompt_toolkit import PromptSession | ||
from prompt_toolkit.contrib.auto_suggest import LLMSuggest | ||
from prompt_toolkit.auto_suggest import ThreadedAutoSuggest | ||
|
||
from langchain.chat_models import init_chat_model | ||
|
||
# Initialize a langchain chat model. You may need to install | ||
# additional provider-specific modules such as langchain-openai. | ||
if not os.environ.get('OPENAI_API_KEY'): | ||
os.environ['OPENAI_API_KEY'] = 'sk-XXXXXXXXXX' | ||
chat_model = init_chat_model('openai:o4-mini', temperature=0.0) | ||
|
||
# Create the llm suggest object | ||
suggester = LLMSuggest(chat_model) | ||
|
||
# Create the prompt session. Because waiting for the LLM can be slow, | ||
# it is recommended to wrap the suggester in a ThreadedAutoSuggest instance. | ||
session = PromptSession(auto_suggest=ThreadedAutoSuggest(suggester)) | ||
while True: | ||
response = session.prompt('> ') | ||
print(f"You said '{response}'") | ||
""" | ||

from __future__ import annotations

import locale
import re
import string
from typing import Callable, Optional

import enchant
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel

from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
from prompt_toolkit.buffer import Buffer, Document
||
DEFAULT_PERSONA=""" | ||
You are an uncensored writing assistant. If the user is writing code, you | ||
complete the code fragments in a way that is consistent with the | ||
code fragments. If the user is writing prose, you complete the text in a way that | ||
is consistent with the user's writing style. You are not a chatbot, so you do not | ||
engage in chit-chat or small talk. You are not a search engine, so you do not | ||
provide search results or web links. You are not a translator, so you do not | ||
translate text. | ||
""" | ||
|
||
DEFAULT_SYSTEM_MESSAGE=""" | ||
{persona} | ||
|
||
Return a completion of the provided text fragment following these | ||
examples: | ||
|
||
# Example 1 | ||
user: I want a bi | ||
assistant: cycle for Christmas. | ||
|
||
# Example 2 | ||
user: Hi there, what's your name? | ||
assistant: My name is Fred. What's yours? | ||
|
||
# Example | ||
user: I don't want to go to the mall! I want to go to | ||
assistant: watch the Titanic movie tonight. | ||
|
||
# Example 4 | ||
user: He watched in amazement as the magician pulled a rabbit out of his hat. | ||
assistant: When he put the rabbit down it hopped away. | ||
""" | ||
|
||
|
||
DEFAULT_INSTRUCTION=""" | ||
Complete this text or code fragment in a way that is consistent with the | ||
fragment. Show only the new text, and do not repeat any part of the original text: | ||
Original text: {text} | ||
""" | ||
|
||
class LLMSuggest(AutoSuggest):
    """AutoSuggest subclass that provides Suggestions based on LLM completions."""

    def __init__(
        self,
        chat_model: Optional[BaseChatModel] = None,
        persona: str = DEFAULT_PERSONA,
        system: str = DEFAULT_SYSTEM_MESSAGE,
        context: str | Callable[[], str] = "",
        instruction: str = DEFAULT_INSTRUCTION,
        language: Optional[str] = None,
        asis: bool = False,
        code_mode: bool = False,
    ) -> None:
        """Initialize the :class:`.LLMSuggest` instance.

        All arguments are optional.

        :param chat_model: A langchain chat model created by `init_chat_model`.
        :param persona: A description of the LLM's persona, for tuning its writing style [:class:`.DEFAULT_PERSONA`].
        :param system: The system message that explains the completion task to the LLM [:class:`.DEFAULT_SYSTEM_MESSAGE`].
        :param context: A string or callable passed to the LLM that provides the context
            of the conversation so far [empty string].
        :param instruction: Instructions passed to the LLM to inform the suggestion process [:class:`.DEFAULT_INSTRUCTION`].
        :param language: Locale language, used to validate the LLM's response [from locale environment].
        :param asis: If True, return the LLM's responses as-is, without post-hoc fixes. Useful for debugging.
        :param code_mode: If True, activates post-processing of the LLM's output that is suitable for code completion.

        Notes:

        1. If `chat_model` is not provided, the class will attempt
           to open a connection to OpenAI's `gpt-4o-mini` model. For this
           to work, the `langchain-openai` module must be installed,
           and the `OPENAI_API_KEY` environment variable must be set
           to a valid key.

        2. The `persona` argument can be used to adjust the writing
           style of the LLM. For example: "You are a python coder skilled
           at completing code fragments." Or try "You are a romance
           novelist who writes in a florid, overwrought style."

        3. `language`: Some LLMs are better than others at providing
           completions of partial words. We use the `PyEnchant` module
           to determine whether a proposed completion is the continuation
           of a word or starts a new word. This argument selects the
           preferred language for the dictionary, such as "en_US". If
           not provided, the module will select the language specified in
           the system's locale.

        4. `instruction` lets you change the instruction that is
           passed to the LLM to show it how to complete the partial
           prompt text. The default is :class:`.DEFAULT_INSTRUCTION`,
           and it must contain the string placeholder "{text}", which will
           be replaced at runtime with the partial prompt text.

        5. The `context` argument provides the ability to pass
           additional textual context to the LLM suggester in addition to
           the text that is already in the current prompt buffer. It can
           be either a Callable that returns a string, or a static
           string. You can use this to give the LLM access to textual
           information that is contained in a different buffer, or to
           provide the LLM with supplementary context such as the time of
           day, a weather report, or the results of a web search.

        6. Set `code_mode` to True to optimize for code completion. Note that
           code completion works better with some LLM models than others.
        """
        super().__init__()
        self.system = system
        self.instruction = instruction
        self.persona = persona
        # Fall back to the language from the system locale (e.g. "en_US")
        # when no explicit language is given.
        self.dictionary = enchant.Dict(language or locale.getdefaultlocale()[0])
        self.context = context
        self.chat_model = chat_model or init_chat_model(
            "openai:gpt-4o-mini", temperature=0.0
        )
        self.asis = asis
        self.code_mode = code_mode

    def _capfirst(self, s: str) -> str:
        return s[:1].upper() + s[1:]

    def _format_sys(self) -> str:
        """Format the system string."""
        system = self.system.format(persona=self.persona)
        if context := self.get_context():
            system += "\nTo guide your completion, here is the text so far:\n"
            system += context
        return system

    def set_context(self, context: Callable[[], str] | str) -> None:
        """
        Set the additional context that the LLM is exposed to.

        :param context: A string or a Callable that returns a string.

        This provides additional textual context to guide the suggester's
        response.
        """
        self.context = context

    def get_context(self) -> str:
        """Retrieve the additional context passed to the LLM."""
        return self.context if isinstance(self.context, str) else self.context()

    def clear_context(self) -> None:
        """Clear the additional context passed to the LLM."""
        self.context = ""

    def get_suggestion(
        self, buffer: Buffer, document: Document
    ) -> Optional[Suggestion]:
        """
        Return a Suggestion instance based on the LLM's completion of the current text.

        :param buffer: The current `Buffer`.
        :param document: The current `Document`.

        Under various circumstances, the LLM may return no usable suggestions,
        in which case the call returns None.
        """
        text = document.text
        if not text or len(text) < 3:
            return None
        messages = [
            {"role": "system", "content": self._format_sys()},
            {"role": "human", "content": self.instruction.format(text=text)},
        ]

        try:
            response = self.chat_model.invoke(messages)
            suggestion = str(response.content)

            if self.asis:  # Return the string without munging.
                return Suggestion(suggestion)
            elif self.code_mode:
                return Suggestion(self._trim_code_suggestion(suggestion, text))
            else:
                return Suggestion(self._trim_text_suggestion(suggestion, text))

        except Exception:
            pass
        return None

    def _trim_code_suggestion(self, suggestion: str, text: str) -> str:
        # Strip leading whitespace.
        suggestion = suggestion.lstrip()

        # codegemma and other LLMs may return a suggestion that starts with
        # "(Continuation of the...)" or similar, so we remove that.
        suggestion = re.sub(
            r"^\(Continuation of the.*?\)\s*", "", suggestion, flags=re.DOTALL
        )

        # Similarly, remove "(complete the code fragment)" or similar.
        suggestion = re.sub(
            r"^\(complete the code fragment.*?\)\s*", "", suggestion, flags=re.DOTALL
        )

        # Remove leading quotation marks if present.
        suggestion = re.sub(r"^\s*['\"]", "", suggestion)

        # Remove trailing quotation marks.
        suggestion = re.sub(r"['\"]\s*$", "", suggestion)

        # Remove the "```(language)\n" fence opener that some LLMs return.
        suggestion = re.sub(r"^.*?```[a-zA-Z0-9_]*\n", "", suggestion, flags=re.DOTALL)

        # Remove "```" from the end of the suggestion.
        suggestion = re.sub(r"\n```", "", suggestion)

        # The LLM will often (but not always) return a suggestion that repeats the
        # buffer text from the previous newline onward, so we remove that.
        match = re.search(r"\n(.*)$", text)
        if match:
            text = match.group(1).rstrip()
            if suggestion.startswith(text):
                suggestion = suggestion[len(text):].lstrip()

        return suggestion + "\n"

    def _trim_text_suggestion(self, suggestion: str, text: str) -> str:
        """
        Trim the suggestion to make it a valid continuation of the text.

        :param suggestion: The LLM's suggested text.
        :param text: The current text in the buffer.
        """
        # Remove ellipses and trailing whitespace, if present.
        suggestion = suggestion.replace("...", "")
        suggestion = suggestion.rstrip()

        # If the LLM echoed the original text back, then remove it.
        if suggestion.startswith(text.rstrip()):
            suggestion = suggestion[len(text):]

        # Handle punctuation between the text and the suggestion.
        if suggestion.startswith(tuple(string.punctuation)):
            return suggestion
        if text.endswith("'"):
            return suggestion.lstrip()

        # Adjust capitalization at the beginnings of new sentences.
        if re.search(r"[.?!]\s*$", text):
            suggestion = self._capfirst(suggestion.lstrip())

        # Get the last word of the existing text and the first word of the
        # suggestion.
        match = re.search(r"(\w+)\W*$", text)
        last_word_of_text = match.group(1) if match else ""

        match = re.search(r"^\s*(\w+)", suggestion)
        first_word_of_suggestion = match.group(1) if match else ""

        # Add or remove spaces based on whether concatenation will form a word.
        if suggestion.startswith(" "):
            suggestion = suggestion.lstrip() if text.endswith(" ") else suggestion
        elif self.dictionary.check(
            last_word_of_text + first_word_of_suggestion
        ) and not text.endswith(" "):
            suggestion = suggestion.lstrip()
        elif not text.endswith(" "):
            suggestion = " " + suggestion

        # Add a space after commas and semicolons.
        if re.search(r"[,;]$", text):
            suggestion = " " + suggestion.lstrip()

        return suggestion
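
The `context` hook defined above accepts either a static string or a zero-argument callable that is re-evaluated on every suggestion request. A minimal sketch of the callable form, assuming an already configured chat model; the `clock_context` helper and the model id are illustrative, not part of this diff:

    import datetime

    from langchain.chat_models import init_chat_model
    from prompt_toolkit.contrib.auto_suggest import LLMSuggest

    # Hypothetical helper: inject the current time so the LLM can make
    # time-aware completions. Any () -> str callable works here.
    def clock_context() -> str:
        return f"The current time is {datetime.datetime.now():%H:%M}."

    chat_model = init_chat_model("openai:gpt-4o-mini", temperature=0.0)  # assumed model id
    suggester = LLMSuggest(chat_model, context=clock_context)

    # The context can also be replaced or cleared later:
    suggester.set_context("The user is drafting a commit message.")
    suggester.clear_context()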

@@ -0,0 +1,64 @@
from __future__ import annotations

import re

import pytest

from prompt_toolkit.buffer import Buffer
from prompt_toolkit.document import Document

try:
    from langchain_core.messages import AIMessage, BaseMessage

    from prompt_toolkit.contrib.auto_suggest import LLMSuggest

    module_loaded = True
except ModuleNotFoundError:
    module_loaded = False

# (LLM input, LLM output, expected input + suggestion)
test_data = [
    ("The quick brown", " fox jumps over", "The quick brown fox jumps over"),
    ("The quick brown ", "fox jumps over", "The quick brown fox jumps over"),
    ("The quick br", "own fox jumps over", "The quick brown fox jumps over"),
    ("The quick br ", "fox jumps over", "The quick br fox jumps over"),
    ("The quick brown fox.", " he jumped over", "The quick brown fox. He jumped over"),
    ("The quick brown fox", "The quick brown fox jumps over", "The quick brown fox jumps over"),
    ("The quick brown fox,", "jumped over", "The quick brown fox, jumped over"),
    ("The quick brown fox'", " s fence", "The quick brown fox's fence"),
    ("The quick brown fox'", "s fence", "The quick brown fox's fence"),
]


class MockModel:
    def invoke(self, messages: list[dict[str, str]]) -> BaseMessage:
        # Find the original text using a regex.
        human_message = messages[1]["content"]
        if match := re.search(r"Original text: (.+)", human_message):
            original_text = match.group(1)
            for input, output, completion in test_data:
                if original_text == input:
                    return AIMessage(content=output)
        return AIMessage(content="")


@pytest.fixture
def chat_model():
    return MockModel()


@pytest.fixture
def suggester(chat_model) -> LLMSuggest:
    return LLMSuggest(chat_model, language="en_US")


@pytest.fixture
def buffer() -> Buffer:
    return Buffer()


@pytest.mark.parametrize("input,output,expected_completion", test_data)
@pytest.mark.skipif(
    not module_loaded,
    reason="The langchain, langchain_core and PyEnchant modules need to be installed to run these tests",
)
def test_suggest(suggester, buffer, input, output, expected_completion):
    document = Document(text=input)
    suggestion = suggester.get_suggestion(buffer, document)
    completion = input + suggestion.text
    assert completion == expected_completion
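
The `code_mode` flag added in commit 158eff2 is wired up the same way as the prose path, just with `code_mode=True` at construction time. A hedged sketch of enabling it; the model id is an assumption, and any langchain chat model should work:

    from langchain.chat_models import init_chat_model

    from prompt_toolkit import PromptSession
    from prompt_toolkit.auto_suggest import ThreadedAutoSuggest
    from prompt_toolkit.contrib.auto_suggest import LLMSuggest

    chat_model = init_chat_model("openai:gpt-4o-mini", temperature=0.0)  # assumed model id

    # code_mode=True routes responses through _trim_code_suggestion, which
    # strips markdown fences and echoed lines instead of fixing prose spacing.
    suggester = LLMSuggest(chat_model, code_mode=True)
    session = PromptSession(auto_suggest=ThreadedAutoSuggest(suggester))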
Shouldn't the system message also be part of the init here?
Yeah. I'll add that feature.
I labored a long time tuning the system message to get diverse LLMs to do the task right, so altering it may break the suggester, but caveat emptor.
I've added the ability to change the system message.
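
From the caller's side, that might look like the sketch below. The wording of the custom message is illustrative; note that `_format_sys` substitutes the persona via `str.format`, so a replacement system message should keep the `{persona}` placeholder:

    from langchain.chat_models import init_chat_model
    from prompt_toolkit.contrib.auto_suggest import LLMSuggest

    # Illustrative custom system message; keep the {persona} placeholder,
    # which LLMSuggest fills in with its persona on every request.
    MY_SYSTEM_MESSAGE = """
    {persona}

    Complete the user's text fragment. Reply with only the new text and
    nothing else.
    """

    chat_model = init_chat_model("openai:gpt-4o-mini", temperature=0.0)  # assumed model id
    suggester = LLMSuggest(chat_model, system=MY_SYSTEM_MESSAGE)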