Source code for jwst.residual_fringe.residual_fringe

"""Apply residual fringe correction."""

import logging
import warnings

import numpy as np
from astropy.io import ascii as astropy_ascii
from astropy.io import fits
from astropy.table import Table
from stdatamodels import fits_support
from stdatamodels.jwst import datamodels

from jwst.residual_fringe import utils
from jwst.stpipe import Step

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Noise factor for DER_SNR spectroscopic signal-to-noise calculation
# (see Stoehr, ADASS 2008: https://archive.stsci.edu/vodocs/der_snr.pdf)
# In do_correction it scales the median absolute second difference of the
# column flux to give the noise estimate used in the per-column SNR check.
DER_SNR_FACTOR = 1.482602 / np.sqrt(6)

# Names exported by ``from ... import *``.
__all__ = ["ResidualFringeCorrection", "NoFringeFlatError"]


class ResidualFringeCorrection:
    """
    Calculate and apply correction for residual fringes.

    Parameters
    ----------
    input_model : `~stdatamodels.jwst.datamodels.IFUImageModel`
        Input data to correct; updated in-place.
    residual_fringe_reference_file : str
        Path to FRINGEFREQ reference file.
    regions_reference_file : str
        Path to REGIONS reference file.
    ignore_regions : dict
        Wavelength regions to ignore.  Keys are "num", "min", and "max".
        Values are the number of regions specified (int), the list of
        minimum wavelength values, and the list of maximum wavelength
        values.  Length of minimum and maximum lists must match.
    save_intermediate_results : bool, optional
        If `True`, intermediate files are saved to disk.
    transmission_level : int, optional
        The transmission level used to extract the appropriate region
        definitions from the REGIONS reference file.
    make_output_path : callable or None, optional
        If provided, is used to create the output file names when
        ``save_intermediate_results`` is `True`.  If None, filenames are
        created with the default ``Step.make_output_path`` method.
    """

    def __init__(
        self,
        input_model,
        residual_fringe_reference_file,
        regions_reference_file,
        ignore_regions,
        save_intermediate_results=False,
        transmission_level=80,
        make_output_path=None,
    ):
        self.input_model = input_model
        self.residual_fringe_reference_file = residual_fringe_reference_file
        self.regions_reference_file = regions_reference_file
        self.ignore_regions = ignore_regions
        self.save_intermediate_results = save_intermediate_results
        self.transmission_level = transmission_level

        # define how filenames are created
        if make_output_path is None:
            self.make_output_path = Step().make_output_path
        else:
            self.make_output_path = make_output_path

        # Work arrays and reference data; populated by do_correction.
        self.rfc_factors = None
        self.fit_mask = None
        self.weighted_pix_num = None
        self.rejected_fit = None
        self.weights_feat = None
        self.input_weights = None
        self.max_amp = None
        self.freq_table = None
        self.slice_map = None
        self.background_fit = None
        self.knot_locations = None
        self.band = None
        self.channel = None

        # used to create additional data that can be plotted outside of step
        self.diagnostic_mode = True

    def do_correction(self):
        """
        Apply residual fringe correction to the input model, in place.

        The ``data`` array of ``self.input_model`` is replaced with the
        corrected array and the same model object is returned.  Columns
        that cannot be fit (too few valid pixels, low SNR, or an exception
        raised by the fitting utilities) are left uncorrected.

        Returns
        -------
        output_model : `~stdatamodels.jwst.datamodels.IFUImageModel`
            Datamodel with correction applied.

        Raises
        ------
        NoFringeFlatError
            If ``meta.cal_step.fringe`` of the input model is not "COMPLETE".
        """
        # Check that the fringe flat has been applied
        if self.input_model.meta.cal_step.fringe != "COMPLETE":
            raise NoFringeFlatError(
                f"The fringe flat step has not been run on file {self.input_model.meta.filename}"
            )

        # Remove any NaN values and flagged DO_NOT_USE pixels from the data prior to processing
        # Set them to 0 for the residual fringe routine
        # They will be re-added at the end
        output_data = self.input_model.data.copy()
        dnu = datamodels.dqflags.pixel["DO_NOT_USE"]
        nanval_indx = np.where(
            np.logical_or(
                np.bitwise_and(self.input_model.dq, dnu).astype(bool), ~np.isfinite(output_data)
            )
        )
        output_data[nanval_indx] = 0

        # normalise the output_data to remove units
        pos_data = self.input_model.data[self.input_model.data > 0]
        normalization_factor = np.median(pos_data)
        output_data /= normalization_factor

        # Load the fringe reference file
        residual_fringe_model = datamodels.FringeFreqModel(self.residual_fringe_reference_file)

        # read in the band
        band = self.input_model.meta.instrument.band.lower()
        if band == "short":
            residual_fringe_table = residual_fringe_model.rfc_freq_short_table
        elif band == "medium":
            residual_fringe_table = residual_fringe_model.rfc_freq_medium_table
        else:
            residual_fringe_table = residual_fringe_model.rfc_freq_long_table
        self.max_amp = residual_fringe_model.max_amp_table
        residual_fringe_model.close()

        self.freq_table = residual_fringe_table

        # Read in the regions reference file
        # Use throughput array defined by self.transmission_level
        allregions = datamodels.RegionsModel(self.regions_reference_file)
        # NOTE(review): this overwrites the attribute with level / 10, so a
        # second call on the same instance would select a different plane.
        self.transmission_level = int(self.transmission_level / 10)

        slice_map = (allregions.regions)[self.transmission_level - 1, :, :].copy()
        log.info(f" Using {self.transmission_level} throughput threshold.")

        self.slice_map = slice_map

        # set up the channels for the detector
        detector = self.input_model.meta.instrument.detector.lower()
        if "short" in detector:
            self.channels = [1, 2]
        elif "long" in detector:
            self.channels = [3, 4]
        log.info(f"Detector {detector} {self.channels} ")

        self.input_weights = self.calc_weights()
        self.weights_feat = self.input_weights.copy()

        self.rfc_factors = np.zeros(self.input_model.data.shape)
        self.fit_mask = np.zeros(self.input_model.data.shape)
        self.weighted_pix_num = np.zeros(self.input_model.data.shape)
        self.rejected_fit = np.zeros(self.input_model.data.shape)

        self.background_fit = np.zeros(self.input_model.data.shape)
        self.knot_locations = np.full_like(self.input_model.data, np.nan)
        allregions.close()

        # intermediate output product - Tables
        stat_table = Table(
            names=(
                "Slice",
                "mean",
                "median",
                "stddev",
                "max",
                "pmean",
                "pmedian",
                "pstddev",
                "pmax",
            ),
            dtype=("i4", "f8", "f8", "f8", "f8", "f8", "f8", "f8", "f8"),
        )

        out_table = Table(
            names=(
                "Slice",
                "col",
                "fringe",
                "sn",
                "periodogram_res",
                "opt_fringes",
                "peak_freq",
                "freq_min",
                "freq_max",
            ),
            dtype=("i4", "i4", "i4", "f8", "f8", "f8", "f8", "f8", "f8"),
        )

        wave_map = self._get_wave_map()
        for c in self.channels:
            num_corrected = 0
            log.info(f"Processing channel {c}")
            (slices_in_channel, xrange_channel, slice_x_ranges, all_slice_masks) = utils.slice_info(
                slice_map, c
            )

            log.debug(f" Slice Ranges {slice_x_ranges}")

            # if the user wants to ignore some values, use the wave_map
            # array to set the corresponding weight values to 0
            if self.ignore_regions["num"] > 0:
                for r in range(self.ignore_regions["num"]):
                    min_wave = self.ignore_regions["min"][r]
                    max_wave = self.ignore_regions["max"][r]
                    self.input_weights[((wave_map > min_wave) & (wave_map < max_wave))] = 0

            for n, ss in enumerate(slices_in_channel):
                log.info(f" Processing slice {ss} =================================")
                log.debug(f" X ranges of slice {slice_x_ranges[n, 1]} {slice_x_ranges[n, 2]}")

                # use the mask to set all out-of-slice pixels to 0 in wmap and data
                # set out-of-slice pixels to 0 in arrays
                ss_data = all_slice_masks[n] * output_data
                ss_wmap = all_slice_masks[n] * wave_map
                ss_weight = all_slice_masks[n] * self.input_weights

                # get the freq_table info for this slice
                this_row = np.where(self.freq_table["slice"] == float(ss))[0][0]
                log.debug(f"Row in reference file for slice {this_row}")

                slice_row = self.freq_table[(self.freq_table["slice"] == float(ss))]
                ffreq = slice_row["ffreq"][0]
                dffreq = slice_row["dffreq"][0]
                max_nfringes = slice_row["max_nfringes"][0]
                min_snr = slice_row["min_snr"][0]
                pgram_res = slice_row["pgram_res"][0]

                # cycle through the cols and fit the fringes
                for col in np.arange(slice_x_ranges[n, 1], slice_x_ranges[n, 2]):
                    # Column views into the per-slice arrays; the in-place
                    # edits below therefore modify ss_data / ss_weight.
                    col_data = ss_data[:, col]
                    col_wmap = ss_wmap[:, col]

                    # because of the curvature of the slices there can be
                    # large regions not falling on a column
                    valid = np.logical_and((col_wmap > 0), ~np.isnan(col_wmap))
                    num_good = np.count_nonzero(valid)

                    # Need at least 50 pixels in column to proceed
                    if num_good <= 50:
                        continue

                    # Boolean indexing copies, so col_data is not touched here.
                    test_flux = col_data[valid]
                    test_flux[test_flux < 0] = 1e-08

                    # Transform wavelength in micron to wavenumber in cm^-1.
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore", RuntimeWarning)
                        col_wnum = 10000.0 / col_wmap

                    # use the error array to get col snr, used to remove noisy pixels
                    col_snr = self.input_model.data[:, col] / self.input_model.err[:, col]

                    # Do some checks on column to make sure there is
                    # reasonable signal. If the SNR < min_snr (CDP), pass
                    # (named n_flux so the slice index ``n`` is not rebound)
                    n_flux = len(test_flux)
                    signal = np.nanmean(test_flux)
                    noise = DER_SNR_FACTOR * np.nanmedian(
                        np.abs(
                            2.0 * test_flux[2 : n_flux - 2]
                            - test_flux[0 : n_flux - 4]
                            - test_flux[4:n_flux]
                        )
                    )

                    snr2 = 0.0
                    if noise != 0:
                        snr2 = signal / noise

                    # Sometimes can return nan, inf for bad data so include this in check
                    if snr2 < min_snr[0]:
                        log.debug(f"SNR too low; not fitting column {col}, {snr2}, {min_snr[0]}")
                        continue

                    log.debug(f"Fitting column {col}")
                    log.debug(f"SNR > {min_snr[0]} ")

                    col_weight = ss_weight[:, col]
                    col_max_amp = np.interp(
                        col_wmap, self.max_amp["Wavelength"], self.max_amp["Amplitude"]
                    )
                    col_snr2 = np.where(col_snr > 10, 1, 0)  # hardcoded at snr > 10 for now

                    # Double the max amplitude
                    col_max_amp *= 2

                    # get the in-slice pixel indices for replacing in output later
                    idx = np.where(col_data > 0)

                    # BayesicFitting doesn't like zeros at data or weight array
                    # edges so set zeros to an arbitrarily small value
                    col_data[col_data <= 0] = 1e-08
                    col_weight[col_weight <= 0] = 1e-08

                    # Check for off-slice pixels and send to be filled with
                    # interpolated/extrapolated wnums to stop BayesicFitting from
                    # crashing. They will not be fitted anyway.
                    found_bad = np.logical_or(np.isnan(col_wnum), np.isinf(col_wnum))
                    num_bad = np.count_nonzero(found_bad)

                    if num_bad > 0:
                        col_wnum[found_bad] = 0
                        col_wnum = utils.fill_wavenumbers(col_wnum)

                    # do feature finding on slice now column-by-column
                    log.debug(" Starting feature finding")

                    # narrow features (similar or less than fringe #1 period)
                    # find spectral features (env is spline fit of troughs and peaks)
                    env, l_x, l_y, _, _, _ = utils.fit_envelope(
                        np.arange(col_data.shape[0]), col_data
                    )
                    mod = np.abs(col_data / env) - 1

                    # Use col_snr to ignore noisy pixels:
                    # given signal in mod, find location of
                    # lines > col_max_amp (fringe contrast)
                    weight_factors = utils.find_lines(mod * col_snr2, col_max_amp)
                    weights_feat = col_weight * weight_factors

                    # account for fringe 2 on broad features in channels 3 and 4
                    # need to smooth out the dichroic fringe as it breaks
                    # the feature finding method
                    if c in [3, 4]:
                        # smoothing window hardcoded to 7 for now (based on testing)
                        win = 7
                        cumsum = np.cumsum(np.insert(col_data, 0, 0))
                        sm_col_data = (cumsum[win:] - cumsum[:-win]) / float(win)

                        # find spectral features (env is spline fit of troughs and peaks)
                        env, l_x, l_y, _, _, _ = utils.fit_envelope(
                            np.arange(col_data.shape[0]), sm_col_data
                        )
                        mod = np.abs(col_data / env) - 1

                        # given signal in mod find location of lines > col_max_amp
                        weight_factors = utils.find_lines(mod, col_max_amp)
                        weights_feat *= weight_factors

                    # iterate over the fringe components to fit, initialize other output arrays
                    # in case fit fails
                    proc_data = col_data.copy()
                    proc_factors = np.ones(col_data.shape)
                    bg_fit = col_data.copy()
                    res_fringe_fit_flag = np.zeros(col_data.shape)
                    wpix_num = 1024

                    # check the end points. A single value followed by gap of zero can cause
                    # problems in the fitting.
                    index = np.where(weights_feat != 0.0)
                    length = np.diff(index[0])

                    if weights_feat[0] != 0 and length[0] > 1:
                        weights_feat[0] = 1e-08

                    if weights_feat[-1] != 0 and length[-1] > 1:
                        weights_feat[-1] = 1e-08

                    # jane added this - fit can fail in evidence function.
                    # once we replace evidence function with astropy routine - we can test
                    # removing setting weights < 0.003 to zero (1e-08)
                    weights_feat[weights_feat <= 0.003] = 1e-08

                    # currently the reference file fits one fringe originating in the
                    # detector pixels, and a second high frequency, low amplitude fringe
                    # in channels 3 and 4 which has been attributed to the dichroics.
                    try:
                        for fn, ff in enumerate(ffreq):
                            # ignore place holder fringes
                            if ff <= 1e-03:
                                continue

                            # check if snr criteria is met for fringe component,
                            # should always be true for fringe 1
                            if snr2 <= min_snr[fn]:
                                continue

                            log.debug(f" Start ffreq = {ff}")
                            log.debug(" Fit spectral baseline")

                            bg_fit, bgindx = utils.fit_1d_background_complex(
                                proc_data,
                                weights_feat,
                                col_wnum,
                                ffreq=ffreq[fn],
                                channel=c,
                            )

                            # get the residual fringes as fraction of signal
                            res_fringes = np.divide(
                                proc_data,
                                bg_fit,
                                out=np.zeros_like(proc_data),
                                where=bg_fit != 0,
                            )
                            np.subtract(res_fringes, 1, out=res_fringes, where=res_fringes != 0)
                            res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08)

                            # fit the residual fringes
                            log.debug(" Set up Bayes evidence")

                            (
                                res_fringe_fit,
                                wpix_num,
                                opt_nfringe,
                                peak_freq,
                                freq_min,
                                freq_max,
                            ) = utils.fit_1d_fringes_bayes_evidence(
                                res_fringes,
                                weights_feat,
                                col_wnum,
                                ffreq[fn],
                                dffreq[fn],
                                max_nfringes[fn],
                                pgram_res[fn],
                                col_snr2,
                            )

                            # check for fit blowing up, reset rfc fit to 0, raise a flag
                            log.debug(" Check residual fringe fit for bad fit regions")
                            res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes(
                                res_fringe_fit, col_max_amp
                            )

                            # correct for residual fringes
                            log.debug(" Divide out residual fringe fit")
                            _, _, _, env, u_x, u_y = utils.fit_envelope(
                                np.arange(res_fringe_fit.shape[0]), res_fringe_fit
                            )

                            rfc_factors = 1 / (
                                res_fringe_fit * (col_weight > 1e-05).astype(int) + 1
                            )
                            proc_data *= rfc_factors
                            proc_factors *= rfc_factors

                            # handle nans or infs that may exist
                            proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08)
                            proc_data[proc_data < 0] = 1e-08

                            out_table.add_row(
                                (
                                    ss,
                                    col,
                                    fn,
                                    snr2,
                                    pgram_res[fn],
                                    opt_nfringe,
                                    peak_freq,
                                    freq_min,
                                    freq_max,
                                )
                            )

                        # define fringe sub after all fringe components corrections
                        fringe_sub = proc_data.copy()
                        rfc_factors = proc_factors.copy()

                        # get the residual fringes as fraction of signal
                        pbg_fit, pbgindx = utils.fit_1d_background_complex(
                            fringe_sub, weights_feat, col_wnum, ffreq=ffreq[0], channel=c
                        )

                        fit_res = np.divide(
                            fringe_sub,
                            pbg_fit,
                            out=np.zeros_like(fringe_sub),
                            where=pbg_fit != 0,
                        )
                        np.subtract(fit_res, 1, out=fit_res, where=fit_res != 0)
                        fit_res *= np.where(col_weight > 1e-07, 1, 1e-08)

                        # NOTE(review): opt_nfringe, peak_freq, freq_min, freq_max
                        # and bgindx are only bound inside the loop above; if every
                        # fringe component is skipped they are undefined or carry
                        # over from a previous column - confirm this cannot happen.
                        out_table.add_row(
                            (
                                ss,
                                col,
                                fn,
                                snr2,
                                pgram_res[0],
                                opt_nfringe,
                                peak_freq,
                                freq_min,
                                freq_max,
                            )
                        )

                        # replace the corrected in-slice column pixels in the data_cor array
                        log.debug(" Update the trace pixels in the output")
                        output_data[idx, col] = fringe_sub[idx]
                        self.rfc_factors[idx, col] = rfc_factors[idx]
                        self.fit_mask[idx, col] = np.ones(1024)[idx]
                        self.weights_feat[idx, col] = weights_feat[idx]
                        self.weighted_pix_num[idx, col] = np.ones(1024)[idx] * (wpix_num / 1024)
                        self.rejected_fit[idx, col] = res_fringe_fit_flag[idx]
                        self.background_fit[idx, col] = bg_fit[idx]
                        self.knot_locations[: bgindx.shape[0], col] = bgindx
                        num_corrected = num_corrected + 1

                    except Exception as e:
                        # Best effort: a failed fit leaves this column uncorrected.
                        log.warning(f" Skipping col={col} {ss}:")
                        log.warning(f" {str(e)}")

                del ss_data, ss_wmap, ss_weight
                # end of column

            del slice_x_ranges, all_slice_masks, slices_in_channel
            # end of channel
            log.info(f"Number of columns corrected for channel {num_corrected}")

        log.info("Processing complete")

        # add units back to output data
        log.debug("Adding units back to output array")
        output_data *= normalization_factor

        # Add NaNs back to output data
        output_data[nanval_indx] = np.nan
        self.input_model.data = output_data
        del output_data

        if self.save_intermediate_results:
            stat_table_name = self.make_output_path(
                basepath=self.input_model.meta.filename, suffix="stat_table", ext=".ecsv"
            )
            log.info(f"Saving intermediate stats table {stat_table_name}")
            astropy_ascii.write(
                stat_table, stat_table_name, format="ecsv", fast_writer=False, overwrite=True
            )

            out_table_name = self.make_output_path(
                basepath=self.input_model.meta.filename, suffix="out_table", ext=".ecsv"
            )
            log.info(f"Saving intermediate output table {out_table_name}")
            astropy_ascii.write(
                out_table, out_table_name, format="ecsv", fast_writer=False, overwrite=True
            )

            fit_results_name = self.make_output_path(
                basepath=self.input_model.meta.filename, suffix="fit_results", ext=".fits"
            )
            log.info(f"Saving intermediate fit results output {fit_results_name}")

            # Get a primary header from the input model
            hdul = fits_support.to_fits(self.input_model.instance, self.input_model._schema)  # noqa: SLF001
            hdr = hdul[0].header
            hdul.close()

            hdu0 = fits.PrimaryHDU(header=hdr)
            hdu1 = fits.ImageHDU(self.rfc_factors, name="RFC_FACTORS")
            hdu2 = fits.ImageHDU(self.fit_mask, name="FIT_MASK")
            hdu3 = fits.ImageHDU(self.weights_feat, name="WEIGHTS_FEATURES")
            hdu4 = fits.ImageHDU(self.weighted_pix_num, name="WEIGHTED_PIXEL_FRACTION")
            hdu5 = fits.ImageHDU(self.background_fit, name="BACKGROUND_FIT")
            hdu6 = fits.ImageHDU(self.knot_locations, name="KNOT_LOCATIONS")
            hdu = fits.HDUList([hdu0, hdu1, hdu2, hdu3, hdu4, hdu5, hdu6])
            hdu.writeto(fit_results_name, overwrite=True)
            hdu.close()

        return self.input_model

    def calc_weights(self):
        """
        Make a weights array based on flux.

        This is a placeholder function. For now, it just returns a
        normalised flux array to use as a weights array: each column of the
        input data is divided by that column's NaN-ignoring mean. This is
        because any smoothing results in incorrect fringe correction around
        emission lines. This can be changed in the future if need be.

        Any non-finite weight (NaN, +inf or -inf, e.g. from a column whose
        mean is zero or which is entirely NaN) is replaced with 0.

        Returns
        -------
        weights : ndarray
            Weights array, same shape as the input data.
        """
        weights = np.zeros(self.input_model.data.shape)
        for c in range(weights.shape[1]):
            flux_1d = self.input_model.data[:, c]
            # Division by a zero or NaN mean is expected for empty columns;
            # the resulting non-finite values are cleaned up below.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                w = flux_1d / np.nanmean(flux_1d)
            weights[:, c] = w

        # replace infs (of either sign) and nans in weights with 0
        weights[~np.isfinite(weights)] = 0
        return weights

    def _get_wave_map(self):
        """
        Get a wavelength map from the input WCS.

        The WCS is evaluated at every pixel of the input data and its third
        output is taken as the wavelength.

        Returns
        -------
        ndarray
            2D map of wavelengths matching self.input_model.data.
        """
        ysize = self.input_model.data.shape[0]
        xsize = self.input_model.data.shape[1]
        y, x = np.mgrid[:ysize, :xsize]
        _, _, wave_map = self.input_model.meta.wcs(x, y)
        return wave_map
class NoFringeFlatError(Exception):
    """Error raised when the input has not been fringe flat corrected."""