From 5eb504fd72d5630e56a338389365cb8afa593cee Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 8 May 2024 11:06:27 +0800 Subject: [PATCH] skip verified config Signed-off-by: yiliu30 --- neural_compressor/common/base_config.py | 7 +++++++ neural_compressor/common/base_tuning.py | 20 ++++++++++++++++++-- neural_compressor/common/tuning_param.py | 3 +++ test/3x/common/test_common.py | 8 ++++++++ 4 files changed, 36 insertions(+), 2 deletions(-) diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 5e9e72a8882..35b0f532738 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -436,6 +436,13 @@ def _is_op_type(name: str) -> bool: def get_config_set_for_tuning(cls): raise NotImplementedError + def __eq__(self, other: BaseConfig) -> bool: + if not isinstance(other, type(self)): + return False + return self.params_list == other.params_list and all( + getattr(self, str(attr)) == getattr(other, str(attr)) for attr in self.params_list + ) + class ComposableConfig(BaseConfig): name = COMPOSABLE_CONFIG diff --git a/neural_compressor/common/base_tuning.py b/neural_compressor/common/base_tuning.py index 54f908232ad..2a1adfa480b 100644 --- a/neural_compressor/common/base_tuning.py +++ b/neural_compressor/common/base_tuning.py @@ -231,13 +231,29 @@ def __len__(self) -> int: class ConfigLoader: - def __init__(self, config_set: ConfigSet, sampler: Sampler = default_sampler) -> None: + def __init__( + self, config_set: ConfigSet, sampler: Sampler = default_sampler, skip_verified_config: bool = True + ) -> None: self.config_set = ConfigSet.from_fwk_configs(config_set) self._sampler = sampler(self.config_set) + self.skip_verified_config = skip_verified_config + self.verify_config_list = list() + + def is_verified_config(self, config): + for verified_config in self.verify_config_list: + if config == verified_config: + return True + return False def __iter__(self) -> Generator[BaseConfig, Any, None]: for index in self._sampler: - yield self.config_set[index] + new_config = self.config_set[index] + if self.skip_verified_config and self.is_verified_config(new_config): + logger.warning("Skip the verified config:") + logger.warning(new_config.to_dict()) + continue + self.verify_config_list.append(new_config) + yield new_config class TuningConfig: diff --git a/neural_compressor/common/tuning_param.py b/neural_compressor/common/tuning_param.py index 207811590ee..3f6d9272e4f 100644 --- a/neural_compressor/common/tuning_param.py +++ b/neural_compressor/common/tuning_param.py @@ -98,3 +98,6 @@ def is_tunable(self, value: Any) -> bool: except Exception as e: logger.debug(f"Failed to validate the input_args: {e}") return False + + def __str__(self) -> str: + return self.name diff --git a/test/3x/common/test_common.py b/test/3x/common/test_common.py index d1df7d98b1d..9eb89b9ef23 100644 --- a/test/3x/common/test_common.py +++ b/test/3x/common/test_common.py @@ -336,6 +336,14 @@ def test_config_loader(self) -> None: for i, config in enumerate(self.loader): self.assertEqual(config, self.config_set[i]) + def test_config_loader_skip_verified_config(self) -> None: + config_set = [FakeAlgoConfig(weight_bits=[4, 8]), FakeAlgoConfig(weight_bits=8)] + config_loader = ConfigLoader(config_set) + config_count = 0 + for i, config in enumerate(config_loader): + config_count += 1 + self.assertEqual(config_count, 2) + if __name__ == "__main__": unittest.main()