Skip to content
This repository was archived by the owner on Sep 25, 2024. It is now read-only.

Commit 812b2fc

Browse files
refactor: add more ruff checks
1 parent 65bebcf commit 812b2fc

20 files changed

+208
-100
lines changed

cortex_shell/cache.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Generator
2-
from hashlib import md5
2+
from functools import wraps
3+
from hashlib import sha256
34
from pathlib import Path
45
from typing import Any, Callable
56

@@ -23,14 +24,15 @@ def __init__(self, size: int, cache_path: Path) -> None:
2324
self._cache_path = cache_path
2425
self._cache_path.mkdir(parents=True, exist_ok=True)
2526

26-
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
27+
def __call__(self, func: Callable[..., Generator[str, None, None]]) -> Callable[..., Generator[str, None, None]]:
2728
"""
2829
The Cache decorator.
2930
3031
:param func: The function to cache.
3132
:return: Wrapped function with caching.
3233
"""
3334

35+
@wraps(func)
3436
def wrapper(*args: Any, **kwargs: Any) -> Generator[str, None, None]:
3537
cache_file = self._cache_path / self._get_hash_from_request(**kwargs)
3638
if kwargs.pop("caching", False) and cache_file.exists():
@@ -66,7 +68,7 @@ def _get_hash_from_request(**kwargs: Any) -> str:
6668
kwargs.pop("caching")
6769
# delete every message except the last one, which is the most recent user prompt
6870
kwargs["messages"] = kwargs["messages"][-1:]
69-
return md5(yaml_dump_str(kwargs).encode("utf-8")).hexdigest()
71+
return sha256(yaml_dump_str(kwargs).encode("utf-8")).hexdigest()
7072

7173
@classmethod
7274
@option_callback

cortex_shell/error_handler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def _write_line(s: str | None = None, **kwargs: Any) -> None:
9191
ErrorHandler._write_line_b(s.encode() if s is not None else s, **kwargs)
9292

9393
@staticmethod
94-
def _force_bytes(exc: Any) -> bytes:
94+
def _force_bytes(obj: Any) -> bytes: # noqa: ANN401
9595
with contextlib.suppress(TypeError):
96-
return bytes(exc)
96+
return bytes(obj)
9797
with contextlib.suppress(Exception):
98-
return str(exc).encode()
99-
return f"<unprintable {type(exc).__name__} object>".encode()
98+
return str(obj).encode()
99+
return f"<unprintable {type(obj).__name__} object>".encode()

cortex_shell/handlers/repl_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _get_prompt_from_input(self, additional_prompt: str | None = None, **kwargs:
4141

4242
super().handle(prompt, **kwargs)
4343

44-
def _get_user_input(self) -> Any:
44+
def _get_user_input(self) -> str:
4545
bindings = KeyBindings()
4646

4747
@bindings.add(Keys.ControlE) # type: ignore[misc]

cortex_shell/post_processing/shell_execution_post_processing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ class Option(Enum):
3131

3232

3333
class ShellExecutionPostProcessing(IPostProcessing):
34-
def __init__(self, shell_role: BuiltinRoleShell, describe_shell_role: BuiltinRoleDescribeShell, client: IClient):
34+
def __init__(
35+
self,
36+
shell_role: BuiltinRoleShell,
37+
describe_shell_role: BuiltinRoleDescribeShell,
38+
client: IClient,
39+
) -> None:
3540
self._shell_role = shell_role
3641
self._describe_shell_role = describe_shell_role
3742
self._handler = self._get_shell_describe_handler(client)

cortex_shell/renderer/plain_renderer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
class PlainRenderer(IRenderer):
15-
def __init__(self, role: Role):
15+
def __init__(self, role: Role) -> None:
1616
self._role = role
1717

1818
def __enter__(self) -> Self:
@@ -22,5 +22,5 @@ def __exit__(self, *args: object) -> None:
2222
# new line
2323
print_formatted_text()
2424

25-
def __call__(self, text: str, chunk: str) -> None:
25+
def __call__(self, text: str, chunk: str) -> None: # noqa: ARG002
2626
print_formatted_text(get_colored_text(chunk, self._role.output.color), end="") # type: ignore[arg-type]

cortex_shell/session/chat_session.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING, Any, cast
44

5+
from ..types import YamlType
56
from ..yaml import yaml_dump, yaml_load
67

78
if TYPE_CHECKING: # pragma: no cover
89
from pathlib import Path
910

10-
from ..configuration.schema import Role
11+
from ruamel.yaml import StreamType
12+
1113
from ..types import Message
1214

1315

1416
class ChatSession:
15-
def __init__(self, file_path: Path, size: int | None = None):
17+
def __init__(self, file_path: Path, size: int | None = None) -> None:
1618
self._file_path = file_path
1719
self._size = size
1820

@@ -38,32 +40,24 @@ def messages(self) -> list[Message]:
3840
def write_messages(self, messages: list[Message]) -> None:
3941
if self._size is not None:
4042
messages = messages[-self._size :]
41-
self._write("messages", messages)
42-
43-
def get_role_name(self) -> str | None:
44-
return self._read("role")
45-
46-
def set_role(self, role: Role) -> None:
47-
self._write("role", role.name)
48-
self._write("system_prompt", role.description)
49-
50-
def get_system_prompt(self) -> str:
51-
return self._read("system_prompt") or ""
43+
self._write("messages", cast(YamlType, messages))
5244

53-
def _write(self, key: str, value: Any) -> None:
45+
def _write(self, key: str, value: YamlType) -> None:
5446
data = self._yaml_load() if self._file_path.exists() else {}
5547
data[key] = value
5648

5749
self._yaml_dump(data)
5850

59-
def _read(self, key: str) -> Any | None:
51+
def _read(self, key: str) -> YamlType:
6052
data = self._yaml_load()
61-
return data.get(key) if data else None
53+
if isinstance(data, dict):
54+
return data.get(key)
55+
return None
6256

63-
def _yaml_load(self) -> Any:
57+
def _yaml_load(self) -> Any: # noqa: ANN401
6458
if not self._file_path.exists():
6559
return None
6660
return yaml_load(stream=self._file_path)
6761

68-
def _yaml_dump(self, data: Any) -> Any:
62+
def _yaml_dump(self, data: Path | StreamType) -> Any: # noqa: ANN401
6963
yaml_dump(data=data, stream=self._file_path)

cortex_shell/session/chat_session_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
class ChatSessionManager:
17-
def __init__(self, storage_path: Path, history_size: int | None = None):
17+
def __init__(self, storage_path: Path, history_size: int | None = None) -> None:
1818
self._storage_path = storage_path
1919
self._history_size = history_size
2020
self._storage_path.mkdir(parents=True, exist_ok=True)

cortex_shell/types/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from __future__ import annotations
22

3-
from typing import Protocol, TypedDict
4-
5-
6-
class StringConvertible(Protocol):
7-
def __str__(self) -> str:
8-
pass
3+
from typing import TypedDict, Union
94

105

116
class Message(TypedDict):
127
role: str
138
content: str
9+
10+
11+
YamlType = Union[str, int, float, bool, None, list["YamlType"], dict[str, "YamlType"]]

cortex_shell/types/prompt_toolkit.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ def __init__( # noqa: C901
3939
values: Sequence[tuple[_T, AnyFormattedText]],
4040
default_value: _T | None = None,
4141
) -> None:
42-
assert len(values) > 0
43-
4442
self.values = values
4543

4644
keys: list[_T] = [value for (value, _) in values]

cortex_shell/util.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
from importlib import resources
1414
from pathlib import Path
1515
from tempfile import gettempdir
16-
from typing import Any, Callable, cast
16+
from typing import Any, Callable, Union, cast
1717

1818
import click.exceptions
1919
import distro
2020
import typer
21+
from click import Context
2122
from pathvalidate import is_valid_filepath
2223
from prompt_toolkit import print_formatted_text as print_formatted_text_orig
2324
from prompt_toolkit.formatted_text import FormattedText
@@ -26,12 +27,16 @@
2627

2728
from . import constants as C # noqa: N812
2829

30+
_TypeOrContext = Union[type, Context]
2931

30-
def option_callback(func: Callable[[Any, str], Any]) -> Callable[[Any, str], Any]:
31-
def wrapper(cls: Any, value: str) -> None:
32+
33+
def option_callback(
34+
func: Callable[[_TypeOrContext, Any], None],
35+
) -> Callable[[_TypeOrContext, Any], None]:
36+
def wrapper(cls_or_ctx: _TypeOrContext, value: Any) -> None: # noqa: ANN401
3237
if not value:
3338
return
34-
func(cls, value)
39+
func(cls_or_ctx, value)
3540
raise typer.Exit
3641

3742
return wrapper
@@ -189,22 +194,13 @@ def get_cache_dir() -> Path:
189194

190195

191196
def print_formatted_text(*values: str | FormattedText, **kwargs: Any) -> None:
192-
# workaround for NoConsoleScreenBufferError
193197
if is_tty():
194198
print_formatted_text_orig(*values, **kwargs)
195199
else:
196-
197-
def to_text(val: Any) -> str:
198-
if isinstance(val, str):
199-
return val
200-
elif isinstance(val, list):
201-
return "".join([v[1] for v in val])
202-
else:
203-
raise TypeError
204-
205-
end = kwargs.get("end", "\n")
206-
sep = kwargs.get("sep", " ")
207-
print(sep.join([to_text(value) for value in values]), end=end)
200+
print(
201+
*(text if isinstance(text, str) else "".join(segment[1] for segment in text) for text in values),
202+
**kwargs,
203+
)
208204

209205

210206
def rmtree(path: Path) -> None:

cortex_shell/yaml.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pydantic import BaseModel
1111

1212
if TYPE_CHECKING: # pragma: no cover
13-
from ruamel.yaml import ScalarNode, StreamTextType, StreamType
13+
from ruamel.yaml import BaseRepresenter, ScalarNode, StreamTextType, StreamType
1414

1515

1616
class YAML(ruamel.yaml.YAML):
@@ -25,41 +25,41 @@ def __init__(self) -> None:
2525
self.representer.add_representer(type(None), self._represent_none)
2626

2727
@staticmethod
28-
def _represent_none(representer: Any, data: Any) -> ScalarNode:
28+
def _represent_none(representer: BaseRepresenter, _data: Any) -> ScalarNode: # noqa: ANN401
2929
return representer.represent_scalar("tag:yaml.org,2002:null", "")
3030

3131

32-
def yaml_load(stream: Path | StreamTextType) -> Any:
32+
def yaml_load(stream: Path | StreamTextType) -> Any: # noqa: ANN401
3333
if isinstance(stream, Path):
3434
with stream.open("r", encoding="utf-8") as file:
3535
return YAML().load(file)
3636
else:
3737
return YAML().load(stream)
3838

3939

40-
def yaml_dump(data: Path | StreamType, stream: Any = None, *, transform: Any = None) -> Any:
40+
def yaml_dump(data: Path | StreamType, stream: Any = None, *, transform: Any = None) -> Any: # noqa: ANN401
4141
if isinstance(stream, Path):
4242
with stream.open("w", encoding="utf-8") as file:
43-
YAML().dump(data=data, stream=file, transform=transform)
43+
return YAML().dump(data=data, stream=file, transform=transform)
4444
else:
45-
YAML().dump(data=data, stream=stream, transform=transform)
45+
return YAML().dump(data=data, stream=stream, transform=transform)
4646

4747

48-
def yaml_dump_str(data: Path | StreamType, *, transform: Any = None) -> Any:
48+
def yaml_dump_str(data: Path | StreamType, *, transform: Any = None) -> str: # noqa: ANN401
4949
stream = StringIO()
5050
yaml_dump(data=data, stream=stream, transform=transform)
5151
return stream.getvalue()
5252

5353

54-
T = TypeVar("T", bound=BaseModel)
54+
_T = TypeVar("_T", bound=BaseModel)
5555

5656

57-
def from_yaml_file(model_type: type[T], file: Path) -> T:
57+
def from_yaml_file(model_type: type[_T], file: Path) -> _T:
5858
return pydantic.TypeAdapter(model_type).validate_python(yaml_load(file.resolve()))
5959

6060

6161
def to_yaml_file(
6262
file: Path,
63-
model: BaseModel | Any,
63+
model: BaseModel,
6464
) -> None:
6565
yaml_dump(json.loads(model.model_dump_json()), file.resolve())

0 commit comments

Comments
 (0)