Skip to content

Commit cef0ba9

Browse files
Starting page number
1 parent 13f5132 commit cef0ba9

File tree

1 file changed

+54
-12
lines changed

1 file changed

+54
-12
lines changed

src/unstructured_client/_hooks/custom/split_pdf_hook.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,16 @@
2929
)
3030
from unstructured_client.models import shared
3131

32+
# TODO: (Marek Połom) - Update documentation before merging
33+
3234
logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)
3335

3436
PARTITION_FORM_FILES_KEY = "files"
3537
PARTITION_FORM_SPLIT_PDF_PAGE_KEY = "split_pdf_page"
38+
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY = "starting_page_number"
3639

3740
SUBSTITUTE_FILENAME = "file_for_partition.pdf"
41+
DEFAULT_STARTING_PAGE_NUMBER = 1
3842

3943

4044
FormData = dict[str, Union[str, shared.Files]]
@@ -88,6 +92,10 @@ def before_request(
8892
Union[requests.PreparedRequest, Exception]: If `splitPdfPage` is set to `true`,
8993
the last page request; otherwise, the original request.
9094
"""
95+
if self.client is None:
96+
logger.warning("HTTP client not accessible! Continuing without splitting.")
97+
return request
98+
9199
operation_id = hook_ctx.operation_id
92100
content_type = request.headers.get("Content-Type")
93101
body = request.body
@@ -104,9 +112,7 @@ def before_request(
104112
if file is None or not isinstance(file, shared.Files) or not self._is_pdf(file):
105113
return request
106114

107-
if self.client is None:
108-
logger.warning("HTTP client not accessible! Continuing without splitting.")
109-
return request
115+
starting_page_number = self._get_starting_page_number(form_data)
110116

111117
pages = self._get_pdf_pages(file.content)
112118
call_api_partial = functools.partial(
@@ -119,10 +125,12 @@ def before_request(
119125
self.partition_requests[operation_id] = []
120126
last_page_content = io.BytesIO()
121127
with ThreadPoolExecutor(max_workers=call_threads) as executor:
122-
for page_content, page_number, all_pages_number in pages:
128+
for page_content, page_index, all_pages_number in pages:
129+
page_number = page_index + starting_page_number
123130
# Check if this page is the last one
124-
if page_number == all_pages_number:
131+
if page_index == all_pages_number:
125132
last_page_content = page_content
133+
last_page_number = page_number
126134
break
127135
self.partition_requests[operation_id].append(
128136
executor.submit(call_api_partial, (page_content, page_number))
@@ -131,7 +139,7 @@ def before_request(
131139
# `before_request` method needs to return a request so we skip sending the last page in parallel
132140
# and return that last page at the end of this method
133141
last_page_request = self._create_request(
134-
request, form_data, last_page_content, file.file_name
142+
request, form_data, last_page_content, file.file_name, last_page_number
135143
)
136144
last_page_prepared_request = self.client.prepare_request(last_page_request)
137145
return last_page_prepared_request
@@ -270,8 +278,7 @@ def _get_pdf_pages(
270278
new_pdf.write(pdf_buffer)
271279
pdf_buffer.seek(0)
272280

273-
# 1-index the page numbers
274-
yield pdf_buffer, offset + 1, offset_end
281+
yield pdf_buffer, offset, offset_end
275282
offset += split_size
276283

277284
def _parse_form_data(self, decoded_data: MultipartDecoder) -> FormData:
@@ -362,7 +369,9 @@ def _call_api(
362369
raise RuntimeError("HTTP client not accessible!")
363370
page_content, page_number = page
364371

365-
new_request = self._create_request(request, form_data, page_content, filename)
372+
new_request = self._create_request(
373+
request, form_data, page_content, filename, page_number
374+
)
366375
prepared_request = self.client.prepare_request(new_request)
367376

368377
try:
@@ -377,6 +386,7 @@ def _create_request(
377386
form_data: FormData,
378387
page_content: io.BytesIO,
379388
filename: str,
389+
page_number: int,
380390
) -> requests.Request:
381391
"""
382392
Creates a request object for a part of a splitted PDF file.
@@ -392,7 +402,7 @@ def _create_request(
392402
original file.
393403
"""
394404
headers = self._prepare_request_headers(request.headers)
395-
payload = self._prepare_request_payload(form_data)
405+
payload = self._prepare_request_payload(form_data, page_number)
396406
body = MultipartEncoder(
397407
fields={
398408
**payload,
@@ -428,7 +438,9 @@ def _prepare_request_headers(
428438
headers.pop("Content-Length", None)
429439
return headers
430440

431-
def _prepare_request_payload(self, form_data: FormData) -> FormData:
441+
def _prepare_request_payload(
442+
self, form_data: FormData, page_number: int
443+
) -> FormData:
432444
"""
433445
Prepares the request payload by removing unnecessary keys and updating the
434446
file.
@@ -442,7 +454,12 @@ def _prepare_request_payload(self, form_data: FormData) -> FormData:
442454
payload = copy.deepcopy(form_data)
443455
payload.pop(PARTITION_FORM_SPLIT_PDF_PAGE_KEY, None)
444456
payload.pop(PARTITION_FORM_FILES_KEY, None)
445-
payload.update({PARTITION_FORM_SPLIT_PDF_PAGE_KEY: "false"})
457+
payload.pop(PARTITION_FORM_STARTING_PAGE_NUMBER_KEY, None)
458+
updated_parameters = {
459+
PARTITION_FORM_SPLIT_PDF_PAGE_KEY: "false",
460+
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY: str(page_number),
461+
}
462+
payload.update(updated_parameters)
446463
return payload
447464

448465
def _create_response(
@@ -540,3 +557,28 @@ def _clear_operation(self, operation_id: str) -> None:
540557
"""
541558
self.partition_responses.pop(operation_id, None)
542559
self.partition_requests.pop(operation_id, None)
560+
561+
def _get_starting_page_number(self, form_data: FormData) -> int:
562+
starting_page_number = DEFAULT_STARTING_PAGE_NUMBER
563+
try:
564+
_starting_page_number = (
565+
form_data.get(PARTITION_FORM_STARTING_PAGE_NUMBER_KEY)
566+
or DEFAULT_STARTING_PAGE_NUMBER
567+
)
568+
starting_page_number = int(_starting_page_number) # type: ignore
569+
except ValueError:
570+
logger.warning(
571+
"'%s' is not a valid integer. Using default value '%d'.",
572+
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
573+
DEFAULT_STARTING_PAGE_NUMBER,
574+
)
575+
576+
if starting_page_number < 1:
577+
logger.warning(
578+
"'%s' is less than 1. Using default value '%d'.",
579+
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
580+
DEFAULT_STARTING_PAGE_NUMBER,
581+
)
582+
starting_page_number = DEFAULT_STARTING_PAGE_NUMBER
583+
584+
return starting_page_number

0 commit comments

Comments
 (0)