import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from astropy.constants import m_p

from shone.opacity import Opacity, generate_synthetic_opacity
from shone.transmission import heng_kitzmann_2017

# Generate the synthetic opacity data; presumably this is what makes the
# 'synthetic' species loadable below — confirm against shone.opacity.
generate_synthetic_opacity()

# load the one opacity file:
opacity = Opacity.load_species_from_name('synthetic')

# get a jitted 3D interpolator over wavelength, temperature, pressure:
interp_opacity = opacity.get_interpolator()

wavelength = np.linspace(1, 5, 500)  # [µm]

temperature = np.array([200, 400, 600, 800])  # [K]
pressure = np.ones_like(temperature)  # [bar]

# Evaluate the opacity at every wavelength for each (temperature, pressure)
# pair. NOTE(review): assumed to be shaped (n_temperature, n_wavelength),
# which is why it is transposed for plotting below — confirm.
example_opacity = interp_opacity(wavelength, temperature, pressure)

# constant cloud opacity, added uniformly at every wavelength below:
kappa_cloud = 5e-2  # [cm2/g]

R_0 = 1 * u.R_earth  # reference radius
P_0 = 1 * u.bar  # reference pressure
T_0 = 290 * u.K  # reference temperature
mmw = 28 * m_p  # mean molecular mass (28 proton masses)
g = 9.8 * u.m / u.s**2  # surface gravity

# Convert the arguments from astropy `Quantity`s to floats in cgs units.
# A tuple (not a one-shot generator) keeps the values reusable if
# `cgs_args` is unpacked more than once.
args = (R_0, P_0, T_0, mmw, g)
cgs_args = tuple(arg.cgs.value for arg in args)

# compute the planetary radius (cm) as a function of wavelength:
Rp = heng_kitzmann_2017.transmission_radius_isothermal_isobaric(
    example_opacity + kappa_cloud, *cgs_args
)

# convert to transit depth (ppm) for a one-solar-radius host star:
Rstar = (1 * u.R_sun).cgs.value
transit_depth_ppm = 1e6 * (Rp / Rstar) ** 2

# one curve (and legend entry) per temperature:
label = [f"{t} K" for t in temperature]
plt.plot(wavelength, transit_depth_ppm.T, label=label)
plt.legend()
plt.gca().set(
    xlabel='Wavelength [µm]',
    ylabel='Transit depth [ppm]'
)