Source code for corgidrp.combine

"""
Module to support frame combination
"""
import warnings
import numpy as np
import corgidrp.data as data
from pyklip.klip import rotate
import corgidrp

def combine_images(data_subset, err_subset, dq_subset, collapse, num_frames_scaling, other_hdus=None):
    """
    Combines several images together

    Pixels flagged in the DQ maps (dq > 0) are set to NaN and ignored by the
    combination. This masking is done in place: ``data_subset`` and
    ``err_subset`` are modified, so pass copies if the inputs must be preserved.

    Args:
        data_subset (np.array): 3-D array of N 2-D images
        err_subset (np.array): 4-D array of N 3-D error maps
        dq_subset (np.array): 3-D array of N 2-D DQ maps
        collapse (str): "mean" or "median" (case-insensitive).
        num_frames_scaling (bool): Multiply by number of frames in sequence in order to ~conserve photons
        other_hdus (list of HDULists, optional): list of other HDULists to be combined in the same way

    Returns:
        tuple: combined images
            np.array: 2-D array of combined images
            np.array: 3-D array of combined error map
            np.array: 2-D array of combined DQ maps
            list of np.array: list of the combined data of other HDUs (None if other_hdus is None)

    Raises:
        ValueError: if collapse is neither "mean" nor "median"
    """
    # validate before touching the inputs so that a bad argument has no side effects
    method = collapse.lower()
    if method not in ("mean", "median"):
        raise ValueError("combine_images can only collapse with mean or median")
    collapse_func = np.nanmean if method == "mean" else np.nanmedian

    tot_frames = data_subset.shape[0]

    # mask bad pixels (in place); every layer of the error map is masked for a bad pixel
    bad = np.where(dq_subset > 0)
    data_subset[bad] = np.nan
    err_subset[bad[0], :, bad[1], bad[2]] = np.nan

    # track the number of good values that go into the combination of each pixel
    n_samples = np.sum(~(dq_subset > 0), axis=0)

    with warnings.catch_warnings():
        # prevent RuntimeWarning (e.g. "Mean of empty slice") for pixels that are bad in every frame
        warnings.filterwarnings('ignore', category=RuntimeWarning)
        data_collapse = collapse_func(data_subset, axis=0)
        # correct assuming standard error propagation
        err_collapse = np.sqrt(np.nanmean(err_subset**2, axis=0)) / np.sqrt(n_samples)
        if method == "median":
            # inflate median error
            err_collapse *= np.sqrt(np.pi/2)

    if num_frames_scaling:
        # scale up by the number of frames
        data_collapse *= tot_frames
        err_collapse *= tot_frames

    # dq collapse: keep all flags on
    dq_collapse = np.bitwise_or.reduce(dq_subset, axis=0)
    # except for those pixels that have been replaced with good values
    dq_collapse[(dq_collapse > 0) & ~np.isnan(data_collapse)] = 0

    if other_hdus is None:
        return data_collapse, err_collapse, dq_collapse, None

    # gather, for each HDU position, a copy of that HDU's data from every HDUList
    # (the number of HDUs is taken from the first HDUList)
    stacked_hdu_data = [[] for _ in range(len(other_hdus[0]))]
    for hdul in other_hdus:
        for i, hdu in enumerate(hdul):
            stacked_hdu_data[i].append(np.copy(hdu.data))

    # now combine each hdu data
    combined_hdus = []
    for hdu_stack in stacked_hdu_data:
        # TODO: not implemented how to take means of anything beyond np.arrays (e.g., np.recarray)
        try:
            # nothing here makes sense to scale by number of frames
            combined = collapse_func(np.array(hdu_stack), axis=0)
        except Exception:
            # best effort: just take the first one if it cannot be combined.
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
            combined = hdu_stack[0]
        combined_hdus.append(combined)

    return data_collapse, err_collapse, dq_collapse, combined_hdus
def combine_subexposures(input_dataset, num_frames_per_group=None, collapse="mean", num_frames_scaling=True, combine_other_hdus=False):
    """
    Combines a sequence of exposures assuming a constant number of frames per group.

    The length of the dataset must be divisible by the number of frames per group.
    Each group is collapsed with either the mean or the median; the collapsed image
    can then be scaled by the number of frames in order to ~conserve the total number
    of photons in the input dataset (this essentially turns a median into a sum).

    Args:
        input_dataset (corgidrp.data.Dataset): input data.
        num_frames_per_group (int): number of subexposures per group. If None, combines all images together
        collapse (str): "mean" or "median". (default: mean)
        num_frames_scaling (bool): Multiply by number of frames in sequence in order to ~conserve photons (default: True)
        combine_other_hdus (bool): Whether to combine other HDUs in the same way as the main data, err, DQ.
                                   Otherwise, uses the HDUs from the first frame in a subset (default: False)

    Returns:
        corgidrp.data.Dataset: dataset after combination of every "num_frames_per_group" frames together

    Raises:
        ValueError: if the dataset length is not a multiple of num_frames_per_group,
                    or if collapse is neither "mean" nor "median"
    """
    n_frames = len(input_dataset)
    if num_frames_per_group is None:
        num_frames_per_group = n_frames

    if n_frames % num_frames_per_group != 0:
        raise ValueError("Input dataset of length {0} cannot be grouped in sets of {1}".format(n_frames, num_frames_per_group))

    if collapse.lower() not in ["mean", "median"]:
        raise ValueError("combine_subexposures can only collapse with mean or median")

    combined_frames = []
    for start in range(0, n_frames, num_frames_per_group):
        stop = start + num_frames_per_group
        group = slice(start, stop)

        # work on copies: combine_images masks bad pixels in place
        data_subset = np.copy(input_dataset.all_data[group])
        err_subset = np.copy(input_dataset.all_err[group])
        dq_subset = np.copy(input_dataset.all_dq[group])
        other_hdus = [input_dataset[j].hdu_list for j in range(start, stop)] if combine_other_hdus else None

        data_collapse, err_collapse, dq_collapse, combined_hdus = combine_images(
            data_subset, err_subset, dq_subset,
            collapse=collapse,
            num_frames_scaling=num_frames_scaling,
            other_hdus=other_hdus)

        # the headers of the combined frame come from the first frame of the group
        first_frame = input_dataset[start]
        pri_hdr = first_frame.pri_hdr.copy()
        ext_hdr = first_frame.ext_hdr.copy()
        ext_hdr["NUM_FR"] = num_frames_per_group
        err_hdr = first_frame.err_hdr.copy()
        dq_hdr = first_frame.dq_hdr.copy()
        hdulist = first_frame.hdu_list.copy()

        # replace the extra HDUs with their combined versions if requested
        # NOTE(review): if hdu_list.copy() is a shallow copy, the assignments below also
        # modify the HDU objects held by the first input frame of the group -- confirm.
        if combine_other_hdus:
            for j, hdu in enumerate(hdulist):
                hdu.data = combined_hdus[j]
                hdu.header["NUM_FR"] = num_frames_per_group
                hdu.header['HISTORY'] = "Combined {0} frames by {1}".format(num_frames_per_group, collapse)

        new_image = data.Image(data_collapse, pri_hdr=pri_hdr, ext_hdr=ext_hdr,
                               err=err_collapse, dq=dq_collapse,
                               err_hdr=err_hdr, dq_hdr=dq_hdr,
                               input_hdulist=hdulist)

        # the combined frame always takes the filename of the last frame in the group
        new_image.filename = input_dataset[stop - 1].filename
        new_image._record_parent_filenames(input_dataset[group])
        combined_frames.append(new_image)

    new_dataset = data.Dataset(combined_frames)
    drpnfile = new_dataset[0].ext_hdr['DRPNFILE']

    # Header keywords are only changed for the combined non-coronagraphic imaging datasets
    first_ext_hdr = input_dataset[0].ext_hdr
    merge_l3_headers = (first_ext_hdr['DPAMNAME'] == 'IMAGING'
                        and first_ext_hdr['LSAMNAME'] == 'OPEN'
                        and first_ext_hdr['DATALVL'] == 'L3')
    if merge_l3_headers:
        # average/delete header keywords as L4 involves combination of multiple frames
        merged_pri_hdr, merged_ext_hdr, _, _ = corgidrp.check.merge_headers(
            input_dataset,
            last_frame_keywords=['VISITID', 'MJDEND', 'SCTEND'],
            first_frame_keywords=['MJDSRT', 'SCTSRT', 'CD1_1', 'CD1_2', 'CD2_1', 'CD2_2',
                                  'CRPIX1', 'CRPIX2', 'NORTHANG'],
            # FILE0 is re-added by _record_parent_filenames below
            deleted_keywords=['CDELT1', 'CDELT2', 'FILE0'] + corgidrp.check.deleted_keywords_default,
            invalid_keywords=[
                # Primary header keywords
                'FILETIME', 'PA_V3', 'PA_APER', 'SVB_1', 'SVB_2', 'SVB_3',
                'ROLL', 'PITCH', 'YAW', 'WBJ_1', 'WBJ_2', 'WBJ_3',
                # Extension header keywords
                'DATETIME', 'FTIMEUTC', 'DATATYPE'],
            averaged_keywords=['EXCAMT', 'NOVEREXP', 'PROXET', 'FCMPOS',
                               'FSMSG1', 'FSMSG2', 'FSMSG3', 'FSMX', 'FSMY',
                               'SB_FP_DX', 'SB_FP_DY', 'SB_FS_DX', 'SB_FS_DY',
                               'Z2AVG', 'Z3AVG', 'Z4AVG', 'Z5AVG', 'Z6AVG', 'Z7AVG', 'Z8AVG',
                               'Z9AVG', 'Z10AVG', 'Z11AVG', 'Z12AVG', 'Z13AVG', 'Z14AVG',
                               'Z2RES', 'Z3RES', 'Z4RES', 'Z5RES', 'Z6RES', 'Z7RES', 'Z8RES',
                               'Z9RES', 'Z10RES', 'Z11RES', 'Z2VAR', 'Z3VAR'])

        # incorporate modified headers in L4 dataset
        # NOTE(review): every combined frame is assigned the same merged header objects
        # (no per-frame copy) and records the whole input dataset as its parents -- confirm
        # this is intended when there is more than one group.
        for combined in new_dataset:
            combined.pri_hdr = merged_pri_hdr
            combined.ext_hdr = merged_ext_hdr
            combined.ext_hdr['NUM_FR'] = num_frames_per_group
            combined.ext_hdr['DRPNFILE'] = drpnfile
            combined._record_parent_filenames(input_dataset)

    new_dataset.update_after_processing_step("Combine_subexposures: combined every {0} frames by {1}".format(num_frames_per_group, collapse))

    return new_dataset
def derotate_arr(data_arr,northang_deg, xcen,ycen,new_center=None,astr_hdr=None,
                 is_dq=False,dq_round_threshold=0.05):
    """Derotates an array based on the provided NORTHANG angle, about the provided center.

    Every 2-D image spanned by the last two axes is rotated independently.
    Treats DQ arrays specially, converting to float to do the rotation, and
    converting back to np.int64 afterwards. DQ output becomes only zeros and
    ones, so detailed DQ flag information is not preserved.

    Args:
        data_arr (np.array): an array with 2-4 dimensions
        northang_deg (float): angle (measured counter-clockwise) of the detector y axis from
            celestial north (degrees). Calculated from the northangle of the astrometric cal
            frame and the PA_APER offset between the astrom cal frame and the science frame.
        xcen (float): x-coordinate of center about which to rotate
        ycen (float): y-coordinate of center about which to rotate
        new_center (tuple, optional): tuple of x- and y- coordinate of the new center to shift to.
        astr_hdr (astropy.fits.Header, optional): WCS header which will be updated.
            Defaults to None.
        is_dq (bool, optional): Flag to determine if this is a DQ array. Defaults to False.
        dq_round_threshold (float, optional): value between 0-1 which determines the threshold
            for spreading dq values to neighboring pixels after derotation.

    Returns:
        np.array: The derotated array. For DQ input, an np.int64 array of zeros and ones.

    Raises:
        ValueError: if data_arr does not have 2, 3 or 4 dimensions
    """
    if not 2 <= data_arr.ndim <= 4:
        raise ValueError('derotate_arr() only supports data with 2-4 dimensions, '
                         'got {0}'.format(data_arr.ndim))

    # Temporarily convert dq to floats
    if is_dq:
        data_arr = data_arr.astype(np.float32)

    # Flatten all leading axes so a single loop handles the 2-D, 3-D and 4-D cases
    lead_shape = data_arr.shape[:-2]
    n_images = int(np.prod(lead_shape))
    images = data_arr.reshape((n_images,) + data_arr.shape[-2:])

    # astr_hdr is only passed with the first image, so it is corrected only once
    derotated_ims = [rotate(im, northang_deg, (xcen, ycen),
                            new_center=new_center,
                            astr_hdr=astr_hdr if (k == 0) else None)
                     for k, im in enumerate(images)]

    # Restore the leading axes
    derotated_stack = np.array(derotated_ims)
    derotated_arr = derotated_stack.reshape(lead_shape + derotated_stack.shape[1:])

    # convert dq_array back to ints
    if is_dq:
        derotated_arr[np.isnan(derotated_arr)] = 1  # assign nans to 1
        return (derotated_arr > dq_round_threshold).astype(np.int64)

    return derotated_arr
def prop_err_dq(sci_dataset,ref_dataset,mode,dq_thresh=1,new_center=None):
    """Applies logic to propagate the dq arrays and error arrays in a dataset
    through PSF subtraction.

    Note that the STARLOCX/STARLOCY keywords in the ext_hdr of every science frame
    (and of every reference frame when the mode includes RDI) are overwritten with
    ``new_center``. The error maps are currently NaN placeholders.

    Args:
        sci_dataset (corgidrp.data.Dataset): The input science dataset.
        ref_dataset (corgidrp.data.Dataset): The input reference dataset (or None if ADI only).
        mode (str): The PSF subtraction mode, e.g. "ADI", "RDI", "ADI+RDI".
        dq_thresh (int): Minimum dq flag value to be considered a bad pixel. Defaults to 1.
        new_center (tuple): New center (xy) to align all frames. Defaults to pixel closest to
            array center.

    Returns:
        tuple of np.array: the dq array and err array which should apply to the
            PSF subtraction output dataset.
    """

    def _align_dq_and_err(dataset, center):
        # Shift the thresholded dq maps and the placeholder error maps of every frame so that
        # the star location recorded in the header lands on `center` (no rotation applied).
        bad_pix = dataset.all_dq >= dq_thresh
        placeholder_errs = np.full_like(dataset.all_err, np.nan)  # Set errors to np.nan for now
        shifted_dqs, shifted_errs = [], []
        for idx, frame in enumerate(dataset):
            old_x, old_y = frame.ext_hdr['STARLOCX'], frame.ext_hdr['STARLOCY']
            # the header is updated so that later steps rotate about the new center
            frame.ext_hdr['STARLOCX'], frame.ext_hdr['STARLOCY'] = center
            shifted_dqs.append(derotate_arr(bad_pix[idx], 0, old_x, old_y,
                                            new_center=center, is_dq=True))
            shifted_errs.append(derotate_arr(placeholder_errs[idx], 0, old_x, old_y,
                                             new_center=center))
        return np.array(shifted_dqs), np.array(shifted_errs)

    if new_center is None:
        img_shape = sci_dataset.all_data.shape
        new_center = [int(img_shape[-1]//2), int(img_shape[-2]//2)]

    uses_rdi = "RDI" in mode

    # Align frames
    # dq shape = (n_rolls, n_wls(optional), y, x)
    aligned_sci_dq_arr, aligned_sci_err_arr = _align_dq_and_err(sci_dataset, new_center)
    if uses_rdi:
        # the aligned reference error maps are computed but not used further
        aligned_ref_dq_arr, aligned_ref_err_arr = _align_dq_and_err(ref_dataset, new_center)

    # If doing ADI, flag pixels that are bad in all science frames
    if 'ADI' in mode:
        aligned_sci_dq_arr[:] = np.all(aligned_sci_dq_arr, axis=0)

    # If using references, flag pixels that are bad in all the ref frames
    if uses_rdi:
        bad_in_all_refs = np.all(aligned_ref_dq_arr, axis=0, keepdims=True)
        aligned_sci_dq_arr = np.logical_or(aligned_sci_dq_arr, bad_in_all_refs)

    # Derotate dq & error about the (updated) star location of each science frame
    derotated_dq_arr = []
    derotated_err_arr = []
    for frame, frame_dq, frame_err in zip(sci_dataset, aligned_sci_dq_arr, aligned_sci_err_arr):
        roll_deg = frame.ext_hdr['NORTHANG']
        star_x, star_y = frame.ext_hdr['STARLOCX'], frame.ext_hdr['STARLOCY']
        derotated_dq_arr.append(derotate_arr(frame_dq, roll_deg, star_x, star_y, is_dq=True))
        derotated_err_arr.append(derotate_arr(frame_err, roll_deg, star_x, star_y))

    # Collapse dq & error: a pixel stays flagged only if it is bad in every derotated frame,
    # and the errors are added in quadrature
    bad_everywhere = np.all(derotated_dq_arr, axis=0)
    dq_out_collapsed = np.where(bad_everywhere, 1, 0)
    err_stack = np.array(derotated_err_arr)
    err_out_collapsed = np.sqrt(np.sum(err_stack**2, axis=0))

    return dq_out_collapsed, err_out_collapsed