`."
- ) from e
- except RepositoryNotFoundError as e:
- logger.error(e)
- raise EnvironmentError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.") from e
- except RevisionNotFoundError as e:
- logger.error(e)
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
- f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
- ) from e
- except EntryNotFoundError:
- return False # File does not exist
- except requests.HTTPError:
- # Any authentication/authorization error will be caught here => default to cache
- return has_file_in_cache
-
-
-class PushToHubMixin:
- """
- A Mixin containing the functionality to push a model or tokenizer to the hub.
- """
-
- def _create_repo(
- self,
- repo_id: str,
- private: Optional[bool] = None,
- token: Optional[Union[bool, str]] = None,
- repo_url: Optional[str] = None,
- organization: Optional[str] = None,
- ) -> str:
- """
- Creates the repo if needed, cleans up `repo_id` using the deprecated kwargs `repo_url` and `organization`, and
- retrieves the token.
- """
- if repo_url is not None:
- warnings.warn(
- "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` " "instead."
- )
- if repo_id is not None:
- raise ValueError("`repo_id` and `repo_url` are both specified. Please set only the argument `repo_id`.")
- repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
- if organization is not None:
- warnings.warn(
- "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
- "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
- )
- if not repo_id.startswith(organization):
- if "/" in repo_id:
- repo_id = repo_id.split("/")[-1]
- repo_id = f"{organization}/{repo_id}"
-
- url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
- return url.repo_id
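-
- # Illustrative cleanup performed above (hypothetical values, assuming the default `https://huggingface.co` endpoint):
- #   repo_url="https://huggingface.co/user/model"  -> repo_id="user/model"
- #   organization="org", repo_id="user/model"      -> repo_id="org/model"
- # The repo is then created with `exist_ok=True`, so calling this twice is safe.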
-
- def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
- """
- Returns a dictionary mapping the files in `working_dir` to their last modification timestamp.
- """
- return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}
-
- def _upload_modified_files(
- self,
- working_dir: Union[str, os.PathLike],
- repo_id: str,
- files_timestamps: Dict[str, float],
- commit_message: Optional[str] = None,
- token: Optional[Union[bool, str]] = None,
- create_pr: bool = False,
- revision: Optional[str] = None,
- commit_description: Optional[str] = None,
- ):
- """
- Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
- """
- if commit_message is None:
- if "Model" in self.__class__.__name__:
- commit_message = "Upload model"
- elif "Config" in self.__class__.__name__:
- commit_message = "Upload config"
- elif "Tokenizer" in self.__class__.__name__:
- commit_message = "Upload tokenizer"
- elif "FeatureExtractor" in self.__class__.__name__:
- commit_message = "Upload feature extractor"
- elif "Processor" in self.__class__.__name__:
- commit_message = "Upload processor"
- else:
- commit_message = f"Upload {self.__class__.__name__}"
- modified_files = [
- f
- for f in os.listdir(working_dir)
- if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
- ]
-
- # filter for actual files + folders at the root level
- modified_files = [
- f
- for f in modified_files
- if os.path.isfile(os.path.join(working_dir, f)) or os.path.isdir(os.path.join(working_dir, f))
- ]
-
- operations = []
- # build commit operations for root-level files and one level of folders
- for file in modified_files:
- if os.path.isdir(os.path.join(working_dir, file)):
- # go over individual files of folder
- for f in os.listdir(os.path.join(working_dir, file)):
- operations.append(
- CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file, f), path_in_repo=os.path.join(file, f))
- )
- else:
- operations.append(CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file))
-
- if revision is not None:
- create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)
-
- logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
- return create_commit(
- repo_id=repo_id,
- operations=operations,
- commit_message=commit_message,
- commit_description=commit_description,
- token=token,
- create_pr=create_pr,
- revision=revision,
- )
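-
- # Intended call sequence (a sketch with hypothetical names):
- #   files_timestamps = self._get_files_timestamps(work_dir)  # snapshot mtimes
- #   ...                                                      # save files into work_dir
- #   self._upload_modified_files(work_dir, "user/repo", files_timestamps)
- # Only files that are new, or whose mtime changed since the snapshot, are committed.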
-
- def push_to_hub(
- self,
- repo_id: str,
- use_temp_dir: Optional[bool] = None,
- commit_message: Optional[str] = None,
- private: Optional[bool] = None,
- token: Optional[Union[bool, str]] = None,
- max_shard_size: Optional[Union[int, str]] = "5GB",
- create_pr: bool = False,
- safe_serialization: bool = True,
- revision: Optional[str] = None,
- commit_description: Optional[str] = None,
- tags: Optional[List[str]] = None,
- **deprecated_kwargs,
- ) -> str:
- """
- Upload the {object_files} to the 🤗 Model Hub.
-
- Parameters:
- repo_id (`str`):
- The name of the repository you want to push your {object} to. It should contain your organization name
- when pushing to a given organization.
- use_temp_dir (`bool`, *optional*):
- Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
- Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
- commit_message (`str`, *optional*):
- Message to commit while pushing. Will default to `"Upload {object}"`.
- private (`bool`, *optional*):
- Whether or not the repository created should be private.
- token (`bool` or `str`, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
- is not specified.
- max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
- Only applicable for models. The maximum size for a checkpoint before being sharded. Each checkpoint shard
- will then be smaller than this size. If expressed as a string, it needs to be digits followed by a unit
- (like `"5MB"`). We default it to `"5GB"` so that users can easily load models on free-tier Google Colab
- instances without any CPU OOM issues.
- create_pr (`bool`, *optional*, defaults to `False`):
- Whether or not to create a PR with the uploaded files or directly commit.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether or not to convert the model weights in safetensors format for safer serialization.
- revision (`str`, *optional*):
- Branch to push the uploaded files to.
- commit_description (`str`, *optional*):
- The description of the commit that will be created.
- tags (`List[str]`, *optional*):
- List of tags to push on the Hub.
-
- Examples:
-
- ```python
- from transformers import {object_class}
-
- {object} = {object_class}.from_pretrained("google-bert/bert-base-cased")
-
- # Push the {object} to your namespace with the name "my-finetuned-bert".
- {object}.push_to_hub("my-finetuned-bert")
-
- # Push the {object} to an organization with the name "my-finetuned-bert".
- {object}.push_to_hub("huggingface/my-finetuned-bert")
- ```
- """
- use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
- ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
- if use_auth_token is not None:
- warnings.warn(
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
- FutureWarning,
- )
- if token is not None:
- raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
- token = use_auth_token
-
- repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
- if repo_path_or_name is not None:
- # Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
- # repo_id from the folder path, if it exists.
- warnings.warn(
- "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
- "`repo_id` instead.",
- FutureWarning,
- )
- if repo_id is not None:
- raise ValueError("`repo_id` and `repo_path_or_name` are both specified. Please set only the argument `repo_id`.")
- if os.path.isdir(repo_path_or_name):
- # repo_path: infer repo_id from the path
- repo_id = repo_path_or_name.split(os.path.sep)[-1]
- working_dir = repo_id
- else:
- # repo_name: use it as repo_id
- repo_id = repo_path_or_name
- working_dir = repo_id.split("/")[-1]
- else:
- # Repo_id is passed correctly: infer working_dir from it
- working_dir = repo_id.split("/")[-1]
-
- # Deprecation warning will be sent after for repo_url and organization
- repo_url = deprecated_kwargs.pop("repo_url", None)
- organization = deprecated_kwargs.pop("organization", None)
-
- repo_id = self._create_repo(repo_id, private=private, token=token, repo_url=repo_url, organization=organization)
-
- # Create a new empty model card and tag it if needed
- model_card = create_and_tag_model_card(repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors)
-
- if use_temp_dir is None:
- use_temp_dir = not os.path.isdir(working_dir)
-
- with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
- files_timestamps = self._get_files_timestamps(work_dir)
-
- # Save all files.
- self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
-
- # Update model card if needed:
- model_card.save(os.path.join(work_dir, "README.md"))
-
- return self._upload_modified_files(
- work_dir,
- repo_id,
- files_timestamps,
- commit_message=commit_message,
- token=token,
- create_pr=create_pr,
- revision=revision,
- commit_description=commit_description,
- )
-
-
-def send_example_telemetry(example_name, *example_args, framework="pytorch"):
- """
- Sends telemetry that helps track example usage.
-
- Args:
- example_name (`str`): The name of the example.
- *example_args (dataclasses or `argparse.ArgumentParser`): The arguments to the script. This function will only
- try to extract the model and dataset name from those. Nothing else is tracked.
- framework (`str`, *optional*, defaults to `"pytorch"`): The framework for the example.
- """
- if is_offline_mode():
- return
-
- data = {"example": example_name, "framework": framework}
- for args in example_args:
- args_as_dict = {k: v for k, v in args.__dict__.items() if not k.startswith("_") and v is not None}
- if "model_name_or_path" in args_as_dict:
- model_name = args_as_dict["model_name_or_path"]
- # Filter out local paths
- if not os.path.isdir(model_name):
- data["model_name"] = args_as_dict["model_name_or_path"]
- if "dataset_name" in args_as_dict:
- data["dataset_name"] = args_as_dict["dataset_name"]
- elif "task_name" in args_as_dict:
- # Extract script name from the example_name
- script_name = example_name.replace("tf_", "").replace("flax_", "").replace("run_", "")
- script_name = script_name.replace("_no_trainer", "")
- data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"
-
- # Send telemetry in the background
- send_telemetry(
- topic="examples", library_name="transformers", library_version=__version__, user_agent=http_user_agent(data)
- )
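-
-# Illustrative payload (hypothetical script and args): for example_name="run_glue" with
-# model_name_or_path="bert-base-uncased" and task_name="mrpc", the data sent is
-#   {"example": "run_glue", "framework": "pytorch",
-#    "model_name": "bert-base-uncased", "dataset_name": "glue-mrpc"}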
-
-
-def convert_file_size_to_int(size: Union[int, str]):
- """
- Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).
-
- Args:
- size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
-
- Example:
- ```py
- >>> convert_file_size_to_int("1MiB")
- 1048576
- ```
- """
- if isinstance(size, int):
- return size
- if size.upper().endswith("GIB"):
- return int(size[:-3]) * (2**30)
- if size.upper().endswith("MIB"):
- return int(size[:-3]) * (2**20)
- if size.upper().endswith("KIB"):
- return int(size[:-3]) * (2**10)
- if size.upper().endswith("GB"):
- int_size = int(size[:-2]) * (10**9)
- return int_size // 8 if size.endswith("b") else int_size
- if size.upper().endswith("MB"):
- int_size = int(size[:-2]) * (10**6)
- return int_size // 8 if size.endswith("b") else int_size
- if size.upper().endswith("KB"):
- int_size = int(size[:-2]) * (10**3)
- return int_size // 8 if size.endswith("b") else int_size
- raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
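-
-# Illustrative conversions (matching the branches above; a trailing lowercase "b"
-# denotes bits and is divided by 8):
-#   convert_file_size_to_int("1GiB") == 2**30
-#   convert_file_size_to_int("5GB")  == 5 * 10**9
-#   convert_file_size_to_int("5Gb")  == 5 * 10**9 // 8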
-
-
-def get_checkpoint_shard_files(
- pretrained_model_name_or_path,
- index_filename,
- cache_dir=None,
- force_download=False,
- proxies=None,
- resume_download=None,
- local_files_only=False,
- token=None,
- user_agent=None,
- revision=None,
- subfolder="",
- _commit_hash=None,
- **deprecated_kwargs,
-):
- """
- For a given model:
-
- - downloads and caches all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID
- on the Hub
- - returns the list of paths to all the shards, as well as some metadata.
-
- For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
- index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
- """
- import json
-
- use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
- if use_auth_token is not None:
- warnings.warn(
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
- FutureWarning,
- )
- if token is not None:
- raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
- token = use_auth_token
-
- if not os.path.isfile(index_filename):
- raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
-
- with open(index_filename, "r") as f:
- index = json.load(f)
-
- shard_filenames = sorted(set(index["weight_map"].values()))
- sharded_metadata = index["metadata"]
- sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
- sharded_metadata["weight_map"] = index["weight_map"].copy()
-
- # First, let's deal with local folder.
- if os.path.isdir(pretrained_model_name_or_path):
- shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
- return shard_filenames, sharded_metadata
-
- # At this stage pretrained_model_name_or_path is a model identifier on the Hub
- cached_filenames = []
- # Check if the model is already cached or not. We only try the last shard, which should cover most cases of an
- # interrupted download.
- last_shard = try_to_load_from_cache(
- pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash
- )
- show_progress_bar = last_shard is None or force_download
- for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
- try:
- # Load from URL
- cached_filename = cached_file(
- pretrained_model_name_or_path,
- shard_filename,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- token=token,
- user_agent=user_agent,
- revision=revision,
- subfolder=subfolder,
- _commit_hash=_commit_hash,
- )
- # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
- # we don't have to catch them here.
- except EntryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
- "required according to the checkpoint index."
- )
- except HTTPError:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
- " again after checking your internet connection."
- )
-
- cached_filenames.append(cached_filename)
-
- return cached_filenames, sharded_metadata
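-
-# Shape of a sharded checkpoint index (illustrative names and values):
-#   {
-#     "metadata": {"total_size": 12345678},
-#     "weight_map": {"embeddings.weight": "model-00001-of-00002.safetensors", ...}
-#   }
-# `shard_filenames` is the sorted set of the `weight_map` values.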
-
-
-# Everything below handles the conversion between the old cache format and the new one.
-
-
-def get_all_cached_files(cache_dir=None):
- """
- Returns a list of all cached files with their metadata.
- """
- if cache_dir is None:
- cache_dir = TRANSFORMERS_CACHE
- else:
- cache_dir = str(cache_dir)
- if not os.path.isdir(cache_dir):
- return []
-
- cached_files = []
- for file in os.listdir(cache_dir):
- meta_path = os.path.join(cache_dir, f"{file}.json")
- if not os.path.isfile(meta_path):
- continue
-
- with open(meta_path, encoding="utf-8") as meta_file:
- metadata = json.load(meta_file)
- url = metadata["url"]
- etag = metadata["etag"].replace('"', "")
- cached_files.append({"file": file, "url": url, "etag": etag})
-
- return cached_files
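-
-# Old cache layout (illustrative): each cached blob `<file>` sits next to a `<file>.json`
-# sidecar of the form
-#   {"url": "https://huggingface.co/.../resolve/main/config.json", "etag": "\"abc123\""}
-# which is exactly what this function reads back.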
-
-
-def extract_info_from_url(url):
- """
- Extract the repo name, revision and filename from a URL.
- """
- search = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
- if search is None:
- return None
- repo, revision, filename = search.groups()
- cache_repo = "--".join(["models"] + repo.split("/"))
- return {"repo": cache_repo, "revision": revision, "filename": filename}
-
-
-def create_and_tag_model_card(
- repo_id: str,
- tags: Optional[List[str]] = None,
- token: Optional[str] = None,
- ignore_metadata_errors: bool = False,
-):
- """
- Creates or loads an existing model card and tags it.
-
- Args:
- repo_id (`str`):
- The repo_id where to look for the model card.
- tags (`List[str]`, *optional*):
- The list of tags to add to the model card.
- token (`str`, *optional*):
- Authentication token, obtained with the `huggingface_hub.HfApi.login` method. Will default to the stored token.
- ignore_metadata_errors (`bool`, *optional*, defaults to `False`):
- If `True`, errors while parsing the metadata section will be ignored. Some information might be lost during
- the process. Use it at your own risk.
- """
- try:
- # Check if the model card is present on the remote repo
- model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
- except EntryNotFoundError:
- # Otherwise create a simple model card from template
- model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
- card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
- model_card = ModelCard.from_template(card_data, model_description=model_description)
-
- if tags is not None:
- for model_tag in tags:
- if model_tag not in model_card.data.tags:
- model_card.data.tags.append(model_tag)
-
- return model_card
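-
-# Minimal usage sketch (hypothetical repo id):
-#   card = create_and_tag_model_card("user/my-model", tags=["text-generation"])
-#   card.save("README.md")  # write the (possibly newly created) card to disk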
-
-
-def clean_files_for(file):
- """
- Remove `file`, `file.json` and `file.lock` if they exist.
- """
- for f in [file, f"{file}.json", f"{file}.lock"]:
- if os.path.isfile(f):
- os.remove(f)
-
-
-def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
- """
- Move `file` to `repo`, following the new Hugging Face Hub cache organization.
- """
- os.makedirs(repo, exist_ok=True)
-
- # refs
- os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
- if revision != commit_hash:
- ref_path = os.path.join(repo, "refs", revision)
- with open(ref_path, "w") as f:
- f.write(commit_hash)
-
- # blobs
- os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
- blob_path = os.path.join(repo, "blobs", etag)
- shutil.move(file, blob_path)
-
- # snapshots
- os.makedirs(os.path.join(repo, "snapshots"), exist_ok=True)
- os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
- pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
- huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
- clean_files_for(file)
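-
-# Resulting layout under `repo` (per cached file):
-#   refs/<revision>                      -> text file holding <commit_hash> (only when revision != commit_hash)
-#   blobs/<etag>                         -> the moved payload
-#   snapshots/<commit_hash>/<filename>   -> symlink pointing at blobs/<etag>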
-
-
-def move_cache(cache_dir=None, new_cache_dir=None, token=None):
- if new_cache_dir is None:
- new_cache_dir = TRANSFORMERS_CACHE
- if cache_dir is None:
- # Migrate from old cache in .cache/huggingface/transformers
- old_cache = Path(TRANSFORMERS_CACHE).parent / "transformers"
- if os.path.isdir(str(old_cache)):
- cache_dir = str(old_cache)
- else:
- cache_dir = new_cache_dir
- cached_files = get_all_cached_files(cache_dir=cache_dir)
- logger.info(f"Moving {len(cached_files)} files to the new cache system")
-
- hub_metadata = {}
- for file_info in tqdm(cached_files):
- url = file_info.pop("url")
- if url not in hub_metadata:
- try:
- hub_metadata[url] = get_hf_file_metadata(url, token=token)
- except requests.HTTPError:
- continue
-
- etag, commit_hash = hub_metadata[url].etag, hub_metadata[url].commit_hash
- if etag is None or commit_hash is None:
- continue
-
- if file_info["etag"] != etag:
- # Cached file is not up to date, we just throw it as a new version will be downloaded anyway.
- clean_files_for(os.path.join(cache_dir, file_info["file"]))
- continue
-
- url_info = extract_info_from_url(url)
- if url_info is None:
- # Not a file from huggingface.co
- continue
-
- repo = os.path.join(new_cache_dir, url_info["repo"])
- move_to_new_cache(
- file=os.path.join(cache_dir, file_info["file"]),
- repo=repo,
- filename=url_info["filename"],
- revision=url_info["revision"],
- etag=etag,
- commit_hash=commit_hash,
- )
-
-
-class PushInProgress:
- """
- Internal class to keep track of a push in progress (which might contain multiple `Future` jobs).
- """
-
- def __init__(self, jobs: Optional[List[futures.Future]] = None) -> None:
- self.jobs = [] if jobs is None else jobs
-
- def is_done(self):
- return all(job.done() for job in self.jobs)
-
- def wait_until_done(self):
- futures.wait(self.jobs)
-
- def cancel(self) -> None:
- self.jobs = [
- job
- for job in self.jobs
- # Cancel the job if it wasn't started yet and remove cancelled/done jobs from the list
- if not (job.cancel() or job.done())
- ]
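-
-# Minimal usage sketch (hypothetical `some_upload_fn`; any `concurrent.futures` executor works):
-#   from concurrent.futures import ThreadPoolExecutor
-#   executor = ThreadPoolExecutor(max_workers=1)
-#   push = PushInProgress(jobs=[executor.submit(some_upload_fn)])
-#   push.is_done()          # False until every job completes
-#   push.wait_until_done()  # blocks on all jobs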
-
-
-cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
-if not os.path.isfile(cache_version_file):
- cache_version = 0
-else:
- with open(cache_version_file) as f:
- try:
- cache_version = int(f.read())
- except ValueError:
- cache_version = 0
-
-cache_is_not_empty = os.path.isdir(TRANSFORMERS_CACHE) and len(os.listdir(TRANSFORMERS_CACHE)) > 0
-
-if cache_version < 1 and cache_is_not_empty:
- if is_offline_mode():
- logger.warning(
- "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
- "cache seems to be that of a previous version. It is very likely that all your calls to any "
- "`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
- "your cache be updated automatically, then you can go back to offline mode."
- )
- else:
- logger.warning(
- "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a "
- "one-time only operation. You can interrupt this and resume the migration later on by calling "
- "`transformers.utils.move_cache()`."
- )
- try:
- if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
- # Users set some env variable to customize cache storage
- move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
- else:
- move_cache()
- except Exception as e:
- trace = "\n".join(traceback.format_tb(e.__traceback__))
- logger.error(
- f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
- "file an issue at https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole "
- "message and we will do our best to help."
- )
-
-if cache_version < 1:
- try:
- os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
- with open(cache_version_file, "w") as f:
- f.write("1")
- except Exception:
- logger.warning(
- f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
- "the environment variable TRANSFORMERS_CACHE to a writable directory."
- )
diff --git a/src/maxdiffusion/transformers/utils/import_utils.py b/src/maxdiffusion/transformers/utils/import_utils.py
deleted file mode 100644
index 04398ffc3..000000000
--- a/src/maxdiffusion/transformers/utils/import_utils.py
+++ /dev/null
@@ -1,1596 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Import utilities: Utilities related to imports and our lazy inits.
-"""
-
-import importlib.metadata
-import importlib.util
-import json
-import os
-import shutil
-import subprocess
-import sys
-import warnings
-from collections import OrderedDict
-from functools import lru_cache
-from itertools import chain
-from types import ModuleType
-from typing import Any, Tuple, Union
-
-from packaging import version
-
-from . import logging
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to handle it better.
-def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
- # Check if the package spec exists and grab its version to avoid importing a local directory
- package_exists = importlib.util.find_spec(pkg_name) is not None
- package_version = "N/A"
- if package_exists:
- try:
- # Primary method to get the package version
- package_version = importlib.metadata.version(pkg_name)
- except importlib.metadata.PackageNotFoundError:
- # Fallback method: Only for "torch" and versions containing "dev"
- if pkg_name == "torch":
- try:
- package = importlib.import_module(pkg_name)
- temp_version = getattr(package, "__version__", "N/A")
- # Check if the version contains "dev"
- if "dev" in temp_version:
- package_version = temp_version
- package_exists = True
- else:
- package_exists = False
- except ImportError:
- # If the package can't be imported, it's not available
- package_exists = False
- else:
- # For packages other than "torch", don't attempt the fallback and set as not available
- package_exists = False
- logger.debug(f"Detected {pkg_name} version: {package_version}")
- if return_version:
- return package_exists, package_version
- else:
- return package_exists
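-
-# Usage sketch for the helper above (most `is_*_available` functions below wrap it):
-#   _is_package_available("safetensors")                      -> bool
-#   _is_package_available("accelerate", return_version=True)  -> (bool, "<version>" or "N/A")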
-
-
-ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
-ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
-
-USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
-
-# Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0.
-USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
-
-FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
-
-# `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it.
-# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
-TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
-
-ACCELERATE_MIN_VERSION = "0.21.0"
-FSDP_MIN_VERSION = "1.12.0"
-XLA_FSDPV2_MIN_VERSION = "2.2.0"
-
-
-_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
-_apex_available = _is_package_available("apex")
-_aqlm_available = _is_package_available("aqlm")
-_av_available = importlib.util.find_spec("av") is not None
-_bitsandbytes_available = _is_package_available("bitsandbytes")
-_eetq_available = _is_package_available("eetq")
-_galore_torch_available = _is_package_available("galore_torch")
-_lomo_available = _is_package_available("lomo_optim")
-# `importlib.metadata.version` only works with `beautifulsoup4`, not `bs4`; for `importlib.util.find_spec` it is the reverse.
-_bs4_available = importlib.util.find_spec("bs4") is not None
-_coloredlogs_available = _is_package_available("coloredlogs")
-# `importlib.metadata.version` doesn't work with `opencv-python-headless`.
-_cv2_available = importlib.util.find_spec("cv2") is not None
-_datasets_available = _is_package_available("datasets")
-_decord_available = importlib.util.find_spec("decord") is not None
-_detectron2_available = _is_package_available("detectron2")
-# We need to check both `faiss` and `faiss-cpu`.
-_faiss_available = importlib.util.find_spec("faiss") is not None
-try:
- _faiss_version = importlib.metadata.version("faiss")
- logger.debug(f"Successfully imported faiss version {_faiss_version}")
-except importlib.metadata.PackageNotFoundError:
- try:
- _faiss_version = importlib.metadata.version("faiss-cpu")
- logger.debug(f"Successfully imported faiss version {_faiss_version}")
- except importlib.metadata.PackageNotFoundError:
- _faiss_available = False
-_ftfy_available = _is_package_available("ftfy")
-_g2p_en_available = _is_package_available("g2p_en")
-_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
-_jieba_available = _is_package_available("jieba")
-_jinja_available = _is_package_available("jinja2")
-_kenlm_available = _is_package_available("kenlm")
-_keras_nlp_available = _is_package_available("keras_nlp")
-_levenshtein_available = _is_package_available("Levenshtein")
-_librosa_available = _is_package_available("librosa")
-_natten_available = _is_package_available("natten")
-_nltk_available = _is_package_available("nltk")
-_onnx_available = _is_package_available("onnx")
-_openai_available = _is_package_available("openai")
-_optimum_available = _is_package_available("optimum")
-_auto_gptq_available = _is_package_available("auto_gptq")
-# `importlib.metadata.version` doesn't work with `awq`
-_auto_awq_available = importlib.util.find_spec("awq") is not None
-_quanto_available = _is_package_available("quanto")
-_pandas_available = _is_package_available("pandas")
-_peft_available = _is_package_available("peft")
-_phonemizer_available = _is_package_available("phonemizer")
-_psutil_available = _is_package_available("psutil")
-_py3nvml_available = _is_package_available("py3nvml")
-_pyctcdecode_available = _is_package_available("pyctcdecode")
-_pygments_available = _is_package_available("pygments")
-_pytesseract_available = _is_package_available("pytesseract")
-_pytest_available = _is_package_available("pytest")
-_pytorch_quantization_available = _is_package_available("pytorch_quantization")
-_rjieba_available = _is_package_available("rjieba")
-_sacremoses_available = _is_package_available("sacremoses")
-_safetensors_available = _is_package_available("safetensors")
-_scipy_available = _is_package_available("scipy")
-_sentencepiece_available = _is_package_available("sentencepiece")
-_is_seqio_available = _is_package_available("seqio")
-_is_gguf_available = _is_package_available("gguf")
-_sklearn_available = importlib.util.find_spec("sklearn") is not None
-if _sklearn_available:
- try:
- importlib.metadata.version("scikit-learn")
- except importlib.metadata.PackageNotFoundError:
- _sklearn_available = False
-_smdistributed_available = importlib.util.find_spec("smdistributed") is not None
-_soundfile_available = _is_package_available("soundfile")
-_spacy_available = _is_package_available("spacy")
-_sudachipy_available, _sudachipy_version = _is_package_available("sudachipy", return_version=True)
-_tensorflow_probability_available = _is_package_available("tensorflow_probability")
-_tensorflow_text_available = _is_package_available("tensorflow_text")
-_tf2onnx_available = _is_package_available("tf2onnx")
-_timm_available = _is_package_available("timm")
-_tokenizers_available = _is_package_available("tokenizers")
-_torchaudio_available = _is_package_available("torchaudio")
-_torchdistx_available = _is_package_available("torchdistx")
-_torchvision_available = _is_package_available("torchvision")
-_mlx_available = _is_package_available("mlx")
-_hqq_available = _is_package_available("hqq")
-
-
-_torch_version = "N/A"
-_torch_available = False
-if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
- _torch_available, _torch_version = _is_package_available("torch", return_version=True)
-else:
- logger.info("Disabling PyTorch because USE_TF is set")
- _torch_available = False
-
-
-_tf_version = "N/A"
-_tf_available = False
-if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
- _tf_available = True
-else:
- if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
- # Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
- # with tensorflow-cpu to make sure it still works!
- _tf_available = importlib.util.find_spec("tensorflow") is not None
- if _tf_available:
- candidates = (
- "tensorflow",
- "tensorflow-cpu",
- "tensorflow-gpu",
- "tf-nightly",
- "tf-nightly-cpu",
- "tf-nightly-gpu",
- "tf-nightly-rocm",
- "intel-tensorflow",
- "intel-tensorflow-avx512",
- "tensorflow-rocm",
- "tensorflow-macos",
- "tensorflow-aarch64",
- )
- _tf_version = None
- # For the metadata, we have to look for both tensorflow and tensorflow-cpu
- for pkg in candidates:
- try:
- _tf_version = importlib.metadata.version(pkg)
- break
- except importlib.metadata.PackageNotFoundError:
- pass
- _tf_available = _tf_version is not None
- if _tf_available:
- if version.parse(_tf_version) < version.parse("2"):
- logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
- _tf_available = False
- else:
- logger.info("Disabling Tensorflow because USE_TORCH is set")
-
-
-_essentia_available = importlib.util.find_spec("essentia") is not None
-try:
- _essentia_version = importlib.metadata.version("essentia")
- logger.debug(f"Successfully imported essentia version {_essentia_version}")
-except importlib.metadata.PackageNotFoundError:
- _essentia_available = False
-
-
-_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
-try:
- _pretty_midi_version = importlib.metadata.version("pretty_midi")
- logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
-except importlib.metadata.PackageNotFoundError:
- _pretty_midi_available = False
-
-
-ccl_version = "N/A"
-_is_ccl_available = (
- importlib.util.find_spec("torch_ccl") is not None or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
-)
-try:
- ccl_version = importlib.metadata.version("oneccl_bind_pt")
- logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
-except importlib.metadata.PackageNotFoundError:
- _is_ccl_available = False
-
-
-_flax_available = False
-if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
- _flax_available, _flax_version = _is_package_available("flax", return_version=True)
- if _flax_available:
- _jax_available, _jax_version = _is_package_available("jax", return_version=True)
- if _jax_available:
- logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
- else:
- _flax_available = _jax_available = False
- _jax_version = _flax_version = "N/A"
-
-
-_torch_fx_available = False
-if _torch_available:
- torch_version = version.parse(_torch_version)
- _torch_fx_available = (torch_version.major, torch_version.minor) >= (
- TORCH_FX_REQUIRED_VERSION.major,
- TORCH_FX_REQUIRED_VERSION.minor,
- )
-
-
-_torch_xla_available = False
-if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
- _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
- if _torch_xla_available:
- logger.info(f"Torch XLA version {_torch_xla_version} available.")
-
-
-def is_kenlm_available():
- return _kenlm_available
-
-
-def is_cv2_available():
- return _cv2_available
-
-
-def is_torch_available():
- return _torch_available
-
-
-def is_torch_deterministic():
- """
- Check whether pytorch uses deterministic algorithms, i.e. whether `torch.get_deterministic_debug_mode()` returns 1 or 2.
- """
- import torch
-
- return torch.get_deterministic_debug_mode() != 0
-
-
-def is_hqq_available():
- return _hqq_available
-
-
-def is_pygments_available():
- return _pygments_available
-
-
-def get_torch_version():
- return _torch_version
-
-
-def is_torch_sdpa_available():
- if not is_torch_available():
- return False
- elif _torch_version == "N/A":
- return False
-
- # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
- # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
- # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
- # NOTE: MLU is OK with non-contiguous inputs.
- if is_torch_mlu_available():
- return version.parse(_torch_version) >= version.parse("2.1.0")
- # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
- return version.parse(_torch_version) >= version.parse("2.1.1")
-
-
-def is_torchvision_available():
- return _torchvision_available
-
-
-def is_galore_torch_available():
- return _galore_torch_available
-
-
-def is_lomo_available():
- return _lomo_available
-
-
-def is_pyctcdecode_available():
- return _pyctcdecode_available
-
-
-def is_librosa_available():
- return _librosa_available
-
-
-def is_essentia_available():
- return _essentia_available
-
-
-def is_pretty_midi_available():
- return _pretty_midi_available
-
-
-def is_torch_cuda_available():
- if is_torch_available():
- import torch
-
- return torch.cuda.is_available()
- else:
- return False
-
-
-def is_mamba_ssm_available():
- if is_torch_available():
- import torch
-
- if not torch.cuda.is_available():
- return False
- else:
- return _is_package_available("mamba_ssm")
- return False
-
-
-def is_causal_conv1d_available():
- if is_torch_available():
- import torch
-
- if not torch.cuda.is_available():
- return False
- return _is_package_available("causal_conv1d")
- return False
-
-
-def is_torch_mps_available():
- if is_torch_available():
- import torch
-
- if hasattr(torch.backends, "mps"):
- return torch.backends.mps.is_available() and torch.backends.mps.is_built()
- return False
-
-
-def is_torch_bf16_gpu_available():
- if not is_torch_available():
- return False
-
- import torch
-
- return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
-
-
-def is_torch_bf16_cpu_available():
- if not is_torch_available():
- return False
-
- import torch
-
- try:
- # multiple levels of AttributeError depending on the pytorch version so do them all in one check
- _ = torch.cpu.amp.autocast
- except AttributeError:
- return False
-
- return True
-
-
-def is_torch_bf16_available():
- # the original bf16 check was for gpu only, but later a cpu/bf16 combo has emerged so this util
- # has become ambiguous and therefore deprecated
- warnings.warn(
- "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
- "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu",
- FutureWarning,
- )
- return is_torch_bf16_gpu_available()
-
-
-@lru_cache()
-def is_torch_fp16_available_on_device(device):
- if not is_torch_available():
- return False
-
- import torch
-
- try:
- x = torch.zeros(2, 2, dtype=torch.float16).to(device)
- _ = x @ x
-
- # At this moment, let's be strict about the check: also verify that `LayerNorm` is supported on the device,
- # because many models use this layer.
- batch, sentence_length, embedding_dim = 3, 4, 5
- embedding = torch.randn(batch, sentence_length, embedding_dim, dtype=torch.float16, device=device)
- layer_norm = torch.nn.LayerNorm(embedding_dim, dtype=torch.float16, device=device)
- _ = layer_norm(embedding)
-
- except: # noqa: E722
- # TODO: more precise exception matching, if possible.
- # most backends should return `RuntimeError` however this is not guaranteed.
- return False
-
- return True
-
-
-@lru_cache()
-def is_torch_bf16_available_on_device(device):
- if not is_torch_available():
- return False
-
- import torch
-
- if device == "cuda":
- return is_torch_bf16_gpu_available()
-
- try:
- x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
- _ = x @ x
- except: # noqa: E722
- # TODO: more precise exception matching, if possible.
- # most backends should return `RuntimeError` however this is not guaranteed.
- return False
-
- return True
-
-
-def is_torch_tf32_available():
- if not is_torch_available():
- return False
-
- import torch
-
- if not torch.cuda.is_available() or torch.version.cuda is None:
- return False
- if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
- return False
- if int(torch.version.cuda.split(".")[0]) < 11:
- return False
- if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
- return False
-
- return True
-
-
-def is_torch_fx_available():
- return _torch_fx_available
-
-
-def is_peft_available():
- return _peft_available
-
-
-def is_bs4_available():
- return _bs4_available
-
-
-def is_tf_available():
- return _tf_available
-
-
-def is_coloredlogs_available():
- return _coloredlogs_available
-
-
-def is_tf2onnx_available():
- return _tf2onnx_available
-
-
-def is_onnx_available():
- return _onnx_available
-
-
-def is_openai_available():
- return _openai_available
-
-
-def is_flax_available():
- return _flax_available
-
-
-def is_ftfy_available():
- return _ftfy_available
-
-
-def is_g2p_en_available():
- return _g2p_en_available
-
-
-@lru_cache()
-def is_torch_tpu_available(check_device=True):
- "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
- warnings.warn(
- "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
- "Please use the `is_torch_xla_available` instead.",
- FutureWarning,
- )
-
- if not _torch_available:
- return False
- if importlib.util.find_spec("torch_xla") is not None:
- if check_device:
- # We need to check if `xla_device` can be found, will raise a RuntimeError if not
- try:
- import torch_xla.core.xla_model as xm
-
- _ = xm.xla_device()
- return True
- except RuntimeError:
- return False
- return True
- return False
-
-
-@lru_cache
-def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
- """
- Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
- the USE_TORCH_XLA to false.
- """
- assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
-
- if not _torch_xla_available:
- return False
-
- import torch_xla
-
- if check_is_gpu:
- return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
- elif check_is_tpu:
- return torch_xla.runtime.device_type() == "TPU"
-
- return True
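-
-# Usage sketch (the two check flags are mutually exclusive):
-#   is_torch_xla_available()                   # True if any XLA device is usable
-#   is_torch_xla_available(check_is_tpu=True)  # True only when the XLA runtime reports a TPU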
-
-
-@lru_cache()
-def is_torch_neuroncore_available(check_device=True):
- if importlib.util.find_spec("torch_neuronx") is not None:
- return is_torch_xla_available()
- return False
-
-
-@lru_cache()
-def is_torch_npu_available(check_device=False):
- "Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
- if not _torch_available or importlib.util.find_spec("torch_npu") is None:
- return False
-
- import torch
- import torch_npu # noqa: F401
-
- if check_device:
- try:
- # Will raise a RuntimeError if no NPU is found
- _ = torch.npu.device_count()
- return torch.npu.is_available()
- except RuntimeError:
- return False
- return hasattr(torch, "npu") and torch.npu.is_available()
-
-
-@lru_cache()
-def is_torch_mlu_available(check_device=False):
- "Checks if `torch_mlu` is installed and potentially if a MLU is in the environment"
- if not _torch_available or importlib.util.find_spec("torch_mlu") is None:
- return False
-
- import torch
- import torch_mlu # noqa: F401
-
- from ..dependency_versions_table import deps
-
- deps["deepspeed"] = "deepspeed-mlu>=0.10.1"
-
- if check_device:
- try:
- # Will raise a RuntimeError if no MLU is found
- _ = torch.mlu.device_count()
- return torch.mlu.is_available()
- except RuntimeError:
- return False
- return hasattr(torch, "mlu") and torch.mlu.is_available()
-
-
-def is_torchdynamo_available():
- if not is_torch_available():
- return False
-
- return version.parse(_torch_version) >= version.parse("2.0.0")
-
-
-def is_torch_compile_available():
- if not is_torch_available():
- return False
-
- import torch
-
- # We don't do any version check here to support nightlies marked as 1.14. Ultimately needs to check version against
- # 2.0 but let's do it later.
- return hasattr(torch, "compile")
-
-
-def is_torchdynamo_compiling():
- if not is_torch_available():
- return False
- try:
- # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible.
- if version.parse(_torch_version) >= version.parse("2.3.0"):
- import torch
-
- return torch.compiler.is_compiling()
- else:
- import torch._dynamo as dynamo # noqa: F401
-
- return dynamo.is_compiling()
- except Exception:
- return False
-
-
-def is_torch_tensorrt_fx_available():
- if importlib.util.find_spec("torch_tensorrt") is None:
- return False
- return importlib.util.find_spec("torch_tensorrt.fx") is not None
-
-
-def is_datasets_available():
- return _datasets_available
-
-
-def is_detectron2_available():
- return _detectron2_available
-
-
-def is_rjieba_available():
- return _rjieba_available
-
-
-def is_psutil_available():
- return _psutil_available
-
-
-def is_py3nvml_available():
- return _py3nvml_available
-
-
-def is_sacremoses_available():
- return _sacremoses_available
-
-
-def is_apex_available():
- return _apex_available
-
-
-def is_aqlm_available():
- return _aqlm_available
-
-
-def is_av_available():
- return _av_available
-
-
-def is_ninja_available():
- r"""
- Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
- [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
- """
- try:
- subprocess.check_output("ninja --version".split())
- except Exception:
- return False
- else:
- return True
-
-
-def is_ipex_available():
- def get_major_and_minor_from_version(full_version):
- return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
-
- if not is_torch_available() or not _ipex_available:
- return False
-
- torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
- ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
- if torch_major_and_minor != ipex_major_and_minor:
- logger.warning(
- f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
- f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
- )
- return False
- return True
-
-
-@lru_cache
-def is_torch_xpu_available(check_device=False):
- """
- Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or
- via stock PyTorch (>=2.4) and potentially if a XPU is in the environment
- """
- if not is_torch_available():
- return False
-
- torch_version = version.parse(_torch_version)
- if is_ipex_available():
- import intel_extension_for_pytorch # noqa: F401
- elif torch_version.major < 2 or (torch_version.major == 2 and torch_version.minor < 4):
- return False
-
- import torch
-
- if check_device:
- try:
- # Will raise a RuntimeError if no XPU is found
- _ = torch.xpu.device_count()
- return torch.xpu.is_available()
- except RuntimeError:
- return False
- return hasattr(torch, "xpu") and torch.xpu.is_available()
-
-
-def is_bitsandbytes_available():
- if not is_torch_available():
- return False
-
- # bitsandbytes throws an error if cuda is not available
- # let's avoid that by adding a simple check
- import torch
-
- return _bitsandbytes_available and torch.cuda.is_available()
-
-
-def is_flash_attn_2_available():
- if not is_torch_available():
- return False
-
- if not _is_package_available("flash_attn"):
- return False
-
- # Let's add an extra check to see if cuda is available
- import torch
-
- if not (torch.cuda.is_available() or is_torch_mlu_available()):
- return False
-
- if torch.version.cuda:
- return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
- elif torch.version.hip:
- # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
- return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
- elif is_torch_mlu_available():
- return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.3.3")
- else:
- return False
-
-
-def is_flash_attn_greater_or_equal_2_10():
- if not _is_package_available("flash_attn"):
- return False
-
- return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
-
-
-def is_flash_attn_greater_or_equal(library_version: str):
- if not _is_package_available("flash_attn"):
- return False
-
- return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
-
-
-def is_torchdistx_available():
- return _torchdistx_available
-
-
-def is_faiss_available():
- return _faiss_available
-
-
-def is_scipy_available():
- return _scipy_available
-
-
-def is_sklearn_available():
- return _sklearn_available
-
-
-def is_sentencepiece_available():
- return _sentencepiece_available
-
-
-def is_seqio_available():
- return _is_seqio_available
-
-
-def is_gguf_available():
- return _is_gguf_available
-
-
-def is_protobuf_available():
- if importlib.util.find_spec("google") is None:
- return False
- return importlib.util.find_spec("google.protobuf") is not None
-
-
-def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
- return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
-
-
-def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
- return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version)
-
-
-def is_optimum_available():
- return _optimum_available
-
-
-def is_auto_awq_available():
- return _auto_awq_available
-
-
-def is_quanto_available():
- return _quanto_available
-
-
-def is_auto_gptq_available():
- return _auto_gptq_available
-
-
-def is_eetq_available():
- return _eetq_available
-
-
-def is_levenshtein_available():
- return _levenshtein_available
-
-
-def is_optimum_neuron_available():
- return _optimum_available and _is_package_available("optimum.neuron")
-
-
-def is_safetensors_available():
- return _safetensors_available
-
-
-def is_tokenizers_available():
- return _tokenizers_available
-
-
-@lru_cache
-def is_vision_available():
- _pil_available = importlib.util.find_spec("PIL") is not None
- if _pil_available:
- try:
- package_version = importlib.metadata.version("Pillow")
- except importlib.metadata.PackageNotFoundError:
- try:
- package_version = importlib.metadata.version("Pillow-SIMD")
- except importlib.metadata.PackageNotFoundError:
- return False
- logger.debug(f"Detected PIL version {package_version}")
- return _pil_available
-
-
-def is_pytesseract_available():
- return _pytesseract_available
-
-
-def is_pytest_available():
- return _pytest_available
-
-
-def is_spacy_available():
- return _spacy_available
-
-
-def is_tensorflow_text_available():
- return is_tf_available() and _tensorflow_text_available
-
-
-def is_keras_nlp_available():
- return is_tensorflow_text_available() and _keras_nlp_available
-
-
-def is_in_notebook():
- try:
- # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
- get_ipython = sys.modules["IPython"].get_ipython
- if "IPKernelApp" not in get_ipython().config:
- raise ImportError("console")
- if "VSCODE_PID" in os.environ:
- raise ImportError("vscode")
- if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0":
- # Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with Jupyter notebook
- # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel
- raise ImportError("databricks")
-
- return importlib.util.find_spec("IPython") is not None
- except (AttributeError, ImportError, KeyError):
- return False
-
-
-def is_pytorch_quantization_available():
- return _pytorch_quantization_available
-
-
-def is_tensorflow_probability_available():
- return _tensorflow_probability_available
-
-
-def is_pandas_available():
- return _pandas_available
-
-
-def is_sagemaker_dp_enabled():
- # Get the sagemaker specific env variable.
- sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
- try:
- # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
- sagemaker_params = json.loads(sagemaker_params)
- if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
- return False
- except json.JSONDecodeError:
- return False
- # Lastly, check if the `smdistributed` module is present.
- return _smdistributed_available
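-
-# Illustrative value of the env variable this parses (set by SageMaker; hypothetical here):
-#   SM_FRAMEWORK_PARAMS='{"sagemaker_distributed_dataparallel_enabled": true}'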
-
-
-def is_sagemaker_mp_enabled():
- # Get the sagemaker specific mp parameters from smp_options variable.
- smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
- try:
- # Parse it and check the field "partitions" is included, it is required for model parallel.
- smp_options = json.loads(smp_options)
- if "partitions" not in smp_options:
- return False
- except json.JSONDecodeError:
- return False
-
- # Get the sagemaker specific framework parameters from mpi_options variable.
- mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
- try:
- # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
- mpi_options = json.loads(mpi_options)
- if not mpi_options.get("sagemaker_mpi_enabled", False):
- return False
- except json.JSONDecodeError:
- return False
- # Lastly, check if the `smdistributed` module is present.
- return _smdistributed_available
-
-
-def is_training_run_on_sagemaker():
- return "SAGEMAKER_JOB_NAME" in os.environ
-
-
-def is_soundfile_availble():
- return _soundfile_available
-
-
-def is_timm_available():
- return _timm_available
-
-
-def is_natten_available():
- return _natten_available
-
-
-def is_nltk_available():
- return _nltk_available
-
-
-def is_torchaudio_available():
- return _torchaudio_available
-
-
-def is_speech_available():
- # For now this depends on torchaudio but the exact dependency might evolve in the future.
- return _torchaudio_available
-
-
-def is_phonemizer_available():
- return _phonemizer_available
-
-
-def torch_only_method(fn):
- def wrapper(*args, **kwargs):
- if not _torch_available:
- raise ImportError(
- "You need to install pytorch to use this method or class, "
- "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
- )
- else:
- return fn(*args, **kwargs)
-
- return wrapper
-
-
-def is_ccl_available():
- return _is_ccl_available
-
-
-def is_decord_available():
- return _decord_available
-
-
-def is_sudachi_available():
- return _sudachipy_available
-
-
-def get_sudachi_version():
- return _sudachipy_version
-
-
-def is_sudachi_projection_available():
- if not is_sudachi_available():
- return False
-
- # NOTE: We require sudachipy>=0.6.8 to use projection option in sudachi_kwargs for the constructor of BertJapaneseTokenizer.
- # - `projection` option is not supported in sudachipy<0.6.8, see https://github.com/WorksApplications/sudachi.rs/issues/230
- return version.parse(_sudachipy_version) >= version.parse("0.6.8")
-
-
-def is_jumanpp_available():
- return (importlib.util.find_spec("rhoknp") is not None) and (shutil.which("jumanpp") is not None)
-
-
-def is_cython_available():
- return importlib.util.find_spec("pyximport") is not None
-
-
-def is_jieba_available():
- return _jieba_available
-
-
-def is_jinja_available():
- return _jinja_available
-
-
-def is_mlx_available():
- return _mlx_available
-
-
-# docstyle-ignore
-AV_IMPORT_ERROR = """
-{0} requires the PyAv library but it was not found in your environment. You can install it with:
-```
-pip install av
-```
-Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-CV2_IMPORT_ERROR = """
-{0} requires the OpenCV library but it was not found in your environment. You can install it with:
-```
-pip install opencv-python
-```
-Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-DATASETS_IMPORT_ERROR = """
-{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
-```
-pip install datasets
-```
-In a notebook or a colab, you can install it by executing a cell with
-```
-!pip install datasets
-```
-then restarting your kernel.
-
-Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
-working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
-that python file if that's the case. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-TOKENIZERS_IMPORT_ERROR = """
-{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
-```
-pip install tokenizers
-```
-In a notebook or a colab, you can install it by executing a cell with
-```
-!pip install tokenizers
-```
-Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-SENTENCEPIECE_IMPORT_ERROR = """
-{0} requires the SentencePiece library but it was not found in your environment. Check out the instructions on the
-installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
-that match your environment. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-PROTOBUF_IMPORT_ERROR = """
-{0} requires the protobuf library but it was not found in your environment. Check out the instructions on the
-installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
-that match your environment. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-FAISS_IMPORT_ERROR = """
-{0} requires the faiss library but it was not found in your environment. Check out the instructions on the
-installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
-that match your environment. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-PYTORCH_IMPORT_ERROR = """
-{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
-installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
-Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-TORCHVISION_IMPORT_ERROR = """
-{0} requires the Torchvision library but it was not found in your environment. Check out the instructions on the
-installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
-Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-PYTORCH_IMPORT_ERROR_WITH_TF = """
-{0} requires the PyTorch library but it was not found in your environment.
-However, we were able to find a TensorFlow installation. TensorFlow classes begin
-with "TF", but are otherwise identically named to our PyTorch classes. This
-means that the TF equivalent of the class you tried to import would be "TF{0}".
-If you want to use TensorFlow, please use TF classes instead!
-
-If you really do want to use PyTorch please go to
-https://pytorch.org/get-started/locally/ and follow the instructions that
-match your environment.
-"""
-
-# docstyle-ignore
-TF_IMPORT_ERROR_WITH_PYTORCH = """
-{0} requires the TensorFlow library but it was not found in your environment.
-However, we were able to find a PyTorch installation. PyTorch classes do not begin
-with "TF", but are otherwise identically named to our TF classes.
-If you want to use PyTorch, please use those classes instead!
-
-If you really do want to use TensorFlow, please follow the instructions on the
-installation page https://www.tensorflow.org/install that match your environment.
-"""
-
-# docstyle-ignore
-BS4_IMPORT_ERROR = """
-{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
-`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-SKLEARN_IMPORT_ERROR = """
-{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
-```
-pip install -U scikit-learn
-```
-In a notebook or a colab, you can install it by executing a cell with
-```
-!pip install -U scikit-learn
-```
-Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-TENSORFLOW_IMPORT_ERROR = """
-{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the
-installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
-Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-DETECTRON2_IMPORT_ERROR = """
-{0} requires the detectron2 library but it was not found in your environment. Check out the instructions on the
-installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
-that match your environment. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-FLAX_IMPORT_ERROR = """
-{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
-installation page: https://github.com/google/flax and follow the ones that match your environment.
-Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-FTFY_IMPORT_ERROR = """
-{0} requires the ftfy library but it was not found in your environment. Check out the instructions on the
-installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
-that match your environment. Please note that you may need to restart your runtime after installation.
-"""
-
-LEVENSHTEIN_IMPORT_ERROR = """
-{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip
-install python-Levenshtein`. Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-G2P_EN_IMPORT_ERROR = """
-{0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
-`pip install g2p-en`. Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-PYTORCH_QUANTIZATION_IMPORT_ERROR = """
-{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip:
-`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`
-Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
-{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as
-explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-TENSORFLOW_TEXT_IMPORT_ERROR = """
-{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as
-explained here: https://www.tensorflow.org/text/guide/tf_text_intro.
-Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-PANDAS_IMPORT_ERROR = """
-{0} requires the pandas library but it was not found in your environment. You can install it with pip as
-explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
-Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-PHONEMIZER_IMPORT_ERROR = """
-{0} requires the phonemizer library but it was not found in your environment. You can install it with pip:
-`pip install phonemizer`. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-SACREMOSES_IMPORT_ERROR = """
-{0} requires the sacremoses library but it was not found in your environment. You can install it with pip:
-`pip install sacremoses`. Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-SCIPY_IMPORT_ERROR = """
-{0} requires the scipy library but it was not found in your environment. You can install it with pip:
-`pip install scipy`. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-SPEECH_IMPORT_ERROR = """
-{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
-`pip install torchaudio`. Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-TIMM_IMPORT_ERROR = """
-{0} requires the timm library but it was not found in your environment. You can install it with pip:
-`pip install timm`. Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-NATTEN_IMPORT_ERROR = """
-{0} requires the natten library but it was not found in your environment. You can install it by referring to:
-shi-labs.com/natten. You can also install it with pip (may take longer to build):
-`pip install natten`. Please note that you may need to restart your runtime after installation.
-"""
-
-NUMEXPR_IMPORT_ERROR = """
-{0} requires the numexpr library but it was not found in your environment. You can install it by referring to:
-https://numexpr.readthedocs.io/en/latest/index.html.
-"""
-
-
-# docstyle-ignore
-NLTK_IMPORT_ERROR = """
-{0} requires the NLTK library but it was not found in your environment. You can install it by referring to:
-https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-VISION_IMPORT_ERROR = """
-{0} requires the PIL library but it was not found in your environment. You can install it with pip:
-`pip install pillow`. Please note that you may need to restart your runtime after installation.
-"""
-
-
-# docstyle-ignore
-PYTESSERACT_IMPORT_ERROR = """
-{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
-`pip install pytesseract`. Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-PYCTCDECODE_IMPORT_ERROR = """
-{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
-`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-ACCELERATE_IMPORT_ERROR = """
-{0} requires the accelerate library >= {ACCELERATE_MIN_VERSION} but it was not found in your environment.
-You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your
-runtime after installation.
-"""
-
-# docstyle-ignore
-CCL_IMPORT_ERROR = """
-{0} requires the torch ccl library but it was not found in your environment. You can install it with pip:
-`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable`
-Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-ESSENTIA_IMPORT_ERROR = """
-{0} requires the essentia library but it was not found in your environment. You can install it with pip:
-`pip install essentia==2.1b6.dev1034`
-Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-LIBROSA_IMPORT_ERROR = """
-{0} requires the librosa library but it was not found in your environment. You can install it with pip:
-`pip install librosa`
-Please note that you may need to restart your runtime after installation.
-"""
-
-# docstyle-ignore
-PRETTY_MIDI_IMPORT_ERROR = """
-{0} requires the pretty_midi library but it was not found in your environment. You can install it with pip:
-`pip install pretty_midi`
-Please note that you may need to restart your runtime after installation.
-"""
-
-DECORD_IMPORT_ERROR = """
-{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
-decord`. Please note that you may need to restart your runtime after installation.
-"""
-
-CYTHON_IMPORT_ERROR = """
-{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install
-Cython`. Please note that you may need to restart your runtime after installation.
-"""
-
-JIEBA_IMPORT_ERROR = """
-{0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install
-jieba`. Please note that you may need to restart your runtime after installation.
-"""
-
-PEFT_IMPORT_ERROR = """
-{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install
-peft`. Please note that you may need to restart your runtime after installation.
-"""
-
-JINJA_IMPORT_ERROR = """
-{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install
-jinja2`. Please note that you may need to restart your runtime after installation.
-"""
-
-BACKENDS_MAPPING = OrderedDict(
- [
- ("av", (is_av_available, AV_IMPORT_ERROR)),
- ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
- ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
- ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
- ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
- ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
- ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
- ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
- ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
- ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
- ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
- ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
- ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
- ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)),
- ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
- ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
- ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
- ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
- ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
- ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
- ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
- ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
- ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
- ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
- ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
- ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)),
- ("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
- ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)),
- ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
- ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
- ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
- ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
- ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
- ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
- ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
- ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
- ("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
- ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
- ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
- ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
- ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
- ]
-)
-
-
-def requires_backends(obj, backends):
- if not isinstance(backends, (list, tuple)):
- backends = [backends]
-
- name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
-
- # Raise an error for users who might not realize that classes without "TF" are torch-only
- if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
- raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
-
- # Raise the inverse error for PyTorch users trying to load TF classes
- if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
- raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
-
- checks = (BACKENDS_MAPPING[backend] for backend in backends)
- failed = [msg.format(name) for available, msg in checks if not available()]
- if failed:
- raise ImportError("".join(failed))
-
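For context, a sketch of how `requires_backends` is meant to be called (the class is hypothetical); one call can guard several backends at once and raises a single `ImportError` concatenating every missing backend's message:

```python
class MyAudioPipeline:
    def __call__(self, audio):
        # Raises one combined ImportError listing every missing backend.
        requires_backends(self, ["torch", "speech"])
        ...
```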
-
-class DummyObject(type):
- """
- Metaclass for the dummy objects. Any class inheriting from it will raise the ImportError generated by
- `requires_backends` each time a user tries to access any attribute of that class.
- """
-
- def __getattribute__(cls, key):
- if key.startswith("_") and key != "_from_config":
- return super().__getattribute__(key)
- requires_backends(cls, cls._backends)
-
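A sketch of the dummy-object pattern this metaclass enables (real dummy classes are auto-generated; this one is illustrative):

```python
class FlaxThing(metaclass=DummyObject):
    _backends = ["flax"]


# Touching any public attribute routes through requires_backends, so when
# Flax is absent this raises FLAX_IMPORT_ERROR instead of an AttributeError:
# FlaxThing.from_pretrained  ->  ImportError
```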
-
-def is_torch_fx_proxy(x):
- if is_torch_fx_available():
- import torch.fx
-
- return isinstance(x, torch.fx.Proxy)
- return False
-
-
-class _LazyModule(ModuleType):
- """
- Module class that surfaces all objects but only performs associated imports when the objects are requested.
- """
-
- # Very heavily inspired by optuna.integration._IntegrationModule
- # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
- def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
- super().__init__(name)
- self._modules = set(import_structure.keys())
- self._class_to_module = {}
- for key, values in import_structure.items():
- for value in values:
- self._class_to_module[value] = key
- # Needed for autocompletion in an IDE
- self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
- self.__file__ = module_file
- self.__spec__ = module_spec
- self.__path__ = [os.path.dirname(module_file)]
- self._objects = {} if extra_objects is None else extra_objects
- self._name = name
- self._import_structure = import_structure
-
- # Needed for autocompletion in an IDE
- def __dir__(self):
- result = super().__dir__()
- # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
- # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
- for attr in self.__all__:
- if attr not in result:
- result.append(attr)
- return result
-
- def __getattr__(self, name: str) -> Any:
- if name in self._objects:
- return self._objects[name]
- if name in self._modules:
- value = self._get_module(name)
- elif name in self._class_to_module.keys():
- module = self._get_module(self._class_to_module[name])
- value = getattr(module, name)
- else:
- raise AttributeError(f"module {self.__name__} has no attribute {name}")
-
- setattr(self, name, value)
- return value
-
- def _get_module(self, module_name: str):
- return importlib.import_module("." + module_name, self.__name__)
-
- def __reduce__(self):
- return (self.__class__, (self._name, self.__file__, self._import_structure))
-
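A minimal sketch of how a package `__init__.py` typically wires this class up; the `import_structure` below is a toy example:

```python
import sys

_import_structure = {"tokenization_utils": ["PreTrainedTokenizer"]}

sys.modules[__name__] = _LazyModule(
    __name__, globals()["__file__"], _import_structure, module_spec=__spec__
)
# `from <package> import PreTrainedTokenizer` now imports `.tokenization_utils`
# only on first access instead of at package import time.
```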
-
-class OptionalDependencyNotAvailable(BaseException):
- """Internally used error class for signalling an optional dependency was not found."""
-
-
-def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
- """Imports transformers directly
-
- Args:
- path (`str`): The path to the source file
- file (`str`, optional): The file to join with the path. Defaults to "__init__.py".
-
- Returns:
- `ModuleType`: The resulting imported module
- """
- name = "maxdiffusion.transformers"
- location = os.path.join(path, file)
- spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path])
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- module = sys.modules[name]
- return module
diff --git a/src/maxdiffusion/transformers/utils/logging.py b/src/maxdiffusion/transformers/utils/logging.py
deleted file mode 100644
index 1a7c1bfd3..000000000
--- a/src/maxdiffusion/transformers/utils/logging.py
+++ /dev/null
@@ -1,395 +0,0 @@
-# coding=utf-8
-# Copyright 2020 Optuna, Hugging Face
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Logging utilities."""
-
-import functools
-import logging
-import os
-import sys
-import threading
-from logging import (
- CRITICAL, # NOQA
- DEBUG, # NOQA
- ERROR, # NOQA
- FATAL, # NOQA
- INFO, # NOQA
- NOTSET, # NOQA
- WARN, # NOQA
- WARNING, # NOQA
-)
-from logging import captureWarnings as _captureWarnings
-from typing import Optional
-
-import huggingface_hub.utils as hf_hub_utils
-from tqdm import auto as tqdm_lib
-
-
-_lock = threading.Lock()
-_default_handler: Optional[logging.Handler] = None
-
-log_levels = {
- "detail": logging.DEBUG, # will also print filename and line number
- "debug": logging.DEBUG,
- "info": logging.INFO,
- "warning": logging.WARNING,
- "error": logging.ERROR,
- "critical": logging.CRITICAL,
-}
-
-_default_log_level = logging.WARNING
-
-_tqdm_active = not hf_hub_utils.are_progress_bars_disabled()
-
-
-def _get_default_logging_level():
- """
- If the TRANSFORMERS_VERBOSITY env var is set to one of the valid choices, return that as the new default level. If it
- is not, fall back to `_default_log_level`.
- """
- env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
- if env_level_str:
- if env_level_str in log_levels:
- return log_levels[env_level_str]
- else:
- logging.getLogger().warning(
- f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, " f"has to be one of: { ', '.join(log_levels.keys()) }"
- )
- return _default_log_level
-
-
-def _get_library_name() -> str:
- return __name__.split(".")[0]
-
-
-def _get_library_root_logger() -> logging.Logger:
- return logging.getLogger(_get_library_name())
-
-
-def _configure_library_root_logger() -> None:
- global _default_handler
-
- with _lock:
- if _default_handler:
- # This library has already configured the library root logger.
- return
- _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
- # set defaults based on https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176
- if sys.stderr is None:
- sys.stderr = open(os.devnull, "w")
-
- _default_handler.flush = sys.stderr.flush
-
- # Apply our default configuration to the library root logger.
- library_root_logger = _get_library_root_logger()
- library_root_logger.addHandler(_default_handler)
- library_root_logger.setLevel(_get_default_logging_level())
- # if logging level is debug, we add pathname and lineno to formatter for easy debugging
- if os.getenv("TRANSFORMERS_VERBOSITY", None) == "detail":
- formatter = logging.Formatter("[%(levelname)s|%(pathname)s:%(lineno)s] %(asctime)s >> %(message)s")
- _default_handler.setFormatter(formatter)
-
- library_root_logger.propagate = False
-
-
-def _reset_library_root_logger() -> None:
- global _default_handler
-
- with _lock:
- if not _default_handler:
- return
-
- library_root_logger = _get_library_root_logger()
- library_root_logger.removeHandler(_default_handler)
- library_root_logger.setLevel(logging.NOTSET)
- _default_handler = None
-
-
-def get_log_levels_dict():
- return log_levels
-
-
-def captureWarnings(capture):
- """
- Calls the `captureWarnings` method from the logging library to enable management of the warnings emitted by the
- `warnings` library.
-
- Read more about this method here:
- https://docs.python.org/3/library/logging.html#integration-with-the-warnings-module
-
- All warnings will be logged through the `py.warnings` logger.
-
- Careful: this method also adds a handler to this logger if it does not already have one, and updates the logging
- level of that logger to match that of the library's root logger.
- """
- logger = get_logger("py.warnings")
-
- if not logger.handlers:
- logger.addHandler(_default_handler)
-
- logger.setLevel(_get_library_root_logger().level)
-
- _captureWarnings(capture)
-
-
-def get_logger(name: Optional[str] = None) -> logging.Logger:
- """
- Return a logger with the specified name.
-
- This function is not supposed to be directly accessed unless you are writing a custom transformers module.
- """
-
- if name is None:
- name = _get_library_name()
-
- _configure_library_root_logger()
- return logging.getLogger(name)
-
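A small usage sketch (not from the diff) of how library modules typically obtain and use their logger; the message strings are illustrative:

```python
# Module-level logger, as library code typically declares it:
logger = get_logger(__name__)
logger.warning("shown at the default WARNING verbosity")
logger.info("hidden until set_verbosity(INFO) or TRANSFORMERS_VERBOSITY=info")
```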
-
-def get_verbosity() -> int:
- """
- Return the current level of the 🤗 Transformers root logger as an int.
-
- Returns:
- `int`: The logging level.
-
-
-
- 🤗 Transformers has the following logging levels:
-
- - 50: `transformers.logging.CRITICAL` or `transformers.logging.FATAL`
- - 40: `transformers.logging.ERROR`
- - 30: `transformers.logging.WARNING` or `transformers.logging.WARN`
- - 20: `transformers.logging.INFO`
- - 10: `transformers.logging.DEBUG`
-
- """
-
- _configure_library_root_logger()
- return _get_library_root_logger().getEffectiveLevel()
-
-
-def set_verbosity(verbosity: int) -> None:
- """
- Set the verbosity level for the 🤗 Transformers root logger.
-
- Args:
- verbosity (`int`):
- Logging level, e.g., one of:
-
- - `transformers.logging.CRITICAL` or `transformers.logging.FATAL`
- - `transformers.logging.ERROR`
- - `transformers.logging.WARNING` or `transformers.logging.WARN`
- - `transformers.logging.INFO`
- - `transformers.logging.DEBUG`
- """
-
- _configure_library_root_logger()
- _get_library_root_logger().setLevel(verbosity)
-
-
-def set_verbosity_info():
- """Set the verbosity to the `INFO` level."""
- return set_verbosity(INFO)
-
-
-def set_verbosity_warning():
- """Set the verbosity to the `WARNING` level."""
- return set_verbosity(WARNING)
-
-
-def set_verbosity_debug():
- """Set the verbosity to the `DEBUG` level."""
- return set_verbosity(DEBUG)
-
-
-def set_verbosity_error():
- """Set the verbosity to the `ERROR` level."""
- return set_verbosity(ERROR)
-
-
-def disable_default_handler() -> None:
- """Disable the default handler of the HuggingFace Transformers's root logger."""
-
- _configure_library_root_logger()
-
- assert _default_handler is not None
- _get_library_root_logger().removeHandler(_default_handler)
-
-
-def enable_default_handler() -> None:
- """Enable the default handler of the HuggingFace Transformers's root logger."""
-
- _configure_library_root_logger()
-
- assert _default_handler is not None
- _get_library_root_logger().addHandler(_default_handler)
-
-
-def add_handler(handler: logging.Handler) -> None:
- """adds a handler to the HuggingFace Transformers's root logger."""
-
- _configure_library_root_logger()
-
- assert handler is not None
- _get_library_root_logger().addHandler(handler)
-
-
-def remove_handler(handler: logging.Handler) -> None:
- """removes given handler from the HuggingFace Transformers's root logger."""
-
- _configure_library_root_logger()
-
- assert handler is not None and handler in _get_library_root_logger().handlers
- _get_library_root_logger().removeHandler(handler)
-
-
-def disable_propagation() -> None:
- """
- Disable propagation of the library log outputs. Note that log propagation is disabled by default.
- """
-
- _configure_library_root_logger()
- _get_library_root_logger().propagate = False
-
-
-def enable_propagation() -> None:
- """
- Enable propagation of the library log outputs. Please disable the HuggingFace Transformers default handler to
- prevent double logging if the root logger has been configured.
- """
-
- _configure_library_root_logger()
- _get_library_root_logger().propagate = True
-
-
-def enable_explicit_format() -> None:
- """
- Enable explicit formatting for every HuggingFace Transformers logger. The explicit formatter is as follows:
- ```
- [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
- ```
- All handlers currently bound to the root logger are affected by this method.
- """
- handlers = _get_library_root_logger().handlers
-
- for handler in handlers:
- formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
- handler.setFormatter(formatter)
-
-
-def reset_format() -> None:
- """
- Resets the formatting for HuggingFace Transformers loggers.
-
- All handlers currently bound to the root logger are affected by this method.
- """
- handlers = _get_library_root_logger().handlers
-
- for handler in handlers:
- handler.setFormatter(None)
-
-
-def warning_advice(self, *args, **kwargs):
- """
- This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
- warning will not be printed
- """
- no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
- if no_advisory_warnings:
- return
- self.warning(*args, **kwargs)
-
-
-logging.Logger.warning_advice = warning_advice
-
-
-@functools.lru_cache(None)
-def warning_once(self, *args, **kwargs):
- """
- This method is identical to `logger.warning()`, but will emit the warning with the same message only once
-
- Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
- The assumption here is that all warning messages are unique across the code. If they aren't, we would need to switch
- to another type of cache that includes the caller frame information in the hashing function.
- """
- self.warning(*args, **kwargs)
-
-
-logging.Logger.warning_once = warning_once
-
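A sketch of the de-duplication this gives (the message is illustrative): since the `lru_cache` key is the `(logger, *args)` tuple, repeated identical calls are swallowed after the first.

```python
logger = get_logger(__name__)
for _ in range(3):
    logger.warning_once("`foo` is deprecated, use `bar` instead")  # logged once
```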
-
-class EmptyTqdm:
- """Dummy tqdm which doesn't do anything."""
-
- def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
- self._iterator = args[0] if args else None
-
- def __iter__(self):
- return iter(self._iterator)
-
- def __getattr__(self, _):
- """Return empty function."""
-
- def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
- return
-
- return empty_fn
-
- def __enter__(self):
- return self
-
- def __exit__(self, type_, value, traceback):
- return
-
-
-class _tqdm_cls:
-
- def __call__(self, *args, **kwargs):
- if _tqdm_active:
- return tqdm_lib.tqdm(*args, **kwargs)
- else:
- return EmptyTqdm(*args, **kwargs)
-
- def set_lock(self, *args, **kwargs):
- self._lock = None
- if _tqdm_active:
- return tqdm_lib.tqdm.set_lock(*args, **kwargs)
-
- def get_lock(self):
- if _tqdm_active:
- return tqdm_lib.tqdm.get_lock()
-
-
-tqdm = _tqdm_cls()
-
-
-def is_progress_bar_enabled() -> bool:
- """Return a boolean indicating whether tqdm progress bars are enabled."""
- global _tqdm_active
- return bool(_tqdm_active)
-
-
-def enable_progress_bar():
- """Enable tqdm progress bar."""
- global _tqdm_active
- _tqdm_active = True
- hf_hub_utils.enable_progress_bars()
-
-
-def disable_progress_bar():
- """Disable tqdm progress bar."""
- global _tqdm_active
- _tqdm_active = False
- hf_hub_utils.disable_progress_bars()
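A short sketch of the toggle defined above; with bars disabled, the module-level `tqdm` silently degrades to `EmptyTqdm`:

```python
disable_progress_bar()
for _ in tqdm(range(3)):
    pass  # EmptyTqdm: iterates normally but renders nothing

enable_progress_bar()
for _ in tqdm(range(3)):
    pass  # real tqdm.auto bar again
```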
diff --git a/src/maxdiffusion/transformers/utils/model_parallel_utils.py b/src/maxdiffusion/transformers/utils/model_parallel_utils.py
deleted file mode 100644
index 3f964adf7..000000000
--- a/src/maxdiffusion/transformers/utils/model_parallel_utils.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# coding=utf-8
-# Copyright 2020 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from math import ceil
-
-
-def assert_device_map(device_map, num_blocks):
- blocks = list(range(0, num_blocks))
-
- device_map_blocks = [item for sublist in list(device_map.values()) for item in sublist]
-
- # Duplicate check
- duplicate_blocks = []
- for i in device_map_blocks:
- if device_map_blocks.count(i) > 1 and i not in duplicate_blocks:
- duplicate_blocks.append(i)
- # Missing blocks
- missing_blocks = [i for i in blocks if i not in device_map_blocks]
- extra_blocks = [i for i in device_map_blocks if i not in blocks]
-
- if len(duplicate_blocks) != 0:
- raise ValueError(
- "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device."
- " These attention blocks were specified more than once: " + str(duplicate_blocks)
- )
- if len(missing_blocks) != 0:
- raise ValueError(
- "There are attention blocks for this model that are not specified in the device_map. Add these attention "
- "blocks to a device on the device_map: " + str(missing_blocks)
- )
- if len(extra_blocks) != 0:
- raise ValueError(
- "The device_map contains more attention blocks than this model has. Remove these from the device_map:"
- + str(extra_blocks)
- )
-
-
-def get_device_map(n_layers, devices):
- """Returns a dictionary of layers distributed evenly across all devices."""
- layers = list(range(n_layers))
- n_blocks = int(ceil(n_layers / len(devices)))
- layers_list = [layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)]
-
- return dict(zip(devices, layers_list))
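A worked example of the even split `get_device_map` performs (values are illustrative):

```python
get_device_map(n_layers=10, devices=[0, 1, 2])
# ceil(10 / 3) == 4 layers per block, so the result is:
# {0: [0, 1, 2, 3], 1: [4, 5, 6, 7], 2: [8, 9]}
```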
diff --git a/src/maxdiffusion/transformers/utils/notebook.py b/src/maxdiffusion/transformers/utils/notebook.py
deleted file mode 100644
index 5b701a64b..000000000
--- a/src/maxdiffusion/transformers/utils/notebook.py
+++ /dev/null
@@ -1,377 +0,0 @@
-# coding=utf-8
-# Copyright 2020 Hugging Face
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import re
-import time
-from typing import Optional
-
-import IPython.display as disp
-
-from ..trainer_callback import TrainerCallback
-from ..trainer_utils import IntervalStrategy, has_length
-
-
-def format_time(t):
- "Format `t` (in seconds) to (h):mm:ss"
- t = int(t)
- h, m, s = t // 3600, (t // 60) % 60, t % 60
- return f"{h}:{m:02d}:{s:02d}" if h != 0 else f"{m:02d}:{s:02d}"
-
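Two quick examples of the formatting above:

```python
format_time(75)    # '01:15'    (under an hour: mm:ss)
format_time(3725)  # '1:02:05'  (an hour or more: h:mm:ss)
```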
-
-def html_progress_bar(value, total, prefix, label, width=300):
- # docstyle-ignore
- return f"""
- <div>
- {prefix}
- <progress value='{value}' max='{total}' style='width:{width}px; height:20px; vertical-align: middle;'></progress>
- {label}
- </div>
- """
-
-
-def text_to_html_table(items):
- "Put the texts in `items` in an HTML table."
- html_code = """<table border="1" class="dataframe">\n"""
- html_code += """ <thead>\n <tr style="text-align: left;">\n"""
- for i in items[0]:
- html_code += f" <th>{i}</th>\n"
- html_code += " </tr>\n </thead>\n <tbody>\n"
- for line in items[1:]:
- html_code += " <tr>\n"
- for elt in line:
- elt = f"{elt:.6f}" if isinstance(elt, float) else str(elt)
- html_code += f" <td>{elt}</td>\n"
- html_code += " </tr>\n"
- html_code += " </tbody>\n</table><p>"
- return html_code
-
-
-class NotebookProgressBar:
- """
- A progress bar for display in a notebook.
-
- Class attributes (overridden by derived classes)
-
- - **warmup** (`int`) -- The number of iterations to do at the beginning while ignoring `update_every`.
- - **update_every** (`float`) -- Since checking the time itself takes some time, we only do it roughly every
- `update_every` seconds. The progress bar uses the average time passed up until now to guess the next value
- for which it will call the update.
-
- Args:
- total (`int`):
- The total number of iterations to reach.
- prefix (`str`, *optional*):
- A prefix to add before the progress bar.
- leave (`bool`, *optional*, defaults to `True`):
- Whether or not to leave the progress bar once it's completed. You can always call the
- [`~utils.notebook.NotebookProgressBar.close`] method to make the bar disappear.
- parent ([`~notebook.NotebookTrainingTracker`], *optional*):
- A parent object (like [`~utils.notebook.NotebookTrainingTracker`]) that spawns progress bars and handles
- their display. If set, the object passed must have a `display()` method.
- width (`int`, *optional*, defaults to 300):
- The width (in pixels) that the bar will take.
-
- Example:
-
- ```python
- import time
-
- pbar = NotebookProgressBar(100)
- for val in range(100):
- pbar.update(val)
- time.sleep(0.07)
- pbar.update(100)
- ```"""
-
- warmup = 5
- update_every = 0.2
-
- def __init__(
- self,
- total: int,
- prefix: Optional[str] = None,
- leave: bool = True,
- parent: Optional["NotebookTrainingTracker"] = None,
- width: int = 300,
- ):
- self.total = total
- self.prefix = "" if prefix is None else prefix
- self.leave = leave
- self.parent = parent
- self.width = width
- self.last_value = None
- self.comment = None
- self.output = None
-
- def update(self, value: int, force_update: bool = False, comment: str = None):
- """
- The main method to update the progress bar to `value`.
-
- Args:
- value (`int`):
- The value to use. Must be between 0 and `total`.
- force_update (`bool`, *optional*, defaults to `False`):
- Whether or not to force an update of the internal state and display (by default, the bar will wait for
- `value` to reach the value it predicted corresponds to a time of more than the `update_every` attribute
- since the last update to avoid adding boilerplate).
- comment (`str`, *optional*):
- A comment to add on the left of the progress bar.
- """
- self.value = value
- if comment is not None:
- self.comment = comment
- if self.last_value is None:
- self.start_time = self.last_time = time.time()
- self.start_value = self.last_value = value
- self.elapsed_time = self.predicted_remaining = None
- self.first_calls = self.warmup
- self.wait_for = 1
- self.update_bar(value)
- elif value <= self.last_value and not force_update:
- return
- elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
- if self.first_calls > 0:
- self.first_calls -= 1
- current_time = time.time()
- self.elapsed_time = current_time - self.start_time
- # We could have value = self.start_value if the update is called twice with the same start value.
- if value > self.start_value:
- self.average_time_per_item = self.elapsed_time / (value - self.start_value)
- else:
- self.average_time_per_item = None
- if value >= self.total:
- value = self.total
- self.predicted_remaining = None
- if not self.leave:
- self.close()
- elif self.average_time_per_item is not None:
- self.predicted_remaining = self.average_time_per_item * (self.total - value)
- self.update_bar(value)
- self.last_value = value
- self.last_time = current_time
- if (self.average_time_per_item is None) or (self.average_time_per_item == 0):
- self.wait_for = 1
- else:
- self.wait_for = max(int(self.update_every / self.average_time_per_item), 1)
-
- def update_bar(self, value, comment=None):
- spaced_value = " " * (len(str(self.total)) - len(str(value))) + str(value)
- if self.elapsed_time is None:
- self.label = f"[{spaced_value}/{self.total} : < :"
- elif self.predicted_remaining is None:
- self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)}"
- else:
- self.label = (
- f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)} <" f" {format_time(self.predicted_remaining)}"
- )
- if self.average_time_per_item == 0:
- self.label += ", +inf it/s"
- else:
- self.label += f", {1/self.average_time_per_item:.2f} it/s"
-
- self.label += "]" if self.comment is None or len(self.comment) == 0 else f", {self.comment}]"
- self.display()
-
- def display(self):
- self.html_code = html_progress_bar(self.value, self.total, self.prefix, self.label, self.width)
- if self.parent is not None:
- # If this is a child bar, the parent will take care of the display.
- self.parent.display()
- return
- if self.output is None:
- self.output = disp.display(disp.HTML(self.html_code), display_id=True)
- else:
- self.output.update(disp.HTML(self.html_code))
-
- def close(self):
- "Closes the progress bar."
- if self.parent is None and self.output is not None:
- self.output.update(disp.HTML(""))
-
-
-class NotebookTrainingTracker(NotebookProgressBar):
- """
- An object tracking the updates of an ongoing training with progress bars and a nice table reporting metrics.
-
- Args:
- num_steps (`int`):
- The number of steps during training.
- column_names (`List[str]`, *optional*):
- The list of column names for the metrics table (will be inferred from the first call to
- [`~utils.notebook.NotebookTrainingTracker.write_line`] if not set).
- """
-
- def __init__(self, num_steps, column_names=None):
- super().__init__(num_steps)
- self.inner_table = None if column_names is None else [column_names]
- self.child_bar = None
-
- def display(self):
- self.html_code = html_progress_bar(self.value, self.total, self.prefix, self.label, self.width)
- if self.inner_table is not None:
- self.html_code += text_to_html_table(self.inner_table)
- if self.child_bar is not None:
- self.html_code += self.child_bar.html_code
- if self.output is None:
- self.output = disp.display(disp.HTML(self.html_code), display_id=True)
- else:
- self.output.update(disp.HTML(self.html_code))
-
- def write_line(self, values):
- """
- Write the values in the inner table.
-
- Args:
- values (`Dict[str, float]`): The values to display.
- """
- if self.inner_table is None:
- self.inner_table = [list(values.keys()), list(values.values())]
- else:
- columns = self.inner_table[0]
- for key in values.keys():
- if key not in columns:
- columns.append(key)
- self.inner_table[0] = columns
- if len(self.inner_table) > 1:
- last_values = self.inner_table[-1]
- first_column = self.inner_table[0][0]
- if last_values[0] != values[first_column]:
- # write new line
- self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
- else:
- # update last line
- new_values = values
- for c in columns:
- if c not in new_values.keys():
- new_values[c] = last_values[columns.index(c)]
- self.inner_table[-1] = [new_values[c] for c in columns]
- else:
- self.inner_table.append([values[c] for c in columns])
-
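A sketch of how `write_line` grows the metrics table (the step and loss values are hypothetical):

```python
tracker = NotebookTrainingTracker(num_steps=100)
tracker.write_line({"Step": 10, "Training Loss": 0.52})
tracker.write_line({"Step": 20, "Training Loss": 0.41, "Validation Loss": 0.44})
# A new row is appended whenever the first column changes; columns seen for
# the first time are added, and metrics absent from a row are set to "No Log".
```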
- def add_child(self, total, prefix=None, width=300):
- """
- Add a child progress bar displayed under the table of metrics. The child progress bar is returned (so it can be
- easily updated).
-
- Args:
- total (`int`): The number of iterations for the child progress bar.
- prefix (`str`, *optional*): A prefix to write on the left of the progress bar.
- width (`int`, *optional*, defaults to 300): The width (in pixels) of the progress bar.
- """
- self.child_bar = NotebookProgressBar(total, prefix=prefix, parent=self, width=width)
- return self.child_bar
-
- def remove_child(self):
- """
- Closes the child progress bar.
- """
- self.child_bar = None
- self.display()
-
-
-class NotebookProgressCallback(TrainerCallback):
- """
- A [`TrainerCallback`] that displays the progress of training or evaluation, optimized for Jupyter Notebooks or
- Google colab.
- """
-
- def __init__(self):
- self.training_tracker = None
- self.prediction_bar = None
- self._force_next_update = False
-
- def on_train_begin(self, args, state, control, **kwargs):
- self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"
- self.training_loss = 0
- self.last_log = 0
- column_names = [self.first_column] + ["Training Loss"]
- if args.eval_strategy != IntervalStrategy.NO:
- column_names.append("Validation Loss")
- self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
-
- def on_step_end(self, args, state, control, **kwargs):
- epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
- self.training_tracker.update(
- state.global_step + 1,
- comment=f"Epoch {epoch}/{state.num_train_epochs}",
- force_update=self._force_next_update,
- )
- self._force_next_update = False
-
- def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
- if not has_length(eval_dataloader):
- return
- if self.prediction_bar is None:
- if self.training_tracker is not None:
- self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader))
- else:
- self.prediction_bar = NotebookProgressBar(len(eval_dataloader))
- self.prediction_bar.update(1)
- else:
- self.prediction_bar.update(self.prediction_bar.value + 1)
-
- def on_predict(self, args, state, control, **kwargs):
- if self.prediction_bar is not None:
- self.prediction_bar.close()
- self.prediction_bar = None
-
- def on_log(self, args, state, control, logs=None, **kwargs):
- # Only for when there is no evaluation
- if args.eval_strategy == IntervalStrategy.NO and "loss" in logs:
- values = {"Training Loss": logs["loss"]}
- # First column is necessarily Step since we're not in epoch eval strategy
- values["Step"] = state.global_step
- self.training_tracker.write_line(values)
-
- def on_evaluate(self, args, state, control, metrics=None, **kwargs):
- if self.training_tracker is not None:
- values = {"Training Loss": "No log", "Validation Loss": "No log"}
- for log in reversed(state.log_history):
- if "loss" in log:
- values["Training Loss"] = log["loss"]
- break
-
- if self.first_column == "Epoch":
- values["Epoch"] = int(state.epoch)
- else:
- values["Step"] = state.global_step
- metric_key_prefix = "eval"
- for k in metrics:
- if k.endswith("_loss"):
- metric_key_prefix = re.sub(r"\_loss$", "", k)
- _ = metrics.pop("total_flos", None)
- _ = metrics.pop("epoch", None)
- _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
- _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
- _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
- _ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
- for k, v in metrics.items():
- splits = k.split("_")
- name = " ".join([part.capitalize() for part in splits[1:]])
- if name == "Loss":
- # Single dataset
- name = "Validation Loss"
- values[name] = v
- self.training_tracker.write_line(values)
- self.training_tracker.remove_child()
- self.prediction_bar = None
- # Evaluation takes a long time so we should force the next update.
- self._force_next_update = True
-
- def on_train_end(self, args, state, control, **kwargs):
- self.training_tracker.update(
- state.global_step,
- comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}",
- force_update=True,
- )
- self.training_tracker = None
diff --git a/src/maxdiffusion/transformers/utils/peft_utils.py b/src/maxdiffusion/transformers/utils/peft_utils.py
deleted file mode 100644
index 538a85afb..000000000
--- a/src/maxdiffusion/transformers/utils/peft_utils.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import importlib
-import os
-from typing import Dict, Optional, Union
-
-from packaging import version
-
-from .hub import cached_file
-from .import_utils import is_peft_available
-
-
-ADAPTER_CONFIG_NAME = "adapter_config.json"
-ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
-ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
-
-
-def find_adapter_config_file(
- model_id: str,
- cache_dir: Optional[Union[str, os.PathLike]] = None,
- force_download: bool = False,
- resume_download: Optional[bool] = None,
- proxies: Optional[Dict[str, str]] = None,
- token: Optional[Union[bool, str]] = None,
- revision: Optional[str] = None,
- local_files_only: bool = False,
- subfolder: str = "",
- _commit_hash: Optional[str] = None,
-) -> Optional[str]:
- r"""
- Simply checks if the model stored on the Hub or locally is an adapter model or not, and returns the path of the
- adapter config file if it is, `None` otherwise.
-
- Args:
- model_id (`str`):
- The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
- cache should not be used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force (re-)downloading the configuration files and override the cached versions if they
- exist.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible.
- Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `huggingface-cli login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id. Since we use a
- git-based system for storing models and other artifacts on huggingface.co, `revision` can be any
- identifier allowed by git.
-
-
-
- To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
-
-
-
- local_files_only (`bool`, *optional*, defaults to `False`):
- If `True`, will only try to load the adapter configuration from local files.
- subfolder (`str`, *optional*, defaults to `""`):
- In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
- specify the folder name here.
- """
- adapter_cached_filename = None
- if model_id is None:
- return None
- elif os.path.isdir(model_id):
- list_remote_files = os.listdir(model_id)
- if ADAPTER_CONFIG_NAME in list_remote_files:
- adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
- else:
- adapter_cached_filename = cached_file(
- model_id,
- ADAPTER_CONFIG_NAME,
- cache_dir=cache_dir,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- token=token,
- revision=revision,
- local_files_only=local_files_only,
- subfolder=subfolder,
- _commit_hash=_commit_hash,
- _raise_exceptions_for_gated_repo=False,
- _raise_exceptions_for_missing_entries=False,
- _raise_exceptions_for_connection_errors=False,
- )
-
- return adapter_cached_filename
-
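An illustrative lookup (the repo id is hypothetical); a non-`None` return means the checkpoint ships an `adapter_config.json` and should be treated as a PEFT adapter:

```python
adapter_config_path = find_adapter_config_file("someuser/llama-lora-adapter")
if adapter_config_path is not None:
    print(f"PEFT adapter detected: {adapter_config_path}")
```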
-
-def check_peft_version(min_version: str) -> None:
- r"""
- Checks if the version of PEFT is compatible.
-
- Args:
- min_version (`str`):
- The minimum version of PEFT to check against.
- """
- if not is_peft_available():
- raise ValueError("PEFT is not installed. Please install it with `pip install peft`")
-
- is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version)
-
- if not is_peft_version_compatible:
- raise ValueError(
- f"The version of PEFT you are using is not compatible, please use a version that is greater" f" than {min_version}"
- )
diff --git a/src/maxdiffusion/transformers/utils/quantization_config.py b/src/maxdiffusion/transformers/utils/quantization_config.py
deleted file mode 100644
index 0e76bfc98..000000000
--- a/src/maxdiffusion/transformers/utils/quantization_config.py
+++ /dev/null
@@ -1,1037 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import copy
-import importlib.metadata
-import json
-import os
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Dict, List, Optional, Union
-
-from packaging import version
-
-from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging
-
-
-if is_torch_available():
- import torch
-
-
-logger = logging.get_logger(__name__)
-
-
-class QuantizationMethod(str, Enum):
- BITS_AND_BYTES = "bitsandbytes"
- GPTQ = "gptq"
- AWQ = "awq"
- AQLM = "aqlm"
- QUANTO = "quanto"
- EETQ = "eetq"
- HQQ = "hqq"
-
-
-class AWQLinearVersion(str, Enum):
- GEMM = "gemm"
- GEMV = "gemv"
- EXLLAMA = "exllama"
-
- @staticmethod
- def from_str(version: str):
- version = version.lower()
- if version == "gemm":
- return AWQLinearVersion.GEMM
- elif version == "gemv":
- return AWQLinearVersion.GEMV
- elif version == "exllama":
- return AWQLinearVersion.EXLLAMA
- else:
- raise ValueError(f"Unknown AWQLinearVersion {version}")
-
-
-class AwqBackendPackingMethod(str, Enum):
- AUTOAWQ = "autoawq"
- LLMAWQ = "llm-awq"
-
-
-@dataclass
-class QuantizationConfigMixin:
- """
- Mixin class for quantization config
- """
-
- quant_method: QuantizationMethod
-
- @classmethod
- def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
- """
- Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters.
-
- Args:
- config_dict (`Dict[str, Any]`):
- Dictionary that will be used to instantiate the configuration object.
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
- Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
- `PreTrainedModel`.
- kwargs (`Dict[str, Any]`):
- Additional parameters from which to initialize the configuration object.
-
- Returns:
- [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
- """
-
- config = cls(**config_dict)
-
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(config, key):
- setattr(config, key, value)
- to_remove.append(key)
- for key in to_remove:
- kwargs.pop(key, None)
-
- if return_unused_kwargs:
- return config, kwargs
- else:
- return config
-
- def to_json_file(self, json_file_path: Union[str, os.PathLike]):
- """
- Save this instance to a JSON file.
-
- Args:
- json_file_path (`str` or `os.PathLike`):
- Path to the JSON file in which this configuration instance's parameters will be saved.
- """
- with open(json_file_path, "w", encoding="utf-8") as writer:
- config_dict = self.to_dict()
- json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
-
- writer.write(json_string)
-
- def to_dict(self) -> Dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
-
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- return copy.deepcopy(self.__dict__)
-
- def __iter__(self):
- """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
- for attr, value in copy.deepcopy(self.__dict__).items():
- yield attr, value
-
- def __repr__(self):
- return f"{self.__class__.__name__} {self.to_json_string()}"
-
- def to_json_string(self, use_diff: bool = True) -> str:
- """
- Serializes this instance to a JSON string.
-
- Args:
- use_diff (`bool`, *optional*, defaults to `True`):
- If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
- is serialized to JSON string.
-
- Returns:
- `str`: String containing all the attributes that make up this configuration instance in JSON format.
- """
- if use_diff is True:
- config_dict = self.to_diff_dict()
- else:
- config_dict = self.to_dict()
- return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
-
- def update(self, **kwargs):
- """
- Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
- returning all the unused kwargs.
-
- Args:
- kwargs (`Dict[str, Any]`):
- Dictionary of attributes to tentatively update this class.
-
- Returns:
- `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
- """
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(self, key):
- setattr(self, key, value)
- to_remove.append(key)
-
- # Remove all the attributes that were updated, without modifying the input dict
- unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
- return unused_kwargs
-
-
-@dataclass
-class HqqConfig(QuantizationConfigMixin):
- """
- This is a wrapper around hqq's BaseQuantizeConfig.
-
- Args:
- nbits (`int`, *optional*, defaults to 4):
- Number of bits. Supported values are (8, 4, 3, 2, 1).
- group_size (`int`, *optional*, defaults to 64):
- Group-size value. Supported values are any value that `weight.shape[axis]` is divisible by.
- quant_zero (`bool`, *optional*, defaults to `True`):
- Quantize the zero-point if set to `True`.
- quant_scale (`bool`, *optional*, defaults to `False`):
- Quantize the scaling if set to `True`.
- offload_meta (`bool`, *optional*, defaults to `False`):
- Offload the meta-data to the CPU if set to `True`.
- view_as_float (`bool`, *optional*, defaults to `False`):
- View the quantized weight as float (used in distributed training) if set to `True`.
- axis (`int`, *optional*, defaults to 0):
- Axis along which grouping is performed. Supported values are 0 or 1.
- dynamic_config (dict, *optional*):
- Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
- If set, each layer specified by its id will use its dedicated quantization configuration.
- skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`):
- List of `nn.Linear` layers to skip.
- kwargs (`Dict[str, Any]`, *optional*):
- Additional parameters from which to initialize the configuration object.
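-
- Example (an illustrative sketch; requires `hqq` to be installed, the import path follows upstream `transformers`, and the model id is an arbitrary example):
-
- ```python
- from transformers import AutoModelForCausalLM, HqqConfig
-
- quantization_config = HqqConfig(nbits=4, group_size=64)
- model = AutoModelForCausalLM.from_pretrained(
- "meta-llama/Llama-2-7b-hf", quantization_config=quantization_config
- )
- ```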
- """
-
- def __init__(
- self,
- nbits: int = 4,
- group_size: int = 64,
- quant_zero: bool = True,
- quant_scale: bool = False,
- offload_meta: bool = False,
- view_as_float: bool = False,
- axis: int = 0,
- dynamic_config: Optional[dict] = None,
- skip_modules: List[str] = ["lm_head"],
- **kwargs,
- ):
- if is_hqq_available():
- from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
- else:
- raise ImportError("HQQ is not installed. Please install it with `pip install hqq` to use `HqqConfig`.")
-
- if axis not in [0, 1]:
- raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
-
- if dynamic_config is not None:
- self.quant_config = {}
- for key in dynamic_config:
- self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key])
- else:
- self.quant_config = HQQBaseQuantizeConfig(
- **{
- "nbits": nbits,
- "group_size": group_size,
- "quant_zero": quant_zero,
- "quant_scale": quant_scale,
- "offload_meta": offload_meta,
- "view_as_float": view_as_float,
- "axis": axis,
- }
- )
-
- self.quant_method = QuantizationMethod.HQQ
- self.skip_modules = skip_modules
-
- self.post_init()
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct.
- """
- pass
-
- def to_dict(self) -> Dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
-
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- return self.quant_config
-
- def __repr__(self):
- config_dict = self.to_dict()
- return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
-
- def to_diff_dict(self) -> Dict[str, Any]:
- """
- Removes all attributes from config which correspond to the default config attributes for better readability and
- serializes to a Python dictionary.
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- config_dict = self.to_dict()
-
- # get the default config dict
- default_config_dict = HqqConfig().to_dict()
-
- serializable_config_dict = {}
-
- # only serialize values that differ from the default config
- for key, value in config_dict.items():
- if value != default_config_dict[key]:
- serializable_config_dict[key] = value
-
- return serializable_config_dict
-
-
-@dataclass
-class BitsAndBytesConfig(QuantizationConfigMixin):
- """
- This is a wrapper class for all the attributes and features that you can play with on a model that has been
- loaded using `bitsandbytes`.
-
- This replaces `load_in_8bit` or `load_in_4bit`; therefore, both options are mutually exclusive.
-
- Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
- then more arguments will be added to this class.
-
- Args:
- load_in_8bit (`bool`, *optional*, defaults to `False`):
- This flag is used to enable 8-bit quantization with LLM.int8().
- load_in_4bit (`bool`, *optional*, defaults to `False`):
- This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
- `bitsandbytes`.
- llm_int8_threshold (`float`, *optional*, defaults to 6.0):
- This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
- Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
- that is above this threshold will be considered an outlier and the operation on those values will be done
- in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
- there are some exceptional systematic outliers that are very differently distributed for large models.
- These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
- magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
- but a lower threshold might be needed for more unstable models (small models, fine-tuning).
- llm_int8_skip_modules (`List[str]`, *optional*):
- An explicit list of the modules that we do not want to convert to 8-bit. This is useful for models such as
- Jukebox that have several heads in different places and not necessarily at the last position. For example
- for `CausalLM` models, the last `lm_head` is kept in its original `dtype`.
- llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
- This flag is used for advanced use cases and users that are aware of this feature. If you want to split
- your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
- this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
- operations will not be run on CPU.
- llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
- This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
- have to be converted back and forth for the backward pass.
- bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
- This sets the computational type which might be different than the input type. For example, inputs might be
- fp32, but computation can be set to bf16 for speedups.
- bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
- This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types
- which are specified by `fp4` or `nf4`.
- bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
- This flag is used for nested quantization where the quantization constants from the first quantization are
- quantized again.
- bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
- This sets the storage type to pack the quantized 4-bit params.
- kwargs (`Dict[str, Any]`, *optional*):
- Additional parameters from which to initialize the configuration object.
- """
-
- def __init__(
- self,
- load_in_8bit=False,
- load_in_4bit=False,
- llm_int8_threshold=6.0,
- llm_int8_skip_modules=None,
- llm_int8_enable_fp32_cpu_offload=False,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=None,
- bnb_4bit_quant_type="fp4",
- bnb_4bit_use_double_quant=False,
- bnb_4bit_quant_storage=None,
- **kwargs,
- ):
- self.quant_method = QuantizationMethod.BITS_AND_BYTES
-
- if load_in_4bit and load_in_8bit:
- raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
-
- self._load_in_8bit = load_in_8bit
- self._load_in_4bit = load_in_4bit
- self.llm_int8_threshold = llm_int8_threshold
- self.llm_int8_skip_modules = llm_int8_skip_modules
- self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
- self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
- self.bnb_4bit_quant_type = bnb_4bit_quant_type
- self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
-
- if bnb_4bit_compute_dtype is None:
- self.bnb_4bit_compute_dtype = torch.float32
- elif isinstance(bnb_4bit_compute_dtype, str):
- self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
- elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
- self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
- else:
- raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
-
- if bnb_4bit_quant_storage is None:
- self.bnb_4bit_quant_storage = torch.uint8
- elif isinstance(bnb_4bit_quant_storage, str):
- if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
- raise ValueError(
- "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
- )
- self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
- elif isinstance(bnb_4bit_quant_storage, torch.dtype):
- self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
- else:
- raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
-
- if kwargs:
- logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.")
-
- self.post_init()
-
- @property
- def load_in_4bit(self):
- return self._load_in_4bit
-
- @load_in_4bit.setter
- def load_in_4bit(self, value: bool):
- if not isinstance(value, bool):
- raise ValueError("load_in_4bit must be a boolean")
-
- if self.load_in_8bit and value:
- raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
- self._load_in_4bit = value
-
- @property
- def load_in_8bit(self):
- return self._load_in_8bit
-
- @load_in_8bit.setter
- def load_in_8bit(self, value: bool):
- if not isinstance(value, bool):
- raise ValueError("load_in_8bit must be a boolean")
-
- if self.load_in_4bit and value:
- raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
- self._load_in_8bit = value
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct.
- """
- if not isinstance(self.load_in_4bit, bool):
- raise ValueError("load_in_4bit must be a boolean")
-
- if not isinstance(self.load_in_8bit, bool):
- raise ValueError("load_in_8bit must be a boolean")
-
- if not isinstance(self.llm_int8_threshold, float):
- raise ValueError("llm_int8_threshold must be a float")
-
- if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
- raise ValueError("llm_int8_skip_modules must be a list of strings")
- if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
- raise ValueError("llm_int8_enable_fp32_cpu_offload must be a boolean")
-
- if not isinstance(self.llm_int8_has_fp16_weight, bool):
- raise ValueError("llm_int8_has_fp16_weight must be a boolean")
-
- if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
- raise ValueError("bnb_4bit_compute_dtype must be torch.dtype")
-
- if not isinstance(self.bnb_4bit_quant_type, str):
- raise ValueError("bnb_4bit_quant_type must be a string")
-
- if not isinstance(self.bnb_4bit_use_double_quant, bool):
- raise ValueError("bnb_4bit_use_double_quant must be a boolean")
-
- if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.39.0"):
- raise ValueError("4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version")
-
- def is_quantizable(self):
- r"""
- Returns `True` if the model is quantizable, `False` otherwise.
- """
- return self.load_in_8bit or self.load_in_4bit
-
- def quantization_method(self):
- r"""
- This method returns the quantization method used for the model. If the model is not quantizable, it returns
- `None`.
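-
- Example (illustrative):
-
- ```python
- assert BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4").quantization_method() == "nf4"
- assert BitsAndBytesConfig().quantization_method() is None  # neither 8-bit nor 4-bit enabled
- ```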
- """
- if self.load_in_8bit:
- return "llm_int8"
- elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
- return "fp4"
- elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
- return "nf4"
- else:
- return None
-
- def to_dict(self) -> Dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
-
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- output = copy.deepcopy(self.__dict__)
- output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
- output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1]
- output["load_in_4bit"] = self.load_in_4bit
- output["load_in_8bit"] = self.load_in_8bit
-
- return output
-
- def __repr__(self):
- config_dict = self.to_dict()
- return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
-
- def to_diff_dict(self) -> Dict[str, Any]:
- """
- Removes all attributes from config which correspond to the default config attributes for better readability and
- serializes to a Python dictionary.
-
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- config_dict = self.to_dict()
-
- # get the default config dict
- default_config_dict = BitsAndBytesConfig().to_dict()
-
- serializable_config_dict = {}
-
- # only serialize values that differ from the default config
- for key, value in config_dict.items():
- if value != default_config_dict[key]:
- serializable_config_dict[key] = value
-
- return serializable_config_dict
-
-
-class ExllamaVersion(int, Enum):
- ONE = 1
- TWO = 2
-
-
-@dataclass
-class GPTQConfig(QuantizationConfigMixin):
- """
- This is a wrapper class for all the attributes and features that you can play with on a model that has been
- loaded using the `optimum` API for GPTQ quantization, relying on the auto_gptq backend.
-
- Args:
- bits (`int`):
- The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
- tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
- The tokenizer used to process the dataset. You can pass either:
- - A custom tokenizer object.
- - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
- - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
- using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
- dataset (`Union[List[str], str]`, *optional*):
- The dataset used for quantization. You can provide your own dataset as a list of strings or just use the
- original datasets used in the GPTQ paper ['wikitext2','c4','c4-new'].
- group_size (`int`, *optional*, defaults to 128):
- The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
- damp_percent (`float`, *optional*, defaults to 0.1):
- The percent of the average Hessian diagonal to use for dampening. Recommended value is 0.1.
- desc_act (`bool`, *optional*, defaults to `False`):
- Whether to quantize columns in order of decreasing activation size. Setting it to False can significantly
- speed up inference but the perplexity may become slightly worse. Also known as act-order.
- sym (`bool`, *optional*, defaults to `True`):
- Whether to use symmetric quantization.
- true_sequential (`bool`, *optional*, defaults to `True`):
- Whether to perform sequential quantization even within a single Transformer block. Instead of quantizing
- the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes
- quantization using inputs that have passed through the previously quantized layers.
- use_cuda_fp16 (`bool`, *optional*, defaults to `False`):
- Whether or not to use the optimized CUDA kernel for fp16 models. The model needs to be in fp16.
- model_seqlen (`int`, *optional*):
- The maximum sequence length that the model can take.
- block_name_to_quantize (`str`, *optional*):
- The transformers block name to quantize. If None, we will infer the block name using common patterns (e.g. `model.layers`).
- module_name_preceding_first_block (`List[str]`, *optional*):
- The layers that are preceding the first Transformer block.
- batch_size (`int`, *optional*, defaults to 1):
- The batch size used when processing the dataset
- pad_token_id (`int`, *optional*):
- The pad token id. Needed to prepare the dataset when `batch_size` > 1.
- use_exllama (`bool`, *optional*):
- Whether to use exllama backend. Defaults to `True` if unset. Only works with `bits` = 4.
- max_input_length (`int`, *optional*):
- The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
- length. It is specific to the exllama backend with act-order.
- exllama_config (`Dict[str, Any]`, *optional*):
- The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults
- to `{"version": 1}` if unset.
- cache_block_outputs (`bool`, *optional*, defaults to `True`):
- Whether to cache block outputs to reuse as inputs for the succeeding block.
- modules_in_block_to_quantize (`List[List[str]]`, *optional*):
- List of list of module names to quantize in the specified block. This argument is useful to exclude certain linear modules from being quantized.
- The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially. If not set, we will quantize all linear layers.
- Example: `modules_in_block_to_quantize = [["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], ["self_attn.o_proj"]]`.
- In this example, we will first quantize the q,k,v layers simultaneously since they are independent.
- Then, we will quantize `self_attn.o_proj` layer with the q,k,v layers quantized. This way, we will get
- better results since it reflects the real input `self_attn.o_proj` will get when the model is quantized.
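-
- Example (an illustrative sketch; the model id is an arbitrary example, and quantization needs a GPU):
-
- ```python
- from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
-
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
- quantization_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
- model = AutoModelForCausalLM.from_pretrained(
- "facebook/opt-125m", device_map="auto", quantization_config=quantization_config
- )
- ```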
- """
-
- def __init__(
- self,
- bits: int,
- tokenizer: Any = None,
- dataset: Optional[Union[List[str], str]] = None,
- group_size: int = 128,
- damp_percent: float = 0.1,
- desc_act: bool = False,
- sym: bool = True,
- true_sequential: bool = True,
- use_cuda_fp16: bool = False,
- model_seqlen: Optional[int] = None,
- block_name_to_quantize: Optional[str] = None,
- module_name_preceding_first_block: Optional[List[str]] = None,
- batch_size: int = 1,
- pad_token_id: Optional[int] = None,
- use_exllama: Optional[bool] = None,
- max_input_length: Optional[int] = None,
- exllama_config: Optional[Dict[str, Any]] = None,
- cache_block_outputs: bool = True,
- modules_in_block_to_quantize: Optional[List[List[str]]] = None,
- **kwargs,
- ):
- self.quant_method = QuantizationMethod.GPTQ
- self.bits = bits
- self.tokenizer = tokenizer
- self.dataset = dataset
- self.group_size = group_size
- self.damp_percent = damp_percent
- self.desc_act = desc_act
- self.sym = sym
- self.true_sequential = true_sequential
- self.use_cuda_fp16 = use_cuda_fp16
- self.model_seqlen = model_seqlen
- self.block_name_to_quantize = block_name_to_quantize
- self.module_name_preceding_first_block = module_name_preceding_first_block
- self.batch_size = batch_size
- self.pad_token_id = pad_token_id
- self.use_exllama = use_exllama
- self.max_input_length = max_input_length
- self.exllama_config = exllama_config
- self.disable_exllama = kwargs.pop("disable_exllama", None)
- self.cache_block_outputs = cache_block_outputs
- self.modules_in_block_to_quantize = modules_in_block_to_quantize
- self.post_init()
-
- def get_loading_attributes(self):
- attributes_dict = copy.deepcopy(self.__dict__)
- loading_attributes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"]
- loading_attributes_dict = {i: j for i, j in attributes_dict.items() if i in loading_attributes}
- return loading_attributes_dict
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct
- """
- if self.bits not in [2, 3, 4, 8]:
- raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
- if self.group_size != -1 and self.group_size <= 0:
- raise ValueError("group_size must be greater than 0 or equal to -1")
- if not (0 < self.damp_percent < 1):
- raise ValueError("damp_percent must between 0 and 1.")
- if self.dataset is not None:
- if isinstance(self.dataset, str):
- if self.dataset in ["ptb", "ptb-new"]:
- raise ValueError(
- f"""{self.dataset} dataset was deprecated. You can only choose between
- ['wikitext2','c4','c4-new']"""
- )
- if self.dataset not in ["wikitext2", "c4", "c4-new"]:
- raise ValueError(
- f"""You have entered a string value for dataset. You can only choose between
- ['wikitext2','c4','c4-new'], but we found {self.dataset}"""
- )
- elif not isinstance(self.dataset, list):
- raise ValueError(
- f"""dataset needs to be either a list of string or a value in
- ['wikitext2','c4','c4-new'], but we found {self.dataset}"""
- )
-
- if self.disable_exllama is None and self.use_exllama is None:
- # New default behaviour
- self.use_exllama = True
- elif self.disable_exllama is not None and self.use_exllama is None:
- # Follow pattern of old config
- logger.warning(
- "Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`."
- "The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
- )
- self.use_exllama = not self.disable_exllama
- self.disable_exllama = None
- elif self.disable_exllama is not None and self.use_exllama is not None:
- # Only happens if user explicitly passes in both arguments
- raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")
-
- if self.exllama_config is None:
- self.exllama_config = {"version": ExllamaVersion.ONE}
- else:
- if "version" not in self.exllama_config:
- raise ValueError("`exllama_config` needs to have a `version` key.")
- elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
- exllama_version = self.exllama_config["version"]
- raise ValueError(
- f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
- )
-
- if self.bits == 4 and self.use_exllama:
- if self.exllama_config["version"] == ExllamaVersion.ONE:
- logger.info(
- "You have activated exllama backend. Note that you can get better inference "
- "speed using exllamav2 kernel by setting `exllama_config`."
- )
- elif self.exllama_config["version"] == ExllamaVersion.TWO:
- optimum_version = version.parse(importlib.metadata.version("optimum"))
- autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
- if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
- raise ValueError(
- f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
- )
- if self.modules_in_block_to_quantize is not None:
- optimum_version = version.parse(importlib.metadata.version("optimum"))
- if optimum_version < version.parse("1.15.0"):
- raise ValueError(
- "You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ."
- )
-
- def to_dict(self):
- config_dict = super().to_dict()
- config_dict.pop("disable_exllama", None)
- return config_dict
-
- def to_dict_optimum(self):
- """
- Get a dict compatible with the optimum GPTQ config.
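-
- Example (illustrative):
-
- ```python
- config = GPTQConfig(bits=4, use_exllama=False)
- assert config.to_dict_optimum()["disable_exllama"] is True  # optimum uses the inverted flag
- ```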
- """
- quant_dict = self.to_dict()
- # make it compatible with optimum config
- quant_dict["disable_exllama"] = not self.use_exllama
- return quant_dict
-
- @classmethod
- def from_dict_optimum(cls, config_dict):
- """
- Instantiate this class from an optimum-compatible GPTQ config dict.
- """
-
- if "disable_exllama" in config_dict:
- config_dict["use_exllama"] = not config_dict["disable_exllama"]
- # switch to None to not trigger the warning
- config_dict["disable_exllama"] = None
-
- config = cls(**config_dict)
- return config
-
-
-@dataclass
-class AwqConfig(QuantizationConfigMixin):
- """
- This is a wrapper class for all the attributes and features that you can play with on a model that has been
- loaded using the `auto-awq` library for AWQ quantization, relying on the auto_awq backend.
-
- Args:
- bits (`int`, *optional*, defaults to 4):
- The number of bits to quantize to.
- group_size (`int`, *optional*, defaults to 128):
- The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
- zero_point (`bool`, *optional*, defaults to `True`):
- Whether to use zero point quantization.
- version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
- The version of the quantization algorithm to use. GEMM is better for a big batch size (e.g. >= 8); otherwise,
- GEMV is better (e.g. < 8). GEMM models are compatible with Exllama kernels.
- backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
- The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users
- that quantize their own models using `llm-awq` library.
- do_fuse (`bool`, *optional*, defaults to `False`):
- Whether to fuse attention and mlp layers together for faster inference.
- fuse_max_seq_len (`int`, *optional*):
- The maximum sequence length to generate when using fusing.
- modules_to_fuse (`dict`, *optional*, defaults to `None`):
- Overwrite the natively supported fusing scheme with the one specified by the users.
- modules_to_not_convert (`list`, *optional*, defaults to `None`):
- The list of modules to not quantize, useful for quantizing models that explicitly require to have
- some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
- Note you cannot quantize directly with transformers; please refer to the `AutoAWQ` documentation for quantizing HF models.
- exllama_config (`Dict[str, Any]`, *optional*):
- You can specify the version of the exllama kernel through the `version` key, the maximum sequence
- length through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key.
- Defaults to `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset.
- """
-
- def __init__(
- self,
- bits: int = 4,
- group_size: int = 128,
- zero_point: bool = True,
- version: AWQLinearVersion = AWQLinearVersion.GEMM,
- backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ,
- do_fuse: Optional[bool] = None,
- fuse_max_seq_len: Optional[int] = None,
- modules_to_fuse: Optional[dict] = None,
- modules_to_not_convert: Optional[List] = None,
- exllama_config: Optional[Dict[str, int]] = None,
- **kwargs,
- ):
- self.quant_method = QuantizationMethod.AWQ
-
- self.bits = bits
- self.group_size = group_size
- self.zero_point = zero_point
- self.version = version
- self.backend = backend
- self.fuse_max_seq_len = fuse_max_seq_len
- self.modules_to_not_convert = modules_to_not_convert
- self.exllama_config = exllama_config
-
- self.modules_to_fuse = modules_to_fuse
- if do_fuse is None:
- self.do_fuse = modules_to_fuse is not None and len(modules_to_fuse) > 0
- else:
- self.do_fuse = do_fuse
-
- self.post_init()
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct
- """
- if not torch.cuda.is_available():
- raise ValueError("AWQ is only available on GPU")
-
- if self.backend not in [AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ]:
- raise ValueError(
- f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
- )
-
- self.version = AWQLinearVersion.from_str(self.version)
- if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]:
- raise ValueError(
- f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}"
- )
-
- if self.backend == AwqBackendPackingMethod.LLMAWQ:
- compute_capability = torch.cuda.get_device_capability()
- major, minor = compute_capability
- if major < 8:
- raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0")
-
- if self.do_fuse and self.fuse_max_seq_len is None:
- raise ValueError(
- "You cannot enable fused modules without specifying a `fuse_max_seq_len`, make sure to pass a valid `fuse_max_seq_len` for your usecase"
- )
-
- if self.do_fuse:
- awq_version_supports_fusing = False
- MIN_AWQ_VERSION = "0.1.7"
- if is_auto_awq_available():
- awq_version_supports_fusing = version.parse(importlib.metadata.version("autoawq")) >= version.parse(MIN_AWQ_VERSION)
-
- if not awq_version_supports_fusing:
- raise ValueError(
- f"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
- )
-
- if self.modules_to_not_convert is not None:
- awq_version_supports_non_conversion = False
- MIN_AWQ_VERSION = "0.1.8"
- if is_auto_awq_available():
- awq_version_supports_non_conversion = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
- MIN_AWQ_VERSION
- )
-
- if not awq_version_supports_non_conversion:
- raise ValueError(
- f"You current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
- )
-
- if self.do_fuse and self.modules_to_fuse is not None:
- required_keys = [
- "hidden_size",
- "num_attention_heads",
- "num_key_value_heads",
- "mlp",
- "attention",
- "layernorm",
- "use_alibi",
- ]
- if not all(key in self.modules_to_fuse for key in required_keys):
- raise ValueError(f"Required fields are missing in the fusing mapping, required fields are {required_keys}")
-
- if self.version == AWQLinearVersion.EXLLAMA:
- awq_version_supports_exllama = False
- MIN_AWQ_VERSION = "0.2.0"
- if is_auto_awq_available():
- awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse(MIN_AWQ_VERSION)
-
- if not awq_version_supports_exllama:
- raise ValueError(
- f"You current version of `autoawq` does not support exllama backend, "
- f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
- )
-
- if self.exllama_config is None:
- self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
- else:
- if "version" not in self.exllama_config:
- raise ValueError("`exllama_config` needs to have a `version` key.")
- elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
- exllama_version = self.exllama_config["version"]
- raise ValueError(
- f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
- )
-
- def get_loading_attributes(self):
- attributes_dict = copy.deepcopy(self.__dict__)
- loading_attributes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len", "exllama_config"]
- loading_attributes_dict = {i: j for i, j in attributes_dict.items() if i in loading_attributes}
- return loading_attributes_dict
-
-
-@dataclass
-class AqlmConfig(QuantizationConfigMixin):
- """
- This is a wrapper class around `aqlm` parameters.
-
- Args:
- in_group_size (`int`, *optional*, defaults to 8):
- The group size along the input dimension.
- out_group_size (`int`, *optional*, defaults to 1):
- The group size along the output dimension. It's recommended to always use 1.
- num_codebooks (`int`, *optional*, defaults to 1):
- Number of codebooks for the Additive Quantization procedure.
- nbits_per_codebook (`int`, *optional*, defaults to 16):
- Number of bits encoding a single codebook vector. Codebook size is `2**nbits_per_codebook`.
- linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*):
- List of full paths of `nn.Linear` weight parameters that shall not be quantized.
- kwargs (`Dict[str, Any]`, *optional*):
- Additional parameters from which to initialize the configuration object.
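-
- Example (illustrative; AQLM checkpoints usually ship their own config, so this only shows construction and serialization):
-
- ```python
- config = AqlmConfig(in_group_size=8, num_codebooks=1, nbits_per_codebook=16)
- print(config.to_json_string(use_diff=False))
- ```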
- """
-
- def __init__(
- self,
- in_group_size: int = 8,
- out_group_size: int = 1,
- num_codebooks: int = 1,
- nbits_per_codebook: int = 16,
- linear_weights_not_to_quantize: Optional[List[str]] = None,
- **kwargs,
- ):
- self.quant_method = QuantizationMethod.AQLM
- self.in_group_size = in_group_size
- self.out_group_size = out_group_size
- self.num_codebooks = num_codebooks
- self.nbits_per_codebook = nbits_per_codebook
- self.linear_weights_not_to_quantize = linear_weights_not_to_quantize
-
- self.post_init()
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
- """
- if not isinstance(self.in_group_size, int):
- raise ValueError("in_group_size must be an int")
- if not isinstance(self.out_group_size, int):
- raise ValueError("out_group_size must be an int")
- if not isinstance(self.num_codebooks, int):
- raise ValueError("num_codebooks must be an int")
- if not isinstance(self.nbits_per_codebook, int):
- raise ValueError("nbits_per_codebook must be an int")
-
- if self.linear_weights_not_to_quantize is not None and not isinstance(self.linear_weights_not_to_quantize, list):
- raise ValueError("linear_weights_not_to_quantize must be a list of strings")
-
- if self.linear_weights_not_to_quantize is None:
- self.linear_weights_not_to_quantize = []
-
-
-@dataclass
-class QuantoConfig(QuantizationConfigMixin):
- """
- This is a wrapper class for all the attributes and features that you can play with on a model that has been
- loaded using `quanto`.
-
- Args:
- weights (`str`, *optional*, defaults to `"int8"`):
- The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2")
- activations (`str`, *optional*):
- The target dtype for the activations after quantization. Supported values are (None,"int8","float8")
- modules_to_not_convert (`list`, *optional*, defaults to `None`):
- The list of modules to not quantize, useful for quantizing models that explicitly require to have
- some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
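-
- Example (an illustrative sketch; the model id is an arbitrary example):
-
- ```python
- from transformers import AutoModelForCausalLM, QuantoConfig
-
- quantization_config = QuantoConfig(weights="int8")
- model = AutoModelForCausalLM.from_pretrained(
- "facebook/opt-125m", device_map="auto", quantization_config=quantization_config
- )
- ```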
- """
-
- def __init__(
- self,
- weights="int8",
- activations=None,
- modules_to_not_convert: Optional[List] = None,
- **kwargs,
- ):
- self.quant_method = QuantizationMethod.QUANTO
- self.weights = weights
- self.activations = activations
- self.modules_to_not_convert = modules_to_not_convert
- self.post_init()
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct
- """
- accepted_weights = ["float8", "int8", "int4", "int2"]
- accepted_activations = [None, "int8", "float8"]
- if self.weights not in accepted_weights:
- raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
- if self.activations not in accepted_activations:
- raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations}")
-
-
-@dataclass
-class EetqConfig(QuantizationConfigMixin):
- """
- This is a wrapper class for all the attributes and features that you can play with on a model that has been
- loaded using `eetq`.
-
- Args:
- weights (`str`, *optional*, defaults to `"int8"`):
- The target dtype for the weights. The only supported value is "int8".
- modules_to_not_convert (`list`, *optional*, defaults to `None`):
- The list of modules to not quantize, useful for quantizing models that explicitly require to have
- some modules left in their original precision.
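-
- Example (an illustrative sketch; requires `eetq` to be installed, and the model id is an arbitrary example):
-
- ```python
- from transformers import AutoModelForCausalLM, EetqConfig
-
- quantization_config = EetqConfig("int8")
- model = AutoModelForCausalLM.from_pretrained(
- "facebook/opt-125m", device_map="auto", quantization_config=quantization_config
- )
- ```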
- """
-
- def __init__(
- self,
- weights: str = "int8",
- modules_to_not_convert: Optional[List] = None,
- **kwargs,
- ):
- self.quant_method = QuantizationMethod.EETQ
- self.weights = weights
- self.modules_to_not_convert = modules_to_not_convert
- self.post_init()
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct
- """
- accepted_weights = ["int8"]
- if self.weights not in accepted_weights:
- raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
diff --git a/src/maxdiffusion/transformers/utils/sentencepiece_model_pb2.py b/src/maxdiffusion/transformers/utils/sentencepiece_model_pb2.py
deleted file mode 100644
index b4b2992a6..000000000
--- a/src/maxdiffusion/transformers/utils/sentencepiece_model_pb2.py
+++ /dev/null
@@ -1,1511 +0,0 @@
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: sentencepiece_model.proto
-
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import message as _message
-from google.protobuf import reflection as _reflection
-from google.protobuf import symbol_database as _symbol_database
-
-
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-DESCRIPTOR = _descriptor.FileDescriptor(
- name="sentencepiece_model.proto",
- package="sentencepiece",
- syntax="proto2",
- serialized_options=b"H\003",
- create_key=_descriptor._internal_create_key,
- serialized_pb=(
- b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\xa1\n\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01'
- b" \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02"
- b" \x01(\t\x12\x41\n\nmodel_type\x18\x03"
- b" \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04"
- b" \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12"
- b' \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n'
- b" \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b"
- b" \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12"
- b' \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r'
- b" \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e"
- b" \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f"
- b" \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12"
- b" \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10"
- b" \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11"
- b" \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14"
- b" \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15"
- b" \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17"
- b" \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16"
- b" \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18"
- b" \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19"
- b" \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e"
- b" \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$"
- b" \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18"
- b' \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18"'
- b" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18)"
- b" \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+"
- b" \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18."
- b" \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30"
- b" \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87"
- b" \x12+\n\x1ctrain_extremely_large_corpus\x18\x31"
- b' \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01'
- b" \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03"
- b" \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12"
- b" \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06"
- b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01'
- b' \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01'
- b" \x01(\t\x12\x10\n\x08\x65xpected\x18\x02"
- b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01'
- b" \x03(\x0b\x32'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02"
- b" \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03"
- b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04"
- b" \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05"
- b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01"
- b" \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03"
- b' \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
- ),
-)
-
-
-_TRAINERSPEC_MODELTYPE = _descriptor.EnumDescriptor(
- name="ModelType",
- full_name="sentencepiece.TrainerSpec.ModelType",
- filename=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- values=[
- _descriptor.EnumValueDescriptor(
- name="UNIGRAM",
- index=0,
- number=1,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.EnumValueDescriptor(
- name="BPE",
- index=1,
- number=2,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.EnumValueDescriptor(
- name="WORD",
- index=2,
- number=3,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.EnumValueDescriptor(
- name="CHAR",
- index=3,
- number=4,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- ],
- containing_type=None,
- serialized_options=None,
- serialized_start=1294,
- serialized_end=1347,
-)
-_sym_db.RegisterEnumDescriptor(_TRAINERSPEC_MODELTYPE)
-
-_MODELPROTO_SENTENCEPIECE_TYPE = _descriptor.EnumDescriptor(
- name="Type",
- full_name="sentencepiece.ModelProto.SentencePiece.Type",
- filename=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- values=[
- _descriptor.EnumValueDescriptor(
- name="NORMAL",
- index=0,
- number=1,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.EnumValueDescriptor(
- name="UNKNOWN",
- index=1,
- number=2,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.EnumValueDescriptor(
- name="CONTROL",
- index=2,
- number=3,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.EnumValueDescriptor(
- name="USER_DEFINED",
- index=3,
- number=4,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.EnumValueDescriptor(
- name="BYTE",
- index=4,
- number=6,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.EnumValueDescriptor(
- name="UNUSED",
- index=5,
- number=5,
- serialized_options=None,
- type=None,
- create_key=_descriptor._internal_create_key,
- ),
- ],
- containing_type=None,
- serialized_options=None,
- serialized_start=2100,
- serialized_end=2184,
-)
-_sym_db.RegisterEnumDescriptor(_MODELPROTO_SENTENCEPIECE_TYPE)
-
-
-_TRAINERSPEC = _descriptor.Descriptor(
- name="TrainerSpec",
- full_name="sentencepiece.TrainerSpec",
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- create_key=_descriptor._internal_create_key,
- fields=[
- _descriptor.FieldDescriptor(
- name="input",
- full_name="sentencepiece.TrainerSpec.input",
- index=0,
- number=1,
- type=9,
- cpp_type=9,
- label=3,
- has_default_value=False,
- default_value=[],
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="input_format",
- full_name="sentencepiece.TrainerSpec.input_format",
- index=1,
- number=7,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="model_prefix",
- full_name="sentencepiece.TrainerSpec.model_prefix",
- index=2,
- number=2,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="model_type",
- full_name="sentencepiece.TrainerSpec.model_type",
- index=3,
- number=3,
- type=14,
- cpp_type=8,
- label=1,
- has_default_value=True,
- default_value=1,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="vocab_size",
- full_name="sentencepiece.TrainerSpec.vocab_size",
- index=4,
- number=4,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=8000,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="accept_language",
- full_name="sentencepiece.TrainerSpec.accept_language",
- index=5,
- number=5,
- type=9,
- cpp_type=9,
- label=3,
- has_default_value=False,
- default_value=[],
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="self_test_sample_size",
- full_name="sentencepiece.TrainerSpec.self_test_sample_size",
- index=6,
- number=6,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=0,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="character_coverage",
- full_name="sentencepiece.TrainerSpec.character_coverage",
- index=7,
- number=10,
- type=2,
- cpp_type=6,
- label=1,
- has_default_value=True,
- default_value=float(0.9995),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="input_sentence_size",
- full_name="sentencepiece.TrainerSpec.input_sentence_size",
- index=8,
- number=11,
- type=4,
- cpp_type=4,
- label=1,
- has_default_value=True,
- default_value=0,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="shuffle_input_sentence",
- full_name="sentencepiece.TrainerSpec.shuffle_input_sentence",
- index=9,
- number=19,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="mining_sentence_size",
- full_name="sentencepiece.TrainerSpec.mining_sentence_size",
- index=10,
- number=12,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=False,
- default_value=0,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=b"\030\001",
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="training_sentence_size",
- full_name="sentencepiece.TrainerSpec.training_sentence_size",
- index=11,
- number=13,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=False,
- default_value=0,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=b"\030\001",
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="seed_sentencepiece_size",
- full_name="sentencepiece.TrainerSpec.seed_sentencepiece_size",
- index=12,
- number=14,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=1000000,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="shrinking_factor",
- full_name="sentencepiece.TrainerSpec.shrinking_factor",
- index=13,
- number=15,
- type=2,
- cpp_type=6,
- label=1,
- has_default_value=True,
- default_value=float(0.75),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="max_sentence_length",
- full_name="sentencepiece.TrainerSpec.max_sentence_length",
- index=14,
- number=18,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=4192,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="num_threads",
- full_name="sentencepiece.TrainerSpec.num_threads",
- index=15,
- number=16,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=16,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="num_sub_iterations",
- full_name="sentencepiece.TrainerSpec.num_sub_iterations",
- index=16,
- number=17,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=2,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="max_sentencepiece_length",
- full_name="sentencepiece.TrainerSpec.max_sentencepiece_length",
- index=17,
- number=20,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=16,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="split_by_unicode_script",
- full_name="sentencepiece.TrainerSpec.split_by_unicode_script",
- index=18,
- number=21,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="split_by_number",
- full_name="sentencepiece.TrainerSpec.split_by_number",
- index=19,
- number=23,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="split_by_whitespace",
- full_name="sentencepiece.TrainerSpec.split_by_whitespace",
- index=20,
- number=22,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="treat_whitespace_as_suffix",
- full_name="sentencepiece.TrainerSpec.treat_whitespace_as_suffix",
- index=21,
- number=24,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=False,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="split_digits",
- full_name="sentencepiece.TrainerSpec.split_digits",
- index=22,
- number=25,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=False,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="control_symbols",
- full_name="sentencepiece.TrainerSpec.control_symbols",
- index=23,
- number=30,
- type=9,
- cpp_type=9,
- label=3,
- has_default_value=False,
- default_value=[],
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="user_defined_symbols",
- full_name="sentencepiece.TrainerSpec.user_defined_symbols",
- index=24,
- number=31,
- type=9,
- cpp_type=9,
- label=3,
- has_default_value=False,
- default_value=[],
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="required_chars",
- full_name="sentencepiece.TrainerSpec.required_chars",
- index=25,
- number=36,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="byte_fallback",
- full_name="sentencepiece.TrainerSpec.byte_fallback",
- index=26,
- number=35,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=False,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="vocabulary_output_piece_score",
- full_name="sentencepiece.TrainerSpec.vocabulary_output_piece_score",
- index=27,
- number=32,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="hard_vocab_limit",
- full_name="sentencepiece.TrainerSpec.hard_vocab_limit",
- index=28,
- number=33,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="use_all_vocab",
- full_name="sentencepiece.TrainerSpec.use_all_vocab",
- index=29,
- number=34,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=False,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="unk_id",
- full_name="sentencepiece.TrainerSpec.unk_id",
- index=30,
- number=40,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=0,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="bos_id",
- full_name="sentencepiece.TrainerSpec.bos_id",
- index=31,
- number=41,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=1,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="eos_id",
- full_name="sentencepiece.TrainerSpec.eos_id",
- index=32,
- number=42,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=2,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="pad_id",
- full_name="sentencepiece.TrainerSpec.pad_id",
- index=33,
- number=43,
- type=5,
- cpp_type=1,
- label=1,
- has_default_value=True,
- default_value=-1,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="unk_piece",
- full_name="sentencepiece.TrainerSpec.unk_piece",
- index=34,
- number=45,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=True,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="bos_piece",
- full_name="sentencepiece.TrainerSpec.bos_piece",
- index=35,
- number=46,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=True,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="eos_piece",
- full_name="sentencepiece.TrainerSpec.eos_piece",
- index=36,
- number=47,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=True,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="pad_piece",
- full_name="sentencepiece.TrainerSpec.pad_piece",
- index=37,
- number=48,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=True,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="unk_surface",
- full_name="sentencepiece.TrainerSpec.unk_surface",
- index=38,
- number=44,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=True,
- default_value=b" \342\201\207 ".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="train_extremely_large_corpus",
- full_name="sentencepiece.TrainerSpec.train_extremely_large_corpus",
- index=39,
- number=49,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=False,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- ],
- extensions=[],
- nested_types=[],
- enum_types=[
- _TRAINERSPEC_MODELTYPE,
- ],
- serialized_options=None,
- is_extendable=True,
- syntax="proto2",
- extension_ranges=[
- (200, 536870912),
- ],
- oneofs=[],
- serialized_start=45,
- serialized_end=1358,
-)
-
-
-_NORMALIZERSPEC = _descriptor.Descriptor(
- name="NormalizerSpec",
- full_name="sentencepiece.NormalizerSpec",
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- create_key=_descriptor._internal_create_key,
- fields=[
- _descriptor.FieldDescriptor(
- name="name",
- full_name="sentencepiece.NormalizerSpec.name",
- index=0,
- number=1,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="precompiled_charsmap",
- full_name="sentencepiece.NormalizerSpec.precompiled_charsmap",
- index=1,
- number=2,
- type=12,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"",
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="add_dummy_prefix",
- full_name="sentencepiece.NormalizerSpec.add_dummy_prefix",
- index=2,
- number=3,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="remove_extra_whitespaces",
- full_name="sentencepiece.NormalizerSpec.remove_extra_whitespaces",
- index=3,
- number=4,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="escape_whitespaces",
- full_name="sentencepiece.NormalizerSpec.escape_whitespaces",
- index=4,
- number=5,
- type=8,
- cpp_type=7,
- label=1,
- has_default_value=True,
- default_value=True,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="normalization_rule_tsv",
- full_name="sentencepiece.NormalizerSpec.normalization_rule_tsv",
- index=5,
- number=6,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- ],
- extensions=[],
- nested_types=[],
- enum_types=[],
- serialized_options=None,
- is_extendable=True,
- syntax="proto2",
- extension_ranges=[
- (200, 536870912),
- ],
- oneofs=[],
- serialized_start=1361,
- serialized_end=1570,
-)
-
-
-_SELFTESTDATA_SAMPLE = _descriptor.Descriptor(
- name="Sample",
- full_name="sentencepiece.SelfTestData.Sample",
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- create_key=_descriptor._internal_create_key,
- fields=[
- _descriptor.FieldDescriptor(
- name="input",
- full_name="sentencepiece.SelfTestData.Sample.input",
- index=0,
- number=1,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="expected",
- full_name="sentencepiece.SelfTestData.Sample.expected",
- index=1,
- number=2,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- ],
- extensions=[],
- nested_types=[],
- enum_types=[],
- serialized_options=None,
- is_extendable=False,
- syntax="proto2",
- extension_ranges=[],
- oneofs=[],
- serialized_start=1641,
- serialized_end=1682,
-)
-
-_SELFTESTDATA = _descriptor.Descriptor(
- name="SelfTestData",
- full_name="sentencepiece.SelfTestData",
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- create_key=_descriptor._internal_create_key,
- fields=[
- _descriptor.FieldDescriptor(
- name="samples",
- full_name="sentencepiece.SelfTestData.samples",
- index=0,
- number=1,
- type=11,
- cpp_type=10,
- label=3,
- has_default_value=False,
- default_value=[],
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- ],
- extensions=[],
- nested_types=[
- _SELFTESTDATA_SAMPLE,
- ],
- enum_types=[],
- serialized_options=None,
- is_extendable=True,
- syntax="proto2",
- extension_ranges=[
- (200, 536870912),
- ],
- oneofs=[],
- serialized_start=1572,
- serialized_end=1693,
-)
-
-
-_MODELPROTO_SENTENCEPIECE = _descriptor.Descriptor(
- name="SentencePiece",
- full_name="sentencepiece.ModelProto.SentencePiece",
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- create_key=_descriptor._internal_create_key,
- fields=[
- _descriptor.FieldDescriptor(
- name="piece",
- full_name="sentencepiece.ModelProto.SentencePiece.piece",
- index=0,
- number=1,
- type=9,
- cpp_type=9,
- label=1,
- has_default_value=False,
- default_value=b"".decode("utf-8"),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="score",
- full_name="sentencepiece.ModelProto.SentencePiece.score",
- index=1,
- number=2,
- type=2,
- cpp_type=6,
- label=1,
- has_default_value=False,
- default_value=float(0),
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="type",
- full_name="sentencepiece.ModelProto.SentencePiece.type",
- index=2,
- number=3,
- type=14,
- cpp_type=8,
- label=1,
- has_default_value=True,
- default_value=1,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- ],
- extensions=[],
- nested_types=[],
- enum_types=[
- _MODELPROTO_SENTENCEPIECE_TYPE,
- ],
- serialized_options=None,
- is_extendable=True,
- syntax="proto2",
- extension_ranges=[
- (200, 536870912),
- ],
- oneofs=[],
- serialized_start=1985,
- serialized_end=2195,
-)
-
-_MODELPROTO = _descriptor.Descriptor(
- name="ModelProto",
- full_name="sentencepiece.ModelProto",
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- create_key=_descriptor._internal_create_key,
- fields=[
- _descriptor.FieldDescriptor(
- name="pieces",
- full_name="sentencepiece.ModelProto.pieces",
- index=0,
- number=1,
- type=11,
- cpp_type=10,
- label=3,
- has_default_value=False,
- default_value=[],
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="trainer_spec",
- full_name="sentencepiece.ModelProto.trainer_spec",
- index=1,
- number=2,
- type=11,
- cpp_type=10,
- label=1,
- has_default_value=False,
- default_value=None,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="normalizer_spec",
- full_name="sentencepiece.ModelProto.normalizer_spec",
- index=2,
- number=3,
- type=11,
- cpp_type=10,
- label=1,
- has_default_value=False,
- default_value=None,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="self_test_data",
- full_name="sentencepiece.ModelProto.self_test_data",
- index=3,
- number=4,
- type=11,
- cpp_type=10,
- label=1,
- has_default_value=False,
- default_value=None,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- _descriptor.FieldDescriptor(
- name="denormalizer_spec",
- full_name="sentencepiece.ModelProto.denormalizer_spec",
- index=4,
- number=5,
- type=11,
- cpp_type=10,
- label=1,
- has_default_value=False,
- default_value=None,
- message_type=None,
- enum_type=None,
- containing_type=None,
- is_extension=False,
- extension_scope=None,
- serialized_options=None,
- file=DESCRIPTOR,
- create_key=_descriptor._internal_create_key,
- ),
- ],
- extensions=[],
- nested_types=[
- _MODELPROTO_SENTENCEPIECE,
- ],
- enum_types=[],
- serialized_options=None,
- is_extendable=True,
- syntax="proto2",
- extension_ranges=[
- (200, 536870912),
- ],
- oneofs=[],
- serialized_start=1696,
- serialized_end=2206,
-)
-
-_TRAINERSPEC.fields_by_name["model_type"].enum_type = _TRAINERSPEC_MODELTYPE
-_TRAINERSPEC_MODELTYPE.containing_type = _TRAINERSPEC
-_SELFTESTDATA_SAMPLE.containing_type = _SELFTESTDATA
-_SELFTESTDATA.fields_by_name["samples"].message_type = _SELFTESTDATA_SAMPLE
-_MODELPROTO_SENTENCEPIECE.fields_by_name["type"].enum_type = _MODELPROTO_SENTENCEPIECE_TYPE
-_MODELPROTO_SENTENCEPIECE.containing_type = _MODELPROTO
-_MODELPROTO_SENTENCEPIECE_TYPE.containing_type = _MODELPROTO_SENTENCEPIECE
-_MODELPROTO.fields_by_name["pieces"].message_type = _MODELPROTO_SENTENCEPIECE
-_MODELPROTO.fields_by_name["trainer_spec"].message_type = _TRAINERSPEC
-_MODELPROTO.fields_by_name["normalizer_spec"].message_type = _NORMALIZERSPEC
-_MODELPROTO.fields_by_name["self_test_data"].message_type = _SELFTESTDATA
-_MODELPROTO.fields_by_name["denormalizer_spec"].message_type = _NORMALIZERSPEC
-DESCRIPTOR.message_types_by_name["TrainerSpec"] = _TRAINERSPEC
-DESCRIPTOR.message_types_by_name["NormalizerSpec"] = _NORMALIZERSPEC
-DESCRIPTOR.message_types_by_name["SelfTestData"] = _SELFTESTDATA
-DESCRIPTOR.message_types_by_name["ModelProto"] = _MODELPROTO
-_sym_db.RegisterFileDescriptor(DESCRIPTOR)
-
-TrainerSpec = _reflection.GeneratedProtocolMessageType(
- "TrainerSpec",
- (_message.Message,),
- {
- "DESCRIPTOR": _TRAINERSPEC,
- "__module__": "sentencepiece_model_pb2",
- # @@protoc_insertion_point(class_scope:sentencepiece.TrainerSpec)
- },
-)
-_sym_db.RegisterMessage(TrainerSpec)
-
-NormalizerSpec = _reflection.GeneratedProtocolMessageType(
- "NormalizerSpec",
- (_message.Message,),
- {
- "DESCRIPTOR": _NORMALIZERSPEC,
- "__module__": "sentencepiece_model_pb2",
- # @@protoc_insertion_point(class_scope:sentencepiece.NormalizerSpec)
- },
-)
-_sym_db.RegisterMessage(NormalizerSpec)
-
-SelfTestData = _reflection.GeneratedProtocolMessageType(
- "SelfTestData",
- (_message.Message,),
- {
- "Sample": _reflection.GeneratedProtocolMessageType(
- "Sample",
- (_message.Message,),
- {
- "DESCRIPTOR": _SELFTESTDATA_SAMPLE,
- "__module__": "sentencepiece_model_pb2",
- # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData.Sample)
- },
- ),
- "DESCRIPTOR": _SELFTESTDATA,
- "__module__": "sentencepiece_model_pb2",
- # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData)
- },
-)
-_sym_db.RegisterMessage(SelfTestData)
-_sym_db.RegisterMessage(SelfTestData.Sample)
-
-ModelProto = _reflection.GeneratedProtocolMessageType(
- "ModelProto",
- (_message.Message,),
- {
- "SentencePiece": _reflection.GeneratedProtocolMessageType(
- "SentencePiece",
- (_message.Message,),
- {
- "DESCRIPTOR": _MODELPROTO_SENTENCEPIECE,
- "__module__": "sentencepiece_model_pb2",
- # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto.SentencePiece)
- },
- ),
- "DESCRIPTOR": _MODELPROTO,
- "__module__": "sentencepiece_model_pb2",
- # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto)
- },
-)
-_sym_db.RegisterMessage(ModelProto)
-_sym_db.RegisterMessage(ModelProto.SentencePiece)
-
-
-DESCRIPTOR._options = None
-_TRAINERSPEC.fields_by_name["mining_sentence_size"]._options = None
-_TRAINERSPEC.fields_by_name["training_sentence_size"]._options = None
-# @@protoc_insertion_point(module_scope)
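For orientation, the descriptors deleted above generate the `ModelProto`, `TrainerSpec`, `NormalizerSpec`, and `SelfTestData` message classes used to read and write a trained SentencePiece `.model` file. A minimal usage sketch follows; the import name and the `tokenizer.model` filename are placeholders, but the message API is standard protobuf:

```python
# Minimal sketch: inspect a SentencePiece model with the generated classes.
# Assumes a module generated from sentencepiece_model.proto is importable
# as `sentencepiece_model_pb2`; the filename below is a placeholder.
import sentencepiece_model_pb2 as model_pb2

proto = model_pb2.ModelProto()
with open("tokenizer.model", "rb") as f:
    proto.ParseFromString(f.read())  # standard protobuf deserialization

# Fields mirror the descriptors above.
print(proto.trainer_spec.vocab_size)           # 8000 by default
print(proto.normalizer_spec.add_dummy_prefix)  # default True
print(proto.pieces[0].piece, proto.pieces[0].score)

# Messages can be modified and re-serialized in place.
proto.normalizer_spec.add_dummy_prefix = False
patched = proto.SerializeToString()
```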
diff --git a/src/maxdiffusion/transformers/utils/sentencepiece_model_pb2_new.py b/src/maxdiffusion/transformers/utils/sentencepiece_model_pb2_new.py
deleted file mode 100644
index 4a0b51104..000000000
--- a/src/maxdiffusion/transformers/utils/sentencepiece_model_pb2_new.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: sentencepiece_model.proto
-"""Generated protocol buffer code."""
-
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
-
-
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18  \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
-)
-
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals)
-if _descriptor._USE_C_DESCRIPTORS is False:
- DESCRIPTOR._options = None
- DESCRIPTOR._serialized_options = b"H\003"
- # (generated by protobuf compiler, but `_TRAINERSPEC` is not defined)
- # _TRAINERSPEC.fields_by_name["mining_sentence_size"]._options = None
- # _TRAINERSPEC.fields_by_name["mining_sentence_size"]._serialized_options = b"\030\001"
- # _TRAINERSPEC.fields_by_name["training_sentence_size"]._options = None
- # _TRAINERSPEC.fields_by_name["training_sentence_size"]._serialized_options = b"\030\001"
- _globals["_TRAINERSPEC"]._serialized_start = 45
- _globals["_TRAINERSPEC"]._serialized_end = 1581
- _globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517
- _globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570
- _globals["_NORMALIZERSPEC"]._serialized_start = 1584
- _globals["_NORMALIZERSPEC"]._serialized_end = 1793
- _globals["_SELFTESTDATA"]._serialized_start = 1795
- _globals["_SELFTESTDATA"]._serialized_end = 1916
- _globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864
- _globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905
- _globals["_MODELPROTO"]._serialized_start = 1919
- _globals["_MODELPROTO"]._serialized_end = 2429
- _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208
- _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418
- _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start = 2323
- _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407
-# @@protoc_insertion_point(module_scope)
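Both deleted modules encode the same `sentencepiece_model.proto` schema for different protobuf runtimes: the file above this one uses the legacy `_descriptor.Descriptor(...)` construction that protobuf 4.x rejects, while `sentencepiece_model_pb2_new.py` uses the `descriptor_pool` + `_builder` path that protobuf 4.x requires. A hedged sketch of how a caller could gate between them; the helper name and import paths are illustrative, not the repository's actual wiring:

```python
# Sketch: select the generated module matching the installed protobuf runtime.
from packaging import version

def import_sentencepiece_pb2():
    import google.protobuf

    if version.parse(google.protobuf.__version__) >= version.parse("4.0.0"):
        # builder-based module, loadable only on protobuf >= 4
        import sentencepiece_model_pb2_new as pb2
    else:
        # legacy Descriptor-based module for protobuf 3.x
        import sentencepiece_model_pb2 as pb2
    return pb2
```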
diff --git a/src/maxdiffusion/transformers/utils/versions.py b/src/maxdiffusion/transformers/utils/versions.py
deleted file mode 100644
index c6dcba2fb..000000000
--- a/src/maxdiffusion/transformers/utils/versions.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# Copyright 2020 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Utilities for working with package versions
-"""
-
-import importlib.metadata
-import operator
-import re
-import sys
-from typing import Optional
-
-from packaging import version
-
-
-ops = {
- "<": operator.lt,
- "<=": operator.le,
- "==": operator.eq,
- "!=": operator.ne,
- ">=": operator.ge,
- ">": operator.gt,
-}
-
-
-def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
- if got_ver is None or want_ver is None:
- raise ValueError(
- f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider"
- f" reinstalling {pkg}."
- )
- if not ops[op](version.parse(got_ver), version.parse(want_ver)):
- raise ImportError(
- f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
- )
-
-
-def require_version(requirement: str, hint: Optional[str] = None) -> None:
- """
- Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
-
- The installed module version comes from the *site-packages* dir via *importlib.metadata*.
-
- Args:
- requirement (`str`): pip-style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
- hint (`str`, *optional*): what suggestion to print in case of requirements not being met
-
- Example:
-
- ```python
- require_version("pandas>1.1.2")
- require_version("numpy>1.18.5", "this is important to have for whatever reason")
- ```"""
-
- hint = f"\n{hint}" if hint is not None else ""
-
- # non-versioned check
- if re.match(r"^[\w_\-\d]+$", requirement):
- pkg, op, want_ver = requirement, None, None
- else:
- match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
- if not match:
- raise ValueError(
- "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but"
- f" got {requirement}"
- )
- pkg, want_full = match[0]
- want_range = want_full.split(",") # there could be multiple requirements
- wanted = {}
- for w in want_range:
- match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
- if not match:
- raise ValueError(
- "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23,"
- f" but got {requirement}"
- )
- op, want_ver = match[0]
- wanted[op] = want_ver
- if op not in ops:
- raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
-
- # special case
- if pkg == "python":
- got_ver = ".".join([str(x) for x in sys.version_info[:3]])
- for op, want_ver in wanted.items():
- _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
- return
-
- # check if any version is installed
- try:
- got_ver = importlib.metadata.version(pkg)
- except importlib.metadata.PackageNotFoundError:
- raise importlib.metadata.PackageNotFoundError(
- f"The '{requirement}' distribution was not found and is required by this application. {hint}"
- )
-
- # check that the right version is installed if version number or a range was provided
- if want_ver is not None:
- for op, want_ver in wanted.items():
- _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
-
-
-def require_version_core(requirement):
- """require_version wrapper which emits a core-specific hint on failure"""
- hint = "Try: `pip install transformers -U` or `pip install -e '.[dev]'` if you're working with git main"
- return require_version(requirement, hint)
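To make the deleted checker concrete: `require_version` accepts a bare distribution name, a single pip-style specifier, or a comma-separated range, special-cases the pseudo-package `python`, and raises on mismatch. A short usage sketch; the pinned versions are examples only:

```python
# Illustrative calls against the API defined above.
require_version("numpy")                     # presence check only, no version compared
require_version("tqdm>=4.27")                # single constraint
require_version("protobuf>=3.20,<5")         # comma range: every constraint must hold
require_version("python>=3.8")               # special case: checks the running interpreter
require_version_core("tokenizers>=0.13.3")   # same check plus the core install hint
```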