|
5 | 5 | import logging
|
6 | 6 | import os
|
7 | 7 | import re
|
| 8 | +import tempfile |
8 | 9 | from collections.abc import Generator
|
9 | 10 | from contextlib import contextmanager
|
10 | 11 | from functools import cached_property
|
@@ -1298,6 +1299,135 @@ def from_repo(
|
1298 | 1299 | logger.exception(f"Failed to initialize codebase: {e}")
|
1299 | 1300 | raise
|
1300 | 1301 |
|
| 1302 | + @classmethod |
| 1303 | + def from_string( |
| 1304 | + cls, |
| 1305 | + code: str, |
| 1306 | + *, |
| 1307 | + language: Literal["python", "typescript"] | ProgrammingLanguage, |
| 1308 | + ) -> "Codebase": |
| 1309 | + """Creates a Codebase instance from a string of code. |
| 1310 | +
|
| 1311 | + Args: |
| 1312 | + code: String containing code |
| 1313 | + language: Language of the code. Defaults to Python. |
| 1314 | +
|
| 1315 | + Returns: |
| 1316 | + Codebase: A Codebase instance initialized with the provided code |
| 1317 | +
|
| 1318 | + Example: |
| 1319 | + >>> # Python code |
| 1320 | + >>> code = "def add(a, b): return a + b" |
| 1321 | + >>> codebase = Codebase.from_string(code, language="python") |
| 1322 | +
|
| 1323 | + >>> # TypeScript code |
| 1324 | + >>> code = "function add(a: number, b: number): number { return a + b; }" |
| 1325 | + >>> codebase = Codebase.from_string(code, language="typescript") |
| 1326 | + """ |
| 1327 | + if not language: |
| 1328 | + msg = "missing required argument language" |
| 1329 | + raise TypeError(msg) |
| 1330 | + |
| 1331 | + logger.info("Creating codebase from string") |
| 1332 | + |
| 1333 | + # Determine language and filename |
| 1334 | + prog_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language |
| 1335 | + filename = "test.ts" if prog_lang == ProgrammingLanguage.TYPESCRIPT else "test.py" |
| 1336 | + |
| 1337 | + # Create codebase using factory |
| 1338 | + from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory |
| 1339 | + |
| 1340 | + files = {filename: code} |
| 1341 | + |
| 1342 | + with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir: |
| 1343 | + logger.info(f"Using directory: {tmp_dir}") |
| 1344 | + codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang) |
| 1345 | + logger.info("Codebase initialization complete") |
| 1346 | + return codebase |
| 1347 | + |
| 1348 | + @classmethod |
| 1349 | + def from_files( |
| 1350 | + cls, |
| 1351 | + files: dict[str, str], |
| 1352 | + *, |
| 1353 | + language: Literal["python", "typescript"] | ProgrammingLanguage | None = None, |
| 1354 | + ) -> "Codebase": |
| 1355 | + """Creates a Codebase instance from multiple files. |
| 1356 | +
|
| 1357 | + Args: |
| 1358 | + files: Dictionary mapping filenames to their content, e.g. {"main.py": "print('hello')"} |
| 1359 | + language: Optional language override. If not provided, will be inferred from file extensions. |
| 1360 | + All files must have extensions matching the same language. |
| 1361 | +
|
| 1362 | + Returns: |
| 1363 | + Codebase: A Codebase instance initialized with the provided files |
| 1364 | +
|
| 1365 | + Raises: |
| 1366 | + ValueError: If file extensions don't match a single language or if explicitly provided |
| 1367 | + language doesn't match the extensions |
| 1368 | +
|
| 1369 | + Example: |
| 1370 | + >>> # Language inferred as Python |
| 1371 | + >>> files = {"main.py": "print('hello')", "utils.py": "def add(a, b): return a + b"} |
| 1372 | + >>> codebase = Codebase.from_files(files) |
| 1373 | +
|
| 1374 | + >>> # Language inferred as TypeScript |
| 1375 | + >>> files = {"index.ts": "console.log('hello')", "utils.tsx": "export const App = () => <div>Hello</div>"} |
| 1376 | + >>> codebase = Codebase.from_files(files) |
| 1377 | + """ |
| 1378 | + # Create codebase using factory |
| 1379 | + from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory |
| 1380 | + |
| 1381 | + if not files: |
| 1382 | + msg = "No files provided" |
| 1383 | + raise ValueError(msg) |
| 1384 | + |
| 1385 | + logger.info("Creating codebase from files") |
| 1386 | + |
| 1387 | + prog_lang = ProgrammingLanguage.PYTHON # Default language |
| 1388 | + |
| 1389 | + if files: |
| 1390 | + py_extensions = {".py"} |
| 1391 | + ts_extensions = {".ts", ".tsx", ".js", ".jsx"} |
| 1392 | + |
| 1393 | + extensions = {os.path.splitext(f)[1].lower() for f in files} |
| 1394 | + inferred_lang = None |
| 1395 | + |
| 1396 | + # all check to ensure that the from_files method is being used for small testing purposes only. |
| 1397 | + # If parsing an actual repo, it should not be used. Instead do Codebase("path/to/repo") |
| 1398 | + if all(ext in py_extensions for ext in extensions): |
| 1399 | + inferred_lang = ProgrammingLanguage.PYTHON |
| 1400 | + elif all(ext in ts_extensions for ext in extensions): |
| 1401 | + inferred_lang = ProgrammingLanguage.TYPESCRIPT |
| 1402 | + else: |
| 1403 | + msg = f"Cannot determine single language from extensions: {extensions}. Files must all be Python (.py) or TypeScript (.ts, .tsx, .js, .jsx)" |
| 1404 | + raise ValueError(msg) |
| 1405 | + |
| 1406 | + if language is not None: |
| 1407 | + explicit_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language |
| 1408 | + if explicit_lang != inferred_lang: |
| 1409 | + msg = f"Provided language {explicit_lang} doesn't match inferred language {inferred_lang} from file extensions" |
| 1410 | + raise ValueError(msg) |
| 1411 | + |
| 1412 | + prog_lang = inferred_lang |
| 1413 | + else: |
| 1414 | + # Default to Python if no files provided |
| 1415 | + prog_lang = ProgrammingLanguage.PYTHON if language is None else (ProgrammingLanguage(language.upper()) if isinstance(language, str) else language) |
| 1416 | + |
| 1417 | + logger.info(f"Using language: {prog_lang}") |
| 1418 | + |
| 1419 | + with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir: |
| 1420 | + logger.info(f"Using directory: {tmp_dir}") |
| 1421 | + |
| 1422 | + # Initialize git repo to avoid "not in a git repository" error |
| 1423 | + import subprocess |
| 1424 | + |
| 1425 | + subprocess.run(["git", "init"], cwd=tmp_dir, check=True, capture_output=True) |
| 1426 | + |
| 1427 | + codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang) |
| 1428 | + logger.info("Codebase initialization complete") |
| 1429 | + return codebase |
| 1430 | + |
1301 | 1431 | def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str]]:
|
1302 | 1432 | """Get all modified symbols in a pull request"""
|
1303 | 1433 | pr = self._op.get_pull_request(pr_id)
|
|
0 commit comments