### Make a plot with slitlets

import numpy as np
import matplotlib.pyplot as plt
from msaexp import msa

uri = 'https://mast.stsci.edu/api/v0.1/Download/file?uri=mast:JWST/product/'
meta = msa.MSAMetafile(uri + 'jw02756001001_01_msa.fits')

# Source IDs whose slitlets get a text label in each panel (set: built once,
# O(1) membership tests inside the slitlet loop).
highlight_ids = {110003, 410044, 410045}

# Fixed view window shared by all panels.  The x-limits run from larger to
# smaller R.A., so R.A. increases to the left of each panel.
xlim = (3.5936537138517317, 3.588363444812261)
ylim = (-30.39750646306242, -30.394291511397544)

fig, axes = plt.subplots(1, 3, figsize=(9, 2.6),
                         sharex=True, sharey=True)

# cos(Dec) at the median source declination: converts R.A. degrees to
# on-sky degrees, used for the panel aspect ratio and all R.A. offsets.
cosd = np.cos(np.deg2rad(np.median(meta.src_table['dec'])))

# Show offset slitlets from three dithered exposures, one panel each.
for i, ax in enumerate(axes):
    # All catalog sources in the MSA metafile, as small semi-transparent dots.
    ax.scatter(meta.src_table['ra'], meta.src_table['dec'],
               marker='.', color='k', alpha=0.5)

    # Slitlet regions for this dither position (dither_point_index is
    # 1-based, hence i + 1).
    slits = meta.regions_from_metafile(dither_point_index=i + 1,
                                       as_string=False,
                                       with_bars=True)
    for s in slits:
        if s.meta['is_source']:
            if s.meta['source_id'] in highlight_ids:
                # Offset the label by 0.8" on the sky; divide by cos(Dec) so
                # the R.A. offset is in the same on-sky units as the ticks.
                ax.text(s.meta['ra'] - 0.8 / 3600. / cosd, s.meta['dec'],
                        s.meta['source_id'],
                        fontsize=7, ha='left', va='center')
            fc = '0.5'
        else:
            # Slitlets not flagged as sources are drawn in pink.
            fc = 'pink'

        # get_patch yields one or more patches per slitlet; high zorder keeps
        # them above the source scatter points.
        for patch in s.get_patch(fc=fc, ec='None', alpha=0.8,
                                 zorder=100):
            ax.add_patch(patch)

    ax.set_aspect(1. / cosd)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

    ax.grid()
    ax.set_title(f'Dither point #{i+1}')

# The panels were created with sharex/sharey, so the tick locator and
# formatter set on one panel apply to all three.  Use the last panel
# explicitly rather than relying on the loop variable.
last_ax = axes[-1]

x0 = np.mean(last_ax.get_xlim())
last_ax.set_xticks(np.array([-5, 0, 5]) / 3600. / cosd + x0)
# NOTE(review): the tick at x0 - 5" (lower R.A., right side of the inverted
# axis) is labelled '+5"'.  Confirm this matches the intended R.A. offset
# sign convention.
last_ax.set_xticklabels(['+5"', 'R.A.', '-5"'])

y0 = np.mean(last_ax.get_ylim())
last_ax.set_yticks(np.array([-5, 0, 5]) / 3600. + y0)
axes[0].set_yticklabels(['-5"', 'Dec.', '+5"'])

# Mark and annotate the panel centre on the middle panel.
axes[1].scatter(x0, y0, marker='x', c='b')
axes[1].text(0.5, 0.45, f'({x0:.6f}, {y0:.6f})',
             ha='left', va='top',
             transform=axes[1].transAxes, fontsize=6,
             color='b')

fig.tight_layout(pad=0.5)