# Source code for msaexp.drizzle

"""
Tools for drizzle-combining MSA spectra
"""

import glob
import warnings

import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage as nd

import jwst.datamodels
import astropy.io.fits as pyfits

import grizli.utils

from . import utils

# Parameter defaults
# Default drizzle parameters; ``drizzle_slitlets`` forwards these to
# `msaexp.utils.drizzle_slits_2d` unless other values are supplied
DRIZZLE_PARAMS = {
    "output": None,
    "single": True,
    "blendheaders": True,
    "pixfrac": 1.0,
    "kernel": "square",
    "fillval": 0,
    "wht_type": "ivm",
    "good_bits": 0,
    "pscale_ratio": 1.0,
    "pscale": None,
}

# Default ``figsize`` for the diagnostic figures
FIGSIZE = (10, 4)

# Default keyword arguments for ``matplotlib.pyplot.imshow`` in the
# diagnostic figures.  ``vmax=None`` lets ``drizzle_slitlets`` pick a
# display range from the data.
IMSHOW_KWS = {
    "vmin": -0.1,
    "vmax": None,
    "aspect": "auto",
    "interpolation": "nearest",
    "origin": "lower",
    "cmap": "cubehelix_r",
}


def metadata_tuple(slit):
    """
    Return the MSA metadata key of a slitlet

    Parameters
    ----------
    slit : `jwst.datamodels.SlitModel`
        Slitlet data object

    Returns
    -------
    meta : (str, str)
        Tuple of ``(msa_metadata_file, msa_metadata_id)`` read from
        ``slit.meta.instrument``
    """
    instrument = slit.meta.instrument
    return (instrument.msa_metadata_file, instrument.msa_metadata_id)
def center_wcs(
    slit,
    waves,
    center_on_source=False,
    force_nypix=31,
    fix_slope=None,
    slit_center=0.0,
    center_phase=-0.5,
):
    """
    Derive a 2D spectral WCS centered on the expected source position along
    the slit, or on the slitlet itself

    Parameters
    ----------
    slit : `jwst.datamodels.SlitModel`
        Slitlet data object

    waves : array-like
        Target wavelength array

    center_on_source : bool
        Center on the source position along the slit.  If not, center on
        `slit_center` slit coordinate

    force_nypix : int
        Cross-dispersion size of the output 2D WCS

    fix_slope : float
        Fixed cross-dispersion pixel size, in units of the slit coordinate
        frame

    slit_center : float
        Define the center of the slit in the slit coordinate frame

    center_phase : float
        Pixel phase defining the center of the slit alignment

    Returns
    -------
    wcs_data : object
        Output from `msaexp.utils.build_slit_centered_wcs`

    offset_to_source : float
        Offset between center of the WCS and the expected source position
        (``0.0`` when ``center_on_source=True``)

    meta : tuple
        MSA key from `msaexp.drizzle.metadata_tuple`

    Notes
    -----
    Sets ``slit.drizzle_slit_offset_source`` as a side effect.
    """
    shared = dict(
        force_nypix=force_nypix,
        get_from_ypos=False,
        phase=center_phase,
        fix_slope=fix_slope,
        slit_center=slit_center,
    )

    # Source-centered WCS first.  ``slit.drizzle_slit_offset`` is read right
    # after the call (presumably updated by build_slit_centered_wcs) and
    # remembered as the source-centered offset.
    wcs_data = utils.build_slit_centered_wcs(
        slit, waves, center_on_source=True, **shared
    )
    slit.drizzle_slit_offset_source = slit.drizzle_slit_offset

    if not center_on_source:
        # Recompute centered on the slitlet and keep the difference as the
        # offset of the source relative to the WCS center
        wcs_data = utils.build_slit_centered_wcs(
            slit, waves, center_on_source=False, **shared
        )
        offset_to_source = (
            slit.drizzle_slit_offset_source - slit.drizzle_slit_offset
        )
    else:
        offset_to_source = 0.0

    return wcs_data, offset_to_source, metadata_tuple(slit)
def _flag_bad_pixels(sci, err, sn_threshold, err_threshold):
    """
    Flag generically invalid pixels in a stack of drizzled slitlets

    A pixel is flagged if its uncertainty is non-positive or non-finite,
    its value is non-finite or exactly zero, its S/N is non-finite or below
    ``sn_threshold``, or its uncertainty exceeds ``err_threshold`` times the
    median valid uncertainty of its own slitlet.

    Parameters
    ----------
    sci, err : (N, NY, NX) arrays
        Science and uncertainty arrays of the drizzled slitlets

    sn_threshold, err_threshold : float
        See `msaexp.drizzle.drizzle_slitlets`

    Returns
    -------
    flagged : (N, NY, NX) bool array
        True for pixels to exclude
    """
    flagged = (err <= 0) | (~np.isfinite(sci)) | (~np.isfinite(err))
    flagged |= sci == 0
    flagged |= ~np.isfinite(sci / err)
    flagged |= sci / err < sn_threshold

    # Per-slitlet mask of pathologically large uncertainties
    for i in range(err.shape[0]):
        ei = err[i, :, :]
        emi = np.isfinite(ei) & (ei > 0)
        flagged[i, :, :] |= ei > err_threshold * np.median(ei[emi])

    return flagged


def drizzle_slitlets(
    id,
    wildcard="*phot",
    files=None,
    output=None,
    verbose=True,
    drizzle_params=DRIZZLE_PARAMS,
    master_bkg=None,
    wave_arrays=None,
    wave_sample=1,
    log_step=True,
    force_nypix=31,
    center_on_source=False,
    center_phase=-0.5,
    fix_slope=None,
    outlier_threshold=5,
    sn_threshold=3,
    bar_threshold=-0.7,
    err_threshold=1000,
    bkg_offset=5,
    bkg_parity=[1, -1],
    mask_padded=False,
    show_drizzled=True,
    show_slits=True,
    imshow_kws=IMSHOW_KWS,
    get_quick_data=False,
    max_sn_threshold=20,
    reopen=True,
    **kwargs,
):
    """
    Implementing more direct drizzling of multiple 2D slitlets

    Parameters
    ----------
    id, wildcard : object, str
        Values to search for extracted slitlet files:

        .. code-block:: python
            :dedent:

            files = glob.glob(f'{wildcard}*_{id}.fits')

    files : list
        Explicit list of either slitlet filenames or
        `jwst.datamodels.SlitModel` objects.  A list made up entirely of
        filenames is processed in sorted order; otherwise the given order
        is kept.  The caller's list is not modified.

    output : str
        Optional rootname of output figures and FITS data

    verbose : bool
        Verbose messaging

    drizzle_params : dict
        Drizzle parameters passed to `msaexp.utils.drizzle_slits_2d`

    master_bkg : array-like, int
        Master background to replace local background derived from the
        drizzled product.  Either a ``(bkg, bkg_w)`` pair or ``0`` for no
        background.  Any other scalar raises `ValueError`.

    wave_arrays : dict, None
        Explicit target wavelength arrays with keys for `{grating}-{filter}`
        combinations

    wave_sample, log_step : float, bool
        If `waves` not specified, generate with
        `msaexp.utils.get_standard_wavelength_grid`

    force_nypix, center_on_source, center_phase, fix_slope : int, bool, float
        Parameters of `msaexp.drizzle.center_wcs`

    outlier_threshold : int
        Outlier threshold in drizzle combination

    sn_threshold : float
        Mask pixels in slitlets where `data/err < sn_threshold`.  For the
        prism, essentially all pixels should have S/N > 5 from the
        background, so this mask can help identify and mask stuck-closed
        slitlets

    bar_threshold : float
        Mask pixels in slitlets where `barshadow < bar_threshold`

    err_threshold : float
        Mask pixels in slitlets where `err > err_threshold*median(err)`.
        There are some strange pixels with very large uncertainties in the
        pipeline products.

    bkg_offset, bkg_parity : int, list
        Offset in pixels for defining the local background of the drizzled
        product, which is derived by rolling the data array by
        `bkg_offset*bkg_parity` pixels.  The standard three-shutter nod
        pattern corresponds to about 5 pixels.  An optimal combination
        seems to be ``fix_slope=0.2``, ``bkg_offset=6``.  If
        ``bkg_offset < 0``, then don't do shifted offset.

    mask_padded : bool
        Mask pixels of slitlets that had been padded around the nominal MSA
        slitlets

    show_drizzled : bool
        Make a figure with `msaexp.drizzle.show_drizzled_product` showing
        the drizzled combined arrays.  If `output` specified, save to
        `{output}_{grating}_{id}.d2d.png`.

    show_slits : bool
        Make a figure with `msaexp.drizzle.show_drizzled_slits` showing the
        individual drizzled slitlets.  If `output` specified, save to
        `{output}_{grating}_{id}.slit2d.png`.  If ``show_slits > 1``, the
        background is subtracted in the slitlet panels.

    imshow_kws : dict
        Keyword arguments for ``matplotlib.pyplot.imshow`` in
        `show_drizzled` and `show_slits` figures.  If ``vmax`` is missing
        or `None`, ``vmax`` and ``vmin`` are derived from each drizzled
        product.  The dict itself is never modified.

    get_quick_data : bool
        Just return `waves`, `slits`, and the drizzled `sci` and `err` data
        arrays before doing any outlier rejection, etc.

    max_sn_threshold : float
        S/N threshold for initial rejection of the maximum pixel in the set

    reopen : bool
        Re-initialize `jwst.datamodels.SlitModel` before drizzling to fix
        apparent memory leak issue

    kwargs : dict
        Ignored

    Returns
    -------
    figs : dict
        Any figures that were created, keys are separated by grating+filter
        here and below

    data : dict
        `~astropy.io.fits.HDUList` FITS data for the drizzled output

    wavedata : dict
        Wavelength arrays

    all_slits : dict
        `SlitModel` objects for the input slitlets

    drz_data : dict
        3D `sci` and `wht` arrays of the drizzled slitlets that were
        combined into the drizzled stack
    """
    if wave_arrays is None:
        wave_arrays = {}

    if files is None:
        files = glob.glob(f"{wildcard}*_{id}.fits")
        # files = glob.glob(f'*{pipeline_extension}*_{id}.fits')

    # Sort filenames for a deterministic order, but don't try to sort
    # SlitModel objects (not orderable) and don't mutate the caller's list
    if all(isinstance(f, str) for f in files):
        files = sorted(files)

    # Read the SlitModels
    gratings = {}
    grating_files = {}

    msg = f"msaexp.drizzle.drizzle_slitlets: {id} read {len(files)} files"
    grizli.utils.log_comment(
        grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
    )

    for file in files:
        if isinstance(file, str):
            slit = jwst.datamodels.SlitModel(file)
            utils.update_slit_metadata(slit)
        else:
            # Pre-loaded models can't be re-read from disk
            slit = file
            reopen = False

        grating = slit.meta.instrument.grating.lower()
        filt = slit.meta.instrument.filter.lower()
        key = f"{grating}-{filt}"
        if key not in gratings:
            gratings[key] = []
            grating_files[key] = []

        gratings[key].append(slit)
        grating_files[key].append(file)

    if verbose:
        for g in gratings:
            msg = f"msaexp.drizzle.drizzle_slitlets: id={id} {g}"
            msg += f" N={len(gratings[g])}"
            grizli.utils.log_comment(
                grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
            )

    # DQ mask
    for gr in gratings:
        for slit in gratings[gr]:
            utils.update_slit_dq_mask(
                slit,
                mask_padded=mask_padded,
                bar_threshold=bar_threshold,
                verbose=False,
            )

    # Loop through gratings
    figs = {}
    data = {}
    wavedata = {}
    all_slits = {}
    drz_data = {}

    for gr in gratings:
        slits = gratings[gr]

        # Default wavelengths
        if gr in wave_arrays:
            waves = wave_arrays[gr]
        else:
            waves = utils.get_standard_wavelength_grid(
                gr.split("-")[0], sample=wave_sample, log_step=log_step
            )

        # Drizzle 2D spectra
        drz = None
        drz_ids = []

        wcs_data = None

        # Approximate for default
        to_ujy = 1.0e12 * 5.0e-13

        wcs_meta = None

        # Get offset from one science slit
        for i in range(len(slits)):
            slit = slits[i]
            if "background" in slit.source_name:
                continue
            elif "-" in slit.source_name:
                continue
            elif slit.data.shape[1] < 50:
                continue

            msg = f"msaexp.drizzle.drizzle_slitlets: get wcs from slit {i} = "
            msg += f" {slit.source_name}"
            grizli.utils.log_comment(
                grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
            )

            _center = center_wcs(
                slit,
                waves,
                force_nypix=force_nypix,
                center_on_source=center_on_source,
                fix_slope=fix_slope,
                center_phase=center_phase,
            )
            wcs_data, offset_to_source, wcs_meta = _center

            try:
                to_ujy = 1.0e12 * slit.meta.photometry.pixelarea_steradians
            except TypeError:
                to_ujy = 1.0

            break

        if wcs_meta is None:
            # Run for background slits centered on slitlet if all skipped above
            _center = center_wcs(
                slits[0],
                waves,
                force_nypix=force_nypix,
                center_on_source=False,
                fix_slope=fix_slope,
                slit_center=1.0,
                center_phase=center_phase,
            )
            wcs_data, offset_to_source, wcs_meta = _center
            try:
                to_ujy = 1.0e12 * slits[0].meta.photometry.pixelarea_steradians
            except TypeError:
                to_ujy = 1.0

        ##################
        # Now do the drizzling
        msg = f"msaexp.drizzle.drizzle_slitlets: output size = {wcs_data[2]}"
        grizli.utils.log_comment(
            grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
        )

        # FITS header metadata
        h = pyfits.Header()
        wcs_header = wcs_data[1]
        for k in wcs_header:
            h[k] = wcs_header[k], wcs_header.comments[k]

        h["BKGOFF"] = bkg_offset, "Background offset"
        h["OTHRESH"] = outlier_threshold, "Outlier mask threshold, sigma"
        h["WSAMPLE"] = wave_sample, "Wavelength sampling factor"
        h["LOGWAVE"] = log_step, "Target wavelengths are log spaced"
        # Pixel values are scaled by ``to_ujy`` (1e12 * pixel area in sr)
        # below, i.e. they are microJansky, not milliJansky
        h["BUNIT"] = "uJy"

        inst = slits[0].meta.instance["instrument"]
        for k in ["grating", "filter"]:
            h[k] = inst[k]

        h["NFILES"] = len(slits), "Number of extracted slitlets"
        h["EFFEXPTM"] = 0.0, "Total effective exposure time"

        # Source metadata from the first slit with a valid RA
        for slit in slits:
            h["SRCNAME"] = slit.source_name, "source_name from MSA file"
            h["SRCID"] = slit.source_id, "source_id from MSA file"
            h["SRCRA"] = slit.source_ra, "source_ra from MSA file"
            h["SRCDEC"] = slit.source_dec, "source_dec from MSA file"
            if slit.source_ra > 0:
                break

        to_ujy_list = []

        for i in range(len(slits)):
            slit = slits[i]
            _file = grating_files[gr][i]
            msg = "msaexp.drizzle.drizzle_slitlets: "
            msg += f"{gr} {i:2} {slit.source_name:18} {slit.source_id:9}"
            msg += f" {slit.source_ypos:6.3f} {_file} {slit.data.shape}"
            grizli.utils.log_comment(
                grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
            )

            drz_ids.append(slit.source_name)

            h["EFFEXPTM"] += slit.meta.exposure.effective_exposure_time

            if (metadata_tuple(slit) != wcs_meta) & center_on_source:
                _center = center_wcs(
                    slit,
                    waves,
                    force_nypix=force_nypix,
                    center_on_source=center_on_source,
                    fix_slope=fix_slope,
                    center_phase=center_phase,
                )
                wcs_data, offset_to_source, wcs_meta = _center

                msg = "msaexp.drizzle.drizzle_slitlets: "
                msg += f"Recenter on source ({metadata_tuple(slit)})"
                msg += f" y={offset_to_source:.2f}"
                grizli.utils.log_comment(
                    grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
                )

                # Recalculate photometry scaling
                try:
                    to_ujy = 1.0e12 * slit.meta.photometry.pixelarea_steradians
                except TypeError:
                    to_ujy = 1.0

            # Do the drizzling
            if reopen:
                _slit = utils.update_slit_dq_mask(
                    _file,
                    mask_padded=mask_padded,
                    bar_threshold=bar_threshold,
                    verbose=False,
                )
            else:
                _slit = slit

            _waves, _header, _drz = utils.drizzle_slits_2d(
                [_slit], build_data=wcs_data, drizzle_params=drizzle_params
            )
            to_ujy_list.append(to_ujy)

            if drz is None:
                drz = _drz
            else:
                drz.extend(_drz)

            slit.close()
            if reopen:
                _slit.close()

        drz_ids = np.array(drz_ids)

        # Are slitlets tagged as background?
        is_bkg = np.zeros(len(drz_ids), dtype=bool)

        ############
        # Combined drizzled spectra

        # First pass - max-clipped median
        sci = np.array([d.data * to_ujy_list[i] for i, d in enumerate(drz)])
        err = np.array([d.err * to_ujy_list[i] for i, d in enumerate(drz)])

        if get_quick_data:
            return waves, slits, sci, err

        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=RuntimeWarning)
            scimax = np.nanmax(sci, axis=0)

        flagged = (sci >= scimax) & (sci / err > max_sn_threshold)
        flagged |= _flag_bad_pixels(sci, err, sn_threshold, err_threshold)

        ivar = 1.0 / err**2
        sci[flagged] = np.nan
        ivar[flagged] = 0

        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=RuntimeWarning)
            avg = np.nanmedian(sci, axis=0)

        avg_w = ivar.sum(axis=0)

        # Subsequent passes - weighted average with outlier rejection
        for _iter in range(3):
            sci = np.array(
                [d.data * to_ujy_list[i] for i, d in enumerate(drz)]
            )
            err = np.array([d.err * to_ujy_list[i] for i, d in enumerate(drz)])

            flagged = np.abs(sci - avg) * np.sqrt(avg_w) > outlier_threshold
            flagged |= _flag_bad_pixels(sci, err, sn_threshold, err_threshold)

            ivar = 1.0 / err**2
            sci[flagged] = 0
            ivar[flagged] = 0

            # Weighted combination of unmasked pixels
            avg = (sci * ivar)[~is_bkg, :].sum(axis=0) / ivar[~is_bkg, :].sum(
                axis=0
            )
            avg_w = ivar[~is_bkg, :].sum(axis=0)

            # Set masked pixels to zero
            msk = ~np.isfinite(avg + avg_w)
            avg[msk] = 0
            avg_w[msk] = 0

        # Use master background if supplied
        if master_bkg is not None:
            if hasattr(master_bkg, "__len__"):
                bkg = master_bkg[0]
                bkg_w = master_bkg[1]
            elif master_bkg in [0, 0.0]:
                bkg = np.zeros_like(avg)
                bkg_w = np.zeros_like(avg)
            else:
                # Previously fell through with ``bkg`` unbound
                raise ValueError(
                    "msaexp.drizzle.drizzle_slitlets: master_bkg must be "
                    "None, 0, or a (bkg, bkg_w) pair"
                )
        elif bkg_offset < 0:
            bkg = np.zeros_like(avg)
            bkg_w = np.zeros_like(avg)
        else:
            # Background by rolling full drizzled array
            bkg_num = avg * 0.0
            bkg_w = avg * 0
            for s in bkg_parity:
                bkg_num += np.roll(avg * avg_w, s * bkg_offset, axis=0)
                bkg_w += np.roll(avg_w, s * bkg_offset, axis=0)

            # Rows that wrapped around the array edges aren't valid
            bkg_w[:bkg_offset] = 0
            bkg_w[-bkg_offset:] = 0
            bkg = bkg_num / bkg_w

        # Set masked back to nan
        avg[msk] = np.nan
        avg_w[msk] = np.nan

        # Trace center
        y0 = (avg.shape[0] - 1) // 2
        to_source = y0 + offset_to_source
        h["SRCYPIX"] = to_source, "Expected row of source centering"

        # Build HDUList
        h["EXTNAME"] = "SCI"
        hdul = pyfits.HDUList([pyfits.PrimaryHDU()])
        hdul.append(pyfits.ImageHDU(data=avg, header=h))

        h["EXTNAME"] = "WHT"
        hdul.append(pyfits.ImageHDU(data=avg_w, header=h))

        h["EXTNAME"] = "BKG"
        hdul.append(pyfits.ImageHDU(data=bkg, header=h))

        h["EXTNAME"] = "WAVE"
        h["BUNIT"] = "micron"
        hdul.append(pyfits.ImageHDU(data=waves, header=h))

        output_key = f"{output}_{gr}_{id}"  # match spec.fits

        if output is not None:
            hdul.writeto(
                f"{output_key}.d2d.fits", overwrite=True, output_verify="fix"
            )

        # Per-grating copy so that neither the caller's dict nor the
        # module-level IMSHOW_KWS default is modified
        grating_imshow_kws = dict(imshow_kws)
        if grating_imshow_kws.get("vmax") is None:
            vmax = np.nanpercentile(avg, 95) * 2
            grating_imshow_kws["vmax"] = vmax
            grating_imshow_kws["vmin"] = -0.1 * vmax

        # Make a figure
        if show_drizzled:
            dfig = show_drizzled_product(hdul, imshow_kws=grating_imshow_kws)
            if output is not None:
                dfig.savefig(f"{output_key}.d2d.png")
        else:
            dfig = None

        if show_slits:
            sfig = show_drizzled_slits(
                slits,
                sci,
                ivar,
                hdul,
                imshow_kws=grating_imshow_kws,
                with_background=(show_slits > 1),
            )
            if output is not None:
                sfig.savefig(f"{output_key}.slit2d.png")
        else:
            sfig = None

        # Add to data dicts
        figs[gr] = dfig, sfig
        data[gr] = hdul
        wavedata[gr] = waves
        all_slits[gr] = slits
        drz_data[gr] = sci, ivar

    return figs, data, wavedata, all_slits, drz_data
def show_drizzled_slits(
    slits,
    sci,
    ivar,
    hdul,
    figsize=FIGSIZE,
    variable_size=True,
    imshow_kws=IMSHOW_KWS,
    with_background=False,
):
    """
    Make a figure showing drizzled slitlets

    Parameters
    ----------
    slits : list
        List of slitlet objects

    sci : (N,NY,NX) array
        Science array of drizzled slitlets

    ivar : (N,NY,NX) array
        Inverse variance weights of the drizzled slitlets

    hdul : `~astropy.io.fits.HDUList`
        Drizzle-combined HDU

    figsize : tuple
        Figure size

    variable_size : bool, optional
        If True, scale the figure height with the number of slitlets.

    imshow_kws : dict
        Keywords passed to `~matplotlib.pyplot.imshow`

    with_background : bool, optional
        If True, subtract the combined background ``hdul['BKG']`` from each
        slitlet before display.

    Returns
    -------
    fig : Figure
    """
    avg = hdul["SCI"].data
    xvalid = np.isfinite(avg).sum(axis=0) > 0
    if xvalid.sum() > 1:
        xr = np.arange(avg.shape[1])[xvalid]
    else:
        xr = np.arange(avg.shape[1])

    bkg = hdul["BKG"].data

    h = hdul["SCI"].header
    bkg_offset = np.abs(h["BKGOFF"])
    x0 = h["SRCYPIX"]
    y0 = (avg.shape[0] - 1) // 2

    # Display mask: 1 where a slitlet pixel has weight, NaN where ivar <= 0
    msk = (ivar > 0) * 1.0
    msk[ivar <= 0] = np.nan

    if variable_size:
        fs = (figsize[0], figsize[1] / 3 * len(slits))
    else:
        fs = figsize

    # squeeze=False so that a single slitlet still yields an indexable array
    fig, axes = plt.subplots(
        len(slits), 1, figsize=fs, sharex=True, sharey=True, squeeze=False
    )
    axes = axes[:, 0]

    for i, slit in enumerate(slits):
        image = sci[i, :, :]
        if with_background:
            # Only touch the image when requested: ``bkg * False`` would
            # still propagate NaN background pixels into the display
            image = image - bkg

        axes[i].imshow(image * msk[i, :, :], **imshow_kws)

        axes[i].text(
            0.02,
            0.02 * figsize[1] / figsize[0] * len(slits) * 2,
            slit.meta.filename,
            ha="left",
            va="bottom",
            transform=axes[i].transAxes,
            bbox={"fc": "w", "alpha": 0.5, "ec": "None"},
            fontsize=6,
        )

    for ax in axes:
        ax.set_yticks([0, x0 - bkg_offset, x0, x0 + bkg_offset, avg.shape[0]])
        ax.set_yticklabels([])
        ax.grid()
        ax.set_xlim(xr[0] - 5, xr[-1] + 5)
        ax.set_ylim(y0 - 2 * bkg_offset, y0 + 2 * bkg_offset)
        ax.hlines(x0, *ax.get_xlim(), color="k", linestyle="-", alpha=0.1)

    axes[-1].set_xlabel("pixel")
    axes[0].set_title(f"{h['SRCNAME']} {h['GRATING']}-{h['FILTER']}")
    fig.tight_layout(pad=0.5)

    return fig
def show_drizzled_product(hdul, figsize=FIGSIZE, imshow_kws=IMSHOW_KWS):
    """
    Make a three-panel figure of a drizzle-combined product

    The panels show, top to bottom, the combined data, the data minus the
    background (``SCI - BKG``), and the background itself.

    Parameters
    ----------
    hdul : `~astropy.io.fits.HDUList`
        Drizzle combined HDU with ``SCI`` and ``BKG`` extensions

    figsize : tuple
        Figure size

    imshow_kws : dict
        kwargs for `~matplotlib.pyplot.imshow`

    Returns
    -------
    fig : Figure
        Figure object
    """
    avg = hdul["SCI"].data
    bkg = hdul["BKG"].data
    header = hdul["SCI"].header

    # Restrict the x range to columns with any finite data, if there are
    # enough of them
    columns = np.arange(avg.shape[1])
    has_data = np.isfinite(avg).sum(axis=0) > 0
    xr = columns[has_data] if has_data.sum() > 1 else columns

    bkg_offset = np.abs(header["BKGOFF"])
    x0 = header["SRCYPIX"]
    y0 = (avg.shape[0] - 1) // 2

    fig, axes = plt.subplots(3, 1, figsize=figsize, sharex=True, sharey=True)

    panels = (("Data", avg), ("Cleaned", avg - bkg), ("Background", bkg))
    label_y = 0.02 * figsize[1] / figsize[0] * 3 * 2

    for ax, (label, image) in zip(axes, panels):
        ax.imshow(image, **imshow_kws)
        ax.text(
            0.02,
            label_y,
            label,
            ha="left",
            va="bottom",
            transform=ax.transAxes,
            bbox={"fc": "w", "alpha": 0.5, "ec": "None"},
            fontsize=8,
        )

    for ax in axes:
        ax.set_yticks([0, x0 - bkg_offset, x0, x0 + bkg_offset, avg.shape[0]])
        ax.set_yticklabels([])
        # Drawn before set_xlim, so the line spans the current x limits
        ax.hlines(x0, *ax.get_xlim(), color="k", linestyle="-", alpha=0.1)
        ax.grid()
        ax.set_xlim(xr[0] - 5, xr[-1] + 5)
        ax.set_ylim(y0 - 2 * bkg_offset, y0 + 2 * bkg_offset)

    axes[-1].set_xlabel("pixel")
    axes[0].set_title(
        f"{header['SRCNAME']} {header['GRATING']}-{header['FILTER']}"
    )
    fig.tight_layout(pad=0.5)

    return fig
def get_xlimits_from_lines(
    hdul,
    sn_thresh=2,
    max_dy=4,
    n_erode=2,
    n_dilate=4,
    pad=10,
    verbose=True,
):
    """
    Find the wavelength-axis pixel range covered by significant flux (e.g.
    emission lines) near the trace of a 2D spectrum

    Parameters
    ----------
    hdul : HDUList
        HDUList with ``SCI`` and ``WHT`` extensions of the 2D spectrum.

    sn_thresh : float, optional
        Minimum ``SCI * sqrt(WHT)`` for a pixel to be selected.

    max_dy : float, optional
        Maximum distance in rows from the trace row ``SRCYPIX`` (or the
        central row if that keyword is missing).

    n_erode : int, optional
        Iterations of binary erosion applied to the selection mask.

    n_dilate : int, optional
        Iterations of binary dilation applied after the erosion.

    pad : int, optional
        Padding added on both sides of the selected pixel range.

    verbose : bool, optional
        Whether to print verbose output.

    Returns
    -------
    xlim : list
        ``[start, stop]`` pixel limits along the wavelength axis, clipped to
        ``[0, NX]``.  The full range ``(0, NX)`` is returned if no pixels
        are selected.
    """
    import scipy.ndimage as nd

    sci_hdu = hdul["SCI"]
    sh = sci_hdu.data.shape
    yp, xp = np.indices(sh)

    if "SRCYPIX" in sci_hdu.header:
        y0 = sci_hdu.header["SRCYPIX"]
    else:
        # Central row by the same (NY - 1) / 2 convention used when
        # SRCYPIX is written and in extract_from_hdul (sh[0] / 2 was half a
        # pixel high)
        y0 = (sh[0] - 1) / 2.0

    msk = sci_hdu.data * np.sqrt(hdul["WHT"].data) > sn_thresh
    msk &= np.abs(yp - y0) < max_dy

    # Erode to drop isolated noise pixels, then dilate to grow the regions
    msk_erode = nd.binary_erosion(msk, iterations=n_erode)
    msk_dilate = nd.binary_dilation(msk_erode, iterations=n_dilate)

    if msk_dilate.sum() == 0:
        xlim = (0, sh[1])
    else:
        xpx = xp[msk_dilate]
        xlim = np.clip([xpx.min() - pad, xpx.max() + pad], 0, sh[1]).tolist()

    msg = f"msaexp.drizzle.get_xlimits_from_lines: {msk_dilate.sum()} pixels, "
    msg += f"slice: {xlim}"
    grizli.utils.log_comment(
        grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
    )

    return xlim
def make_optimal_extraction(
    waves,
    sci2d,
    wht2d,
    var2d=None,
    profile_slice=None,
    prf_center=None,
    prf_sigma=1.0,
    sigma_bounds=(0.5, 2.5),
    center_limit=4,
    fix_center=False,
    fix_sigma=False,
    trim=0,
    bkg_offset=6,
    bkg_parity=[-1, 1],
    offset_for_chi2=1.0,
    max_wht_percentile=None,
    max_med_wht_factor=10,
    verbose=True,
    ap_radius=None,
    ap_center=None,
    **kwargs,
):
    """
    Optimal extraction from 2D arrays

    Parameters
    ----------
    waves : 1D array
        Wavelengths, microns

    sci2d : 2D array
        Data array.  Note: pixels where ``sci2d + weight`` is non-finite
        are set to zero in this array **in place**.

    wht2d : 2D array
        Inverse variance weight array

    var2d : 2D array, None
        Optional variance array.  If given, ``1/var2d`` replaces ``wht2d``
        as the weight for the profile fit and the 1D extraction

    profile_slice : tuple, slice
        Slice along wavelength axis where to determine the cross-dispersion
        profile.  A tuple of integers is interpreted as pixel limits, a
        tuple of floats as wavelength limits in microns

    prf_center : float
        Profile center, relative to the cross-dispersion center of the
        array.  If `None`, then try to estimate it from the data

    prf_sigma : float
        Width of the extraction profile in pixels

    sigma_bounds : (float, float)
        Parameter bounds for `prf_sigma`

    center_limit : float
        Maximum offset from `prf_center` allowed

    fix_center : bool
        Fix the centering in the fit

    fix_sigma : bool
        Fix the width in the fit

    trim : int
        Number of pixels to trim from the edges of the extracted spectrum

    bkg_offset, bkg_parity : int, list
        Parameters for the local background determination (see
        `~msaexp.drizzle.drizzle_slitlets`).  The profile is "subtracted"
        in the same way as the data.

    offset_for_chi2 : float
        If specified, compute chi2 of the profile fit offsetting the first
        parameter by +/- this value

    max_wht_percentile : float
        Maximum percentile of WHT to consider valid

    max_med_wht_factor : float
        Maximum weight value relative to the median nonzero weight to
        consider valid

    verbose : bool
        Status messages

    ap_center, ap_radius : int, int
        Center and radius of fixed-width aperture extraction, in pixels.
        If not specified, then

        .. code-block:: python
            :dedent:

            >>> ap_center = int(np.round(ytrace + fit_center))
            >>> ap_radius = np.clip(int(np.round(prof_sigma*2.35/2)), 1, 3)

        A negative ``ap_center`` uses the peak of the collapsed profile.

    kwargs : dict
        Ignored keyword args

    Returns
    -------
    sci2d_out : array
        Output 2D sci array

    wht2d_out : array
        Output 2D wht array

    profile2d : array
        2D optimal extraction profile

    spec : `~astropy.table.Table`
        Optimally-extracted 1D spectrum

    prof_tab : `~astropy.table.Table`
        Table of the collapsed 1D profile
    """
    import scipy.ndimage as nd
    import astropy.units as u
    from scipy.optimize import least_squares

    from .version import __version__ as msaexp_version

    sh = wht2d.shape

    ok = np.isfinite(sci2d * wht2d) & (wht2d > 0)
    if max_wht_percentile is not None:
        wperc = np.percentile(wht2d[ok], max_wht_percentile)
        ok &= wht2d < wperc

    if max_med_wht_factor is not None:
        med_wht = np.nanmedian(wht2d[ok])
        ok &= wht2d < max_med_wht_factor * med_wht

    if var2d is None:
        wht_mask = wht2d * 1
    else:
        wht_mask = 1.0 / var2d

    wht_mask[~ok] = 0.0

    if profile_slice is not None:
        if not isinstance(profile_slice, slice):
            # Accept numpy integers as pixel limits too
            if isinstance(profile_slice[0], (int, np.integer)):
                # pixels
                profile_slice = slice(*profile_slice)
            else:
                # Wavelengths interpolated on pixel grid
                xpix = np.arange(sh[1])
                xsl = np.round(np.interp(profile_slice, waves, xpix)).astype(int)
                xsl = np.clip(xsl, 0, sh[1])
                if verbose:
                    print(f"Wavelength slice: {profile_slice} > {xsl} pix")

                profile_slice = slice(*xsl)

        prof1d = np.nansum((sci2d * wht_mask)[:, profile_slice], axis=1)
        prof1d /= np.nansum(wht_mask[:, profile_slice], axis=1)
        slice_limits = profile_slice.start, profile_slice.stop
    else:
        prof1d = np.nansum(sci2d * wht_mask, axis=1) / np.nansum(
            wht_mask, axis=1
        )
        slice_limits = 0, sh[1]

    ytrace = (sh[0] - 1) / 2.0
    x0 = np.arange(sh[0]) - ytrace

    if prf_center is None:
        prf_center = np.nanargmax(prof1d) - (sh[0] - 1) / 2.0
        if verbose:
            print(f"Set prf_center: {prf_center} {sh} {ok.sum()}")

    msg = "msaexp.drizzle.extract_from_hdul: Initial center = "
    msg += f" {prf_center:6.2f}, sigma = {prf_sigma:6.2f}"
    grizli.utils.log_comment(
        grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
    )

    #############
    # Integrated gaussian profile
    # fit_type: 3 = fit center & sigma, 2 = fit center, 1 = fit sigma,
    # 0 = both fixed
    fit_type = 3 - 2 * fix_center - 1 * fix_sigma

    wht_mask[~ok] = 0.0

    p00_name = None
    # Initialized here so the fixed-profile branch (fit_type == 0) doesn't
    # hit a NameError when the table metadata is written below
    dchi2dp = None

    if fit_type == 0:
        args = (
            waves,
            sci2d,
            wht_mask,
            prf_center,
            prf_sigma,
            bkg_offset,
            bkg_parity,
            3,
            1,
            (verbose > 1),
        )
        pnorm, pmodel = utils.objfun_prf([prf_center, prf_sigma], *args)
        profile2d = pmodel / pnorm
        pmask = (profile2d > 0) & np.isfinite(profile2d)
        profile2d[~pmask] = 0
        fit_center = prf_center
        fit_sigma = prf_sigma
    else:
        # Fit it
        if fix_sigma:
            p00_name = "center"
            p0 = [prf_center]
            # Bounds are relative to prf_center, as for the free fit below
            bounds = (prf_center - center_limit, prf_center + center_limit)
        elif fix_center:
            p00_name = "sigma"
            p0 = [prf_sigma]
            bounds = sigma_bounds
        else:
            p00_name = "center"
            p0 = [prf_center, prf_sigma]
            bounds = (
                (-center_limit + prf_center, sigma_bounds[0]),
                (center_limit + prf_center, sigma_bounds[1]),
            )

        args = (
            waves,
            sci2d,
            wht_mask,
            prf_center,
            prf_sigma,
            bkg_offset,
            bkg_parity,
            fit_type,
            1,
            (verbose > 1),
        )
        lmargs = (
            waves,
            sci2d,
            wht_mask,
            prf_center,
            prf_sigma,
            bkg_offset,
            bkg_parity,
            fit_type,
            2,
            (verbose > 1),
        )
        _res = least_squares(
            utils.objfun_prf,
            p0,
            args=lmargs,
            method="trf",
            bounds=bounds,
            loss="huber",
        )

        # dchi2 / dp0
        if offset_for_chi2 is not None:
            chiargs = (
                waves,
                sci2d,
                wht_mask,
                prf_center,
                prf_sigma,
                bkg_offset,
                bkg_parity,
                fit_type,
                3,
                (verbose > 1),
            )
            delta = _res.x * 0.0
            dchi2dp = []
            for d in [-offset_for_chi2, 0.0, offset_for_chi2]:
                delta[0] = d
                dchi2dp.append(utils.objfun_prf(_res.x + delta, *chiargs))

            msg = f"msaexp.drizzle.extract_from_hdul: dchi2/d{p00_name} = "
            dchi = dchi2dp[0] - dchi2dp[1]
            dchi += dchi2dp[2] - dchi2dp[1]
            msg += f"{dchi/2.:.1f}"
            grizli.utils.log_comment(
                grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
            )

        pnorm, pmodel = utils.objfun_prf(_res.x, *args)
        profile2d = pmodel / pnorm
        pmask = (profile2d > 0) & np.isfinite(profile2d)
        profile2d[~pmask] = 0

        if fix_sigma:
            fit_center = _res.x[0]
            fit_sigma = prf_sigma
        elif fix_center:
            fit_sigma = _res.x[0]
            fit_center = prf_center
        else:
            fit_center, fit_sigma = _res.x

    # Optimal (profile-weighted) 1D extraction
    wht1d = np.nansum(wht_mask * profile2d**2, axis=0)
    sci1d = np.nansum(sci2d * wht_mask * profile2d, axis=0) / wht1d

    if profile_slice is not None:
        pfit1d = np.nansum(
            (wht_mask * profile2d * sci1d)[:, profile_slice], axis=1
        )
        pfit1d /= np.nansum((wht_mask)[:, profile_slice], axis=1)
    else:
        pfit1d = np.nansum(profile2d * sci1d * wht_mask, axis=1)
        pfit1d /= np.nansum(wht_mask, axis=1)

    if trim > 0:
        bad = nd.binary_dilation(wht1d <= 0, iterations=trim)
        wht1d[bad] = 0

    sci1d[wht1d <= 0] = 0
    err1d = np.sqrt(1 / wht1d)
    err1d[wht1d <= 0] = 0

    #######
    # Make tables

    # Flux conversion
    to_ujy = 1.0

    spec = grizli.utils.GTable()
    spec.meta["VERSION"] = msaexp_version, "msaexp software version"
    spec.meta["TOMUJY"] = to_ujy, "Conversion from pixel values to microJansky"
    spec.meta["PROFCEN"] = fit_center, "PRF profile center"
    spec.meta["PROFSIG"] = fit_sigma, "PRF profile sigma"
    spec.meta["PROFSTRT"] = slice_limits[0], "Start of profile slice"
    spec.meta["PROFSTOP"] = slice_limits[1], "End of profile slice"
    spec.meta["YTRACE"] = ytrace, "Expected center of trace"
    spec.meta["MAXWPERC"] = max_wht_percentile, "Maximum weight percentile"
    spec.meta["MAXWFACT"] = max_med_wht_factor, "Maximum weight factor"

    prof_tab = grizli.utils.GTable()
    prof_tab.meta["VERSION"] = msaexp_version, "msaexp software version"
    prof_tab["pix"] = x0
    prof_tab["profile"] = prof1d
    prof_tab["pfit"] = pfit1d
    prof_tab.meta["PROFCEN"] = fit_center, "PRF profile center"
    prof_tab.meta["PROFSIG"] = fit_sigma, "PRF profile sigma"
    prof_tab.meta["PROFSTRT"] = slice_limits[0], "Start of profile slice"
    prof_tab.meta["PROFSTOP"] = slice_limits[1], "End of profile slice"
    prof_tab.meta["YTRACE"] = ytrace, "Expected center of trace"

    if (dchi2dp is not None) & (p00_name is not None):
        prof_tab.meta["DCHI2PAR"] = (p00_name, "Parameter for dchi2/dparam")
        # f-strings: previously the literal "{p00_name}" was written
        prof_tab.meta["CHI2A"] = (
            dchi2dp[0],
            f"Chi2 with d{p00_name} = -{offset_for_chi2}",
        )
        prof_tab.meta["CHI2B"] = dchi2dp[1], "Chi2 with dparam = 0. (best fit)"
        prof_tab.meta["CHI2C"] = (
            dchi2dp[2],
            f"Chi2 with d{p00_name} = +{offset_for_chi2}",
        )

    spec["wave"] = waves
    spec["wave"].unit = u.micron
    spec["flux"] = sci1d * to_ujy
    spec["err"] = err1d * to_ujy
    spec["flux"].unit = u.microJansky
    spec["err"].unit = u.microJansky

    # Aperture extraction
    if ap_center is None:
        ap_center = int(np.round(ytrace + fit_center))
    elif ap_center < 0:
        ap_center = np.nanargmax(prof1d)

    if ap_radius is None:
        ap_radius = np.clip(int(np.round(fit_sigma * 2.35 / 2)), 1, 3)

    # Clip the start at 0: a negative start would wrap to the end of the array
    sly = slice(max(ap_center - ap_radius, 0), ap_center + ap_radius + 1)

    aper_sci = np.nansum(sci2d[sly, :], axis=0)
    aper_var = np.nansum(1.0 / wht_mask[sly, :], axis=0)
    aper_corr = np.nansum(profile2d, axis=0) / np.nansum(
        profile2d[sly, :], axis=0
    )

    spec["aper_flux"] = aper_sci * to_ujy
    spec["aper_err"] = np.sqrt(aper_var) * to_ujy
    spec["aper_corr"] = aper_corr
    spec["aper_flux"].unit = u.microJansky
    spec["aper_err"].unit = u.microJansky
    spec["aper_flux"].description = (
        f"Flux in trace aperture ({ap_center}, {ap_radius})"
    )
    spec["aper_err"].description = (
        "Flux uncertainty in trace aperture " f"({ap_center}, {ap_radius})"
    )

    spec.meta["APER_Y0"] = (ap_center, "Fixed aperture center")
    spec.meta["APER_DY"] = (ap_radius, "Fixed aperture radius, pix")

    msg = "msaexp.drizzle.extract_from_hdul: aperture extraction = "
    msg += f"({ap_center}, {ap_radius})"
    grizli.utils.log_comment(
        grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
    )

    # Zero out invalid pixels.  sci2d is modified in place; extract_from_hdul
    # writes its own (modified) sci2d array to the output HDU.
    msk = np.isfinite(sci2d + wht_mask)
    sci2d[~msk] = 0
    wht_mask[~msk] = 0

    return sci2d * to_ujy, wht2d / to_ujy**2, profile2d, spec, prof_tab
def extract_from_hdul(
    hdul,
    prf_center=None,
    master_bkg=None,
    verbose=True,
    line_limit_kwargs={},
    **kwargs,
):
    """
    Run 1D extraction on arrays from a combined dataset

    Parameters
    ----------
    hdul : `~astropy.io.fits.HDUList`
        Output data from from `~msaexp.drizzle.drizzle_slitlets`.  It is not
        modified.

    prf_center : float, None
        Initial profile center.  If not specified, get from the
        `'SRCYPIX'` keyword in `hdul['SCI'].header`

    master_bkg : array
        Optional master background to use instead of `hdul['BKG'].data`

    verbose : bool
        Printing status messages

    line_limit_kwargs : dict
        Keyword arguments passed to `msaexp.drizzle.get_xlimits_from_lines`.
        If non-empty, the result is used as ``profile_slice``.

    kwargs : dict
        Keyword arguments passed to `msaexp.drizzle.make_optimal_extraction`

    Returns
    -------
    outhdu : `~astropy.io.fits.HDUList`
        New HDU list with ``SPEC1D``, ``SCI``, ``WHT``, ``PROFILE`` and
        ``PROF1D`` extensions
    """
    if master_bkg is None:
        if "BKG" in hdul:
            bkg_i = hdul["BKG"].data
        else:
            bkg_i = None
    else:
        bkg_i = master_bkg

    sci = hdul["SCI"]
    sci2d = sci.data * 1
    if bkg_i is not None:
        sci2d -= bkg_i

    wht2d = hdul["WHT"].data * 1

    if "WAVE" in hdul:
        waves = hdul["WAVE"].data
    elif "SPEC1D" in hdul:
        tab = grizli.utils.read_catalog(hdul["SPEC1D"])
        waves = tab["wave"].data
    else:
        _gr = sci.header["GRATING"].lower()
        waves = utils.get_standard_wavelength_grid(
            _gr, sample=sci.header["WSAMPLE"], log_step=sci.header["LOGWAVE"]
        )

    if line_limit_kwargs:
        kwargs["profile_slice"] = get_xlimits_from_lines(
            hdul, **line_limit_kwargs
        )

    if prf_center is None:
        if "SRCYPIX" in sci.header:
            y0 = sci.header["SRCYPIX"]
        else:
            y0 = (sci.data.shape[0] - 1) / 2

        # make_optimal_extraction expects the center relative to the
        # central row
        prf_center = y0 - (sci.data.shape[0] - 1) / 2.0

    _data = make_optimal_extraction(
        waves,
        sci2d,
        wht2d,
        prf_center=prf_center,
        verbose=verbose,
        **kwargs,
    )

    _sci2d, _wht2d, profile2d, spec, prof = _data

    outhdu = pyfits.HDUList()
    outhdu.append(pyfits.BinTableHDU(data=spec, name="SPEC1D"))

    # Work on a copy so the caller's SCI header isn't polluted with the
    # extraction metadata
    header = sci.header.copy()
    for k in spec.meta:
        header[k] = spec.meta[k]

    msg = "msaexp.drizzle.extract_from_hdul: Output center = "
    msg += f" {header['PROFCEN']:6.2f}, sigma = {header['PROFSIG']:6.2f}"
    grizli.utils.log_comment(
        grizli.utils.LOGFILE, msg, verbose=verbose, show_date=False
    )

    outhdu.append(pyfits.ImageHDU(data=sci2d, header=header, name="SCI"))
    outhdu.append(pyfits.ImageHDU(data=wht2d, header=header, name="WHT"))
    outhdu.append(pyfits.ImageHDU(data=profile2d, header=header, name="PROFILE"))

    outhdu.append(pyfits.BinTableHDU(data=prof, name="PROF1D"))

    # Propagate SCI metadata to the 1D spectrum extension
    for k in outhdu["SCI"].header:
        if k not in outhdu["SPEC1D"].header:
            outhdu["SPEC1D"].header[k] = outhdu["SCI"].header[k]

    return outhdu