From 28578b96bf6217fa2b79699838e5a4af30843de4 Mon Sep 17 00:00:00 2001 From: Yi Liu <106061964+yiliu30@users.noreply.github.com> Date: Wed, 10 Jul 2024 13:19:27 +0800 Subject: [PATCH] Add docstring for `common` module (#1905) Signed-off-by: yiliu30 --- neural_compressor/common/__init__.py | 1 + neural_compressor/common/base_config.py | 309 ++++++++++++++++++-- neural_compressor/common/base_tuning.py | 139 ++++++++- neural_compressor/common/tuning_param.py | 36 ++- neural_compressor/common/utils/__init__.py | 1 + neural_compressor/common/utils/constants.py | 6 +- neural_compressor/common/utils/logger.py | 8 + neural_compressor/common/utils/save_load.py | 2 +- neural_compressor/common/utils/utility.py | 20 +- 9 files changed, 475 insertions(+), 47 deletions(-) diff --git a/neural_compressor/common/__init__.py b/neural_compressor/common/__init__.py index 93b3de4b22b..e38627d5c7c 100644 --- a/neural_compressor/common/__init__.py +++ b/neural_compressor/common/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""The common module.""" from neural_compressor.common.utils import ( level, diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 267a1ed5deb..3138ed2e728 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""The base config.""" from __future__ import annotations @@ -51,12 +52,23 @@ ] -# Config registry to store all registered configs. class ConfigRegistry(object): + """A registry for managing configuration classes for different algorithms within specific frameworks.""" + registered_configs = {} _config_registry = None def __new__(cls) -> Self: + """Create a new instance of the ConfigRegistry class. + + This method is responsible for creating a new instance of the ConfigRegistry class. + It ensures that only one instance of the class is created by checking if the `_config_registry` + attribute is None. If it is None, a new instance is created and assigned to `_config_registry`. + If `_config_registry` is not None, the existing instance is returned. + + Returns: + The instance of the ConfigRegistry class. + """ if cls._config_registry is None: cls._config_registry = super(ConfigRegistry, cls).__new__(cls) @@ -64,20 +76,22 @@ def __new__(cls) -> Self: @classmethod def register_config_impl(cls, framework_name: str, algo_name: str, priority: Union[float, int] = 0): - """Register config decorator. + """Register a configuration decorator. - The register the configuration classes for different algorithms within specific frameworks. + This decorator is used to register the configuration classes + for different algorithms within specific frameworks. Usage example: - @ConfigRegistry.register_config(framework_name=FRAMEWORK_NAME, algo_name=ExampleAlgorithm, priority=100) + @ConfigRegistry.register_config_impl(framework_name=FRAMEWORK_NAME, algo_name=ExampleAlgorithm, priority=1) class ExampleAlgorithmConfig: # Configuration details for the ExampleAlgorithm Args: - framework_name: the framework name. - algo_name: the algorithm name. - priority: priority: the priority of the configuration. 
A larger number indicates a higher priority, - which will be tried first at the auto-tune stage. Defaults to 0. + framework_name (str): The framework name. + algo_name (str): The algorithm name. + priority (Union[float, int], optional): The priority of the configuration. + A larger number indicates a higher priority, which will be tried first + at the auto-tune stage. Defaults to 0. """ def decorator(config_cls): @@ -89,12 +103,21 @@ def decorator(config_cls): @classmethod def get_all_configs(cls) -> Dict[str, Dict[str, Dict[str, object]]]: - """Get all registered configurations.""" + """Get all registered configurations. + + Returns: + Dict[str, Dict[str, Dict[str, object]]]: A dictionary containing all registered configurations. + """ return cls.registered_configs @classmethod def get_sorted_configs(cls) -> Dict[str, OrderedDict[str, Dict[str, object]]]: - """Get registered configurations sorted by priority.""" + """Get registered configurations sorted by priority. + + Returns: + Dict[str, OrderedDict[str, Dict[str, object]]]: + A dictionary containing registered configurations sorted by priority. + """ sorted_configs = OrderedDict() for framework_name, algos in sorted(cls.registered_configs.items()): sorted_configs[framework_name] = OrderedDict( @@ -104,7 +127,11 @@ def get_sorted_configs(cls) -> Dict[str, OrderedDict[str, Dict[str, object]]]: @classmethod def get_cls_configs(cls) -> Dict[str, Dict[str, object]]: - """Get registered configurations without priority.""" + """Get registered configurations without priority. + + Returns: + Dict[str, Dict[str, object]]: A dictionary containing registered configurations without priority. + """ cls_configs = {} for framework_name, algos in cls.registered_configs.items(): cls_configs[framework_name] = {} @@ -114,6 +141,14 @@ def get_cls_configs(cls) -> Dict[str, Dict[str, object]]: @classmethod def get_all_config_cls_by_fwk_name(cls, fwk_name: str) -> List[Type[BaseConfig]]: + """Get all registered configuration classes for a specific framework. + + Args: + fwk_name (str): The framework name. + + Returns: + List[Type[BaseConfig]]: A list of all registered configuration classes for the specified framework. + """ configs_cls = [] for algo_name, config_pairs in cls.registered_configs.get(fwk_name, {}).items(): configs_cls.append(config_pairs["cls"]) @@ -139,21 +174,31 @@ class ExampleAlgorithmConfig: priority: the priority of the configuration. A larger number indicates a higher priority, which will be tried first at the auto-tune stage. Defaults to 0. """ - return config_registry.register_config_impl(framework_name=framework_name, algo_name=algo_name, priority=priority) class BaseConfig(ABC): - """The base config for all algorithm configs.""" + """The base config for all algorithm configs. + + Attributes: + name (str): The name of the config. + params_list (list): The list of **tunable parameters** in the config. + """ name = BASE_CONFIG params_list = [] def __init__(self, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST) -> None: + """Initialize the BaseConfig. + + Args: + white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): The white list of operator names or types. + Defaults to DEFAULT_WHITE_LIST. + """ self._global_config: Optional[BaseConfig] = None # For PyTorch, operator_type is the collective name for module type and functional operation type, # for example, `torch.nn.Linear`, and `torch.nn.functional.linear`. 
- # local config is the collections of operator_type configs and operator configs + # local config is the collections of operator_type configs and operator configs. self._local_config: Dict[str, Optional[BaseConfig]] = {} self._white_list = white_list @@ -176,37 +221,67 @@ def _post_init(self): @property def white_list(self): + """Get the white list of operator names or types. + + Returns: + The white list of operator names or types. + """ return self._white_list @white_list.setter def white_list(self, op_name_or_type_list: Optional[List[OP_NAME_OR_MODULE_TYPE]]): + """Set the white list of operator names or types. + + Args: + op_name_or_type_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): The white list of operator names or types. + """ self._white_list = op_name_or_type_list @property def global_config(self): + """Get the global configuration object. + + Returns: + The global configuration object. + """ return self._global_config @global_config.setter def global_config(self, config): + """Set the global configuration object. + + Args: + config: The global configuration object. + """ self._global_config = config @property def local_config(self): + """Get the local configuration objects. + + Returns: + The local configuration objects. + """ return self._local_config @local_config.setter def local_config(self, config): + """Set the local configuration objects. + + Args: + config: The local configuration objects. + """ self._local_config = config def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig: """Set custom configuration based on the global configuration object. Args: - operator_name_or_list (Union[List, str, Callable]): specific operator - config (BaseConfig): specific configuration + operator_name_or_list (Union[List, str, Callable]): Specific operator name or list of operator names. + config (BaseConfig): Specific configuration. Returns: - Updated Config + Updated Config. """ if isinstance(operator_name_or_list, list): for operator_name in operator_name_or_list: @@ -220,6 +295,11 @@ def set_local(self, operator_name_or_list: Union[List, str, Callable], config: B return self def to_dict(self): + """Convert the config to a dictionary. + + Returns: + The config as a dictionary. + """ result = {} global_config = self.get_params_dict() if bool(self.local_config): @@ -233,6 +313,11 @@ def to_dict(self): return result def get_params_dict(self): + """Get a dictionary containing the parameters and their values for the current instance. + + Returns: + A dictionary containing the parameters and their values. + """ result = dict() for param, value in self.__dict__.items(): if param not in ["_global_config", "_local_config", "_white_list"]: @@ -241,10 +326,10 @@ def get_params_dict(self): @classmethod def from_dict(cls, config_dict): - """Construct config from a dict. + """Construct config from a dictionary. Args: - config_dict: _description_ + config_dict: The dictionary containing the config. Returns: The constructed config. @@ -262,31 +347,49 @@ def from_dict(cls, config_dict): @classmethod def to_diff_dict(cls, instance) -> Dict[str, Any]: + """Compare the instance with the default BaseConfig and return the differences as a dictionary. + + Args: + instance: The instance to compare. + + Returns: + A dictionary representation of the instance with only the differences from the class defaults. + """ # TODO (Yi) to implement it return {} @classmethod def from_json_file(cls, filename): + """Load config from a JSON file. 
+
+        Args:
+            filename (str): The path to the JSON file.
+
+        Returns:
+            The loaded config.
+        """
         with open(filename, "r", encoding="utf-8") as file:
             config_dict = json.load(file)
         return cls.from_dict(**config_dict)
 
     def to_json_file(self, filename):
+        """Save the config to a JSON file.
+
+        Args:
+            filename (str): The path to save the JSON file.
+        """
         config_dict = self.to_dict()
         with open(filename, "w", encoding="utf-8") as file:
             json.dump(config_dict, file, indent=4)
         logger.info("Dump the config into %s.", filename)
 
     def to_json_string(self, use_diff: bool = False) -> str:
         """Serializes this instance to a JSON string.
 
         Args:
-            use_diff (`bool`, *optional*, defaults to `True`):
-                If set to `True`, only the difference between the config instance and the default `BaseConfig()`
-                is serialized to JSON string.
+            use_diff (bool, optional): If True, only the difference between the config instance and the default
+                BaseConfig is serialized to JSON string. Defaults to False.
 
         Returns:
-            `str`: String containing all the attributes that make up this configuration instance in JSON format.
+            The config as a JSON string.
         """
         if use_diff is True:
             config_dict = self.to_diff_dict(self)
@@ -298,6 +401,11 @@ def to_json_string(self, use_diff: bool = False) -> str:
             return config_dict
 
     def __repr__(self) -> str:
+        """Return a string representation of the config.
+
+        Returns:
+            str: The string representation of the config.
+        """
         return f"{self.__class__.__name__} {self.to_json_string()}"
 
     @classmethod
@@ -308,10 +416,29 @@ def register_supported_configs(cls):
 
     @classmethod
     def validate(self, user_config: BaseConfig):
+        """Validates the user configuration.
+
+        Args:
+            user_config (BaseConfig): The user configuration to be validated.
+
+        Returns:
+            None
+        """
         # TODO(Yi) validate the user config
         pass
 
     def __add__(self, other: BaseConfig) -> BaseConfig:
+        """Combine two configs.
+
+        If the other config is an instance of the same class, the local configs will be combined.
+        Otherwise, a `ComposableConfig` will be created to combine the two configs.
+
+        Args:
+            other (BaseConfig): The other config to combine.
+
+        Returns:
+            BaseConfig: The combined config.
+        """
         if isinstance(other, type(self)):
             for op_name, config in other.local_config.items():
                 self.set_local(op_name, config)
@@ -321,6 +448,15 @@ def __add__(self, other: BaseConfig) -> BaseConfig:
             return self
 
     @staticmethod
     def get_the_default_value_of_param(config: BaseConfig, param: str) -> Any:
+        """Get the default value of a parameter in the config.
+
+        Args:
+            config (BaseConfig): The config object.
+            param (str): The name of the parameter.
+
+        Returns:
+            default_value: The default value of the parameter.
+        """
         # Get the signature of the __init__ method
         signature = inspect.signature(config.__init__)
 
@@ -420,6 +556,20 @@ def _get_op_name_op_type_config(self):
     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
     ) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
+        """Generate the configuration mapping based on the model information.
+
+        Args:
+            config_list (List[BaseConfig], optional): A list of BaseConfig objects to be converted.
+                If not provided, the method will use the current instance of BaseConfig. Defaults to None.
+            model_info (List[Tuple[str, str]], optional): A list of tuples representing the model information.
+                Each tuple contains the operation name and operation type. Defaults to None.
+ + Returns: + OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]: + A OrderedDict representing the configuration mapping. + The keys of the outer OrderedDict are tuples of (operation name, operation type), + and the values are inner OrderedDicts containing the corresponding configuration objects. + """ config_mapping = OrderedDict() if config_list is None: config_list = [self] @@ -452,9 +602,28 @@ def _is_op_type(name: str) -> bool: @classmethod @abstractmethod def get_config_set_for_tuning(cls): + """A set of predefined configurations used for tuning. + + This method should be implemented by subclasses to provide a set of configurations + that can be used for auto-tune. + + Returns: + set: A set of configurations for tuning. + + Raises: + NotImplementedError: If the method is not implemented by the subclass. + """ raise NotImplementedError def __eq__(self, other: BaseConfig) -> bool: + """Check if the current BaseConfig object is equal to another BaseConfig object. + + Args: + other (BaseConfig): The other BaseConfig object to compare with. + + Returns: + bool: True if the objects are equal, False otherwise. + """ if not isinstance(other, type(self)): return False return self.params_list == other.params_list and all( @@ -463,12 +632,42 @@ def __eq__(self, other: BaseConfig) -> bool: class ComposableConfig(BaseConfig): + """A class representing a composable configuration. + + This class allows for composing multiple configurations together by using the `+` operator. + + Args: + configs (List[BaseConfig]): A list of base configurations to be composed. + + Attributes: + config_list (List[BaseConfig]): The list of base configurations. + """ + name = COMPOSABLE_CONFIG def __init__(self, configs: List[BaseConfig]) -> None: + """Initializes a new ComposableConfig. + + Args: + configs (List[BaseConfig]): A list of BaseConfig objects. + + Returns: + None + """ self.config_list = configs def __add__(self, other: BaseConfig) -> BaseConfig: + """Adds another BaseConfig object to the current BaseConfig object. + + If the other object is of the same type as the current object, the config_list of the other object is appended + to the config_list of the current object. Otherwise, the other object is appended directly to the config_list. + + Args: + other (BaseConfig): The other BaseConfig object to be added. + + Returns: + BaseConfig: The updated BaseConfig object after the addition. + """ if isinstance(other, type(self)): self.config_list.extend(other.config_list) else: @@ -476,6 +675,17 @@ def __add__(self, other: BaseConfig) -> BaseConfig: return self def to_dict(self, params_list=[], operator2str=None): + """Converts the configuration object to a dictionary. + + Args: + params_list (list): A list of parameters to include in the dictionary. + If empty, all parameters will be included. + operator2str (callable): A function that converts operator objects to strings. + If None, the default conversion will be used. + + Returns: + dict: A dictionary representation of the configuration object. + """ result = {} for config in self.config_list: result[config.name] = config.to_dict() @@ -483,6 +693,18 @@ def to_dict(self, params_list=[], operator2str=None): @classmethod def from_dict(cls, config_dict: OrderedDict[str, Dict], config_registry: Dict[str, BaseConfig]): + """Create a BaseConfig object from a dictionary representation. + + Args: + config_dict (OrderedDict[str, Dict]): The dictionary representation of the configuration. 
+ config_registry (Dict[str, BaseConfig]): The registry of available configurations. + + Returns: + BaseConfig: The created BaseConfig object. + + Raises: + AssertionError: If the config_dict does not include at least one configuration. + """ assert len(config_dict) >= 1, "The config dict must include at least one configuration." num_configs = len(config_dict) name, value = next(iter(config_dict.items())) @@ -493,14 +715,37 @@ def from_dict(cls, config_dict: OrderedDict[str, Dict], config_registry: Dict[st return config def to_json_string(self, use_diff: bool = False) -> str: + """Convert the object to a JSON string representation. + + Args: + use_diff (bool): Whether to include only the differences from the base configuration. + Defaults to False. + + Returns: + str: The JSON string representation of the object. + """ return json.dumps(self.to_dict(), indent=2) + "\n" def __repr__(self) -> str: + """A string representation of the object. + + Returns: + str: The string representation of the object. + """ return f"{self.__class__.__name__} {self.to_json_string()}" def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: Dict[str, Any] = None ) -> OrderedDict[str, BaseConfig]: + """Converts the configuration list to a mapping of (op_name, op_type) to corresponding BaseConfig objects. + + Args: + config_list (List[BaseConfig], optional): List of BaseConfig objects. Defaults to None. + model_info (Dict[str, Any], optional): Dictionary containing model information. Defaults to None. + + Returns: + OrderedDict[str, BaseConfig]: Mapping of (op_name, op_type) to corresponding BaseConfig objects. + """ config_mapping = OrderedDict() for config in self.config_list: op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() @@ -520,10 +765,12 @@ def register_supported_configs(cls): @classmethod def get_config_set_for_tuning(cls) -> None: + """Get the set of predefined configurations used for tuning.""" # TODO (Yi) handle the composable config in `tuning_config` return None def get_model_info(self, model, *args, **kwargs): + """Get the model information.""" model_info_dict = dict() for config in self.config_list: model_info_dict.update({config.name: config.get_model_info(model, *args, **kwargs)}) @@ -531,6 +778,14 @@ def get_model_info(self, model, *args, **kwargs): def get_all_config_set_from_config_registry(fwk_name: str) -> Union[BaseConfig, List[BaseConfig]]: + """Retrieves all the configuration sets from the config registry for a given framework name. + + Args: + fwk_name (str): The name of the framework. + + Returns: + Union[BaseConfig, List[BaseConfig]]: The configuration set(s) for the given framework name. + """ all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls_by_fwk_name(fwk_name) config_set = [] for config_cls in all_registered_config_cls: diff --git a/neural_compressor/common/base_tuning.py b/neural_compressor/common/base_tuning.py index 30910a865f7..88f6be5b188 100644 --- a/neural_compressor/common/base_tuning.py +++ b/neural_compressor/common/base_tuning.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""The auto-tune module.""" import copy import uuid @@ -36,6 +36,8 @@ class EvaluationFuncWrapper: + """Evaluation function wrapper.""" + def __init__(self, eval_fn: Callable, eval_args=None): """Evaluation function wrapper. 
@@ -47,6 +49,14 @@ def __init__(self, eval_fn: Callable, eval_args=None): self.eval_args = eval_args def evaluate(self, model) -> Union[float, int]: + """Evaluates the given model using the evaluation function and arguments provided. + + Args: + model: The model to be evaluated. + + Returns: + The evaluation result, which can be a float or an integer. + """ result = self.eval_fn(model, *self.eval_args) if self.eval_args else self.eval_fn(model) return result @@ -79,6 +89,7 @@ def eval_perf(molde): EVAL_FN_TEMPLATE: Dict[str, Any] = {EVAL_FN: None, WEIGHT: 1.0, FN_NAME: None} def __init__(self) -> None: + """Initializes the BaseTuning class.""" self.eval_fn_registry: List[Dict[str, Any]] = [] def evaluate(self, model) -> float: @@ -101,6 +112,11 @@ def _update_the_objective_score(self, eval_pair, eval_result, overall_result) -> return overall_result + eval_result * eval_pair[self.WEIGHT] def get_number_of_eval_functions(self) -> int: + """Returns the number of evaluation functions in the eval_fn_registry. + + Returns: + int: The number of evaluation functions. + """ return len(self.eval_fn_registry) def _set_eval_fn_registry(self, user_eval_fns: List[Dict]) -> None: @@ -114,7 +130,22 @@ def _set_eval_fn_registry(self, user_eval_fns: List[Dict]) -> None: ] def set_eval_fn_registry(self, eval_fns: Optional[Union[Callable, Dict, List[Dict]]] = None) -> None: - # About the eval_fns format, refer the class docstring for details. + """Set the evaluation function registry. + + Args: + eval_fns (Optional[Union[Callable, Dict, List[Dict]]]): The evaluation function(s) to be registered. + It can be a single function, a dictionary, or a list of dictionaries. + If `eval_fns` is None, the method returns without making any changes. + If `eval_fns` is a single function, it will be converted into a dictionary and added to the registry. + If `eval_fns` is a dictionary, it will be added to the registry as is. + If `eval_fns` is a list of dictionaries, each dictionary will be added to the registry. + + Raises: + NotImplementedError: If `eval_fns` is not a dict or a list of dicts. + + Note: + The format of the evaluation function(s) should follow the class docstring for details. + """ if eval_fns is None: return elif callable(eval_fns): @@ -132,6 +163,11 @@ def set_eval_fn_registry(self, eval_fns: Optional[Union[Callable, Dict, List[Dic self._set_eval_fn_registry(eval_fns) def self_check(self) -> None: + """Perform a self-check to ensure that there is at least one evaluation metric registered for auto-tune. + + Raises: + AssertionError: If no evaluation metric is registered for auto-tune. + """ # check the number of evaluation functions num_eval_fns = self.get_number_of_eval_functions() assert num_eval_fns > 0, "Please ensure that you register at least one evaluation metric for auto-tune." @@ -142,15 +178,26 @@ def self_check(self) -> None: class ConfigSet: + """A class representing a set of configurations. + + Args: + config_list (List[BaseConfig]): A list of BaseConfig objects. + + Attributes: + config_list (List[BaseConfig]): The list of BaseConfig objects. + """ def __init__(self, config_list: List[BaseConfig]) -> None: + """Initializes a ConfigSet object.""" self.config_list = config_list def __getitem__(self, index) -> BaseConfig: + """Get the config object by index.""" assert 0 <= index < len(self.config_list), f"Index {index} out of range." 
return self.config_list[index] def __len__(self) -> int: + """Get the number of configs in the config_list.""" return len(self.config_list) @classmethod @@ -168,10 +215,19 @@ def _from_list_of_configs(cls, fwk_configs: List[BaseConfig]) -> List[BaseConfig @classmethod def generate_config_list(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]): - # There are several cases for the input `fwk_configs`: - # 1. fwk_configs is a single config - # 2. fwk_configs is a list of configs - # For a single config, we need to check if it can be expanded or not. + """Generate the config_list based on the input fwk_configs. + + There are several cases for the input `fwk_configs`: + 1. fwk_configs is a single config + 2. fwk_configs is a list of configs + For a single config, we need to check if it can be expanded or not. + + Args: + fwk_configs (Union[BaseConfig, List[BaseConfig]]): A single config or a list of configs. + + Returns: + List[BaseConfig]: The generated config_list. + """ config_list = [] if isinstance(fwk_configs, BaseConfig): config_list = cls._from_single_config(fwk_configs) @@ -187,10 +243,11 @@ def from_fwk_configs(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]) -> " Args: fwk_configs: A single config or a list of configs. - Examples: - 1) single config: RTNConfig(weight_group_size=32) - 2) single expandable config: RTNConfig(weight_group_size=[32, 64]) - 3) mixed 1) and 2): [RTNConfig(weight_group_size=32), RTNConfig(weight_group_size=[32, 64])] + + Examples of `fwk_configs`: + 1) single config: RTNConfig(weight_group_size=32) + 2) single expandable config: RTNConfig(weight_group_size=[32, 64]) + 3) mixed 1) and 2): [RTNConfig(weight_group_size=32), RTNConfig(weight_group_size=[32, 64])] Returns: ConfigSet: A ConfigSet object. @@ -200,7 +257,10 @@ def from_fwk_configs(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]) -> " class Sampler: + """Base class for samplers.""" + def __init__(self, config_source: Optional[ConfigSet]) -> None: + """Initializes a Sampler object.""" pass def __iter__(self) -> Iterator[BaseConfig]: @@ -218,12 +278,15 @@ class SequentialSampler(Sampler): config_source: Sized def __init__(self, config_source: Sized) -> None: + """Initializes a SequentialSampler object.""" self.config_source = config_source def __iter__(self) -> Iterator[int]: + """Iterate over indices of config set elements.""" return iter(range(len(self.config_source))) def __len__(self) -> int: + """Get the number of configs in the config_source.""" return len(self.config_source) @@ -231,21 +294,32 @@ def __len__(self) -> int: class ConfigLoader: + """ConfigLoader is a generator that yields configs from a config set.""" + def __init__( self, config_set: ConfigSet, sampler: Sampler = default_sampler, skip_verified_config: bool = True ) -> None: + """Initializes the ConfigLoader class. + + Args: + config_set (ConfigSet): The configuration set. + sampler (Sampler, optional): The sampler to use for sampling configurations. Defaults to default_sampler. + skip_verified_config (bool, optional): Whether to skip verified configurations. Defaults to True. 
+ """ self.config_set = ConfigSet.from_fwk_configs(config_set) self._sampler = sampler(self.config_set) self.skip_verified_config = skip_verified_config self.verify_config_list = list() def is_verified_config(self, config): + """Check if the config is verified.""" for verified_config in self.verify_config_list: if config == verified_config: return True return False def __iter__(self) -> Generator[BaseConfig, Any, None]: + """Iterate over the config set and yield configs.""" for index in self._sampler: new_config = self.config_set[index] if self.skip_verified_config and self.is_verified_config(new_config): @@ -318,25 +392,59 @@ def __init__(self, trial_index: int, trial_result: Union[int, float], quant_conf class TuningMonitor: + """The tuning monitor class for auto-tuning.""" + def __init__(self, tuning_config: TuningConfig) -> None: + """Initialize a TuningMonitor class. + + Args: + tuning_config (TuningConfig): The configuration object for tuning. + + Attributes: + tuning_config (TuningConfig): The configuration object for tuning. + trial_cnt (int): The number of trials performed. + tuning_history (List[_TrialRecord]): The history of tuning records. + baseline: The baseline value for comparison. + """ self.tuning_config = tuning_config self.trial_cnt = 0 self.tuning_history: List[_TrialRecord] = [] self.baseline = None def add_trial_result(self, trial_index: int, trial_result: Union[int, float], quant_config: BaseConfig) -> None: + """Adds a trial result to the tuning history. + + Args: + trial_index (int): The index of the trial. + trial_result (Union[int, float]): The result of the trial. + quant_config (BaseConfig): The quantization configuration used for the trial. + """ self.trial_cnt += 1 trial_record = _TrialRecord(trial_index, trial_result, quant_config) self.tuning_history.append(trial_record) def set_baseline(self, baseline: float): + """Set the baseline value for auto-tune. + + Args: + baseline (float): The baseline value to be set. + """ self.baseline = baseline logger.info(f"Fp32 baseline is {self.baseline}") def get_number_of_trials(self): + """Returns the number of trials in the tuning history.""" return len(self.tuning_history) def get_best_trial_record(self) -> _TrialRecord: + """Returns the best trial record based on the trial result. + + Raises: + AssertionError: If there are no trial records in the tuning monitor. + + Returns: + The best trial record. + """ assert self.get_number_of_trials() > 0, "No trial record in tuning monitor." # Put the record with a higher score at the beginning sorted_trials_records: List[_TrialRecord] = sorted( @@ -345,6 +453,11 @@ def get_best_trial_record(self) -> _TrialRecord: return sorted_trials_records[0] def get_best_quant_config(self) -> BaseConfig: + """Get the best quantization configuration based on the best trial record. + + Returns: + The best quantization configuration (BaseConfig). + """ best_trial_record = self.get_best_trial_record() return best_trial_record.quant_config @@ -354,7 +467,6 @@ def need_stop(self) -> bool: Returns: stop_flag: True if need to stop, otherwise False. """ - # reach max trials reach_max_trials = self.trial_cnt >= self.tuning_config.max_trials # reach accuracy goal @@ -368,6 +480,11 @@ def need_stop(self) -> bool: def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger, TuningMonitor]: + """Initializes the tuning process. + + Args: + tuning_config (TuningConfig): The configuration for the tuning process. 
+ """ config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) tuning_logger = TuningLogger() tuning_monitor = TuningMonitor(tuning_config) diff --git a/neural_compressor/common/tuning_param.py b/neural_compressor/common/tuning_param.py index 3f6d9272e4f..d3f7d452e1c 100644 --- a/neural_compressor/common/tuning_param.py +++ b/neural_compressor/common/tuning_param.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""The tunable parameters module.""" + import typing from enum import Enum, auto @@ -22,6 +24,14 @@ class ParamLevel(Enum): + """Enumeration representing the different levels of tuning parameters. + + Attributes: + OP_LEVEL: Represents the level of tuning parameters for operations. + OP_TYPE_LEVEL: Represents the level of tuning parameters for operation types. + MODEL_LEVEL: Represents the level of tuning parameters for models. + """ + OP_LEVEL = auto() OP_TYPE_LEVEL = auto() MODEL_LEVEL = auto() @@ -63,6 +73,15 @@ def __init__( options=None, level: ParamLevel = ParamLevel.OP_LEVEL, ) -> None: + """Initialize a TuningParam object. + + Args: + name (str): The name of the tuning parameter. + default_val (Any, optional): The default value of the tuning parameter. Defaults to None. + tunable_type (optional): The type of the tuning parameter. Defaults to None. + options (optional): The available options for the tuning parameter. Defaults to None. + level (ParamLevel, optional): The level of the tuning parameter. Defaults to ParamLevel.OP_LEVEL. + """ self.name = name self.default_val = default_val self.tunable_type = tunable_type @@ -70,14 +89,14 @@ def __init__( self.level = level @staticmethod - def create_input_args_model(expect_args_type: Any) -> type: + def create_input_args_model(expect_args_type: Any): """Dynamically create an InputArgsModel based on the provided type hint. - Parameters: - - expect_args_type (Any): The user-provided type hint for input_args. + Args: + expect_args_type (Any): The user-provided type hint for input_args. Returns: - - type: The dynamically created InputArgsModel class. + The dynamically created InputArgsModel class. """ class DynamicInputArgsModel(BaseModel): @@ -86,6 +105,14 @@ class DynamicInputArgsModel(BaseModel): return DynamicInputArgsModel def is_tunable(self, value: Any) -> bool: + """Checks if the given value is tunable based on the specified tunable type. + + Args: + value (Any): The value to be checked for tunability. + + Returns: + bool: True if the value is tunable, False otherwise. + """ # Use `Pydantic` to validate the input_args. # TODO: refine the implementation in further. assert isinstance( @@ -100,4 +127,5 @@ def is_tunable(self, value: Any) -> bool: return False def __str__(self) -> str: + """Return the name of the tuning parameter.""" return self.name diff --git a/neural_compressor/common/utils/__init__.py b/neural_compressor/common/utils/__init__.py index 88aeae223ef..fba5ee02680 100644 --- a/neural_compressor/common/utils/__init__.py +++ b/neural_compressor/common/utils/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""The utils of common module.""" from neural_compressor.common.utils.constants import * from neural_compressor.common.utils.logger import * diff --git a/neural_compressor/common/utils/constants.py b/neural_compressor/common/utils/constants.py index 629a3f5743e..adf7755003b 100644 --- a/neural_compressor/common/utils/constants.py +++ b/neural_compressor/common/utils/constants.py @@ -14,9 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -# All constants +"""All frameworks-agnostic constants.""" # constants for configs GLOBAL = "global" @@ -53,6 +51,8 @@ class Mode(Enum): + """Enumeration class representing different modes of the quantizer execution.""" + PREPARE = "prepare" CONVERT = "convert" QUANTIZE = "quantize" diff --git a/neural_compressor/common/utils/logger.py b/neural_compressor/common/utils/logger.py index 94a7ca09c50..4c933368fdd 100644 --- a/neural_compressor/common/utils/logger.py +++ b/neural_compressor/common/utils/logger.py @@ -161,36 +161,44 @@ class TuningLogger: @classmethod def tuning_start(cls) -> None: + """Log the start of the tuning process.""" logger.info("Tuning started.") @classmethod def trial_start(cls, trial_index: int = None) -> None: + """Log the start of a trial.""" logger.info("%d-trail started.", trial_index) @classmethod def execution_start(cls, mode=Mode.QUANTIZE, stacklevel=2): + """Log the start of the execution process.""" log_msg = _get_log_msg(mode) assert log_msg is not None, "Please check `mode` in execution_start function of TuningLogger class." logger.info("{} started.".format(log_msg), stacklevel=stacklevel) @classmethod def execution_end(cls, mode=Mode.QUANTIZE, stacklevel=2): + """Log the end of the execution process.""" log_msg = _get_log_msg(mode) assert log_msg is not None, "Please check `mode` in execution_end function of TuningLogger class." logger.info("{} end.".format(log_msg), stacklevel=stacklevel) @classmethod def evaluation_start(cls) -> None: + """Log the start of the evaluation process.""" logger.info("Evaluation started.") @classmethod def evaluation_end(cls) -> None: + """Log the end of the evaluation process.""" logger.info("Evaluation end.") @classmethod def trial_end(cls, trial_index: int = None) -> None: + """Log the end of a trial.""" logger.info("%d-trail end.", trial_index) @classmethod def tuning_end(cls) -> None: + """Log the end of the tuning process.""" logger.info("Tuning completed.") diff --git a/neural_compressor/common/utils/save_load.py b/neural_compressor/common/utils/save_load.py index 15de5d8c2a3..e7258fc5218 100644 --- a/neural_compressor/common/utils/save_load.py +++ b/neural_compressor/common/utils/save_load.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""The module for save/load config.""" import json import os @@ -26,7 +27,6 @@ def save_config_mapping(config_mapping, qconfig_file_path): # pragma: no cover config_mapping (dict): config mapping. qconfig_file_path (str): path to saved json file. 
""" - per_op_qconfig = {} for (op_name, op_type), op_config in config_mapping.items(): value = {op_config.name: op_config.to_dict()} diff --git a/neural_compressor/common/utils/utility.py b/neural_compressor/common/utils/utility.py index 82f24243a9b..2b606bb8a76 100644 --- a/neural_compressor/common/utils/utility.py +++ b/neural_compressor/common/utils/utility.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""The utility of common module.""" import collections import importlib @@ -43,7 +44,6 @@ def singleton(cls): """Singleton decorator.""" - instances = {} def _singleton(*args, **kw): @@ -211,6 +211,15 @@ def set_tensorboard(tensorboard: bool): def log_process(mode=Mode.QUANTIZE): + """Decorator function that logs the stage of process. + + Args: + mode (Mode): The mode of the process. + + Returns: + The decorated function. + """ + def log_process_wrapper(func): def inner_wrapper(*args, **kwargs): start_log = default_tuning_logger.execution_start @@ -235,6 +244,15 @@ def inner_wrapper(*args, **kwargs): def call_counter(func): + """A decorator that keeps track of the number of times a function is called. + + Args: + func: The function to be decorated. + + Returns: + The decorated function. + """ + def wrapper(*args, **kwargs): FUNC_CALL_COUNTS[func.__name__] += 1 return func(*args, **kwargs)