Skip to content

Commit 4942a2d

Browse files
Starting page number
1 parent 6145375 commit 4942a2d

File tree

1 file changed

+59
-13
lines changed

1 file changed

+59
-13
lines changed

src/unstructured_client/_hooks/custom/split_pdf_hook.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,16 @@
2929
)
3030
from unstructured_client.models import shared
3131

32+
# TODO: (Marek Połom) - Update documentation before merging
33+
3234
logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)
3335

3436
PARTITION_FORM_FILES_KEY = "files"
3537
PARTITION_FORM_SPLIT_PDF_PAGE_KEY = "split_pdf_page"
38+
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY = "starting_page_number"
39+
40+
DEFAULT_STARTING_PAGE_NUMBER = 1
41+
3642

3743
FormData = dict[str, Union[str, shared.Files]]
3844

@@ -85,6 +91,10 @@ def before_request(
8591
Union[requests.PreparedRequest, Exception]: If `splitPdfPage` is set to `true`,
8692
the last page request; otherwise, the original request.
8793
"""
94+
if self.client is None:
95+
logger.warning("HTTP client not accessible! Continuing without splitting.")
96+
return request
97+
8898
operation_id = hook_ctx.operation_id
8999
content_type = request.headers.get("Content-Type")
90100
body = request.body
@@ -101,9 +111,7 @@ def before_request(
101111
if file is None or not isinstance(file, shared.Files) or not self._is_pdf(file):
102112
return request
103113

104-
if self.client is None:
105-
logger.warning("HTTP client not accessible! Continuing without splitting.")
106-
return request
114+
starting_page_number = self._get_starting_page_number(form_data)
107115

108116
pages = self._get_pdf_pages(file.content)
109117
call_api_partial = functools.partial(
@@ -116,10 +124,12 @@ def before_request(
116124
self.partition_requests[operation_id] = []
117125
last_page_content = io.BytesIO()
118126
with ThreadPoolExecutor(max_workers=call_threads) as executor:
119-
for page_content, page_number, all_pages_number in pages:
127+
for page_content, page_index, all_pages_number in pages:
128+
page_number = page_index + starting_page_number
120129
# Check if this page is the last one
121-
if page_number == all_pages_number:
130+
if page_index == all_pages_number:
122131
last_page_content = page_content
132+
last_page_number = page_number
123133
break
124134
self.partition_requests[operation_id].append(
125135
executor.submit(call_api_partial, (page_content, page_number))
@@ -128,7 +138,7 @@ def before_request(
128138
# `before_request` method needs to return a request so we skip sending the last page in parallel
129139
# and return that last page at the end of this method
130140
last_page_request = self._create_request(
131-
request, form_data, last_page_content, file.file_name
141+
request, form_data, last_page_content, file.file_name, last_page_number
132142
)
133143
last_page_prepared_request = self.client.prepare_request(last_page_request)
134144
return last_page_prepared_request
@@ -217,7 +227,9 @@ def _is_pdf(self, file: shared.Files) -> bool:
217227
bool: True if the file is a PDF, False otherwise.
218228
"""
219229
if not file.file_name.endswith(".pdf"):
220-
logger.warning("Given file doesn't have '.pdf' extension. Continuing without splitting.")
230+
logger.warning(
231+
"Given file doesn't have '.pdf' extension. Continuing without splitting."
232+
)
221233
return False
222234

223235
try:
@@ -267,8 +279,7 @@ def _get_pdf_pages(
267279
new_pdf.write(pdf_buffer)
268280
pdf_buffer.seek(0)
269281

270-
# 1-index the page numbers
271-
yield pdf_buffer, offset + 1, offset_end
282+
yield pdf_buffer, offset, offset_end
272283
offset += split_size
273284

274285
def _parse_form_data(self, decoded_data: MultipartDecoder) -> FormData:
@@ -349,7 +360,9 @@ def _call_api(
349360
raise RuntimeError("HTTP client not accessible!")
350361
page_content, page_number = page
351362

352-
new_request = self._create_request(request, form_data, page_content, filename)
363+
new_request = self._create_request(
364+
request, form_data, page_content, filename, page_number
365+
)
353366
prepared_request = self.client.prepare_request(new_request)
354367

355368
try:
@@ -364,6 +377,7 @@ def _create_request(
364377
form_data: FormData,
365378
page_content: io.BytesIO,
366379
filename: str,
380+
page_number: int,
367381
) -> requests.Request:
368382
"""
369383
Creates a request object for a part of a splitted PDF file.
@@ -379,7 +393,7 @@ def _create_request(
379393
original file.
380394
"""
381395
headers = self._prepare_request_headers(request.headers)
382-
payload = self._prepare_request_payload(form_data)
396+
payload = self._prepare_request_payload(form_data, page_number)
383397
body = MultipartEncoder(
384398
fields={
385399
**payload,
@@ -415,7 +429,9 @@ def _prepare_request_headers(
415429
headers.pop("Content-Length", None)
416430
return headers
417431

418-
def _prepare_request_payload(self, form_data: FormData) -> FormData:
432+
def _prepare_request_payload(
433+
self, form_data: FormData, page_number: int
434+
) -> FormData:
419435
"""
420436
Prepares the request payload by removing unnecessary keys and updating the
421437
file.
@@ -429,7 +445,12 @@ def _prepare_request_payload(self, form_data: FormData) -> FormData:
429445
payload = copy.deepcopy(form_data)
430446
payload.pop(PARTITION_FORM_SPLIT_PDF_PAGE_KEY, None)
431447
payload.pop(PARTITION_FORM_FILES_KEY, None)
432-
payload.update({PARTITION_FORM_SPLIT_PDF_PAGE_KEY: "false"})
448+
payload.pop(PARTITION_FORM_STARTING_PAGE_NUMBER_KEY, None)
449+
updated_parameters = {
450+
PARTITION_FORM_SPLIT_PDF_PAGE_KEY: "false",
451+
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY: str(page_number),
452+
}
453+
payload.update(updated_parameters)
433454
return payload
434455

435456
def _create_response(
@@ -527,3 +548,28 @@ def _clear_operation(self, operation_id: str) -> None:
527548
"""
528549
self.partition_responses.pop(operation_id, None)
529550
self.partition_requests.pop(operation_id, None)
551+
552+
def _get_starting_page_number(self, form_data: FormData) -> int:
553+
starting_page_number = DEFAULT_STARTING_PAGE_NUMBER
554+
try:
555+
_starting_page_number = (
556+
form_data.get(PARTITION_FORM_STARTING_PAGE_NUMBER_KEY)
557+
or DEFAULT_STARTING_PAGE_NUMBER
558+
)
559+
starting_page_number = int(_starting_page_number) # type: ignore
560+
except ValueError:
561+
logger.warning(
562+
"'%s' is not a valid integer. Using default value '%d'.",
563+
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
564+
DEFAULT_STARTING_PAGE_NUMBER,
565+
)
566+
567+
if starting_page_number < 1:
568+
logger.warning(
569+
"'%s' is less than 1. Using default value '%d'.",
570+
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
571+
DEFAULT_STARTING_PAGE_NUMBER,
572+
)
573+
starting_page_number = DEFAULT_STARTING_PAGE_NUMBER
574+
575+
return starting_page_number

0 commit comments

Comments
 (0)