import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u
from astropy.constants import G

from jax import numpy as jnp

from shone.chemistry import FastchemWrapper
from shone.opacity import Opacity
from shone.transmission import de_wit_seager_2013

# Log-spaced model grids shared by the opacity, chemistry and transmission steps.
wavelength = np.geomspace(0.5, 5, num=500)  # 500 wavelengths [µm]
pressure = np.geomspace(1e-6, 1, num=50)  # 50 levels, NumPy's default count [bar]

# Weak power-law temperature profile pinned to 700 K at the reference pressure.
reference_pressure = 0.1  # [bar]
temperature = 700 * (pressure / reference_pressure) ** 0.05  # [K]

def _interpolated_opacity(species):
    """Load shone's demo opacity grid for ``species`` and evaluate it on the model grids."""
    interpolator = Opacity.load_demo_species(species).get_interpolator()
    return interpolator(wavelength, temperature, pressure)


molecules = ['H2O', 'CO2']
# One interpolated opacity array per molecule, in the same order as ``molecules``.
opacity_samples = [_interpolated_opacity(species) for species in molecules]

# NOTE(review): the per-species opacities are summed with equal weight and the
# result is later passed to the transmission code as a single-species array,
# while ``vmr_indices`` lists both molecules. Confirm against shone's
# ``transmission_radius`` documentation whether the opacities should instead
# be stacked per species (leading axis = species) so each can be weighted by
# its own abundance.
total_opacity = jnp.array(opacity_samples).sum(axis=0)

# Equilibrium chemistry on the model temperature-pressure grid, computed via
# shone's FastChem wrapper.
chem = FastchemWrapper(temperature, pressure)

# Volume mixing ratios from the wrapper (presumably one row per pressure
# level and one column per species — TODO confirm against shone docs).
vmr = chem.vmr()
# Column indices of the opacity species ('H2O', 'CO2') in the wrapper's
# species table; presumably used downstream to select their abundances.
vmr_indices = chem.get_column_index(species_name=molecules)
# Species weights, in atomic mass units per the variable name.
weights_amu = chem.get_weights()

# Planet and host-star properties, passed to the transmission code as bare
# cgs floats. Units stay attached through the surface-gravity calculation so
# astropy converts every factor consistently. Multiplying the SI constant G
# by pre-stripped cgs floats only works because G happens to be the sole
# dimensioned factor.
planet_radius = 1 * u.R_earth
planet_mass = 1 * u.M_earth

R_p0 = planet_radius.cgs.value  # reference planet radius [cm]
mass = planet_mass.cgs.value  # planet mass [g]
g = (G * planet_mass / planet_radius**2).cgs.value  # surface gravity [cm / s^2]
R_star = (1 * u.R_sun).cgs.value  # stellar radius [cm]

# Transmission spectrum from shone's de Wit & Seager (2013) implementation,
# with Rayleigh scattering enabled, expressed as a planet-to-star radius ratio.
opacity_input = total_opacity[None, ...]  # summed opacity with a length-1 leading axis prepended
transit_radius = de_wit_seager_2013.transmission_radius(
    wavelength, temperature, pressure, g, R_p0,
    opacity_input,
    vmr, vmr_indices, weights_amu,
    rayleigh_scattering=True,
)  # presumably [cm], matching R_p0 and R_star — TODO confirm
Rp_Rs = transit_radius / R_star

# Plot the transmission spectrum and label the main absorption features.
ax = plt.gca()
ax.plot(wavelength, Rp_Rs)

label_height = 0.0135  # y position shared by all feature labels
water_peaks = [1.4, 1.9, 2.7]  # H2O band positions [µm]
# (label, x position in µm) pairs, CO2 first, then the water bands.
feature_labels = [("CO$_2$", 4.32)] + [("H$_2$O", peak) for peak in water_peaks]
for text, center in feature_labels:
    ax.annotate(text, (center, label_height), ha='center')

ax.set(
    xlabel='Wavelength [µm]',
    ylabel='$R_{\\rm p}~/~R_{\\rm s}$',
    ylim=(0.009, 0.014),
)