Skip to content

Commit

Permalink
skip verified config
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <[email protected]>
  • Loading branch information
yiliu30 committed May 8, 2024
1 parent 1a45090 commit 5eb504f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 2 deletions.
7 changes: 7 additions & 0 deletions neural_compressor/common/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,13 @@ def _is_op_type(name: str) -> bool:
def get_config_set_for_tuning(cls):
raise NotImplementedError

def __eq__(self, other: BaseConfig) -> bool:
if not isinstance(other, type(self)):
return False
return self.params_list == other.params_list and all(
getattr(self, str(attr)) == getattr(other, str(attr)) for attr in self.params_list
)


class ComposableConfig(BaseConfig):
name = COMPOSABLE_CONFIG
Expand Down
20 changes: 18 additions & 2 deletions neural_compressor/common/base_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,29 @@ def __len__(self) -> int:


class ConfigLoader:
def __init__(self, config_set: ConfigSet, sampler: Sampler = default_sampler) -> None:
def __init__(
self, config_set: ConfigSet, sampler: Sampler = default_sampler, skip_verified_config: bool = True
) -> None:
self.config_set = ConfigSet.from_fwk_configs(config_set)
self._sampler = sampler(self.config_set)
self.skip_verified_config = skip_verified_config
self.verify_config_list = list()

def is_verified_config(self, config):
for verified_config in self.verify_config_list:
if config == verified_config:
return True
return False

def __iter__(self) -> Generator[BaseConfig, Any, None]:
for index in self._sampler:
yield self.config_set[index]
new_config = self.config_set[index]
if self.skip_verified_config and self.is_verified_config(new_config):
logger.warning("Skip the verified config:")
logger.warning(new_config.to_dict())
continue
self.verify_config_list.append(new_config)
yield new_config


class TuningConfig:
Expand Down
3 changes: 3 additions & 0 deletions neural_compressor/common/tuning_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,6 @@ def is_tunable(self, value: Any) -> bool:
except Exception as e:
logger.debug(f"Failed to validate the input_args: {e}")
return False

def __str__(self) -> str:
return self.name
8 changes: 8 additions & 0 deletions test/3x/common/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,14 @@ def test_config_loader(self) -> None:
for i, config in enumerate(self.loader):
self.assertEqual(config, self.config_set[i])

def test_config_loader_skip_verified_config(self) -> None:
config_set = [FakeAlgoConfig(weight_bits=[4, 8]), FakeAlgoConfig(weight_bits=8)]
config_loader = ConfigLoader(config_set)
config_count = 0
for i, config in enumerate(config_loader):
config_count += 1
self.assertEqual(config_count, 2)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5eb504f

Please sign in to comment.