
update README.md
sandeshkatakam committed Dec 29, 2024
1 parent 1955d9d commit be6b6bf
Showing 6 changed files with 1,743 additions and 15 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -1,3 +1,5 @@
<div align="center">

# OmnixRL
## **NOTE: [Development in Progress]**
Inspired by OpenAI Spinning Up RL Algorithms Educational Resource implemented in JAX
@@ -10,10 +12,11 @@ Inspired by OpenAI Spinning Up RL Algorithms Educational Resource implemented in
[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

<p align="center" style="text-align: center; width: 100%;">
  <img src="./assets/imgs/OmnixRL_logo.png" alt="OmnixRL Logo" style="max-width: 600px; width: 100%; height: auto; display: block; margin: 0 auto;">
</p>

</div>
A comprehensive reinforcement learning library implemented in JAX, inspired by OpenAI's Spinning Up. It provides clean, modular implementations of popular RL algorithms with a focus on research experimentation, and serves as a framework for developing novel RL algorithms.
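
A minimal quick-start sketch, based only on the `VPG` constructor that appears later in this commit; the environment name, the `value_function=None` placeholder, and the hyperparameter values are assumptions:

```python
import gym
from omnixrl.algos.vpg.vpg import VPG, policy

env = gym.make("CartPole-v1")
pi = policy(num_actions=env.action_space.n)
agent = VPG(env=env, policy=pi, value_function=None,
            learning_rate=3e-4, gamma=0.99)
# The training loop is not part of this commit.
```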

## Core Features
5 changes: 4 additions & 1 deletion omnixrl/algo_core.py
@@ -2,7 +2,7 @@
import jax.numpy as jnp
from jax.random import PRNGKey

class BaseNN(ABC):
    """
    Abstract base class for user-defined models in the RL library.
    """
@@ -128,6 +128,9 @@ def loss_fn(params, states, actions, rewards):
            return -jnp.mean(rewards)  # Placeholder loss

        return self.train_step(states, actions, rewards, loss_fn)



############################# User Defined Models Examples in Different neural network libraries ##########################
# In FLAX:
from flax import linen as nn
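
# The Flax example is truncated in this hunk; below is a minimal sketch of the
# kind of user-defined model this section points at (the class name, layer
# sizes, and activations are assumptions, and no claim is made that it already
# satisfies the BaseNN interface):
from typing import Sequence

class FlaxMLP(nn.Module):
    """Hypothetical user-defined model: a small MLP written with Flax linen."""
    features: Sequence[int] = (64, 64, 2)

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Hidden layers with tanh nonlinearities, linear output layer.
        for size in self.features[:-1]:
            x = nn.tanh(nn.Dense(size)(x))
        return nn.Dense(self.features[-1])(x)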
84 changes: 75 additions & 9 deletions omnixrl/algos/vpg/core.py
@@ -1,20 +1,86 @@
import numpy as np
import scipy.signal

import jax
import jax.numpy as jnp
import equinox as eqx
import distrax  # provides the Normal and Categorical distributions used below

from gym.spaces import Box, Discrete
from typing import List, Callable  # NOTE: Use JAXTyping library later on
from omnixrl import BaseNN

def combined_shape(length, shape=None):
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)
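
# scipy.signal is imported above; in Spinning Up it accompanies a discounted
# cumulative-sum helper that is not part of this commit. A sketch of that
# helper, assumed to follow the standard Spinning Up implementation:
def discount_cumsum(x, discount):
    """
    Discounted cumulative sum over the time axis:
        out[t] = x[t] + discount * x[t+1] + discount**2 * x[t+2] + ...
    """
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]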

def mlp(sizes, activation=jax.nn.relu):
    """Convenience constructor: build an MLP from a list of layer sizes."""
    # A torch-style nn.Sequential construction does not carry over to JAX directly;
    # delegate to the MLP class defined below, which applies no activation after
    # the final layer (equivalent to the previous Identity output activation).
    return MLP(layer_sizes=list(sizes), activation=activation)


class MLP:
    def __init__(self, layer_sizes: List[int], activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu, key=None):
        """
        Initialize an MLP with variable layer sizes.

        Args:
            layer_sizes: List of integers specifying the number of units in each
                layer, including the input and output layers.
            activation: Activation function to use between layers.
            key: PRNG key for weight initialization; a fixed default key is used
                if none is provided.
        """
        self.layer_sizes = layer_sizes
        self.activation = activation

        # Initialize weights and biases.
        self.params = self.initialize_params(key if key is not None else jax.random.PRNGKey(0))

    def initialize_params(self, key):
        """Initialize weights (He-scaled Gaussian) and biases (zeros) for each layer."""
        params = []
        # Split the key so each layer gets its own independent initialization.
        layer_keys = jax.random.split(key, len(self.layer_sizes) - 1)
        for layer_key, (in_size, out_size) in zip(
            layer_keys, zip(self.layer_sizes[:-1], self.layer_sizes[1:])
        ):
            weight = jax.random.normal(layer_key, (in_size, out_size)) * jnp.sqrt(2.0 / in_size)
            bias = jnp.zeros(out_size)
            params.append({'weight': weight, 'bias': bias})
        return params

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Perform a forward pass through the MLP."""
        for i, layer in enumerate(self.params):
            x = jnp.dot(x, layer['weight']) + layer['bias']
            if i < len(self.params) - 1:  # Apply activation for all but the last layer
                x = self.activation(x)
        return x

    def apply(self, x: jnp.ndarray) -> jnp.ndarray:
        """Alias for forward."""
        return self.forward(x)
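
# Example usage of the MLP above (shapes are illustrative assumptions):
#   net = MLP(layer_sizes=[4, 64, 64, 2])   # map a 4-dim observation to 2 action logits
#   logits = net.apply(jnp.ones((4,)))      # -> shape (2,)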




class Actor(BaseNN):  # Inherit from the base neural network class for shared methods
    def _distribution(self, obs):
        raise NotImplementedError

    def _log_prob_from_distributions(self, pi, act):
        raise NotImplementedError

    def forward(self, obs, act=None):
        # Return the action distribution for obs and, if an action is provided,
        # the log-probability of that action under the distribution.
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distributions(pi, act)
        return pi, logp_a

class MLPCategoricalActor(Actor):
    def __init__(self, act):
        super().__init__()

class MLPGaussianActor(Actor):
    def __init__(self, act):
        super().__init__()

    def _distribution(self, obs):
        # Placeholder: a full implementation maps obs to a mean through the policy
        # network and keeps a learnable log-std; here obs is used directly as the mean.
        return distrax.Normal(loc=obs, scale=1.0)



class MLPCritic(BaseNN):
    def __init__(self, values):
        super().__init__()
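
# The actor and critic classes above are still stubs. A sketch of how
# MLPCategoricalActor might be completed for Discrete action spaces, following
# the Spinning Up pattern and reusing the MLP class above (the constructor
# signature and hidden sizes are assumptions, not part of this commit):
class MLPCategoricalActorSketch(Actor):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        # Maps observations to unnormalized log-probabilities (logits) over actions.
        self.logits_net = MLP(layer_sizes=[obs_dim, *hidden_sizes, act_dim])

    def _distribution(self, obs):
        logits = self.logits_net.apply(obs)
        return distrax.Categorical(logits=logits)

    def _log_prob_from_distributions(self, pi, act):
        return pi.log_prob(act)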
11 changes: 11 additions & 0 deletions omnixrl/algos/vpg/vpg.py
@@ -1,5 +1,16 @@
from typing import Any
import numpy as np

class policy:
    def __init__(self, num_actions):
        self.num_actions = num_actions
        self.returns = []
        self.states = []
        self.actions = []

    def get_action(self, state):
        # Implement policy to select an action based on the given state
        pass
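
    # `get_action` above is still a stub; a purely illustrative placeholder (an
    # assumption, not the intended learned policy) could sample uniformly over
    # the action space until a parameterized policy is wired in:
    #
    #     def get_action(self, state):
    #         return np.random.randint(self.num_actions)
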
class VPG:
    def __init__(self, env, policy, value_function, learning_rate, gamma):
        self.env = env
