import numpy as np
import matplotlib.pyplot as plt
from msaexp.utils import get_prism_bar_correction

# Plot the bar-shadow factor returned by msaexp's get_prism_bar_correction
# for slitlets of 1, 2 and 3 shutters, all evaluated on a common grid of
# scaled cross-dispersion positions.

# Sampling grid. Per the x-axis label below, the scaled coordinate is the
# cross-dispersion pixel offset divided by 5.
scaled_yshutter = np.linspace(-1.6, 1.6, 512)

fig, ax = plt.subplots(1, 1, figsize=(6, 4))

for n in (1, 2, 3):
    # get_prism_bar_correction returns a 2-tuple; only the first element
    # (the bar-shadow curve) is plotted. The second is unused here
    # (presumably the wrapped coordinate -- confirm against msaexp).
    bar, _wrapped = get_prism_bar_correction(
        scaled_yshutter,
        num_shutters=n,
        wrap=False,
    )
    ax.plot(scaled_yshutter, bar, label=f'{n}-shutter', alpha=0.5)

ax.legend(loc='lower right', fontsize=6)
ax.grid()

ax.set_xlabel('scaled_yshutter = cross-dispersion pixel / 5')
ax.set_ylabel('bar shadow factor')

fig.tight_layout(pad=1)

# Use pyplot.show() rather than Figure.show(). Figure.show() does not run
# the GUI event loop, so in a plain script the window closes immediately;
# under non-GUI backends it only emits a warning. pyplot.show() blocks
# until the window is closed in scripts and works in notebooks too.
plt.show()