Source code for spotter.star

"""
Star object and related utilities for HEALPix-based stellar surface modeling.

Defines the Star class, which encapsulates a HEALPix map, limb darkening,
orientation, and physical properties. Includes visualization and transit utilities.
"""

import equinox as eqx
import healpy as hp
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike

from spotter import core, utils, viz


class Star(eqx.Module):
    """
    A Star object whose surface is described by HEALPix map(s).

    The HEALPix maps can be a 2D array with a shape of (wavelengths, pixels),
    or a 1D array with a shape of (pixels). When providing polynomial limb
    darkening coefficients, different options are possible:

    * u is 1D and y is 1D: Single set of limb darkening coefficients and a
      single map.
    * u is 1D and y is 2D: The same limb darkening coefficients are applied
      to all wavelength maps.
    * u is 2D and y is 1D: The limb darkening coefficients are different for
      each wavelength but the map is the same.
    * ``u.shape[0]`` == ``y.shape[0]``: u and y are 2D arrays specifying the
      limb darkening coeffs and maps for each wavelength.

    Parameters
    ----------
    y : ArrayLike or None, optional
        HEALPix map of the star, with shape (pixels,) or (wavelengths, pixels).
        Must be provided; passing ``None`` raises a TypeError.
    u : ArrayLike or None, optional
        Polynomial limb darkening coefficients with shape (order,) or
        (wavelengths, order). By default None. If provided, must either be
        coefficients applied to all wavelengths, or have the same length as y
        (i.e. defined for the same number of wavelengths).
    inc : float or None, optional
        Inclination of the star, in radians. 0 is pole-on, pi/2 is
        equator-on. By default None.
    obl : float or None, optional
        Obliquity of the star, in radians. 0 is no obliquity, pi/2 is maximum
        obliquity. By default None.
    period : float or None, optional
        Period of the star, in days. By default None.
    radius : float or None, optional
        Radius of the star, in solar radii. By default None, which is stored
        as 1.0.
    wv : float or None, optional
        Wavelength of the star maps, in meters. By default None. If provided,
        must be compatible with either the shape of u and/or y.

    Attributes
    ----------
    y : ArrayLike
        HEALPix map of the star, with shape (wavelengths, pixels).
    u : ArrayLike or None
        Polynomial limb darkening coefficients with shape (wavelengths, order).
    period : float or None
        Period of the star, in days.
    inc : float or None
        Inclination of the star, in radians. 0 is pole-on, pi/2 is equator-on.
    obl : float or None
        Obliquity of the star, in radians. 0 is no obliquity, pi/2 is maximum
        obliquity.
    radius : float or None
        Radius of the star, in solar radii.
    wv : float or None
        Wavelength of the star maps, in meters.
    sides : int
        Number of HEALPix sides.

    Examples
    --------
    .. plot::

        import numpy as np
        from spotter.star import Star, show

        star = Star.from_sides(30, inc=0.5, u=(0.4, 0.3), obl=0.5)
        show(star)
    """

    y: ArrayLike
    """HEALPix map of the star, with shape (wavelengths, pixels)."""
    u: ArrayLike | None = None
    """Polynomial limb darkening coefficients with shape (wavelengths, order)."""
    period: float | None = None
    """Period of the star, in days."""
    inc: float | None = None
    """Inclination of the star, in radians. 0 is pole-on, pi/2 is equator-on."""
    obl: float | None = None
    """Obliquity of the star, in radians. 0 is no obliquity, pi/2 is maximum obliquity."""
    radius: float | None = None
    """Radius of the star, in solar radii."""
    wv: float | None = None
    """Wavelength of the star maps, in meters."""
    sides: int = eqx.field(static=True)
    """Number of HEALPix sides."""

    def __init__(
        self,
        y: ArrayLike | None = None,
        u: ArrayLike | None = None,
        inc: float | None = None,
        obl: float | None = None,
        period: float | None = None,
        radius: float | None = None,
        wv: float | None = None,
    ):
        if y is None:
            # Fail early with a clear message instead of an opaque error
            # from jnp.atleast_2d(None).
            raise TypeError("Star requires a HEALPix map `y` (got None).")
        # Maps and limb-darkening coefficients are always stored 2D, with
        # the wavelength axis first.
        self.y = jnp.atleast_2d(y)
        self.u = jnp.atleast_2d(u) if u is not None else None
        self.inc = inc
        self.obl = obl
        self.period = period
        # Presumably recovers the HEALPix nside from the first map's pixel
        # count (see core._N_or_Y_to_N_n).
        self.sides = core._N_or_Y_to_N_n(self.y[0])[0]
        self.radius = radius if radius is not None else 1.0
        self.wv = wv

    @property
    def N(self):
        """Return the number of sides of the star map."""
        return self.sides

    @property
    def x(self):
        """Return the xyz coordinates of the star pixels."""
        return core.vec(self.sides)

    @property
    def size(self):
        """Return the number of pixels in the star map."""
        return hp.nside2npix(self.sides)

    @property
    def resolution(self):
        """Return the approximate size of a single map pixel in radians."""
        return hp.nside2resol(self.sides)

    def __getitem__(self, key):
        """
        Return a new Star with selected wavelength(s).

        Limb darkening coefficients and wavelengths that are defined per
        wavelength (same leading length as ``y``) are selected with the same
        key so they stay aligned with the selected maps. Coefficients or
        wavelengths shared by all maps are kept unchanged.

        Parameters
        ----------
        key : int, slice, or array_like
            Index or indices to select.

        Returns
        -------
        Star
            New Star object with selected map(s).
        """
        n_maps = self.y.shape[0]
        updates = {"y": self.y[key]}
        if self.u is not None and self.u.shape[0] == n_maps:
            updates["u"] = self.u[key]
        if (
            self.wv is not None
            and jnp.ndim(self.wv) > 0
            and jnp.shape(self.wv)[0] == n_maps
        ):
            updates["wv"] = jnp.asarray(self.wv)[key]
        return self.set(**updates)

    @classmethod
    def from_sides(cls, sides: int, **kwargs):
        """
        Create a Star object with a given number of sides.

        Parameters
        ----------
        sides : int
            Number of sides of the HEALPix map.
        **kwargs
            Additional keyword arguments for Star.

        Returns
        -------
        Star
            Star object with the given number of sides and a uniform map of
            ones.
        """
        y = np.ones(core._N_or_Y_to_N_n(sides)[1])
        return cls(y, **kwargs)

    def phase(self, time: ArrayLike | None) -> ArrayLike:
        """
        Compute the rotation phase for a given time.

        Parameters
        ----------
        time : array_like or None
            Time(s) in days.

        Returns
        -------
        phase : float or array_like
            Rotation phase(s) in radians. 0.0 when ``time`` is None, and zeros
            shaped like ``time`` when the star has no period.
        """
        if time is None:
            return 0.0
        return (
            2 * jnp.pi * time / self.period
            if self.period is not None
            else jnp.zeros_like(time)
        )

    @staticmethod
    def _operand(other):
        """Return the map of ``other`` if it is a Star, else ``other`` itself."""
        return other.y if isinstance(other, Star) else other

    def __mul__(self, other):
        """
        Multiply the star map by another Star or scalar.

        Parameters
        ----------
        other : Star or scalar
            Object to multiply with.

        Returns
        -------
        Star
            Resulting Star object.
        """
        return self.set(y=self.y * self._operand(other))

    def __rmul__(self, other):
        """
        Multiply the star map by another Star or scalar (right-mult).

        Parameters
        ----------
        other : Star or scalar
            Object to multiply with.

        Returns
        -------
        Star
            Resulting Star object.
        """
        return self.__mul__(other)

    def __add__(self, other):
        """
        Add another Star or scalar to the star map.

        Parameters
        ----------
        other : Star or scalar
            Object to add.

        Returns
        -------
        Star
            Resulting Star object.
        """
        return self.set(y=self.y + self._operand(other))

    def __radd__(self, other):
        """
        Add another Star or scalar to the star map (right-add).

        Parameters
        ----------
        other : Star or scalar
            Object to add.

        Returns
        -------
        Star
            Resulting Star object.
        """
        return self.__add__(other)

    def __sub__(self, other):
        """
        Subtract another Star or scalar from the star map.

        Parameters
        ----------
        other : Star or scalar
            Object to subtract.

        Returns
        -------
        Star
            Resulting Star object (``self - other``).
        """
        return self.set(y=self.y - self._operand(other))

    def __rsub__(self, other):
        """
        Subtract the star map from another Star or scalar (right-sub).

        Parameters
        ----------
        other : Star or scalar
            Object to subtract from.

        Returns
        -------
        Star
            Resulting Star object (``other - self``).
        """
        # Reflected subtraction is not commutative: ``other - self``.
        return self.set(y=self._operand(other) - self.y)

    def set(self, **kwargs):
        """
        Return a Star object with updated attributes.

        Parameters
        ----------
        **kwargs
            Attributes to update (any of ``y``, ``u``, ``inc``, ``obl``,
            ``period``, ``radius``, ``wv``).

        Returns
        -------
        Star
            Star object with updated attributes. ``sides`` is recomputed from
            the new map.
        """
        current = {
            "y": self.y,
            "u": self.u,
            "inc": self.inc,
            "obl": self.obl,
            "period": self.period,
            "radius": self.radius,
            "wv": self.wv,
        }
        current.update(kwargs)
        return Star(**current)

    def spot(self, lat: float, lon: float, radius: float, sharpness: float = 20):
        """
        Return a HEALPix map with a spot.

        Parameters
        ----------
        lat : float
            Latitude of the spot, in radians.
        lon : float
            Longitude of the spot, in radians.
        radius : float
            Radius of the spot, in radians.
        sharpness : float, optional
            Sharpness of the spot edge (default 20).

        Returns
        -------
        ArrayLike
            HEALPix map with a spot.
        """
        return core.spot(self.sides, lat, lon, radius, sharpness=sharpness)

    @property
    def coords(self):
        """
        Return the coordinates of the star pixels.

        Returns
        -------
        coords : ndarray
            Cartesian coordinates of pixels.
        """
        return core.vec(self.sides)
def show(star: Star, phase: ArrayLike = 0.0, ax=None, xsize=800, rv=False, **kwargs):
    """
    Show the star map.

    Only the first wavelength map (and the first row of limb darkening
    coefficients, if any) is shown.

    Parameters
    ----------
    star : Star
        Star object to show.
    phase : ArrayLike, optional
        Phase of the star map to show (default 0.0).
    ax : matplotlib axis, optional
        Axis to plot the star map (default None).
    xsize : int, optional
        Output image size (default 800).
    rv : bool, optional
        Forwarded to ``viz.show``. When True, the star's radius and period are
        also passed (presumably to render a radial-velocity map); otherwise
        both are passed as None. Default False.
    **kwargs
        Additional keyword arguments for viz.show.
    """
    radius = None
    period = None
    if rv:
        # Explicit None checks rather than ``x or default``: ``or`` calls
        # bool(), which fails on traced or multi-element JAX arrays and
        # silently replaces legitimate zero values.
        radius = star.radius if star.radius is not None else 1.0
        period = star.period

    viz.show(
        star.y[0],
        star.inc if star.inc is not None else np.pi / 2,
        star.obl if star.obl is not None else 0.0,
        star.u[0] if star.u is not None else None,
        radius=radius,
        period=period,
        rv=rv,
        phase=phase,
        ax=ax,
        xsize=xsize,
        **kwargs,
    )
def video(star: Star, duration: int = 4, fps: int = 10, rv=False, **kwargs):
    """
    Create an HTML video of the star map (for Jupyter notebooks).

    The video is built from the first wavelength map only. A missing
    inclination defaults to pi/2 (equator-on) and a missing obliquity to 0.

    Parameters
    ----------
    star : Star
        Star object to show.
    duration : int, optional
        Duration of the video in seconds (default 4).
    fps : int, optional
        Frames per second (default 10).
    rv : bool, optional
        Forwarded to ``viz.video``. When True, ``star.radius`` and
        ``star.period`` are passed along as well; otherwise both are None.
        Default False.
    **kwargs
        Additional keyword arguments for viz.video.
    """
    first_map = star.y[0]
    inclination = np.pi / 2 if star.inc is None else star.inc
    obliquity = 0.0 if star.obl is None else star.obl
    limb_darkening = None if star.u is None else star.u[0]

    if rv:
        radius, period = star.radius, star.period
    else:
        radius = period = None

    viz.video(
        first_map,
        inclination,
        obliquity,
        limb_darkening,
        radius=radius,
        period=period,
        rv=rv,
        duration=duration,
        fps=fps,
        **kwargs,
    )
def _rotate_about_axis(vectors, axis, angle):
    """
    Rotate row vectors by ``angle`` radians about one Cartesian axis.

    Parameters
    ----------
    vectors : ArrayLike
        Vectors to rotate, one per row (last axis of length 3).
    axis : int
        Index (0, 1 or 2) of the rotation axis.
    angle : float or ArrayLike
        Scalar rotation angle, in radians.

    Returns
    -------
    ArrayLike
        Rotated vectors, or ``vectors`` unchanged when ``angle`` is exactly 0.
    """
    from jax.scipy.spatial.transform import Rotation

    # A dummy non-zero angle is used to build the rotation when ``angle`` is
    # 0, and the unrotated vectors are selected afterwards. This presumably
    # keeps a zero-length rotation vector (whose normalization divides by
    # its norm) out of the computation so gradients through jnp.where stay
    # finite.
    safe_angle = jnp.where(angle == 0.0, 1.0, angle)
    rotvec = [0.0, 0.0, 0.0]
    rotvec[axis] = safe_angle
    rotated = Rotation.from_rotvec(rotvec).apply(vectors)
    return jnp.where(angle == 0.0, vectors, rotated)


def transited_star(
    star: Star,
    x: float = 0.0,
    y: float = 0.0,
    z: float = 0.0,
    r: float = 0.0,
    time: float | None = None,
):
    """
    Return a star transited by a circular opaque disk.

    The pixel vectors are rotated in sequence: by the rotation phase at
    ``time``, then by the inclination (no tilt when equator-on or when
    ``star.inc`` is None), then by the obliquity (skipped when ``star.obl`` is
    None). Each step applies to the output of the previous one. Pixels whose
    projected position lies within ``r`` of (x, -y) are masked with a steep
    sigmoid.

    Parameters
    ----------
    star : Star
        Star object to be transited.
    x : float, optional
        x coordinate of the disk center (default 0.0).
    y : float, optional
        y coordinate of the disk center (default 0.0).
    z : float, optional
        z coordinate of the disk center (default 0.0). When negative, the map
        is returned unchanged.
    r : float, optional
        Radius of the disk (default 0.0).
    time : float, optional
        Time in days (default None, meaning no rotation by phase).

    Returns
    -------
    Star
        Star object transited by the disk.
    """
    # core.vec returns rows ordered (z, y, x) here; re-stack them as (x, y, z).
    _z, _y, _x = core.vec(star.sides).T
    v = jnp.stack((_x, _y, _z), axis=-1)

    rv = v
    if time is not None:
        rv = _rotate_about_axis(rv, 0, star.phase(time))

    # inc = pi/2 (equator-on) corresponds to no tilt.
    inc_angle = -jnp.pi / 2 + star.inc if star.inc is not None else 0.0
    rv = _rotate_about_axis(rv, 1, inc_angle)

    if star.obl is not None:
        rv = _rotate_about_axis(rv, 2, star.obl)

    px, py, _ = rv.T
    distance = jnp.linalg.norm(
        jnp.array([px, py]) - jnp.array([x, -y])[:, None], axis=0
    )
    # Steep sigmoid of (distance - r): presumably ~0 inside the disk and ~1
    # outside, depending on utils.sigmoid's convention.
    spotted_star = utils.sigmoid(distance - r, 1000.0) * star
    # For z < 0 the original map is kept (presumably the disk is behind the
    # star).
    return star.set(y=jnp.where(z < 0, star.y, spotted_star.y))