Skip to content

Commit

Permalink
Fix download bug
Browse files Browse the repository at this point in the history
  • Loading branch information
BobaZooba committed Dec 6, 2023
1 parent 561d108 commit d6c060b
Showing 1 changed file with 82 additions and 5 deletions.
87 changes: 82 additions & 5 deletions src/xllm/run/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,16 @@
from transformers import (
AutoTokenizer,
)
from transformers.modeling_utils import CONFIG_NAME, cached_file
from transformers.modeling_utils import (
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
_add_variant,
cached_file,
extract_commit_hash,
)

from ..core.config import Config
from ..datasets.registry import datasets_registry
Expand Down Expand Up @@ -91,18 +100,86 @@ def prepare(config: Config) -> None:
logger.info(f"Tokenizer {config.correct_tokenizer_name_or_path} loaded")

# model
cached_file(
cache_dir = None
proxies = None
force_download = None
resume_download = None
local_files_only = False
token = None
revision = "main"
subfolder = ""

use_safetensors = None

variant = None

user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}

filename = _add_variant(SAFE_WEIGHTS_NAME, variant)

# model
resolved_config_file = cached_file(
config.model_name_or_path,
CONFIG_NAME,
cache_dir=None,
cache_dir=cache_dir,
force_download=False,
resume_download=False,
proxies=None,
local_files_only=False,
token=None,
local_files_only=local_files_only,
token=token,
revision="main",
subfolder="",
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)

commit_hash = extract_commit_hash(resolved_config_file, None)

cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"token": token,
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}

resolved_archive_file = cached_file(config.model_name_or_path, filename, **cached_file_kwargs)

if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
_ = cached_file(
config.model_name_or_path,
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
if use_safetensors:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
"and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved "
"with `safe_serialization=True` or do not set `use_safetensors=True`."
)
else:
# This repo has no safetensors file of any kind, we switch to PyTorch.
filename = _add_variant(WEIGHTS_NAME, variant)
resolved_archive_file = cached_file(config.model_name_or_path, filename, **cached_file_kwargs)

if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
resolved_archive_file = cached_file(
config.model_name_or_path,
_add_variant(WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)

if resolved_archive_file is None:
raise EnvironmentError(
f"{config.model_name_or_path} does not appear to have a file named"
f" {_add_variant(WEIGHTS_NAME, variant)}."
)

logger.info(f"Model {config.model_name_or_path} loaded")

0 comments on commit d6c060b

Please sign in to comment.