Skip to content

Commit

Permalink
Ensure that there is one way to collect psi sources and make it easie…
Browse files Browse the repository at this point in the history
…r to reuse values

PiperOrigin-RevId: 725181143
  • Loading branch information
tamaranorman authored and Torax team committed Feb 13, 2025
1 parent b695cc5 commit c70e73a
Show file tree
Hide file tree
Showing 28 changed files with 207 additions and 707 deletions.
1 change: 0 additions & 1 deletion torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ def _build_single_source_builder_from_config(
source_builder_class = source_lib.make_source_builder(
supported_source.source_class,
runtime_params_type=model_function.runtime_params_class,
links_back=model_function.links_back,
model_func=model_function.source_profile_function,
)

Expand Down
138 changes: 69 additions & 69 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from torax.geometry import standard_geometry
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib

_trapz = jax.scipy.integrate.trapezoid
Expand Down Expand Up @@ -341,48 +342,19 @@ def _prescribe_currents(


def _calculate_currents_from_psi(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
source_models: source_models_lib.SourceModels,
source_profiles: source_profiles_lib.SourceProfiles,
) -> state.Currents:
"""Creates the initial Currents using psi to calculate jtot.
Args:
static_runtime_params_slice: Static runtime parameters.
dynamic_runtime_params_slice: General runtime parameters at t_initial.
geo: Geometry of the tokamak.
core_profiles: Core profiles.
source_models: All TORAX source/sink functions. If not provided, uses the
default sources.
Returns:
currents: Plasma currents
"""

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
"""Creates the initial Currents using psi to calculate jtot."""
jtot, jtot_face, Ip_profile_face = physics.calc_jtot_from_psi(
geo,
core_profiles.psi,
)

bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)

# calculate "External" current profile (e.g. ECCD)
# form of external current on face grid
external_current = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
bootstrap_profile = source_profiles.j_bootstrap
# Note that the psi sources here are the standard sources and don't include
# the bootstrap current.
external_current = sum(source_profiles.psi.values())
johm = jtot - external_current - bootstrap_profile.j_bootstrap
currents = state.Currents(
jtot=jtot,
Expand Down Expand Up @@ -466,7 +438,7 @@ def _calculate_psi_grad_constraint(
)


def _init_psi_and_current(
def _init_psi_psidot_and_current(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand All @@ -493,6 +465,22 @@ def _init_psi_and_current(
Returns:
Refined core profiles.
"""
source_profiles = source_profiles_lib.SourceProfiles(
j_bootstrap=source_profiles_lib.BootstrapCurrentProfile.zero_profile(geo),
qei=source_profiles_lib.QeiInfo.zeros(geo),
)
# Updates the calculated source profiles with the standard source profiles.
source_profile_builders.build_standard_source_profiles(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
psi_only=True,
calculate_anyway=True,
calculated_source_profiles=source_profiles,
)

# Retrieving psi from the profile conditions.
if dynamic_runtime_params_slice.profile_conditions.psi is not None:
psi = cell_variable.CellVariable(
Expand All @@ -504,12 +492,19 @@ def _init_psi_and_current(
dr=geo.drho_norm,
)
core_profiles = dataclasses.replace(core_profiles, psi=psi)
currents = _calculate_currents_from_psi(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
)
source_profiles = dataclasses.replace(
source_profiles, j_bootstrap=bootstrap_profile
)
currents = _calculate_currents_from_psi(
geo=geo,
core_profiles=core_profiles,
source_profiles=source_profiles,
)
# Retrieving psi from the standard geometry input.
elif (
Expand All @@ -528,30 +523,29 @@ def _init_psi_and_current(
dr=geo.drho_norm,
)
core_profiles = dataclasses.replace(core_profiles, psi=psi)
currents = _calculate_currents_from_psi(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
source_profiles = dataclasses.replace(
source_profiles, j_bootstrap=bootstrap_profile
)
currents = _calculate_currents_from_psi(
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
source_profiles=source_profiles,
)
# Calculating j according to nu formula and psi from j.
elif (
isinstance(geo, circular_geometry.CircularAnalyticalGeometry)
or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
# First calculate currents without bootstrap.
bootstrap = source_profiles_lib.BootstrapCurrentProfile.zero_profile(
geo
)
external_current = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
external_current = sum(source_profiles.psi.values())
currents = _prescribe_currents(
bootstrap_profile=bootstrap,
bootstrap_profile=source_profiles.j_bootstrap,
external_current=external_current,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
Expand Down Expand Up @@ -590,7 +584,28 @@ def _init_psi_and_current(
else:
raise ValueError('Cannot compute psi for given config.')

core_profiles = dataclasses.replace(core_profiles, psi=psi, currents=currents)
core_profiles = dataclasses.replace(
core_profiles, psi=psi, currents=currents)
bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
source_profiles = dataclasses.replace(
source_profiles, j_bootstrap=bootstrap_profile
)
# psidot calculated here with phibdot=0 in geo, since this is initial
# conditions and we don't yet have information on geo_t_plus_dt for the
# phibdot calculation.
psidot = ohmic_heat_source.calculate_psidot_from_psi_sources(
source_profiles=source_profiles,
resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult,
psi=psi,
geo=geo,
)
psidot_cell_var = dataclasses.replace(core_profiles.psidot, value=psidot)
core_profiles = dataclasses.replace(core_profiles, psidot=psidot_cell_var)

return core_profiles

Expand Down Expand Up @@ -667,29 +682,14 @@ def initial_core_profiles(
nref=jnp.asarray(dynamic_runtime_params_slice.numerics.nref),
)

core_profiles = _init_psi_and_current(
core_profiles = _init_psi_psidot_and_current(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
source_models,
)

# psidot calculated here with phibdot=0 in geo, since this is initial
# conditions and we don't yet have information on geo_t_plus_dt for the
# phibdot calculation.
psidot = dataclasses.replace(
core_profiles.psidot,
value=ohmic_heat_source.calc_psidot(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
source_models,
),
)
core_profiles = dataclasses.replace(core_profiles, psidot=psidot)

# Set psi as source of truth and recalculate jtot, q, s
return physics.update_jtot_q_face_s_face(
geo=geo,
Expand Down
7 changes: 1 addition & 6 deletions torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from torax.geometry import geometry_provider as geometry_provider_lib
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import ohmic_heat_source
from torax.sources import source_operations
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.stepper import stepper as stepper_lib
Expand Down Expand Up @@ -690,14 +689,10 @@ def _update_psidot(
core_sources: source_profiles_lib.SourceProfiles,
) -> state.CoreProfiles:
"""Update psidot based on new core_profiles."""
psi_sources = source_operations.sum_sources_psi(geo, core_sources)

psidot = dataclasses.replace(
core_profiles.psidot,
value=ohmic_heat_source.calculate_psidot_from_psi_sources(
psi_sources=psi_sources,
sigma=core_sources.j_bootstrap.sigma,
sigma_face=core_sources.j_bootstrap.sigma_face,
source_profiles=core_sources,
resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult,
psi=core_profiles.psi,
geo=geo,
Expand Down
2 changes: 0 additions & 2 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from torax.geometry import geometry
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models
from torax.sources import source_profiles


Expand Down Expand Up @@ -128,7 +127,6 @@ def bremsstrahlung_model_func(
source_name: str,
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_model_func: source_models.SourceModels | None,
) -> tuple[chex.Array, ...]:
"""Model function for the Bremsstrahlung heat sink."""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
Expand Down
3 changes: 0 additions & 3 deletions torax/sources/cyclotron_radiation_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from torax.geometry import geometry
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models
from torax.sources import source_profiles


Expand Down Expand Up @@ -285,7 +284,6 @@ def cyclotron_radiation_albajar(
source_name: str,
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels,
) -> tuple[array_typing.ArrayFloat, ...]:
"""Calculates the cyclotron radiation heat sink contribution to the electron heat equation.
Expand Down Expand Up @@ -314,7 +312,6 @@ def cyclotron_radiation_albajar(
source_name: The name of the source.
core_profiles: The core profiles object.
unused_calculated_source_profiles: Unused.
unused_source_models: Unused.
Returns:
The cyclotron radiation heat sink contribution to the electron heat
Expand Down
11 changes: 2 additions & 9 deletions torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from torax.sources import formulas
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models
from torax.sources import source_profiles

InterpolatedVarTimeRhoInput = (
Expand Down Expand Up @@ -104,35 +103,29 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):


def calc_heating_and_current(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> tuple[chex.Array, ...]:
"""Model function for the electron-cyclotron source.
Based on Lin-Liu, Y. R., Chan, V. S., & Prater, R. (2003).
See https://torax.readthedocs.io/en/latest/electron-cyclotron-derivation.html
Args:
static_runtime_params_slice: Static runtime parameters.
unused_static_runtime_params_slice: Static runtime parameters.
dynamic_runtime_params_slice: Global runtime parameters
geo: Magnetic geometry.
source_name: Name of the source.
core_profiles: CoreProfiles component of the state.
unused_calculated_source_profiles: Unused.
unused_source_models: Unused.
Returns:
2D array of electron cyclotron heating power density and current density.
"""
del (
unused_source_models,
static_runtime_params_slice,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
Expand Down
4 changes: 0 additions & 4 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from torax.sources import formulas
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models
from torax.sources import source_profiles


Expand Down Expand Up @@ -83,7 +82,6 @@ def calc_puff_source(
source_name: str,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> tuple[chex.Array, ...]:
"""Calculates external source term for n from puffs."""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
Expand Down Expand Up @@ -176,7 +174,6 @@ def calc_generic_particle_source(
source_name: str,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> tuple[chex.Array, ...]:
"""Calculates external source term for n from SBI."""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
Expand Down Expand Up @@ -262,7 +259,6 @@ def calc_pellet_source(
source_name: str,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> tuple[chex.Array, ...]:
"""Calculates external source term for n from pellets."""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
Expand Down
3 changes: 1 addition & 2 deletions torax/sources/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import ClassVar, Optional
from typing import ClassVar

import chex
import jax
Expand Down Expand Up @@ -149,7 +149,6 @@ def fusion_heat_model_func(
unused_source_name: str,
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: Optional['source_models.SourceModels'],
) -> tuple[chex.Array, ...]:
"""Model function for fusion heating."""
# pytype: enable=name-error
Expand Down
Loading

0 comments on commit c70e73a

Please sign in to comment.