Skip to content

Commit

Permalink
Added the ability to blend lines based on wavelength bins (#815)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillJRoper authored Jan 29, 2025
2 parents 8fb1192 + 79bcedf commit 5214813
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 82 deletions.
29 changes: 29 additions & 0 deletions docs/source/lines/galaxy_lines.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,35 @@
"for line in stars.lines[\"emergent\"]:\n",
" print(f\"{line.id}: {line.flux} @ {line.obslam}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Blending lines\n",
"\n",
"``Lines`` in a ``LineCollection`` can be blended based on a given wavelength resolution using the ``get_blended_lines`` method. This method takes a set of wavelength bins, either arbitrarily defined or based on a particular observatory, and returns a new ``LineCollection`` containing lines bleneded within each bin."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from unyt import angstrom\n",
"\n",
"print(\"Before blending:\")\n",
"print(stars.lines[\"emergent\"])\n",
"\n",
"# Blend the lines onto an arbitrary wavelength grid\n",
"lam_bins = np.arange(4000, 7000, 1000) * angstrom\n",
"blended_lines = stars.lines[\"emergent\"].get_blended_lines(lam_bins)\n",
"\n",
"print(\"After blending:\")\n",
"print(blended_lines)"
]
}
],
"metadata": {},
Expand Down
183 changes: 106 additions & 77 deletions src/synthesizer/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from synthesizer import exceptions, line_ratios
from synthesizer.conversions import lnu_to_llam, standard_to_vacuum
from synthesizer.units import Quantity, accepts
from synthesizer.utils import TableFormatter
from synthesizer.warnings import deprecation


Expand Down Expand Up @@ -248,6 +249,9 @@ class LineCollection:
A list of available line diagrams.
"""

# Define quantities
wavelengths = Quantity()

def __init__(self, lines):
"""
Initialise LineCollection.
Expand Down Expand Up @@ -288,11 +292,11 @@ def __init__(self, lines):
self.wavelengths = self.wavelengths[sorted_arguments]

# Include line ratio and diagram definitions
self.line_ratios = line_ratios
self._line_ratios = line_ratios

# Create list of available line ratios
self.available_ratios = []
for ratio_id, ratio in self.line_ratios.ratios.items():
for ratio_id, ratio in self._line_ratios.ratios.items():
# Create a set from the ratio line ids while also unpacking
# any comma separated lines
ratio_line_ids = set()
Expand All @@ -305,7 +309,7 @@ def __init__(self, lines):

# Create list of available line diagnostics
self.available_diagrams = []
for diagram_id, diagram in self.line_ratios.diagrams.items():
for diagram_id, diagram in self._line_ratios.diagrams.items():
# Create a set from the diagram line ids while also unpacking
# any comma separated lines
diagram_line_ids = set()
Expand Down Expand Up @@ -357,33 +361,6 @@ def concatenate(self, other):

return LineCollection(my_lines)

def __str__(self):
"""
Function to print a basic summary of the LineCollection object.
Returns a string containing the id, wavelength, luminosity,
equivalent width, and flux if generated.
Returns:
summary (str)
Summary string containing the total mass formed and
lists of the available SEDs, lines, and images.
"""

# Set up string for printing
summary = ""

# Add the content of the summary to the string to be printed
summary += "-" * 10 + "\n"
summary += "LINE COLLECTION\n"
summary += f"number of lines: {len(self.line_ids)}\n"
summary += f"lines: {self.line_ids}\n"
summary += f"available ratios: {self.available_ratios}\n"
summary += f"available diagrams: {self.available_diagrams}\n"
summary += "-" * 10

return summary

def __iter__(self):
"""
Overload iteration to allow simple looping over Line objects,
Expand All @@ -408,6 +385,23 @@ def __next__(self):
# Return the filter
return self.lines[self.line_ids[self._current_ind - 1]]

def __len__(self):
"""Return the number of lines in the collection."""
return self.nlines

def __str__(self):
"""
Return a string representation of the LineCollection object.
Returns:
table (str)
A string representation of the LineCollection object.
"""
# Intialise the table formatter
formatter = TableFormatter(self)

return formatter.get_table("LineCollection")

def sum(self):
"""
For collections containing lines from multiple particles calculate the
Expand Down Expand Up @@ -465,7 +459,7 @@ def get_ratio(self, ratio_id):
# defined in the line_ratios module...
if isinstance(ratio_id, str):
# Check if ratio_id exists
if ratio_id not in self.line_ratios.available_ratios:
if ratio_id not in self._line_ratios.available_ratios:
raise exceptions.UnrecognisedOption(
f"ratio_id not recognised ({ratio_id})"
)
Expand All @@ -477,7 +471,7 @@ def get_ratio(self, ratio_id):
f"this ratio ({ratio_id})"
)

line1, line2 = self.line_ratios.ratios[ratio_id]
line1, line2 = self._line_ratios.ratios[ratio_id]

# Otherwise interpret as a list
elif isinstance(ratio_id, list):
Expand All @@ -502,7 +496,7 @@ def get_diagram(self, diagram_id):
# defined in the line_ratios module...
if isinstance(diagram_id, str):
# check if ratio_id exists
if diagram_id not in self.line_ratios.available_diagrams:
if diagram_id not in self._line_ratios.available_diagrams:
raise exceptions.UnrecognisedOption(
f"diagram_id not recognised ({diagram_id})"
)
Expand All @@ -514,7 +508,7 @@ def get_diagram(self, diagram_id):
f"this diagram ({diagram_id})"
)

ab, cd = self.line_ratios.diagrams[diagram_id]
ab, cd = self._line_ratios.diagrams[diagram_id]

# Otherwise interpret as a list
elif isinstance(diagram_id, list):
Expand Down Expand Up @@ -576,6 +570,77 @@ def get_flux(self, cosmo, z, igm=None):
for line in self.lines.values():
line.get_flux(cosmo, z, igm)

@accepts(wavelength_bins=angstrom)
def get_blended_lines(self, wavelength_bins):
"""
Blend lines separated by less than the provided wavelength resolution.
We use a set of wavelength bins to enable the user to control exactly
which lines are blended together. This also enables an array to be
used emulating an instrument resolution.
A simple resolution would lead to ambiguity in situations where A and
B are blended, and B and C are blended, but A and C are not.
Args:
wavelength_bins (unyt_array)
The wavelength bin edges into which the lines will be blended.
Any lines outside the range of the bins will be ignored.
Returns:
LineCollection
A new LineCollection object containing the blended lines.
"""
# Ensure the bins are sorted and actually have a length
wavelength_bins = np.sort(wavelength_bins)
if len(wavelength_bins) < 2:
raise exceptions.InconsistentArguments(
"Wavelength bins must have a length of at least 2"
)

# Sort wavelengths into the bins getting the indices in each bin
bin_inds = np.digitize(self.wavelengths, wavelength_bins)

# Create a dictionary to hold the blended lines
blended_lines = np.empty(len(wavelength_bins), dtype=object)

# Initialise the array of blended lines to None
for i in range(blended_lines.size):
blended_lines[i] = None

# Loop bin indices and combine the lines into the blended_lines array
for i, bin_ind in enumerate(bin_inds):
# If the bin index is 0 or the length of the bins then it lay
# outside the range of the bins
if bin_ind == 0 or bin_ind == len(wavelength_bins):
continue

# Ok, now we can handle the off by 1 error that digitize gives us
bin_ind -= 1

# Get the line id
line_id = self.line_ids[i]

# Get the line itself
line = self.lines[line_id]

# If the bin is empty, just store the line
if blended_lines[bin_ind] is None:
blended_lines[bin_ind] = line

# Otherwise, combine the line with the existing line
else:
blended_lines[bin_ind] = blended_lines[bin_ind] + line

# Convert the array of lines to a dictionary ready to make a new
# LineCollection
new_lines = {}
for line in blended_lines:
if line is not None:
new_lines[line.id] = line

return LineCollection(new_lines)


class Line:
"""
Expand Down Expand Up @@ -763,52 +828,16 @@ def _make_line_from_lines(self, lines):

def __str__(self):
"""
Return a basic summary of the Line object.
Returns a string containing the id, wavelength, luminosity,
equivalent width, and flux if generated.
Return a string representation of the LineCollection object.
Returns:
summary (str)
Summary string containing the total mass formed and
lists of the available SEDs, lines, and images.
"""
# Set up string for printing
pstr = ""

# Add the content of the summary to the string to be printed
pstr += "-" * 10 + "\n"
pstr += f"SUMMARY OF {self.id}" + "\n"
pstr += f"wavelength: {self.wavelength:.1f}" + "\n"
if isinstance(self.luminosity, np.ndarray):
mean_lum = np.mean(self._luminosity)
pstr += f"Npart: {self.luminosity.size}\n"
pstr += (
f"<log10(luminosity/{self.luminosity.units})>: "
f"{np.log10(mean_lum):.2f}\n"
)
mean_eq = np.mean(self.equivalent_width)
pstr += f"<equivalent width>: {mean_eq:.0f}" + "\n"
mean_flux = np.mean(self.flux) if self.flux is not None else None
pstr += (
f"<log10(flux/{self.flux.units}): {np.log10(mean_flux):.2f}"
if self.flux is not None
else ""
)
else:
pstr += (
f"log10(luminosity/{self.luminosity.units}): "
f"{np.log10(self.luminosity):.2f}\n"
)
pstr += f"equivalent width: {self.equivalent_width:.0f}" + "\n"
pstr += (
f"log10(flux/{self.flux.units}): {np.log10(self.flux):.2f}"
if self.flux is not None
else ""
)
pstr += "-" * 10
table (str)
A string representation of the LineCollection object.
"""
# Intialise the table formatter
formatter = TableFormatter(self)

return pstr
return formatter.get_table("Line")

def __add__(self, second_line):
"""
Expand Down
4 changes: 4 additions & 0 deletions src/synthesizer/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, bar):
"lam": Angstrom,
"obslam": Angstrom,
"wavelength": Angstrom,
"wavelengths": Angstrom,
"vacuum_wavelength": Angstrom,
"original_lam": Angstrom,
"lam_min": Angstrom,
Expand Down Expand Up @@ -184,6 +185,8 @@ class Units(metaclass=UnitSingleton):
Observer frame wavelength unit.
wavelength (unyt.unit_object.Unit)
Alias for rest frame wavelength unit.
wavelengths (unyt.unit_object.Unit)
Alias for rest frame wavelength unit.
nu (unyt.unit_object.Unit)
Rest frame frequency unit.
Expand Down Expand Up @@ -295,6 +298,7 @@ def __init__(self, units=None, force=False):
# vacuum rest frame wavelengths alias
self.vacuum_wavelength = Angstrom
self.wavelength = Angstrom # rest frame wavelengths alias
self.wavelengths = Angstrom # rest frame wavelengths alias
self.original_lam = Angstrom # SVO filter wavelengths
self.lam_min = Angstrom # filter minimum wavelength
self.lam_max = Angstrom # filter maximum wavelength
Expand Down
38 changes: 33 additions & 5 deletions src/synthesizer/utils/ascii_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def format_array(self, array):
str:
The formatted string showing the mean value of the array.
"""
# Handle an empty array
if len(array) == 0:
return "[]"

# Handle the case where the array is full of strings
if isinstance(array[0], str):
# Print the first 3 elements followed by an ellipsis
return "[" + ", ".join(array[:3]) + ", ...]"

return (
f"{np.min(array):.2e} -> {np.max(array):.2e} "
f"(Mean: {np.mean(array):.2e})"
Expand Down Expand Up @@ -129,20 +138,39 @@ def format_list(self, lst):
out = []
line = []
for i, value in enumerate(lst):
# If the value is not a string, float or int, just get the Type
if not isinstance(value, (str, float, int)):
value = type(value).__name__

# Handle the first value
if i == 0:
line.append(f"[{value}, ")
line.append(f"[{value}")

# Handle the first value on a new line
elif len(line) == 0:
line.append(f" {value}")

# Handle any other value on a line
else:
line.append(f" {value}, ")
if len(line) == 4:
out.append("".join(line))
line.append(f"{value}")

# Do we need to start a new line?
if len(", ".join(line)) > 40:
out.append(", ".join(line) + ",")
line = []

# Trying to make things pretty... if theres only 1 element add a
# trailing comma
if len(line) == 1:
line[0] += ", "

# Handle the edge case where line is empty (we don't want the closing
# bracket on a new line).
if len(line) > 0:
out.append("".join(line) + "]")
out.append(", ".join(line) + "]")
else:
out[-1] += "]"

return out

def get_value_rows(self):
Expand Down

0 comments on commit 5214813

Please sign in to comment.