Skip to content

Commit

Permalink
update user_config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sandeshkatakam committed Dec 30, 2024
1 parent 60624d6 commit 721cda2
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 84 deletions.
File renamed without changes.
Empty file added configs/dqn_atari.yaml
Empty file.
40 changes: 40 additions & 0 deletions configs/ppo_cartpole.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Environment
env:
env_id: "CartPole-v1"
max_episode_steps: 500
action_space_type: "discrete"
state_space_type: "continuous"

# Network Architecture
network:
actor_hidden_sizes: [64, 64]
critic_hidden_sizes: [64, 64]
activation: "tanh"
layer_norm: false

# Training Process
training:
total_timesteps: 100000
batch_size: 64
learning_rate: 0.0003
device: "cuda"
seed: 42
num_envs: 8
eval_freq: 5000

# PPO Specific
algorithm:
n_steps: 2048
n_epochs: 10
gamma: 0.99
gae_lambda: 0.95
clip_range: 0.2
ent_coef: 0.0
vf_coef: 0.5
max_grad_norm: 0.5

# Logging
logging:
log_dir: "logs/ppo_cartpole"
tensorboard: true
verbose: 1
41 changes: 41 additions & 0 deletions configs/sac_pendulum.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Environment
env:
env_id: "Pendulum-v1"
max_episode_steps: 200
action_space_type: "continuous"
state_space_type: "continuous"
reward_scale: 0.1

# Network Architecture
network:
actor_hidden_sizes: [256, 256]
critic_hidden_sizes: [256, 256]
activation: "relu"
output_activation: "tanh"

# Training Process
training:
total_timesteps: 500000
batch_size: 256
learning_rate: 0.0003
device: "cuda"
seed: 42
num_envs: 1
eval_freq: 10000

# SAC Specific
algorithm:
buffer_size: 1000000
learning_starts: 10000
tau: 0.005
gamma: 0.99
train_freq: 1
gradient_steps: 1
ent_coef: "auto"

# Logging
logging:
log_dir: "logs/sac_pendulum"
wandb_project: "sac_training"
tensorboard: true
verbose: 1
166 changes: 82 additions & 84 deletions omnixrl/user_config.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,76 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Union, Dict, Any
from pathlib import Path
import yaml
import json

from .configs.base_config import EnvConfig, NetworkConfig
from .configs.algo_configs import PPOConfig, SACConfig

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class EnvConfig:
env_id: str = "CartPole-v1"
max_episode_steps: int = 1000
action_space_type: str = "discrete"
state_space_type: str = "continuous"
reward_scale: float = 1.0

@dataclass
class ModelConfig:
actor_hidden_sizes: List[int] = (256, 256)
critic_hidden_sizes: List[int] = (256, 256)
activation: str = "relu"
layer_norm: bool = False


@dataclass
class PPOConfig:
n_steps: int = 2048
n_epochs: int = 10
gamma: float = 0.99
gae_lambda: float = 0.95
clip_range: float = 0.2

@dataclass
class SACConfig:
buffer_size: int = 1_000_000
tau: float = 0.005
gamma: float = 0.99
train_freq: int = 1
@dataclass
class Config:
"""Base configuration class for RL algorithms
Attributes:
learning_rate (float): Learning rate for optimization
batch_size (int): Batch size for training
hidden_sizes (List[int]): Neural network architecture
max_steps (int): Maximum number of training steps
device (str): Device to run on ('cpu' or 'cuda')
seed (Optional[int]): Random seed for reproducibility
"""
# Training parameters
learning_rate: float = 0.001
batch_size: int = 64
hidden_sizes: List[int] = (256, 256)
max_steps: int = 1_000_000
device: str = "cpu"
seed: Optional[int] = None

def validate(self) -> bool:
"""Validate configuration values"""
try:
assert self.learning_rate > 0, "Learning rate must be positive"
assert self.batch_size > 0, "Batch size must be positive"
assert all(h > 0 for h in self.hidden_sizes), "Hidden sizes must be positive"
assert self.max_steps > 0, "Max steps must be positive"
assert self.device in ["cpu", "cuda"], "Device must be 'cpu' or 'cuda'"
return True
except AssertionError as e:
print(f"Configuration validation failed: {e}")
return False
"""Main configuration class that combines all configs"""
env: EnvConfig
model: ModelConfig
algorithm: Union[PPOConfig, SACConfig]

@classmethod
def from_dict(cls, config: Dict) -> "Config":
"""Create a Config instance from a dictionary"""
# Filter only valid keys
valid_keys = cls.__dataclass_fields__.keys()
filtered_config = {
k: v for k, v in config.items()
if k in valid_keys
}
def from_dict(cls, config: Dict[str, Any]) -> "Config":
# Create component configs
env_config = EnvConfig(**config.get("env", {}))
model_config = ModelConfig(**config.get("network", {}))

# Create instance
instance = cls(**filtered_config)
instance.validate()
return instance
# Determine algorithm type and create config
algo_type = config.get("algorithm_type", "ppo")
if algo_type == "ppo":
algo_config = PPOConfig(**config.get("algorithm", {}))
elif algo_type == "sac":
algo_config = SACConfig(**config.get("algorithm", {}))
else:
raise ValueError(f"Unknown algorithm type: {algo_type}")

return cls(
env=env_config,
model=model_config,
algorithm=algo_config
)

@classmethod
def from_file(cls, path: str) -> "Config":
"""Load configuration from a YAML or JSON file"""
path = Path(path)

# Load file
with open(path, 'r') as f:
if path.suffix in ['.yaml', '.yml']:
config_dict = yaml.safe_load(f)
Expand All @@ -67,42 +80,27 @@ def from_file(cls, path: str) -> "Config":
raise ValueError(f"Unsupported file format: {path.suffix}")

return cls.from_dict(config_dict)




# Example usage:
if __name__ == "__main__":
# From dictionary
config_dict = {
"learning_rate": 0.001,
"batch_size": 32,
"hidden_sizes": [64, 64],
"device": "cuda"
}
config1 = Config.from_dict(config_dict)
print("Config from dict:", config1)

# Create example YAML file
yaml_config = """
learning_rate: 0.0003
batch_size: 64
hidden_sizes: [256, 256]
device: cpu
seed: 42
"""

with open("example_config.yaml", "w") as f:
f.write(yaml_config)

# Load from YAML
config2 = Config.from_file("example_config.yaml")
print("\nConfig from YAML:", config2)

# Validation example
invalid_config = {
"learning_rate": -0.001, # Invalid negative learning rate
"batch_size": 64
}
print("\nTrying invalid config:")
config3 = Config.from_dict(invalid_config) # Will print validation error
def validate(self) -> bool:
"""Validate all configurations"""
try:
# Validate environment config
assert self.env.max_episode_steps > 0
assert self.env.action_space_type in ["discrete", "continuous"]
# Validate network config
assert all(h > 0 for h in self.network.actor_hidden_sizes)
assert self.network.activation in ["relu", "tanh"]
# Validate algorithm config
if isinstance(self.algorithm, PPOConfig):
assert 0 < self.algorithm.clip_range < 1
assert 0 < self.algorithm.gamma <= 1
elif isinstance(self.algorithm, SACConfig):
assert self.algorithm.buffer_size > 0
assert 0 < self.algorithm.tau <= 1
return True
except AssertionError as e:
print(f"Configuration validation failed: {e}")
return False

0 comments on commit 721cda2

Please sign in to comment.