# ExoIris: fast, flexible, and easy exoplanet transmission spectroscopy in Python.
# Copyright (C) 2024 Hannu Parviainen
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import Union, Optional
import numba
from astropy.io import fits as pf
from astropy.stats import mad_std
from astropy.utils import deprecated
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.pyplot import subplots, setp
from matplotlib.ticker import LinearLocator, FuncFormatter
from numpy import (
any,
all,
isfinite,
where,
all,
zeros_like,
diff,
asarray,
interp,
arange,
floor,
ndarray,
ceil,
newaxis,
inf,
array,
ones,
poly1d,
polyfit,
nanpercentile,
atleast_2d,
nan,
sqrt,
nanmedian,
nanmean,
unique,
ascontiguousarray,
vstack,
ones_like,
average,
)
from pytransit.orbits import fold
from .binning import Binning, CompoundBinning
from .ephemeris import Ephemeris
from .bin1d import bin1d
from .bin2d import bin2d
def _load(fname: Path | str) -> "TSData | TSDataGroup":
    """Load a `BBData`, `TSData`, or `TSDataGroup` object from a FITS file.

    The object type is selected from the keywords present in the primary header. ``BBDATA`` and
    ``TSDATA`` store the name of a single saved data set, and ``TSDGROUP`` marks a file written by
    `TSDataGroup.save`.

    Parameters
    ----------
    fname
        Path to the FITS file.

    Returns
    -------
    TSData | TSDataGroup
        The imported object.

    Raises
    ------
    ValueError
        If the primary header contains none of the recognised keywords.
    """
    path = Path(fname)
    with pf.open(path) as hdul:
        header = hdul[0].header
        if 'BBDATA' in header:
            from .bbdata import BBData
            return BBData.import_fits(header['BBDATA'], hdul)
        if 'TSDATA' in header:
            return TSData.import_fits(header['TSDATA'], hdul)
        if 'TSDGROUP' in header:
            return TSDataGroup.import_fits(hdul)
    raise ValueError(f"'{path.name}' is not a proper TSData or TSDataGroup file.")
class TSData:
    """Transmission spectroscopy time series with associated wavelength, flux, and error arrays.

    `TSData` is a utility class representing transmission spectroscopy time series data with associated
    wavelength, fluxes, and errors. It provides methods for manipulating and analyzing the data.

    The 2D arrays `fluxes`, `errors`, and `mask` have the shape ``(nwl, npt)``. The first axis runs over
    wavelength and the second over exposures. `mask` is True for valid data points, and `transit_mask` is
    a 1D per-exposure array that is True for out-of-transit exposures.
    """

    def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, errors: Sequence, name: str,
                 noise_group: int = 0, wl_edges: Sequence | None = None, tm_edges: Sequence | None = None,
                 transit_mask: ndarray | None = None, ephemeris: Ephemeris | None = None, n_baseline: int = 1,
                 mask: ndarray | None = None, epoch_group: int = 0, offset_group: int = 0,
                 mask_nonfinite_errors: bool = True, covs: ndarray | None = None) -> None:
        """
        Parameters
        ----------
        time
            1D Array of time values.
        wavelength
            1D Array of wavelength values.
        fluxes
            2D array of flux values with a shape ``(nwl, npt)``, where ``nwl`` is the number of wavelengths and ``npt`` the
            number of exposures.
        errors
            2D Array of error values with a shape ``(nwl, npt)``, where ``nwl`` is the number of wavelengths and ``npt`` the
            number of exposures.
        name
            Name for the data set.
        noise_group
            Noise group the data belongs to.
        wl_edges
            Tuple containing left and right wavelength edges for each wavelength element. If None, the edges
            are computed from the wavelength centres.
        tm_edges
            Tuple containing left and right time edges for each exposure. If None, the edges are computed from
            the time values.
        transit_mask
            1D boolean array with one value per exposure; True marks out-of-transit exposures. Defaults to all
            True. It is recomputed from `ephemeris` if an ephemeris is given.
        ephemeris
            Optional ephemeris used to build the transit mask.
        n_baseline
            Number of polynomial time terms (beyond the constant) in the default baseline covariates.
        mask
            Optional 2D boolean data mask (True = valid) with the same shape as `fluxes`. It is combined with the
            finiteness of the fluxes (and of the errors if `mask_nonfinite_errors` is True). The array passed in
            is not modified.
        epoch_group
            Epoch group the data belongs to.
        offset_group
            Offset group the data belongs to.
        mask_nonfinite_errors
            If True, data points with non-finite errors are masked.
        covs
            Optional 2D array of baseline covariates with one row per exposure. If None, a constant column plus
            `n_baseline` powers of the mean-centred time (scaled to unit standard deviation) are used.

        Raises
        ------
        ValueError
            If the array shapes are inconsistent, a group index or `n_baseline` is negative, or the time or
            wavelength arrays contain non-finite values.
        """
        time, wavelength, fluxes, errors = asarray(time), asarray(wavelength), asarray(fluxes), asarray(errors)
        if fluxes.shape[0] != wavelength.size:
            raise ValueError("The size of the flux array's first axis must match the size of the wavelength array.")
        if fluxes.ndim != 2 or fluxes.shape[1] != time.size:
            raise ValueError("The size of the flux array's second axis must match the size of the time array.")
        if transit_mask is not None:
            transit_mask = asarray(transit_mask, dtype=bool)
            if transit_mask.size != time.size:
                raise ValueError("The size of the out-of-transit mask array must match the size of the time array.")
        if n_baseline < 0:
            raise ValueError("n_baseline must be a non-negative integer.")
        if noise_group < 0:
            raise ValueError("noise_group must be a non-negative integer.")
        if epoch_group < 0:
            raise ValueError("epoch_group must be a non-negative integer.")
        if offset_group < 0:
            raise ValueError("offset_group must be a non-negative integer.")
        if not all(isfinite(time)):
            raise ValueError("The time array must contain only finite values.")
        if not all(isfinite(wavelength)):
            raise ValueError("The wavelength array must contain only finite values.")

        self.name: str = name
        self.mask_nonfinite_errors: bool = mask_nonfinite_errors
        self.time: ndarray = time.copy()
        self.wavelength: ndarray = wavelength

        # Build a fresh mask array so that a caller-supplied mask is never modified in place.
        if mask is None:
            self.mask: ndarray = isfinite(fluxes)
        else:
            self.mask = asarray(mask, dtype=bool) & isfinite(fluxes)
        if self.mask_nonfinite_errors:
            self.mask &= isfinite(errors)
        self.fluxes: ndarray = where(self.mask, fluxes, nan)
        self.errors: ndarray = where(self.mask, errors, nan)

        if covs is not None:
            self.covs: ndarray = covs
        else:
            # Constant column followed by powers of the mean-centred time; the non-constant
            # columns are scaled to unit standard deviation.
            ctime = self.time - self.time.mean()
            self.covs = ascontiguousarray(vstack([ones(self.time.size)] + [ctime**i for i in range(1, n_baseline + 1)]).T)
            self.covs[:, 1:] /= self.covs[:, 1:].std(axis=0)

        self.transit_mask: ndarray = transit_mask if transit_mask is not None else ones(time.size, dtype=bool)
        self._wlmask: ndarray = all(self.mask, 1)
        self._wls_with_nan: ndarray = where(~self._wlmask)[0]
        self._ephemeris: Ephemeris | None = ephemeris
        self.n_baseline: int = n_baseline
        self.noise_group: int = noise_group
        self.epoch_group: int = epoch_group
        self.offset_group: int = offset_group
        self._dataset: Optional['TSDataGroup'] = None
        self.minwl: float = 0.0
        self.maxwl: float = inf
        self.mintm: float = 0.0
        self.maxtm: float = inf
        self._update()

        if wl_edges is None:
            self._wl_l_edges, self._wl_r_edges = self._edges_from_centers(self.wavelength)
        else:
            self._wl_l_edges, self._wl_r_edges = asarray(wl_edges[0]), asarray(wl_edges[1])

        if tm_edges is None:
            self._tm_l_edges, self._tm_r_edges = self._edges_from_centers(self.time)
        else:
            self._tm_l_edges, self._tm_r_edges = asarray(tm_edges[0]), asarray(tm_edges[1])

    @staticmethod
    def _edges_from_centers(x: ndarray) -> tuple[ndarray, ndarray]:
        """Return the left and right edges for the element centres `x`.

        Each element is centred on its value and gets a width equal to the distance to the next centre. The
        last element reuses the width of the second-to-last one, so `x` must contain at least two values.
        """
        dx = zeros_like(x)
        dx[:-1] = diff(x)
        dx[-1] = dx[-2]
        return x - 0.5 * dx, x + 0.5 * dx

    def export_fits(self) -> pf.HDUList:
        """Generate a `~astropy.io.fits.HDUList` containing HDUs storing the data and metadata.

        Returns
        -------
        ~astropy.io.fits.HDUList
        """
        time = pf.ImageHDU(self.time, name=f'time_{self.name}')
        wave = pf.ImageHDU(self.wavelength, name=f'wave_{self.name}')
        data = pf.ImageHDU(array([self.fluxes, self.errors]), name=f'data_{self.name}')
        covs = pf.ImageHDU(self.covs, name=f'covs_{self.name}')
        ootm = pf.ImageHDU(self.transit_mask.astype(int), name=f'ootm_{self.name}')
        mask = pf.ImageHDU(self.mask.astype(int), name=f'mask_{self.name}')
        data.header['ngroup'] = self.noise_group
        data.header['nbasel'] = self.n_baseline
        data.header['epgroup'] = self.epoch_group
        data.header['offgroup'] = self.offset_group
        # TODO: export ephemeris
        return pf.HDUList([time, wave, data, covs, ootm, mask])

    @staticmethod
    def import_fits(name: str, hdul: pf.HDUList) -> 'TSData':
        """Import a data set from a `~astropy.io.fits.HDUList`.

        Parameters
        ----------
        name
            The name of the dataset to be imported from the `~astropy.io.fits.HDUList`.
        hdul
            The `~astropy.io.fits.HDUList` containing the data.

        Returns
        -------
        TSData
        """
        time = hdul[f'TIME_{name}'].data.astype('d')
        wave = hdul[f'WAVE_{name}'].data.astype('d')
        data = hdul[f'DATA_{name}'].data.astype('d')
        ootm = hdul[f'OOTM_{name}'].data.astype(bool)
        mask = hdul[f'MASK_{name}'].data.astype(bool)

        # Optional HDUs and header keywords fall back to the constructor defaults (older files lack them).
        try:
            covs = hdul[f'COVS_{name}'].data.astype('d')
        except KeyError:
            covs = None
        try:
            noise_group = hdul[f'DATA_{name}'].header['NGROUP']
        except KeyError:
            noise_group = 0
        try:
            ephemeris_group = hdul[f'DATA_{name}'].header['EPGROUP']
        except KeyError:
            ephemeris_group = 0
        try:
            offset_group = hdul[f'DATA_{name}'].header['OFFGROUP']
        except KeyError:
            offset_group = 0
        try:
            n_baseline = hdul[f'DATA_{name}'].header['NBASEL']
        except KeyError:
            n_baseline = 1

        # TODO: import ephemeris
        return TSData(time, wave, data[0], data[1], name=name, noise_group=noise_group, transit_mask=ootm,
                      n_baseline=n_baseline, mask=mask, epoch_group=ephemeris_group, offset_group=offset_group,
                      covs=covs)

    def save(self, fname: Path | str, overwrite: bool = True):
        """Save the data set to a FITS file.

        The primary header stores the data set name under the ``TSDATA`` keyword so that `TSData.load` can
        recognise the file.

        Parameters
        ----------
        fname
            Output file path.
        overwrite
            Whether an existing file is overwritten.
        """
        hdul = pf.HDUList([pf.PrimaryHDU()] + self.export_fits())
        hdul[0].header['TSDATA'] = self.name
        hdul.writeto(fname, overwrite=overwrite)

    @staticmethod
    def load(fname: Path | str, noise_group: int | None = None) -> "TSData | TSDataGroup":
        """Load a data set (or data group) from a FITS file.

        Parameters
        ----------
        fname
            Path to the FITS file.
        noise_group
            If given, assigned to the `noise_group` attribute of the loaded object.
        """
        d = _load(fname)
        if noise_group is not None:
            d.noise_group = noise_group
        return d

    def __repr__(self) -> str:
        return f"TSData Name:'{self.name}' [{self.wavelength[0]:.2f} - {self.wavelength[-1]:.2f}] nwl={self.nwl} npt={self.npt}"

    @property
    def ephemeris(self) -> Ephemeris | None:
        """Ephemeris."""
        return self._ephemeris

    @ephemeris.setter
    def ephemeris(self, ep: Ephemeris) -> None:
        # Setting a new ephemeris also recomputes the transit mask.
        self._ephemeris = ep
        self.mask_transit(ephemeris=ep)

    @property
    def bbox_wl(self) -> tuple[float, float]:
        """Wavelength bounds of the bounding box."""
        return self.minwl, self.maxwl

    @property
    def bbox_tm(self) -> tuple[float, float]:
        """Time bounds of the bounding box."""
        return self.mintm, self.maxtm

    def mask_transit(self, t0: float | None = None, p: float | None = None, t14: float | None = None,
                     ephemeris: Ephemeris | None = None, elims: tuple[int, int] | None = None) -> 'TSData':
        """Create a transit mask based on a given ephemeris or exposure index limits.

        The resulting `transit_mask` is a 1D boolean array with one value per exposure, True for
        out-of-transit exposures.

        Parameters
        ----------
        t0
            The zero-epoch time.
        p
            The orbital period of the planet.
        t14
            The duration of the full transit in days.
        ephemeris
            The ephemeris object containing transit timing information. Takes precedence over t0, p, and t14.
        elims
            The limits of the region to mask in exposure indices (Python slice semantics).

        Returns
        -------
        TSData
            The data set itself.

        Raises
        ------
        ValueError
            If neither an ephemeris (or t0, p, and t14) nor exposure limits are given.
        """
        # Compare against None explicitly so that a zero epoch (t0 == 0) is accepted.
        if ephemeris is not None or (t0 is not None and p is not None and t14 is not None):
            self._ephemeris = ephemeris if ephemeris is not None else Ephemeris(t0, p, t14)
            phase = fold(self.time, self.ephemeris.period, self.ephemeris.zero_epoch)
            # 0.502 widens the half-duration by 0.4% to leave a small margin around the transit.
            self.transit_mask = abs(phase) > 0.502 * self.ephemeris.duration
        elif elims is not None:
            # The transit mask is per exposure (1D), so only the exposure index range is masked.
            self.transit_mask = ones(self.time.size, bool)
            self.transit_mask[elims[0]:elims[1]] = False
        else:
            raise ValueError("Transit masking requires either t0, p, and t14, ephemeris, or transit limits in exposure indices.")
        return self

    def estimate_average_uncertainties(self):
        """Estimate the per-wavelength average flux uncertainties.

        Estimate the per-wavelength flux uncertainties as the standard deviation of the first differences of
        the valid out-of-transit fluxes divided by sqrt(2). The estimate is assigned to every valid data point
        of the wavelength element; masked points keep their NaN errors.

        Notes
        -----
        Modifies the `~TSData.errors` attribute in place.
        """
        for ipb in range(self.nwl):
            valid = self.mask[ipb]
            self.errors[ipb, valid] = diff(self.fluxes[ipb, self.transit_mask & valid]).std() / sqrt(2)

    def _update(self) -> None:
        """Update the size, bounding-box, transit-mask, and wavelength-mask attributes."""
        self.nwl = self.wavelength.size
        self.npt = self.time.size
        self.minwl = self.wavelength.min()
        self.maxwl = self.wavelength.max()
        self.mintm = self.time.min()
        self.maxtm = self.time.max()
        if self._ephemeris is not None:
            self.mask_transit(ephemeris=self._ephemeris)
        self._wlmask = all(self.mask, 1)
        self._wls_with_nan = where(~self._wlmask)[0]

    def _update_data_mask(self) -> None:
        """Rebuild the data mask from the finiteness of the fluxes (and errors) and NaN-fill masked points."""
        self.mask = isfinite(self.fluxes)
        if self.mask_nonfinite_errors:
            self.mask &= isfinite(self.errors)
        self.fluxes = where(self.mask, self.fluxes, nan)
        self.errors = where(self.mask, self.errors, nan)

    def normalize_to_poly(self, deg: int = 1) -> 'TSData':
        """Normalize the baseline flux for each spectroscopic light curve.

        Normalize the baseline flux using a low-order polynomial fitted to the out-of-transit
        data for each spectroscopic light curve. Light curves with two or fewer valid out-of-transit points
        are set to NaN (and thus masked).

        Parameters
        ----------
        deg
            The degree of the fitted polynomial. Should be 0 or 1. Higher degrees are not allowed
            because they could affect the transit depths.

        Returns
        -------
        TSData
            The data set itself, modified in place.

        Raises
        ------
        ValueError
            If `deg` is greater than 1.
        """
        if deg > 1:
            raise ValueError("The degree of the fitted polynomial ('deg') should be 0 or 1. Higher degrees "
                             "are not allowed because they could affect the transit depths.")
        if self.transit_mask is None:
            raise ValueError("The out-of-transit mask must be defined for normalization. "
                             "Call TSData.mask_transit(...) first.")
        for ipb in range(self.nwl):
            mask = self.transit_mask & self.mask[ipb]
            if mask.sum() > 2:
                bl = poly1d(polyfit(self.time[mask], self.fluxes[ipb, mask], deg=deg))(self.time)
                self.fluxes[ipb, :] /= bl
                self.errors[ipb, :] /= bl
            else:
                self.fluxes[ipb, :] = nan
                self.errors[ipb, :] = nan
        self._update_data_mask()
        return self

    def partition_time(self, tlims: tuple[tuple[float, float]]) -> 'TSData | TSDataGroup':
        """Partition the data into n segments defined by tlims.

        Each segment keeps the wavelength and time edges, covariates, and group settings of the original
        data set, and is named ``<name>_<i>`` with ``i`` starting from 1.

        Parameters
        ----------
        tlims
            The lower and upper time limits (inclusive) for each segment.

        Returns
        -------
        TSDataGroup
            A group containing the segments, or a single `TSData` if only one segment is given.

        Raises
        ------
        ValueError
            If `tlims` is empty.
        """
        parts = []
        for i, lims in enumerate(tlims, start=1):
            m = (self.time >= lims[0]) & (self.time <= lims[1])
            parts.append(TSData(name=f'{self.name}_{i}', time=self.time[m], wavelength=self.wavelength,
                                fluxes=self.fluxes[:, m], errors=self.errors[:, m], mask=self.mask[:, m],
                                noise_group=self.noise_group,
                                epoch_group=self.epoch_group,
                                offset_group=self.offset_group,
                                wl_edges=(self._wl_l_edges, self._wl_r_edges),
                                tm_edges=(self._tm_l_edges[m], self._tm_r_edges[m]),
                                transit_mask=self.transit_mask[m],
                                ephemeris=self.ephemeris,
                                n_baseline=self.n_baseline,
                                mask_nonfinite_errors=self.mask_nonfinite_errors,
                                covs=self.covs[m]))
        if not parts:
            raise ValueError("At least one time interval must be given in tlims.")
        return parts[0] if len(parts) == 1 else TSDataGroup(parts)

    def crop_wavelength(self, lmin: float, lmax: float, inplace: bool = True) -> 'TSData':
        """Crop the data to include only the wavelength range between lmin and lmax (exclusive).

        Parameters
        ----------
        lmin
            The minimum wavelength value to crop.
        lmax
            The maximum wavelength value to crop.
        inplace
            If True, the data will be modified in place, otherwise a new TSData object will be returned.
        """
        m = (self.wavelength > lmin) & (self.wavelength < lmax)
        if inplace:
            self.wavelength = self.wavelength[m]
            self.fluxes = self.fluxes[m]
            self.errors = self.errors[m]
            self.mask = self.mask[m]
            self._wl_l_edges = self._wl_l_edges[m]
            self._wl_r_edges = self._wl_r_edges[m]
            self._update()
            return self
        else:
            return TSData(name=self.name, time=self.time,
                          wavelength=self.wavelength[m],
                          fluxes=self.fluxes[m],
                          errors=self.errors[m],
                          mask=self.mask[m],
                          noise_group=self.noise_group,
                          epoch_group=self.epoch_group,
                          offset_group=self.offset_group,
                          wl_edges=(self._wl_l_edges[m], self._wl_r_edges[m]),
                          tm_edges=(self._tm_l_edges, self._tm_r_edges),
                          transit_mask=self.transit_mask, ephemeris=self.ephemeris,
                          n_baseline=self.n_baseline,
                          mask_nonfinite_errors=self.mask_nonfinite_errors,
                          covs=self.covs)

    def crop_time(self, tmin: float, tmax: float, inplace: bool = True) -> 'TSData':
        """Crop the data to include only the time range between tmin and tmax (exclusive).

        Parameters
        ----------
        tmin
            The minimum time value to crop.
        tmax
            The maximum time value to crop.
        inplace
            If True, the data will be modified in place, otherwise a new TSData object will be returned.
        """
        m = (self.time > tmin) & (self.time < tmax)
        if inplace:
            self.time = self.time[m]
            self.fluxes = self.fluxes[:, m]
            self.errors = self.errors[:, m]
            self.mask = self.mask[:, m]
            self.transit_mask = self.transit_mask[m]
            self._tm_l_edges = self._tm_l_edges[m]
            self._tm_r_edges = self._tm_r_edges[m]
            self.covs = self.covs[m]
            self._update()
            return self
        else:
            return TSData(name=self.name, time=self.time[m],
                          wavelength=self.wavelength,
                          fluxes=self.fluxes[:, m],
                          errors=self.errors[:, m],
                          mask=self.mask[:, m],
                          noise_group=self.noise_group,
                          epoch_group=self.epoch_group,
                          offset_group=self.offset_group,
                          wl_edges=(self._wl_l_edges, self._wl_r_edges),
                          tm_edges=(self._tm_l_edges[m], self._tm_r_edges[m]),
                          transit_mask=self.transit_mask[m], ephemeris=self.ephemeris,
                          n_baseline=self.n_baseline,
                          mask_nonfinite_errors=self.mask_nonfinite_errors,
                          covs=self.covs[m])

    def remove_fully_masked(self, inplace: bool = True) -> 'TSData':
        """Drop wavelength rows and time columns that are fully masked.

        A row or column is fully masked when every element of `self.mask` in it
        is False. All wavelength- and time-indexed arrays are sliced
        consistently, and derived attributes are refreshed via `_update`.

        Parameters
        ----------
        inplace
            If True, modify the current object in place and return `self`.
            Otherwise return a new `TSData` object.
        """
        kw = any(self.mask, axis=1)
        kt = any(self.mask, axis=0)
        if inplace:
            self.wavelength = self.wavelength[kw]
            self.time = self.time[kt]
            self.fluxes = self.fluxes[kw][:, kt]
            self.errors = self.errors[kw][:, kt]
            self.mask = self.mask[kw][:, kt]
            self.transit_mask = self.transit_mask[kt]
            self.covs = self.covs[kt]
            self._wl_l_edges = self._wl_l_edges[kw]
            self._wl_r_edges = self._wl_r_edges[kw]
            self._tm_l_edges = self._tm_l_edges[kt]
            self._tm_r_edges = self._tm_r_edges[kt]
            self._update()
            return self
        else:
            return TSData(name=self.name, time=self.time[kt],
                          wavelength=self.wavelength[kw],
                          fluxes=self.fluxes[kw][:, kt],
                          errors=self.errors[kw][:, kt],
                          mask=self.mask[kw][:, kt],
                          noise_group=self.noise_group,
                          epoch_group=self.epoch_group,
                          offset_group=self.offset_group,
                          wl_edges=(self._wl_l_edges[kw], self._wl_r_edges[kw]),
                          tm_edges=(self._tm_l_edges[kt], self._tm_r_edges[kt]),
                          transit_mask=self.transit_mask[kt], ephemeris=self.ephemeris,
                          n_baseline=self.n_baseline,
                          mask_nonfinite_errors=self.mask_nonfinite_errors,
                          covs=self.covs[kt])

    # TODO: separate mask into bad data mask and outlier mask.
    def mask_outliers(self, sigma: float = 5.0) -> 'TSData':
        """Mask outliers along the wavelength axis.

        For each exposure, the median and the MAD-based standard deviation of the fluxes are computed across
        all wavelengths (ignoring NaNs). Data points deviating from that median by `sigma` or more standard
        deviations are masked.

        Parameters
        ----------
        sigma
            The number of standard deviations to use as the threshold for outliers.

        Note
        ----
        The data will be modified in place.
        """
        fm = nanmedian(self.fluxes, axis=0)
        fe = mad_std(self.fluxes, axis=0, ignore_nan=True)
        self.mask &= abs(self.fluxes - fm) / fe < sigma
        self.fluxes = where(self.mask, self.fluxes, nan)
        self.errors = where(self.mask, self.errors, nan)
        self._wlmask = all(self.mask, 1)
        self._wls_with_nan = where(~self._wlmask)[0]
        return self

    @deprecated("0.10", alternative="TSData.mask_outliers")
    def remove_outliers(self, sigma: float = 5.0) -> 'TSData':
        """Remove outliers along the wavelength axis (deprecated alias of `mask_outliers`)."""
        return self.mask_outliers(sigma=sigma)

    def plot(self, ax=None, vmin: float = None, vmax: float = None, cmap=None, figsize=None, data=None,
             plims: tuple[float, float] | None = None) -> Figure:
        """Plot the spectroscopic light curves as a 2D image.

        Plot the spectroscopic light curves as a 2D image with time on the x-axis, wavelength and light curve index
        on the y-axis, and the flux as a color.

        Parameters
        ----------
        ax
            The subplot axes on which to plot. If None, a new figure and axes will be created.
        vmin
            The minimum value of the color scale.
        vmax
            The maximum value of the color scale.
        cmap
            The colormap to be used.
        figsize
            The size of the figure in inches (width, height).
        data
            Dataset to plot instead of self.fluxes.
        plims
            Percentile flux limits. Overrides vmin and vmax.

        Returns
        -------
        ~matplotlib.figure.Figure
        """
        if ax is None:
            fig, ax = subplots(figsize=figsize, constrained_layout=True)
        else:
            fig = ax.figure
        tref = floor(self.time.min())

        # Mappings between physical coordinates and array indices for the secondary axes.
        def forward_y(y):
            return interp(y, self.wavelength, arange(self.nwl))

        def inverse_y(y):
            return interp(y, arange(self.nwl), self.wavelength)

        def forward_x(x):
            return interp(x, self.time - tref, arange(self.npt))

        def inverse_x(x):
            return interp(x, arange(self.npt), self.time - tref)

        data = data if data is not None else self.fluxes
        if plims is not None:
            vmin, vmax = nanpercentile(data, plims)
        ax.pcolormesh(self.time - tref, self.wavelength, data, vmin=vmin, vmax=vmax, cmap=cmap)
        if self.ephemeris is not None:
            for tl in self.ephemeris.transit_limits(self.time.mean()):
                ax.axvline(tl - tref, ls='--', c='k')
        setp(ax, ylabel=r'Wavelength [$\mu$m]', xlabel=f'Time - {tref:.0f} [BJD]')
        ax.yaxis.set_major_locator(LinearLocator(10))
        ax.yaxis.set_major_formatter('{x:.2f}')
        ax.xaxis.set_major_locator(LinearLocator())
        ax.xaxis.set_major_formatter('{x:.3f}')
        if self.name != "":
            ax.set_title(self.name)
        axy2 = ax.secondary_yaxis('right', functions=(forward_y, inverse_y))
        axy2.set_ylabel('Light curve index')
        axy2.set_yticks(forward_y(ax.get_yticks()))
        axy2.yaxis.set_major_formatter('{x:.0f}')
        axx2 = ax.secondary_xaxis('top', functions=(forward_x, inverse_x))
        axx2.set_xlabel('Exposure index')
        axx2.xaxis.set_major_locator(LinearLocator())
        axx2.xaxis.set_major_formatter('{x:.0f}')
        ax.axx2 = axx2
        ax.axy2 = axy2
        return fig

    def create_white_light_curve(self, data: ndarray | None = None) -> ndarray:
        """Create an inverse-variance weighted white light curve.

        Parameters
        ----------
        data
            Optional 2D array with the same shape as the flux array to average instead of `fluxes`.

        Returns
        -------
        ndarray
            1D array with one value per exposure. Exposures whose weights sum to zero (no usable data)
            are set to NaN.

        Raises
        ------
        ValueError
            If `data` does not have the same shape as the 2D flux array.
        """
        if data is not None and data.shape != self.fluxes.shape:
            raise ValueError("The data must have the same shape as the 2D flux array.")
        data = data if data is not None else self.fluxes
        weights = where(isfinite(data) & isfinite(self.errors), 1/self.errors**2, 0.0)
        values = where(isfinite(data), data, 0)
        wsum = weights.sum(axis=0)
        has_weight = wsum > 0
        # Same as numpy.average(values, axis=0, weights=weights), except that fully masked exposures give
        # NaN instead of raising ZeroDivisionError.
        wlc = (values * weights).sum(axis=0) / where(has_weight, wsum, 1.0)
        return where(has_weight, wlc, nan)

    def plot_white(self, ax: Axes | None = None, figsize: tuple[float, float] | None = None) -> Figure:
        """Plot a white light curve.

        Parameters
        ----------
        ax
            The axes on which to plot. If None, a new figure and axes are created.
        figsize
            The size of the figure to create if `ax` is None. It should be a tuple in the format (width, height).

        Returns
        -------
        ~matplotlib.figure.Figure
        """
        if ax is None:
            fig, ax = subplots(figsize=figsize)
        else:
            fig = ax.figure
        tref = floor(self.time.min())
        ax.plot(self.time, self.create_white_light_curve())
        if self.ephemeris is not None:
            for tl in self.ephemeris.transit_limits(self.time.mean()):
                ax.axvline(tl, ls='--', c='k')
        if self.name != "":
            ax.set_title(self.name)

        def forward_x(x):
            return interp(x, self.time, arange(self.npt))

        def inverse_x(x):
            return interp(x, arange(self.npt), self.time)

        axx2 = ax.secondary_xaxis('top', functions=(forward_x, inverse_x))
        axx2.set_xlabel('Exposure index')
        axx2.xaxis.set_major_locator(LinearLocator())
        axx2.xaxis.set_major_formatter('{x:.0f}')
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, p: f"{x-tref:.3f}"))
        setp(ax, xlabel=f'Time - {tref:.0f} [BJD]', ylabel='Normalized flux', xlim=[self.time[0]-0.003, self.time[-1]+0.003])
        return fig

    def plot_baseline(self, ax: Axes | None = None, figsize: tuple[float, float] | None = None) -> Figure:
        """Plot the out-of-transit spectroscopic light curves as a 2D image.

        In-transit exposures (where `transit_mask` is False) are set to NaN before plotting with `plot`.

        Parameters
        ----------
        ax
            The axes on which to plot. If None, a new figure and axes are created.
        figsize
            The size of the figure to create if `ax` is None. It should be a tuple in the format (width, height).

        Returns
        -------
        ~matplotlib.figure.Figure
        """
        return self.plot(ax=ax, figsize=figsize, data=where(self.transit_mask, self.fluxes, nan))

    def plot_mean_error(self, ax: Axes | None = None, figsize: tuple[float, float] | None = None) -> Figure:
        """Plot the mean flux uncertainty (in ppm) as a function of wavelength.

        Masked (NaN) errors are ignored when computing the mean.

        Parameters
        ----------
        ax
            The axes on which to plot. If None, a new figure and axes are created.
        figsize
            The size of the figure to create if `ax` is None.

        Returns
        -------
        ~matplotlib.figure.Figure
        """
        if ax is None:
            fig, ax = subplots(figsize=figsize, constrained_layout=True)
        else:
            fig = ax.figure
        ax.plot(self.wavelength, nanmean(self.errors, axis=1)*1e6)
        setp(ax, xlabel=r'Wavelength [$\mu$m]', ylabel='Mean flux error [ppm]')
        return fig

    def __add__(self, other: Union['TSData', 'TSDataGroup']) -> 'TSDataGroup':
        """Combine this data set with another data set or data group into a `TSDataGroup`.

        Parameters
        ----------
        other
            The TSData or TSDataGroup object to be combined with the current TSData object.

        Returns
        -------
        TSDataGroup
        """
        if isinstance(other, TSData):
            return TSDataGroup([self, other])
        else:
            return TSDataGroup([self]) + other

    def bin(self,
            wave_binning: Optional[Union[Binning, CompoundBinning]] = None,
            time_binning: Optional[Union[Binning, CompoundBinning]] = None,
            wave_nb: Optional[int] = None, wave_bw: Optional[float] = None, wave_r: Optional[float] = None,
            time_nb: Optional[int] = None, time_bw: Optional[float] = None,
            estimate_errors: bool = False) -> 'TSData':
        """Bin the data along the wavelength and/or time axes.

        Bin the data along the wavelength and/or time axes. If only wavelength (or only time) binning
        arguments are given, this delegates to `bin_wavelength` (or `bin_time`). If a binning object is not
        given for an axis, a Binning object is created from that axis' minimum and maximum values.

        Parameters
        ----------
        wave_binning
            The wavelength binning to use.
        time_binning
            The time binning to use.
        wave_nb
            Number of wavelength bins.
        wave_bw
            Wavelength bin width.
        wave_r
            Wavelength bin resolution.
        time_nb
            Number of time bins.
        time_bw
            Time bin width in seconds.
        estimate_errors
            Should the uncertainties be estimated from the data.

        Returns
        -------
        TSData
        """
        if wave_binning is None and wave_nb is None and wave_bw is None and wave_r is None:
            return self.bin_time(time_binning, time_nb, time_bw, estimate_errors=estimate_errors)
        if time_binning is None and time_nb is None and time_bw is None:
            return self.bin_wavelength(wave_binning, wave_nb, wave_bw, wave_r, estimate_errors=estimate_errors)

        with warnings.catch_warnings():
            warnings.simplefilter('ignore', numba.NumbaPerformanceWarning)
            if wave_binning is None:
                wave_binning = Binning(self.bbox_wl[0], self.bbox_wl[1], nb=wave_nb, bw=wave_bw, r=wave_r)
            if time_binning is None:
                # The time bin width is given in seconds, the time array is in days.
                time_binning = Binning(self.time.min(), self.time.max(), nb=time_nb, bw=time_bw/(24*60*60) if time_bw is not None else None)
            bf, be, bn = bin2d(self.fluxes, self.errors,
                               self._wl_l_edges, self._wl_r_edges,
                               self._tm_l_edges, self._tm_r_edges,
                               wave_binning.bins, time_binning.bins,
                               estimate_errors=estimate_errors)
            bc, _ = bin1d(self.covs, ones_like(self.covs),
                          self._tm_l_edges, self._tm_r_edges,
                          time_binning.bins,
                          estimate_errors=False)
            if not all(isfinite(be)):
                warnings.warn('Error estimation failed for some bins, check the error array.')
            d = TSData(time_binning.bins.mean(1), wave_binning.bins.mean(1), bf, be,
                       name=self.name,
                       wl_edges=(wave_binning.bins[:, 0], wave_binning.bins[:, 1]),
                       tm_edges=(time_binning.bins[:, 0], time_binning.bins[:, 1]),
                       noise_group=self.noise_group,
                       epoch_group=self.epoch_group,
                       offset_group=self.offset_group,
                       ephemeris=self.ephemeris,
                       n_baseline=self.n_baseline,
                       mask_nonfinite_errors=self.mask_nonfinite_errors,
                       covs=bc)
            if self.ephemeris is not None:
                d.mask_transit(ephemeris=self.ephemeris)
            return d

    def bin_wavelength(self, binning: Optional[Union[Binning, CompoundBinning]] = None,
                       nb: Optional[int] = None, bw: Optional[float] = None, r: Optional[float] = None,
                       estimate_errors: bool = False) -> 'TSData':
        """Bin the data along the wavelength axis.

        Bin the data along the wavelength axis. If binning is not specified, a Binning object is created using the
        minimum and maximum values of the wavelength.

        Parameters
        ----------
        binning
            The binning method to use.
        nb
            Number of bins.
        bw
            Bin width.
        r
            Bin resolution.
        estimate_errors
            Should the uncertainties be estimated from the data.

        Returns
        -------
        TSData
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', numba.NumbaPerformanceWarning)
            if binning is None:
                binning = Binning(self.bbox_wl[0], self.bbox_wl[1], nb=nb, bw=bw, r=r)
            bf, be = bin1d(self.fluxes, self.errors, self._wl_l_edges, self._wl_r_edges,
                           binning.bins, estimate_errors=estimate_errors)
            if not all(isfinite(be)):
                warnings.warn('Error estimation failed for some bins, check the error array.')
            return TSData(self.time, binning.bins.mean(1), bf, be,
                          wl_edges=(binning.bins[:, 0], binning.bins[:, 1]),
                          name=self.name,
                          tm_edges=(self._tm_l_edges, self._tm_r_edges),
                          noise_group=self.noise_group,
                          epoch_group=self.epoch_group,
                          offset_group=self.offset_group,
                          transit_mask=self.transit_mask,
                          ephemeris=self.ephemeris,
                          n_baseline=self.n_baseline,
                          mask_nonfinite_errors=self.mask_nonfinite_errors,
                          covs=self.covs)

    def bin_time(self, binning: Optional[Union[Binning, CompoundBinning]] = None,
                 nb: Optional[int] = None, bw: Optional[float] = None,
                 estimate_errors: bool = False) -> 'TSData':
        """Bin the data along the time axis.

        Bin the data along the time axis. If binning is not specified, a Binning object is created using the
        minimum and maximum time values. The baseline covariates are binned with the same time bins.

        Parameters
        ----------
        binning
            The binning method to use.
        nb
            Number of bins.
        bw
            Bin width in seconds.
        estimate_errors
            Should the uncertainties be estimated from the data.

        Returns
        -------
        TSData
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', numba.NumbaPerformanceWarning)
            if binning is None:
                # The bin width is given in seconds, the time array is in days.
                binning = Binning(self.time.min(), self.time.max(), nb=nb, bw=bw/(24*60*60) if bw is not None else None)
            bf, be = bin1d(self.fluxes.T, self.errors.T, self._tm_l_edges, self._tm_r_edges,
                           binning.bins, estimate_errors=estimate_errors)
            bc, _ = bin1d(self.covs, ones_like(self.covs), self._tm_l_edges, self._tm_r_edges, binning.bins, False)
            if not all(isfinite(be)):
                warnings.warn('Error estimation failed for some bins, check the error array.')
            d = TSData(binning.bins.mean(1), self.wavelength, bf.T, be.T,
                       wl_edges=(self._wl_l_edges, self._wl_r_edges),
                       tm_edges=(binning.bins[:, 0], binning.bins[:, 1]),
                       name=self.name,
                       noise_group=self.noise_group,
                       ephemeris=self.ephemeris,
                       n_baseline=self.n_baseline,
                       epoch_group=self.epoch_group,
                       offset_group=self.offset_group,
                       mask_nonfinite_errors=self.mask_nonfinite_errors,
                       covs=bc)
            if self.ephemeris is not None:
                d.mask_transit(ephemeris=self.ephemeris)
            return d
class TSDataGroup:
    """`TSDataGroup` is a high-level data storage class that can contain multiple `TSData` objects.

    Data set names must be unique within a group. Adding a data set to a group sets the data set's
    `_dataset` back-reference to the group.
    """

    def __init__(self, data: Sequence[TSData]):
        """
        Parameters
        ----------
        data
            The data sets to include in the group.

        Raises
        ------
        ValueError
            If two data sets share the same name.
        """
        self.data: list[TSData] = []
        self.wlmin: float = inf
        self.wlmax: float = -inf
        self.tmin: float = inf
        self.tmax: float = -inf
        self._noise_groups: ndarray | None = None
        for d in data:
            self._add_data(d)

    def _add_data(self, d: TSData) -> None:
        """Append a data set and update the noise groups and the wavelength and time bounds."""
        if d.name in self.names:
            raise ValueError('A TSData object with the same name already exists.')
        d._dataset = self
        self.data.append(d)
        self._noise_groups = array([d.noise_group for d in self.data])
        self.wlmin = min(self.wlmin, d.wavelength.min())
        self.wlmax = max(self.wlmax, d.wavelength.max())
        self.tmin = min(self.tmin, d.time.min())
        self.tmax = max(self.tmax, d.time.max())

    @property
    def names(self) -> list[str]:
        """List of data set names."""
        return [d.name for d in self.data]

    @property
    def times(self) -> list[ndarray]:
        """List of 1D time arrays."""
        return [d.time for d in self.data]

    @property
    def wavelengths(self) -> list[ndarray]:
        """List of 1D wavelength arrays."""
        return [d.wavelength for d in self.data]

    @property
    def fluxes(self) -> list[ndarray]:
        """List of 2D flux arrays."""
        return [d.fluxes for d in self.data]

    @property
    def errors(self) -> list[ndarray]:
        """List of 2D error arrays."""
        return [d.errors for d in self.data]

    @property
    def noise_groups(self) -> ndarray[int] | None:
        """Array of noise groups, or None if the group is empty."""
        return self._noise_groups

    @property
    def n_noise_groups(self) -> int:
        """Number of distinct noise groups (zero for an empty group)."""
        if self._noise_groups is None:
            return 0
        return len(unique(self._noise_groups))

    @property
    def offset_groups(self) -> list[int]:
        """List of offset groups."""
        return [d.offset_group for d in self.data]

    @property
    def epoch_groups(self) -> list[int]:
        """List of epoch groups."""
        return [d.epoch_group for d in self.data]

    @property
    def n_baselines(self) -> list[int]:
        """Number of baseline coefficients for each data set."""
        return [d.n_baseline for d in self.data]

    @property
    def size(self) -> int:
        """Number of data sets."""
        return len(self.data)

    def export_fits(self) -> pf.HDUList:
        """Export the dataset along with its metadata to a FITS HDU list.

        The first HDU (``dataset``) stores the number of data sets and their names in its header, followed by
        the HDUs of each data set.

        Returns
        -------
        ~astropy.io.fits.HDUList
        """
        ds = pf.ImageHDU(name='dataset')
        ds.header['ndata'] = self.size
        for i, n in enumerate(self.names):
            ds.header[f'name_{i}'] = n
        hdul = pf.HDUList([ds])
        for d in self.data:
            hdul += d.export_fits()
        return hdul

    @staticmethod
    def import_fits(hdul: pf.HDUList) -> 'TSDataGroup':
        """Import all the data from a FITS HDU list.

        Parameters
        ----------
        hdul
            HDU list containing FITS data.

        Returns
        -------
        TSDataGroup
        """
        ds = hdul['DATASET']
        data = []
        for i in range(ds.header['NDATA']):
            name = ds.header[f'NAME_{i}']
            # Data sets whose FLUX HDU carries a truthy BBDATA keyword are imported as BBData.
            if f'FLUX_{name}' in hdul and hdul[f'FLUX_{name}'].header.get('BBDATA', False):
                from .bbdata import BBData
                data.append(BBData.import_fits(name, hdul))
            else:
                data.append(TSData.import_fits(name, hdul))
        return TSDataGroup(data)

    def mask_transit(self, tc: float | None = None, p: float | None = None, t14: float | None = None,
                     ephemeris: Ephemeris | None = None) -> 'TSDataGroup':
        """Create transit masks for all the data sets.

        Parameters
        ----------
        tc
            The zero-epoch time.
        p
            The orbital period of the planet.
        t14
            The duration of the full transit.
        ephemeris
            Optional ephemeris object used instead of tc, p, and t14.

        Returns
        -------
        TSDataGroup
            The group itself.
        """
        for d in self.data:
            if ephemeris is None:
                d.mask_transit(tc, p, t14)
            else:
                d.mask_transit(ephemeris=ephemeris)
        return self

    def __getitem__(self, index: int) -> TSData:
        return self.data[index]

    def __len__(self) -> int:
        return self.size

    def __repr__(self):
        return f"TSDataGroup with {self.size} groups"

    def save(self, fname: Path | str, overwrite: bool = True):
        """Save the group to a FITS file.

        The primary header gets a ``TSDGROUP`` keyword so that `TSDataGroup.load` can recognise the file.

        Parameters
        ----------
        fname
            Output file path.
        overwrite
            Whether an existing file is overwritten.
        """
        hdul = pf.HDUList([pf.PrimaryHDU()] + self.export_fits())
        hdul[0].header['tsdgroup'] = True
        hdul.writeto(fname, overwrite=overwrite)

    @staticmethod
    def load(fname: Path | str) -> "TSData | TSDataGroup":
        """Load a data group (or single data set) from a FITS file."""
        return _load(fname)

    def plot(self, axs=None, vmin: float = None, vmax: float = None, ncols: int = 1, cmap=None, figsize=None, data: ndarray | None = None) -> Figure:
        """Plot all the data sets.

        Parameters
        ----------
        axs
            A 2D ndarray of Axes used for plotting. If None, a new set of subplots will be created.
        vmin
            The minimum value for the color mapping.
        vmax
            The maximum value for the color mapping.
        ncols
            The number of columns in the subplot grid.
        cmap
            The colormap used for mapping the data values to colors.
        figsize
            The size of the figure created if `axs` is None.
        data
            Sequence with one 2D array per data set to plot instead of the fluxes. If None, the fluxes of each
            data set are plotted.

        Returns
        -------
        ~matplotlib.figure.Figure
        """
        if axs is None:
            nrows = int(ceil(self.size / ncols))
            fig, axs = subplots(nrows, ncols=ncols, figsize=figsize, squeeze=False)
        else:
            axs = atleast_2d(axs)
            fig = axs.flat[0].get_figure()

        if data is None:
            for i in range(self.size):
                self.data[i].plot(ax=axs.flat[i], vmin=vmin, vmax=vmax, cmap=cmap)
        else:
            for i in range(self.size):
                self.data[i].plot(ax=axs.flat[i], vmin=vmin, vmax=vmax, cmap=cmap, data=data[i])
        # Only the bottom row keeps its x-axis label.
        setp(axs[:-1, :], xlabel='')
        return fig

    def plot_white(self, axs=None, ncols: int = 1, figsize=None) -> Figure:
        """Plot the white light curves.

        Parameters
        ----------
        axs
            A 2D ndarray of Axes used for plotting. If None, a new set of subplots will be created.
        ncols
            The number of columns in the subplot grid.
        figsize
            The size of the figure created if `axs` is None.

        Returns
        -------
        ~matplotlib.figure.Figure
        """
        if axs is None:
            nrows = int(ceil(self.size / ncols))
            fig, axs = subplots(nrows, ncols=ncols, figsize=figsize, squeeze=False)
        else:
            axs = atleast_2d(axs)
            fig = axs.flat[0].get_figure()

        for i in range(self.size):
            self.data[i].plot_white(ax=axs.flat[i])
        setp(axs[:-1, :], xlabel='')
        return fig

    def __add__(self, other):
        """Return a new group containing this group's data sets followed by `other`'s.

        Returns NotImplemented for unsupported operand types so that Python raises a TypeError.
        """
        if isinstance(other, TSData):
            return TSDataGroup(self.data + [other])
        elif isinstance(other, TSDataGroup):
            return TSDataGroup(self.data + other.data)
        return NotImplemented
class TSDataSet(TSDataGroup):
    """Subclass of `TSDataGroup` that adds no behaviour of its own.

    NOTE(review): presumably kept as a backward-compatible name for `TSDataGroup`; confirm before removing.
    """