| | import contextvars |
| | import importlib |
| | from contextlib import contextmanager |
| | from typing import Any, Type |
| |
|
| |
|
| | def get_class_from_str(class_str: str, package: str | None = None) -> Type[Any]: |
| | """ |
| | Converts a string to the corresponding class object, supporting relative imports. |
| | For relative module paths (starting with '.'), a package must be provided. |
| | |
| | Args: |
| | class_str: String representation of the class, either absolute or relative. |
| | package: Package context, only required for relative imports. |
| | |
| | Returns: Class object corresponding to the provided string. |
| | """ |
| | if not isinstance(class_str, str) and isinstance(class_str, type): |
| | return class_str |
| |
|
| | module_path, _, class_name = class_str.rpartition(".") |
| | if not module_path and class_str.startswith("."): |
| | module_path = "." |
| | if module_path.startswith("."): |
| | if not package: |
| | raise ValueError("Relative module path provided without a package context.") |
| | module = importlib.import_module(module_path, package=package) |
| | else: |
| | module = importlib.import_module(module_path) |
| | return getattr(module, class_name) |
| |
|
| |
|
| | def get_str_from_class(cls: Type[Any], package: str | None = None) -> str: |
| | """ |
| | Converts a class object to its string representation. |
| | If a package is provided and the class's module is a submodule of the package, |
| | the returned string will use a relative import. |
| | Otherwise, an absolute import string is returned. |
| | |
| | Args: |
| | cls: Class object to convert. |
| | package: Package context, only required for relative imports. |
| | |
| | Returns: String representation of the class. |
| | """ |
| | if isinstance(cls, str): |
| | return cls |
| |
|
| | module_path = cls.__module__ |
| | class_name = cls.__name__ |
| |
|
| | if package: |
| | |
| | if module_path == package: |
| | return f".{class_name}" |
| | |
| | elif module_path.startswith(package + "."): |
| | |
| | relative = module_path[len(package) :] |
| | if not relative.startswith("."): |
| | relative = "." + relative |
| | return f"{relative}.{class_name}" |
| | return f"{module_path}.{class_name}" |
| |
|
| |
|
# Context-local flag read by model-construction code via `use_init_empty_weights.get()`.
# It is False unless changed inside an `init_empty_weights(...)` block.
use_init_empty_weights = contextvars.ContextVar("init_empty_weights", default=False)


@contextmanager
def init_empty_weights(value: bool):
    """
    Context manager that sets `use_init_empty_weights` to `value` for the duration
    of the `with` block and restores the previous value on exit.
    Code that builds a (parametrized) model can call `use_init_empty_weights.get()`
    to decide whether it should be initialized with empty weights.

    Args:
        value: Whether models created inside the block should be initialized with
            empty weights.
    """
    reset_token = use_init_empty_weights.set(value)
    try:
        yield
    finally:
        # Restore the prior value even if the block raised, so nested or
        # consecutive uses never leak their setting.
        use_init_empty_weights.reset(reset_token)
| |
|