"""
Light curve and design matrix utilities for rotating and transited stars described by HEALPix maps.
This module provides functions to compute design matrices, light curves, and
transit light curves for stars with arbitrary surface maps and limb darkening.
"""
from functools import partial
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
from spotter import core, utils
from spotter.star import Star, transited_star
[docs]
def design_matrix(star: Star, time: ArrayLike, normalize: bool = True) -> ArrayLike:
    """
    Compute the design matrix for a rotating Star.

    For every time, one matrix is built per entry of ``star.y`` and/or
    ``star.u`` via ``core.design_matrix``:

    - if ``star.u`` is None, it maps over ``star.y`` and forwards ``u=None``;
    - if there is a single map, that map is shared across every ``star.u``;
    - if there is a single ``u``, it is shared across every map;
    - otherwise maps and ``u`` entries are paired element-wise.

    Parameters
    ----------
    star : Star
        Star object.
    time : ArrayLike
        Time array in days.
    normalize : bool, optional
        Forwarded unchanged to ``core.design_matrix`` (default True).

    Returns
    -------
    matrix : ndarray
        Design matrix. Each scalar time yields a 2-D block, so the output has
        the shape of ``time`` followed by two trailing axes.
    """

    def _at_time(star, t):
        # The phase depends only on the (scalar) time, not on the map or the
        # limb-darkening entry, so it is evaluated once and shared below.
        phase = star.phase(t)

        def _one(y_map, u_coeffs):
            return core.design_matrix(
                y_map, phase, star.inc, u_coeffs, normalize=normalize
            )

        if star.u is None:
            return jax.vmap(lambda y_map: _one(y_map, star.u))(star.y)
        if len(star.y) == 1:
            return jax.vmap(lambda u_coeffs: _one(star.y[0], u_coeffs))(star.u)
        if len(star.u) == 1:
            return jax.vmap(lambda y_map: _one(y_map, star.u[0]))(star.y)
        return jax.vmap(_one)(star.y, star.u)

    return jnp.vectorize(_at_time, excluded=(0,), signature="()->(m,n)")(star, time)
[docs]
def light_curve(star: Star, time: ArrayLike, normalize=True) -> ArrayLike:
    """
    Compute the light curve of a rotating Star.

    The flux at each time is the row-wise dot product of the design matrix
    (one row per entry of ``star.y``) with ``star.y``.

    Parameters
    ----------
    star : Star
        Star object.
    time : ArrayLike
        Time array in days.
    normalize : bool, optional
        Whether to divide the light curve by the mean of ``star.y``
        (default True). This does not affect the design matrix itself, which
        is always built with ``design_matrix``'s default ``normalize=True``.

    Returns
    -------
    lc : ndarray
        Light curve array, transposed so that for a 1-D ``time`` the time
        axis is last.
    """

    def _flux_at(star, t):
        return jnp.einsum("ij,ij->i", design_matrix(star, t), star.y)

    scale = 1.0
    if normalize:
        scale = 1 / jnp.mean(star.y)

    fluxes = jnp.vectorize(_flux_at, excluded=(0,), signature="()->(n)")(star, time)
    return fluxes.T * scale
[docs]
def transit_design_matrix(star, x, y, z, r, time=None, normalize = True):
    """
    Compute the design matrix for a transited Star.

    The rotating-star design matrix is multiplied, pixel by pixel, by a smooth
    occultation mask built from the distance between each (rotated) pixel and
    a disk of radius ``r`` centred at ``(x, y)``. The mask is only applied
    when ``z >= 0``; otherwise the design matrix is returned unchanged.

    Parameters
    ----------
    star : Star
        Star object.
    x : float
        x coordinate of the disk center.
    y : float
        y coordinate of the disk center.
    z : float
        z coordinate of the disk center.
    r : float
        Radius of the disk.
    time : float or None, optional
        Time in days. NOTE(review): ``None`` is passed straight to
        ``design_matrix`` and ``star.phase``; confirm both accept it.
    normalize : bool, optional
        Forwarded to ``design_matrix`` (default True).

    Returns
    -------
    matrix : ndarray
        Transit design matrix: the output of ``design_matrix`` multiplied by
        the per-pixel occultation mask.
    """
    from jax.scipy.spatial.transform import Rotation

    X = design_matrix(star, time, normalize)

    # Pixel direction vectors; the columns returned by core.vec are reversed
    # before stacking.
    _z, _y, _x = core.vec(star.sides).T
    v = jnp.stack((_x, _y, _z), axis=-1)

    # Rotate by the rotational phase about the first axis.
    # NOTE(review): unlike inc/obl below, the phase is not swapped for a
    # non-zero placeholder before ``from_rotvec``. The outer ``where`` keeps the
    # value correct at phase == 0, but gradients there may be NaN — confirm.
    phase = star.phase(time)
    _rv = Rotation.from_rotvec([phase, 0.0, 0.0]).apply(v)
    rv = jnp.where(phase == 0.0, v, _rv)

    # Inclination about the second axis. A zero angle is replaced by a dummy
    # value before building the rotation (presumably to keep ``from_rotvec``
    # away from a zero-length rotation vector), and the unrotated vectors are
    # selected afterwards by testing the *original* angle.
    inc_angle = -jnp.pi / 2 + star.inc if star.inc is not None else 0.0
    _inc_angle = jnp.where(inc_angle == 0.0, 1.0, inc_angle)
    _rv = Rotation.from_rotvec([0.0, _inc_angle, 0.0]).apply(rv)
    rv = jnp.where(inc_angle == 0.0, rv, _rv)

    # Obliquity about the third axis, with the same zero-angle guard.
    if star.obl is not None:
        obl_angle = jnp.where(star.obl == 0.0, 1.0, star.obl)
        _rv = Rotation.from_rotvec([0.0, 0.0, obl_angle]).apply(rv)
        # Must test star.obl, not obl_angle: obl_angle is never 0 after the
        # substitution above, which would wrongly rotate by 1 rad when obl == 0.
        rv = jnp.where(star.obl == 0.0, rv, _rv)

    # Projected distance from each rotated pixel to the disk centre.
    _x, _y, _ = rv.T
    distance = jnp.linalg.norm(
        jnp.array([_x, _y]) - jnp.array([y, -x])[:, None], axis=0
    )
    # Assumes utils.sigmoid(d, k) goes from ~0 (d < 0, inside the disk) to ~1
    # (d > 0, outside) with steepness k — confirm against spotter.utils.
    transited_y = utils.sigmoid(distance - r, 1000.0)
    return X * jnp.where(z >= 0, transited_y, jnp.ones_like(transited_y))
[docs]
def transit_light_curve(
    star: Star,
    x: float = 0.0,
    y: float = 0.0,
    z: float = 0.0,
    r: float = 0.0,
    time: float = 0.0,
    normalize=True,
):
    """
    Compute the light curve of a transited Star.

    ``time``, ``x``, ``y`` and ``z`` are broadcast together element-wise.
    ``r`` is shared by every evaluation. At each point the flux is the
    row-wise dot product of ``transit_design_matrix`` with ``star.y``.

    Parameters
    ----------
    star : Star
        Star object.
    x : float or ArrayLike, optional
        x coordinate of the disk center (default 0.0).
    y : float or ArrayLike, optional
        y coordinate of the disk center (default 0.0).
    z : float or ArrayLike, optional
        z coordinate of the disk center (default 0.0).
    r : float, optional
        Radius of the disk (default 0.0). Not broadcast.
    time : float or ArrayLike, optional
        Time in days (default 0.0).
    normalize : bool, optional
        Whether to divide the light curve by the mean of ``star.y``
        (default True). ``transit_design_matrix`` is always called with its
        default ``normalize=True``.

    Returns
    -------
    lc : ndarray
        Transit light curve array, transposed so that for 1-D inputs the
        time axis is last.
    """
    if normalize:
        scale = 1 / jnp.mean(star.y)
    else:
        scale = 1.0

    def _flux_at(star, t, xc, yc, zc):
        occulted = transit_design_matrix(star, xc, yc, zc, r, t)
        return jnp.einsum("ij,ij->i", occulted, star.y)

    vectorized = jnp.vectorize(
        _flux_at, excluded=(0,), signature="(),(),(),()->(n)"
    )
    return vectorized(star, time, x, y, z).T * scale