import re

import numpy as np
import matplotlib.pyplot as plt

from jax import numpy as jnp

from shone.chemistry import FastchemWrapper
from shone.opacity import Opacity

# Log-spaced pressure levels from 1e-6 bar to 1 bar; 50 levels is
# np.geomspace's default count, spelled out here for clarity.
pressure = np.geomspace(1e-6, 1, num=50)  # [bar]

# Power-law temperature profile, equal to 700 K where pressure == 0.1 bar.
temperature = 700 * np.power(pressure / 0.1, 0.05)  # [K]

# 500 log-spaced wavelengths between 0.5 and 5 µm.
wavelength = np.geomspace(0.5, 5, num=500)  # [µm]

# Species whose demo opacity tables are loaded and combined below.
molecules = ['H2O', 'CO2']

# Interpolate each species' demo opacity onto the (wavelength, temperature,
# pressure) grid defined above; one entry per species, in ``molecules`` order.
opacity_samples = []
for molecule in molecules:
    opacity = Opacity.load_demo_species(molecule)
    interp_opacity = opacity.get_interpolator()
    opacity_samples += [interp_opacity(wavelength, temperature, pressure)]

# Element-wise sum over the species axis (axis 0 of the stacked samples).
# NOTE(review): this is an unweighted sum of per-species opacities; confirm
# whether abundance (VMR) weighting is meant to be applied downstream.
total_opacity = jnp.sum(jnp.array(opacity_samples), axis=0)

# Chemistry model evaluated on the same temperature/pressure grid used for
# the opacities above.
chem = FastchemWrapper(temperature, pressure)

# Mixing ratios per pressure level and species. The plotting code indexes this
# as vmr[:, vmr_indices], so it is presumably (n_pressure, n_species) —
# TODO confirm against FastchemWrapper.vmr().
vmr = chem.vmr()
# Column indices into ``vmr`` for the entries of ``molecules``; the plot pairs
# them positionally with ``molecules``.
vmr_indices = chem.get_column_index(species_name=molecules)
# Species weights; the ``_amu`` suffix suggests atomic mass units — confirm.
weights_amu = chem.get_weights()

# Two panels sharing the pressure (y) axis: the p-T profile on the left and
# the mixing-ratio profile of each species in ``molecules`` on the right.
fig, ax = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

ax[0].semilogy(temperature, pressure, color='k')
ax[0].set(
    xlabel='Temperature [K]',
    ylabel='Pressure [bar]',
    title='p-T structure'
)
# Put higher pressure at the bottom; because the y axis is shared, this
# also flips the right-hand panel.
ax[0].invert_yaxis()

# The selected columns of ``vmr`` are assumed to follow the order of
# ``molecules``; transposing yields one profile per species.
for molecule, vmr_i in zip(molecules, vmr[:, vmr_indices].T):
    # Subscript every digit run in the formula (H2O -> H$_{2}$O,
    # CH4 -> CH$_{4}$), not only the digit "2".
    ax[1].loglog(
        vmr_i, pressure, label=re.sub(r'(\d+)', r'$_{\1}$', molecule)
    )
ax[1].legend()
ax[1].set(
    xlabel='VMR',
    title='Chemistry'
)
plt.tight_layout()