diff --git a/src/synthesizer/emission_models/attenuation/dust.py b/src/synthesizer/emission_models/attenuation/dust.py index a2d90daaa5..94208ea591 100644 --- a/src/synthesizer/emission_models/attenuation/dust.py +++ b/src/synthesizer/emission_models/attenuation/dust.py @@ -473,10 +473,10 @@ def get_tau(self, lam, interp="cubic"): An array of wavelengths or a single wavlength at which to calculate optical depths (in AA, global unit). interp (str) - The type of interpolation to use. Can be ‘linear’, ‘nearest’, - ‘nearest-up’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, - ‘previous’, or ‘next’. ‘zero’, ‘slinear’, ‘quadratic’ and - ‘cubic’ refer to a spline interpolation of zeroth, first, + The type of interpolation to use. Can be 'linear', 'nearest', + 'nearest-up', 'zero', 'slinear', 'quadratic', 'cubic', + 'previous', or 'next'. 'zero', 'slinear', 'quadratic' and + 'cubic' refer to a spline interpolation of zeroth, first, second or third order. Uses scipy.interpolate.interp1d. Returns: @@ -531,7 +531,7 @@ def __init__(self, model="SMCBar"): self.emodel = WD01(self.model) @accepts(lam=angstrom) - def get_tau(self, lam): + def get_tau(self, lam, interp="slinear"): """ Calculate V-band normalised optical depth. @@ -540,33 +540,35 @@ def get_tau(self, lam): An array of wavelengths or a single wavlength at which to calculate optical depths (in AA, global unit). + interp (str) + The type of interpolation to use. Can be 'linear', 'nearest', + 'nearest-up', 'zero', 'slinear', 'quadratic', 'cubic', + 'previous', or 'next'. 'zero', 'slinear', 'quadratic' and + 'cubic' refer to a spline interpolation of zeroth, first, + second or third order. Uses scipy.interpolate.interp1d. + Returns: float/array-like, float The optical depth. """ - return self.emodel(lam.to_astropy()) - @accepts(lam=angstrom) - def get_transmission(self, tau_v, lam): - """ - Return the transmitted flux/luminosity fraction. - - Args: - tau_v (float/array-like, float) - Optical depth in the V-band. Can either be a single float or - array. + lam_lims = np.logspace(2, 8, 10000) * angstrom + lam_v = 5500 * angstrom # V-band wavelength + func = interpolate.interp1d( + lam_lims, + self.emodel(lam_lims.to_astropy()), + kind=interp, + fill_value="extrapolate", + ) + out = func(lam) / func(lam_v) - lam (array-like, float) - The wavelengths (with units) at which to calculate - transmission. + if np.isscalar(lam): + if lam > lam_lims[-1]: + out = func(lam_lims[-1]) + elif np.sum(lam > lam_lims[-1]) > 0: + out[(lam > lam_lims[-1])] = func(lam_lims[-1]) - Returns: - array-like - The transmission at each wavelength. Either (lam.size,) in - shape for singular tau_v values or (tau_v.size, lam.size) - tau_v is an array. - """ - return self.emodel.extinguish(x=lam.to_astropy(), Av=1.086 * tau_v) + return out @accepts(lam=um)