1
1
import traceback
2
- from http import HTTPStatus # Add this import
2
+ from http import HTTPStatus
3
3
from typing import Callable , TypeVar
4
4
5
- from starlette .background import BackgroundTasks
6
5
from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
7
6
from starlette .requests import Request
8
7
from starlette .responses import JSONResponse , Response
9
8
10
9
from codegen .runner .sandbox .runner import SandboxRunner
11
10
from codegen .shared .exceptions .compilation import UserCodeException
12
11
from codegen .shared .logging .get_logger import get_logger
13
- from codegen .shared .performance .stopwatch_utils import stopwatch
14
12
15
13
logger = get_logger (__name__ )
16
14
@@ -34,13 +32,10 @@ async def dispatch(self, request: TRequest, call_next: RequestResponseEndpoint)
34
32
return await call_next (request )
35
33
36
34
async def process_request (self , request : TRequest , call_next : RequestResponseEndpoint ) -> TResponse :
37
- background_tasks = BackgroundTasks ()
38
35
try :
39
36
logger .info (f"> (CodemodRunMiddleware) Request: { request .url .path } " )
40
37
self .runner .codebase .viz .clear_graphviz_data ()
41
38
response = await call_next (request )
42
- background_tasks .add_task (self .cleanup_after_codemod , is_exception = False )
43
- response .background = background_tasks
44
39
return response
45
40
46
41
except UserCodeException as e :
@@ -52,21 +47,4 @@ async def process_request(self, request: TRequest, call_next: RequestResponseEnd
52
47
message = f"Unexpected error for { request .url .path } "
53
48
logger .exception (message )
54
49
res = JSONResponse (status_code = HTTPStatus .INTERNAL_SERVER_ERROR , content = {"detail" : message , "error" : str (e ), "traceback" : traceback .format_exc ()})
55
- background_tasks .add_task (self .cleanup_after_codemod , is_exception = True )
56
- res .background = background_tasks
57
50
return res
58
-
59
- async def cleanup_after_codemod (self , is_exception : bool = False ):
60
- if is_exception :
61
- # TODO: instead of committing transactions, we should just rollback
62
- logger .info ("Committing pending transactions due to exception" )
63
- self .runner .codebase .ctx .commit_transactions (sync_graph = False )
64
- await self .reset_runner ()
65
-
66
- @stopwatch
67
- async def reset_runner (self ):
68
- logger .info ("=====[ reset_runner ]=====" )
69
- logger .info (f"Syncing runner to commit: { self .runner .commit } ..." )
70
- self .runner .codebase .checkout (commit = self .runner .commit )
71
- self .runner .codebase .clean_repo ()
72
- self .runner .codebase .checkout (branch = self .runner .codebase .default_branch , create_if_missing = True )
0 commit comments