[Distributed] Support index + multi-bin loading #1275

Merged
merged 1 commit on Oct 7, 2024

torchchat/distributed/checkpoint_utils.py (81 changes: 42 additions & 39 deletions)
@@ -19,7 +19,8 @@
from torchchat.cli.builder import BuilderArgs, _load_checkpoint


_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
_DEFAULT_SAFETENSOR_INDEX = "model.safetensors.index.json"
_DEFAULT_BIN_INDEX = "pytorch_model.bin.index.json"
_CONFIG_NAME = "config.json"
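
(For context: both index filenames follow the standard Hugging Face sharded-checkpoint layout, a JSON file whose "weight_map" maps each parameter FQN to the shard file that stores it. Below is a minimal sketch of reading such an index; the helper name is hypothetical, and `read_weights_from_json` in this module plays the same role.)

```python
import json
from typing import Dict, Optional


def read_weight_map(index_path: str) -> Optional[Dict[str, str]]:
    """Return the parameter-name -> shard-file mapping from an HF index file.

    Both index formats share roughly the same shape:
    {"metadata": {...},
     "weight_map": {"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors", ...}}
    """
    with open(index_path, "r") as f:
        index = json.load(f)
    return index.get("weight_map")
```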


@@ -81,31 +82,6 @@ def get_hf_path_from_model_id(model_id: str) -> str:
    return file_location


def get_hf_weight_map_and_path(
    model_id: str,
) -> Tuple[Dict[str, str], str, Dict[str, str]]:
    """Get the weight map for a given HF model id and also the cache path for loading the weights"""
    index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
    if not os.path.exists(index_file):
        raise FileNotFoundError(
            f"Weight index file for {model_id} does not exist in HF cache."
        )
    logger.info(
        f"Loading weight map from: {index_file}"
    )
    weight_map = read_weights_from_json(index_file)
    if weight_map is None:
        raise ValueError(f"Weight map not found in config file {index_file}")
    weight_map, new_to_old_keymap = remap_weight_keys(weight_map)
    weight_path = os.path.dirname(index_file)
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weight path {weight_path} does not exist")
    logger.info(
        f"Loading weights from: {weight_path}"
    )
    return weight_map, weight_path, new_to_old_keymap


def remap_weight_keys(dictionary):
"""Remap the keys of a dictionary to match the expected format of the tune model."""
# hf_key : dist_model_key
@@ -141,12 +117,13 @@ def remap_weight_keys(dictionary):
    return new_dict, key_mapping


def load_safetensor_weights(
def load_weights_per_map(
    stage_module: Module,
    weight_map: Dict[str, str],
    file_location: str,
    new_to_old_keymap: Dict[str, str],
    device: torch.device = "cuda",
    device: torch.device,
    is_safetensor: bool,
    purge_model_prefix: bool = True,
    ignore_cache_layers: bool = True,
    model_config: Optional[Dict] = None,
@@ -160,6 +137,7 @@ def load_safetensor_weights(
        file_location (str): Directory containing the weight files.
        new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
        device (torch.device): The device to load tensors onto.
        is_safetensor (bool): Whether the files are safetensors.
        purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
        ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
        model_config (Optional[Dict]): Model configuration.
@@ -178,9 +156,13 @@ def load_safetensor_weights(
    for file in needed_files:
        full_path = os.path.join(file_location, file)
        # logger.info(f"Loading checkpoint file: {full_path}")
        try:
            checkpoint = load_safetensor_file(full_path, "cpu") # device)
        # TODO: directly load to device
        if is_safetensor:
            checkpoint = load_safetensor_file(full_path)
        else:
            checkpoint = torch.load(full_path, mmap=True, weights_only=True)

        try:
            update_state_dict(
                stage_state_dict,
                checkpoint,
@@ -189,10 +171,9 @@
                new_to_old_keymap=new_to_old_keymap,
                updated_states=updated_states,
            )
        except FileNotFoundError:
            logger.error(f"File not found: {full_path}")
        except Exception as e:
            logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
            logger.error(f"Error during checkpoint processing:")
            raise e

    missing_keys = handle_missing_keys(
        stage_state_dict, updated_states, ignore_cache_layers
@@ -244,12 +225,14 @@ def get_needed_files(
    return needed_files


def load_safetensor_file(full_path: str, device: torch.device) -> Dict[str, torch.Tensor]:
def load_safetensor_file(
    full_path: str,
    device: str = "cpu",
) -> Dict[str, torch.Tensor]:
    tensors = {}
    with safe_open(full_path, framework="pt", device=device) as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
    return tensors
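
(To illustrate the safetensor/bin split this PR relies on in `load_weights_per_map`, here is a rough self-contained shard loader under the same assumptions: safetensors shards via `safe_open`, legacy `.bin` shards via `torch.load` with `mmap=True` and `weights_only=True`, which needs a recent PyTorch. This is a sketch, not the PR's exact code.)

```python
from typing import Dict

import torch
from safetensors import safe_open


def load_shard(full_path: str, is_safetensor: bool, device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Load a single checkpoint shard into a flat {name: tensor} dict."""
    if is_safetensor:
        tensors = {}
        with safe_open(full_path, framework="pt", device=device) as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)
        return tensors
    # .bin shards are pickled state dicts; mmap avoids materializing the whole
    # file in RAM, and weights_only restricts unpickling to tensor data.
    return torch.load(full_path, map_location=device, mmap=True, weights_only=True)
```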


@@ -378,15 +361,35 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
    files), and fill into `stage_module`. Model config is needed b/c we permute
    wq and wk weights based on attn heads.
    """
    # Get the weight map for a given HF model id
    try:
        index_file = cached_file(distribution, _DEFAULT_SAFETENSOR_INDEX)
        is_safetensor = True
    except:
        index_file = cached_file(distribution, _DEFAULT_BIN_INDEX)
        is_safetensor = False
    logger.info(f"Loading weight map from: {index_file}")

    # Read the weight map from the index file
    weight_map = read_weights_from_json(index_file)
    if weight_map is None:
        raise ValueError(f"Weight map not found in config file {index_file}")

    # Remap the FQNs to the FQNs in HF checkpoints
    weight_map, new_to_old_keymap = remap_weight_keys(weight_map)

    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
    # Get the dir containing the weight files
    weight_dir = os.path.dirname(index_file)
    logger.info(f"Loading weights from: {weight_dir}")

    num_loaded_weights, num_missing_weights = load_safetensor_weights(
    # Load the weights into the stage module
    num_loaded_weights, num_missing_weights = load_weights_per_map(
        stage_module,
        weight_map,
        weight_path,
        key_map,
        weight_dir,
        new_to_old_keymap,
        device,
        is_safetensor,
        model_config=model_config,
    )
    logger.info(
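
(A rough usage sketch of the updated entry point for anyone trying the change locally; `stage_module` and `model_config` are placeholders for objects produced by torchchat's pipeline setup, and the model id is only an example.)

```python
import torch

from torchchat.distributed.checkpoint_utils import load_weights_from_hf_format


def load_stage_weights(stage_module, model_config) -> None:
    """Fill one pipeline stage's module with weights from an HF checkpoint."""
    # Example model id; after this PR, repos that ship either
    # model.safetensors.index.json or pytorch_model.bin.index.json both work.
    distribution = "meta-llama/Llama-2-7b-hf"
    device = torch.device("cuda", 0)
    load_weights_from_hf_format(stage_module, distribution, device, model_config)
```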