Skip to content

Commit f8c048a

Browse files
committed
fix: handle older huggingface-hub versions without ProgressCallback
1 parent 7f2fc19 commit f8c048a

File tree

1 file changed

+45
-31
lines changed

1 file changed

+45
-31
lines changed

src/vector_search.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import time
88
from typing import Any, Dict, List, Optional
9+
from concurrent.futures import ThreadPoolExecutor
910

1011
from qdrant_client import QdrantClient
1112
from qdrant_client.http import models
@@ -80,36 +81,45 @@ def _load_embedding_model(self, model_name: str) -> SentenceTransformer:
8081
valid_params['cache_folder'] = cache_dir
8182

8283
# Track progress for downloading model components
83-
from huggingface_hub import logging as hf_logging
84-
85-
# Create handler for tracking download progress
86-
class ProgressHandler(hf_logging.ProgressCallback):
87-
def __init__(self):
88-
super().__init__()
89-
self.current_file = None
90-
self.progress = {}
91-
92-
def on_download(self, filename: str, chunk_size: int, chunk_index: int, total_size: int):
93-
file_display_name = filename.split('/')[-1]
94-
self.current_file = file_display_name
95-
96-
if not total_size:
97-
# If total size is unknown, just log each chunk
98-
logger.info(f"Downloading {file_display_name}: chunk {chunk_index}")
99-
return
100-
101-
# Calculate progress percentage
102-
downloaded = chunk_index * chunk_size
103-
percentage = min(100, int(downloaded * 100 / total_size))
104-
105-
# Update progress
106-
if percentage % 10 == 0 and (file_display_name not in self.progress or self.progress[file_display_name] < percentage):
107-
self.progress[file_display_name] = percentage
108-
logger.info(f"Downloading {file_display_name}: {percentage}% ({downloaded//1024}KB / {total_size//1024}KB)")
84+
try:
85+
# Try using the ProgressCallback from huggingface_hub (newer versions)
86+
from huggingface_hub import logging as hf_logging
87+
88+
# Check if ProgressCallback exists in this version
89+
if hasattr(hf_logging, 'ProgressCallback'):
90+
# Create handler for tracking download progress
91+
class ProgressHandler(hf_logging.ProgressCallback):
92+
def __init__(self):
93+
super().__init__()
94+
self.current_file = None
95+
self.progress = {}
96+
97+
def on_download(self, filename: str, chunk_size: int, chunk_index: int, total_size: int):
98+
file_display_name = filename.split('/')[-1]
99+
self.current_file = file_display_name
100+
101+
if not total_size:
102+
# If total size is unknown, just log each chunk
103+
logger.info(f"Downloading {file_display_name}: chunk {chunk_index}")
104+
return
105+
106+
# Calculate progress percentage
107+
downloaded = chunk_index * chunk_size
108+
percentage = min(100, int(downloaded * 100 / total_size))
109+
110+
# Update progress
111+
if percentage % 10 == 0 and (file_display_name not in self.progress or self.progress[file_display_name] < percentage):
112+
self.progress[file_display_name] = percentage
113+
logger.info(f"Downloading {file_display_name}: {percentage}% ({downloaded//1024}KB / {total_size//1024}KB)")
109114

110-
# Register progress handler
111-
progress_handler = ProgressHandler()
112-
hf_logging.callback_registry.register_callback(progress_handler)
115+
# Register progress handler
116+
progress_handler = ProgressHandler()
117+
hf_logging.callback_registry.register_callback(progress_handler)
118+
logger.info("Using progress callback for HuggingFace model downloads")
119+
else:
120+
logger.info("Progress callback not available in this huggingface-hub version")
121+
except (ImportError, AttributeError):
122+
logger.info("HuggingFace progress tracking not available, continuing without progress reporting")
113123

114124
# Log start of model loading
115125
logger.info(f"Starting to load model: {model_name}")
@@ -122,8 +132,12 @@ def on_download(self, filename: str, chunk_size: int, chunk_index: int, total_si
122132
**valid_params,
123133
)
124134

125-
# Unregister progress handler after loading
126-
hf_logging.callback_registry.unregister_callback(progress_handler)
135+
# Unregister progress handler after loading if it was registered
136+
try:
137+
if 'progress_handler' in locals() and hasattr(hf_logging, 'callback_registry'):
138+
hf_logging.callback_registry.unregister_callback(progress_handler)
139+
except Exception as e:
140+
logger.debug(f"Failed to unregister progress callback: {e}")
127141

128142
# Log final message
129143
logger.info(f"Model {model_name} loaded successfully")

0 commit comments

Comments
 (0)