Skip to content

Commit

Permalink
Expose some providers (such as sources) for access
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725324007
  • Loading branch information
Torax team committed Feb 13, 2025
1 parent 105e8c3 commit 33a3c02
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torax.config.build_sim import build_sim_from_config
from torax.config.config_loader import import_module
from torax.interpolated_param import InterpolatedVarSingleAxis
from torax.interpolated_param import InterpolatedVarTimeRho
from torax.interpolated_param import InterpolationMode
from torax.output import ToraxSimOutputs
from torax.sim import Sim
Expand All @@ -35,6 +36,8 @@
'build_sim_from_config',
'import_module',
'InterpolatedVarSingleAxis',
'InterpolatedVarTimeRho',
'InterpolationMode',
'Sim',
'SimError',
'ToraxSimOutputs',
Expand Down
12 changes: 12 additions & 0 deletions torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,18 @@ def __init__(
def sources(self) -> dict[str, sources_params.RuntimeParams]:
return self._sources

@property
def sources_providers(
self,
) -> dict[str, sources_params.RuntimeParamsProvider]:
return self._sources_providers

@property
def pedestal_runtime_params_provider(
self,
) -> pedestal_model_params.RuntimeParamsProvider:
return self._pedestal_runtime_params_provider

def validate_new(
self,
new_provider: DynamicRuntimeParamsSliceProvider,
Expand Down
7 changes: 7 additions & 0 deletions torax/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,8 @@ def __init__(
"""
self.rho_norm = rho_norm
self.sorted_indices = np.array(sorted(values.keys()))
self._rho_interpolation_mode = rho_interpolation_mode
self._time_interpolation_mode = time_interpolation_mode
rho_norm_interpolated_values = np.stack(
[
InterpolatedVarSingleAxis(
Expand All @@ -498,6 +500,11 @@ def __init__(
interpolation_mode=time_interpolation_mode,
)

@property
def interpolation_mode(self) -> tuple[InterpolationMode, InterpolationMode]:
"""Returns the interpolation mode used by this param."""
return self._time_interpolation_mode, self._rho_interpolation_mode

def get_value(self, x: chex.Numeric) -> chex.Array:
"""Returns the value of this parameter interpolated at x=time."""
return self._time_interpolated_var.get_value(x)

0 comments on commit 33a3c02

Please sign in to comment.