import numpy as np
import matplotlib.pyplot as plt
import msaexp.utils as msautils

# Demonstrate msaexp.utils.LookupTablePSF: evaluate the cross-dispersion
# profile on a (wavelength, slit-y) grid for a straight trace and for a
# curved, broadened trace, display both, and check their normalization.


def _show_prf(ax, image, extent):
    """Draw a 2D profile ``image`` on ``ax`` in (wavelength, y pixel) coords.

    ``origin='lower'`` is required: row 0 of ``image`` corresponds to the
    most negative slit offset (``yslit[0]``), which must be drawn at the
    bottom of the axes to agree with ``extent``. With matplotlib's default
    ``origin='upper'`` the image would be flipped vertically relative to the
    y-axis labels.

    Returns the ``AxesImage`` created by ``ax.imshow``.
    """
    return ax.imshow(image, extent=extent, origin='lower', aspect='auto')


# Wavelength samples (units presumably microns -- confirm against the
# LookupTablePSF docs) and integer cross-dispersion offsets from -9 to +9.
waves = np.linspace(0.8, 5.6, 128)
yslit = np.arange(-9, 10, dtype=float)

# 2D coordinate arrays of shape (len(yslit), len(waves)): rows index slit
# pixels, columns index wavelengths.
w2d, y2d = np.meshgrid(waves, yslit)

# imshow extent given as pixel *edges*, so each cell is centred on its
# (wavelength, y) sample instead of the outermost samples sitting on the
# frame of the axes.
_wave_step = waves[1] - waves[0]
_y_step = yslit[1] - yslit[0]
_extent = (
    waves[0] - _wave_step / 2,
    waves[-1] + _wave_step / 2,
    yslit[0] - _y_step / 2,
    yslit[-1] + _y_step / 2,
)

fig, axes = plt.subplots(2, 1, figsize=(8, 5), sharex=True, sharey=True)

# Initialize the lookup table
prf_model = msautils.LookupTablePSF()

# Straight trace: sigma=0 and zero offset along the whole spectrum.
prf = prf_model.evaluate(sigma=0, dy=0.0, slit_coords=(w2d, y2d))
_show_prf(axes[0], prf, _extent)

# Curved trace: offset dy(wave) = -((wave - 4) / 2)**2, which is zero at
# wave = 4 and becomes more negative away from it (-2.56 at 0.8, -0.64 at
# 5.6). sigma=0.2 presumably adds extra broadening -- confirm its units in
# LookupTablePSF.evaluate.
dy = -((w2d - 4) / 2) ** 2
prf2 = prf_model.evaluate(sigma=0.2, dy=dy, slit_coords=(w2d, y2d))
_show_prf(axes[1], prf2, _extent)

axes[1].set_xlabel('wavelength')
axes[1].set_ylabel('y pixel')
fig.tight_layout(pad=1)

# Verify that the integral of each profile across the slit (axis 0, the
# y-pixel rows) is 1.0 at every wavelength, i.e. the profile is normalized.
# This presumes the trace stays well inside the +/-9 pixel grid; flux
# falling off the grid edges would make the column sums drop below 1.
for _label, _image in (('straight', prf), ('curved', prf2)):
    _colsum = _image.sum(axis=0)
    assert np.allclose(_colsum, 1.0, rtol=0.01), (
        f"{_label} trace profile not normalized: column sums span "
        f"[{_colsum.min():.4f}, {_colsum.max():.4f}]"
    )