import numpy as np
import matplotlib.pyplot as plt
from msaexp.utils import get_prism_wave_bar_correction

# Grid of normalized cross-dispersion positions
# (scaled_yshutter = cross-dispersion pixel / 5, as stated on the x-axis label).
scaled_yshutter = np.linspace(-1.6, 1.6, 512)

fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# Overplot the unwrapped (wrap=False) 3-shutter bar shadow profile at several
# wavelengths.  Each curve's color is the wavelength linearly mapped from the
# 0.8-5.3 range onto [0, 1] of the RdYlBu_r colormap.
for wave_um in (1.0, 2.0, 3.0, 4.0, 5.0):
    # Constant-wavelength array matching the shape/dtype of the y grid.
    wave_array = np.full_like(scaled_yshutter, wave_um)
    shadow, _ = get_prism_wave_bar_correction(
        scaled_yshutter,
        wave_array,
        num_shutters=3,
        wrap=False,
    )

    line_color = plt.cm.RdYlBu_r(np.interp(wave_um, [0.8, 5.3], [0, 1]))
    ax.plot(
        scaled_yshutter,
        shadow,
        label=f'{wave_um:.0f} um',
        alpha=0.5,
        color=line_color,
    )

# Axis decorations.
ax.legend(loc='lower right', fontsize=6)
ax.grid()
ax.set_xlabel('scaled_yshutter = cross-dispersion pixel / 5')
ax.set_ylabel('bar shadow factor')

fig.tight_layout(pad=1)
# NOTE(review): Figure.show() does not block; when this file is run as a
# standalone script (rather than e.g. a docs plot directive), plt.show() may
# be needed to keep the window open.
fig.show()