# Source code for tweakwcs.matchutils

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
A module that provides algorithms for matching catalogs and for the initial
estimation of shifts based on 2D histograms.

:License: :doc:`LICENSE`

"""
import logging
import warnings
from abc import ABC, abstractmethod

import numpy as np
import astropy
from astropy.utils.decorators import deprecated
from astropy.utils.exceptions import AstropyDeprecationWarning

from stsci.stimage import xyxymatch
from scipy import spatial

from . import __version__  # noqa: F401

__author__ = 'Mihai Cara'

__all__ = ['MatchCatalogs', 'XYXYMatch']

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


class MatchCatalogs(ABC):
    """
    Abstract base class defining the common interface for catalog matching.

    Concrete matchers derive from this class and implement ``__call__``.

    """

    def __init__(self):
        """Initialize the matcher. The base class stores no state."""

    @abstractmethod
    def __call__(self, refcat, imcat, **kwargs):
        """
        Performs catalog matching.

        Parameters
        ----------
        refcat: astropy.table.Table
            A reference source catalog. Reference catalog must contain
            ``'TPx'`` and ``'TPy'`` columns that provide undistorted
            (distortion-correction applied) source coordinates in a
            coordinate system common (shared) with the image catalog
            ``imcat``.

        imcat: astropy.table.Table
            Source catalog associated with an image. Image catalog must
            contain ``'TPx'`` and ``'TPy'`` columns that provide undistorted
            (distortion-correction applied) source coordinates in a
            coordinate system common (shared) with the reference catalog
            ``refcat``.

        **kwargs : dict
            Any keyword arguments for ``__call__`` specific to subclass.

        Returns
        -------
        (refcat_idx, imcat_idx): tuple of numpy.ndarray
            A tuple of two 1D `numpy.ndarray` containing indices of matched
            sources in the ``refcat`` and ``imcat`` catalogs accordingly.

        """
class XYXYMatch(MatchCatalogs):
    """
    Catalog source matching in tangent plane. Uses ``xyxymatch`` algorithm
    to cross-match sources between this catalog and a reference catalog.

    .. note::
        The tangent plane is a plane tangent to the celestial sphere and it
        must be not distorted, that is, if *image* coordinates are distorted,
        then distortion correction must be applied to them before tangent
        plane coordinates are computed. Alternatively, one can think that
        undistorted world coordinates are projected from the sphere onto the
        tangent plane.

    """

    def __init__(self, searchrad=3.0, separation=0.5, use2dhist=True,
                 xoffset=0.0, yoffset=0.0, tolerance=1.0):
        """
        Parameters
        ----------
        searchrad: float, optional
            The search radius for a match (in units of the tangent plane).

        separation: float, optional
            The minimum separation in the tangent plane (in units of
            the tangent plane) for sources in the image and reference
            catalogs in order to be considered to be distinct sources.
            Objects closer together than ``separation`` distance
            are removed from the image and reference coordinate catalogs
            prior to matching. This parameter gets passed directly to
            :py:func:`~stsci.stimage.xyxymatch` for use in matching the object
            lists from each image with the reference catalog's object list.

        use2dhist: bool, optional
            Use 2D histogram to find initial offset?

        xoffset: float, optional
            Initial estimate for the offset in X (in units of the tangent
            plane) between the sources in the image and the reference
            catalogs. This offset will be used for all input images provided.
            This parameter is ignored when ``use2dhist`` is `True`.

        yoffset: float, optional
            Initial estimate for the offset in Y (in units of the tangent
            plane) between the sources in the image and the reference
            catalogs. This offset will be used for all input images provided.
            This parameter is ignored when ``use2dhist`` is `True`.

        tolerance: float, optional
            The matching tolerance (in units of the tangent plane) after
            applying an initial solution derived from the 'triangles'
            algorithm. This parameter gets passed directly to
            :py:func:`~stsci.stimage.xyxymatch` for use in matching the object
            lists from each image with the reference image's object list.

        Raises
        ------
        ValueError
            If ``searchrad``, ``separation``, or ``tolerance`` is not a
            positive number.

        """
        self._use2dhist = use2dhist

        if searchrad > 0:
            self._searchrad = searchrad
        else:
            raise ValueError("'searchrad' must be a positive number.")

        if separation > 0:
            self._separation = separation
        else:
            raise ValueError("'separation' must be a positive number.")

        if tolerance > 0:
            self._tolerance = tolerance
        else:
            raise ValueError("'tolerance' must be a positive number.")

        self._xoffset = float(xoffset)
        self._yoffset = float(yoffset)

    def __call__(self, refcat, imcat, tp_pscale=1.0, tp_units=None, **kwargs):
        r"""
        Performs catalog matching.

        Parameters
        ----------
        refcat: astropy.table.Table
            A reference source catalog. Reference catalog must contain
            ``'TPx'`` and ``'TPy'`` columns that provide undistorted
            (distortion-correction applied) source coordinates in some
            *tangent plane*. The ``'RA'`` and ``'DEC'`` columns in
            the ``refcat`` catalog will be ignored.

        imcat: astropy.table.Table
            Source catalog associated with an image. The catalog must contain
            ``'TPx'`` and ``'TPy'`` columns that provide undistorted
            (distortion-correction applied) source coordinates in
            **the same**\ *tangent plane* used to define ``refcat``'s
            tangent plane coordinates. In this case, the ``'x'`` and ``'y'``
            columns in the ``imcat`` catalog will be ignored.

        tp_pscale: float
            Pixel scale: size of an image pixel in the tangent plane.
            Pixel scale is in the same units as the coordinates of
            the tangent plane. Pixel scale is used to compute bin size used
            for initial 2D histogram alignment performed before matching.

        tp_units: str
            Units of the tangent plane coordinates.

        **kwargs : dict
            Only the deprecated ``tp_wcs`` keyword is recognized. When it is
            present and not `None`, tangent plane coordinates are computed
            from the ``'x'``/``'y'`` columns of ``imcat`` and the
            ``'RA'``/``'DEC'`` columns of ``refcat`` instead of being read
            from the ``'TPx'``/``'TPy'`` columns.

        Returns
        -------
        (refcat_idx, imcat_idx): tuple of numpy.ndarray
            A tuple of two 1D `numpy.ndarray` containing indices of matched
            sources in the ``refcat`` and ``imcat`` catalogs accordingly.

        Raises
        ------
        TypeError
            If ``refcat`` or ``imcat`` is not an `astropy.table.Table`.

        ValueError
            If ``refcat`` or ``imcat`` is empty.

        KeyError
            If a catalog lacks the coordinate columns required by the
            selected mode (with or without ``tp_wcs``).

        """
        # Import the submodule explicitly: a bare ``import astropy`` does not
        # guarantee that ``astropy.table`` has been loaded.
        from astropy.table import Table

        # Check catalogs:
        if not isinstance(refcat, Table):
            raise TypeError("'refcat' must be an instance of "
                            "astropy.table.Table")

        if not refcat:
            raise ValueError("Reference catalog must contain at least one "
                             "source.")

        if not isinstance(imcat, Table):
            raise TypeError("'imcat' must be an instance of "
                            "astropy.table.Table")

        if not imcat:
            raise ValueError("Image catalog must contain at least one "
                             "source.")

        if 'tp_wcs' in kwargs:
            warnings.warn(
                "Argument 'tp_wcs' has been deprecated since version 0.8.1. "
                "Please use 'tp_pscale' instead and populate 'TPx' and 'TPy' "
                "columns of input catalogs.",
                AstropyDeprecationWarning
            )

        tp_wcs = kwargs.get('tp_wcs')

        if tp_wcs is None:
            if 'TPx' not in refcat.colnames or 'TPy' not in refcat.colnames:
                raise KeyError("When tangent plane WCS is not provided, "
                               "'refcat' must contain both 'TPx' and 'TPy' "
                               "columns.")

            if 'TPx' not in imcat.colnames or 'TPy' not in imcat.colnames:
                raise KeyError("When tangent plane WCS is not provided, "
                               "'imcat' must contain both 'TPx' and 'TPy' "
                               "columns.")

            imxy = np.asarray([imcat['TPx'], imcat['TPy']]).T
            refxy = np.asarray([refcat['TPx'], refcat['TPy']]).T

        else:
            if 'RA' not in refcat.colnames or 'DEC' not in refcat.colnames:
                raise KeyError("When tangent plane WCS is provided, 'refcat' "
                               "must contain both 'RA' and 'DEC' columns.")

            if 'x' not in imcat.colnames or 'y' not in imcat.colnames:
                raise KeyError("When tangent plane WCS is provided, 'imcat' "
                               "must contain both 'x' and 'y' columns.")

            # compute x & y in the tangent plane provided by tp_wcs:
            imxy = np.asarray(
                tp_wcs.det_to_tanp(imcat['x'], imcat['y'])
            ).T

            refxy = np.asarray(
                tp_wcs.world_to_tanp(refcat['RA'], refcat['DEC'])
            ).T

        # Catalog names are used only for logging. Coerce them to ``str`` so
        # that a non-string ``meta['name']`` cannot break message formatting.
        imcat_name = imcat.meta.get('name', 'Unnamed')
        imcat_name = 'Unnamed' if imcat_name is None else str(imcat_name)

        refcat_name = refcat.meta.get('name', 'Unnamed')
        refcat_name = 'Unnamed' if refcat_name is None else str(refcat_name)

        log.info(
            f"Matching sources from '{imcat_name}' catalog with sources "
            f"from the reference '{refcat_name}' catalog."
        )

        if self._use2dhist:
            # Determine xyoff (X,Y offset) and tolerance
            # to be used with xyxymatch:
            xyoff = _estimate_2dhist_shift(
                imxy,
                refxy,
                searchrad=self._searchrad,
                pscale=tp_pscale,
                units=tp_units
            )
        else:
            xyoff = (self._xoffset, self._yoffset)

        matches = xyxymatch(
            imxy,
            refxy,
            origin=xyoff,
            tolerance=self._tolerance,
            separation=self._separation
        )

        return matches['ref_idx'], matches['input_idx']
def _xy_2dhist(imgxy, refxy, r):
    """
    Build a 2D histogram of coordinate differences ``imgxy - refxy``.

    Parameters
    ----------
    imgxy: numpy.ndarray
        Array of shape ``(N, 2)`` with image source coordinates.

    refxy: numpy.ndarray
        Array of shape ``(M, 2)`` with reference source coordinates.

    r: float
        Search radius, in the same units as the coordinates. Bins have
        unit width.

    Returns
    -------
    numpy.ndarray
        A square array with ``2 * ceil(r) + 1`` bins per axis, indexed as
        ``[y, x]``. The central bin corresponds to zero offset.

    """
    # trim to only pairs within (r+0.5) * np.sqrt(2) using a kdtree
    # to avoid computing differences for many widely separated pairs.
    kdtree = spatial.KDTree(refxy)
    neighbors = kdtree.query_ball_point(imgxy, (r + 0.5) * np.sqrt(2))
    lens = [len(n) for n in neighbors]
    mi = np.repeat(np.arange(imgxy.shape[0]), lens)
    if len(mi) > 0:
        mr = np.concatenate([n for n in neighbors if len(n) > 0])
    else:
        # no candidate pairs: reuse the empty integer index array
        mr = mi

    dx = imgxy[mi, 0] - refxy[mr, 0]
    dy = imgxy[mi, 1] - refxy[mr, 1]
    idx = np.where((dx < r + 0.5) & (dx >= -r - 0.5) &
                   (dy < r + 0.5) & (dy >= -r - 0.5))
    r = int(np.ceil(r))
    h = np.histogram2d(dx[idx], dy[idx], 2 * r + 1,
                       [[-r - 0.5, r + 0.5], [-r - 0.5, r + 0.5]])
    # histogram2d returns an array indexed as [x, y]; transpose to [y, x]
    return h[0].T


def _estimate_2dhist_shift(imgxy, refxy, searchrad=3.0, pscale=1.0,
                           units=None):
    """
    Create a 2D matrix-histogram which contains the delta between each
    XY position and each UV position. Then estimate initial offset
    between catalogs.

    ``pscale`` is used to make bins of size approximately equal to image
    pixel.

    Parameters
    ----------
    imgxy: numpy.ndarray
        Array of shape ``(N, 2)`` with image source coordinates.

    refxy: numpy.ndarray
        Array of shape ``(M, 2)`` with reference source coordinates.

    searchrad: float, optional
        Search radius, in the same units as the coordinates.

    pscale: float, optional
        Histogram bin size, in the same units as the coordinates.

    units: str, optional
        Name of the coordinate units; used only in log messages.

    Returns
    -------
    (xp, yp): tuple of float
        Estimated X and Y shifts in the units of the input coordinates.
        ``(0.0, 0.0)`` is returned when no shift could be estimated.

    """
    log.info("Computing initial guess for X and Y shifts...")
    if units is None:
        units = 'tangent plane units'

    # create ZP matrix
    zpmat = _xy_2dhist(imgxy / pscale, refxy / pscale, r=searchrad / pscale)

    # Index of the zero-shift bin along each axis. The histogram half-width
    # is ceil(searchrad / pscale) bins, which generally differs from
    # searchrad / pscale, so the center must be taken from the histogram
    # itself rather than from ``searchrad``.
    ny, nx = zpmat.shape
    x0 = (nx - 1) / 2.0
    y0 = (ny - 1) / 2.0

    nonzeros = np.count_nonzero(zpmat)
    if nonzeros == 0:
        # no matches within search radius. Return (0, 0):
        log.warning(
            f"No matches found within a search radius of {searchrad:g} ({units})."
        )
        return 0.0, 0.0

    elif nonzeros == 1:
        # only one non-zero bin:
        yp, xp = np.unravel_index(np.argmax(zpmat), zpmat.shape)
        maxval = zpmat[yp, xp]
        xp = pscale * (xp - x0)
        yp = pscale * (yp - y0)
        log.info(
            f"Found initial X and Y shifts of {xp:.4g}, {yp:.4g} ({units}) "
            f"based on a single non-zero bin and {int(maxval):d} matches."
        )
        return xp, yp

    (xp, yp), fit_status, fit_sl = _find_peak(
        zpmat,
        peak_fit_box=5,
        mask=zpmat > 0
    )

    if fit_status.startswith('ERROR'):
        log.warning(
            f"No valid shift found within a search radius of {searchrad:g} {units}."
        )
        return 0.0, 0.0

    xp = pscale * (xp - x0)
    yp = pscale * (yp - y0)

    if fit_status == 'WARNING:EDGE':
        log.info("Found peak in the 2D histogram lies at the edge of the "
                 "histogram. Try increasing 'searchrad' for improved results.")

    flux = int(zpmat[fit_sl].sum())

    # Attempt to estimate "significance of detection": the background is the
    # mean of all populated bins that are below the maximum.
    maxval = zpmat.max()
    zpmat_mask = (zpmat > 0) & (zpmat < maxval)

    if np.any(zpmat_mask):
        # selected bins are all > 0, hence their mean is positive
        bkg = zpmat[zpmat_mask].mean()
        sig = maxval / np.sqrt(bkg)
        log.info(
            f"Found initial X and Y shifts of {xp:.4g}, {yp:.4g} ({units}) "
            f"with significance of {sig:.4g} and {flux:d} matches."
        )
    else:
        log.warning("Unable to estimate significance of the detection of the "
                    "initial shift.")
        log.info(
            f"Found initial X and Y shifts of {xp:.4g}, {yp:.4g} ({units}) "
            f"with {flux:d} matches."
        )

    return xp, yp


def _find_peak(data, peak_fit_box=5, mask=None):
    """
    Find location of the peak in an array. This is done by fitting a second
    degree 2D polynomial to the data within a `peak_fit_box` and computing the
    location of its maximum. An initial estimate of the position of the maximum
    will be performed by searching for the location of the pixel/array element
    with the maximum value.

    Parameters
    ----------
    data: numpy.ndarray
        2D data.

    peak_fit_box: int, optional
        Size (in pixels) of the box around the initial estimate of the maximum
        to be used both for quadratic fitting from which peak location is
        computed and for the center-of-mass estimate. It is assumed that
        fitting box is a square with sides of length given by
        ``peak_fit_box``.

    mask: numpy.ndarray, optional
        A boolean type `~numpy.ndarray` indicating "good" pixels in image data
        (`True`) and "bad" pixels (`False`). If not provided all finite pixels
        in ``data`` will be used for fitting.

    Returns
    -------
    coord: tuple of float
        A pair of coordinates ``(x, y)`` of the peak.

    fit_status: str
        Status of the peak search. Currently the following values can be
        returned:

        - ``'SUCCESS'``: Fit was successful and peak is not on the edge of
          the input array;
        - ``'ERROR:NODATA'``: Not enough valid data to perform the fit; The
          returned coordinate is the center of input array;
        - ``'WARNING:EDGE'``: Peak lies on the edge of the input array.
          Returned coordinates are the result of a discrete search;
        - ``'WARNING:BADFIT'``: Performed fit did not find a maximum.
          Returned coordinates are the result of a center-of-mass estimate;
        - ``'WARNING:CENTER-OF-MASS'``: Returned coordinates are the result
          of a center-of-mass estimate instead of a polynomial fit. This is
          either due to too few points to perform a fit, polynomial peak
          being outside of the fit box, or due to a failure of the
          polynomial fit.

    fit_box: a tuple of `slice`
        A tuple of `slice` objects of the form
        ``(slice(y1, y2, None), slice(x1, x2, None))`` that indicates pixel
        ranges used for fitting (these indices can be used directly for
        slicing input data)

    Raises
    ------
    ValueError
        If ``peak_fit_box`` is smaller than 1.

    """
    def _center_of_mass(v, d, x1, x2, y1, y2):
        # Compute center-of-mass. Assumes that ``v`` was computed using
        # coordinates ``x``, ``y`` relative to ``x1 - 1`` and ``y1 - 1``
        # (columns 1 and 2 of ``v`` hold ``x`` and ``y``).
        # Returned coordinate is relative to origin.
        m = np.logical_not(np.isfinite(d))
        vx = v[:, 1].flatten()
        vy = v[:, 2].flatten()
        d[m] = 0
        vx[m] = 0
        vy[m] = 0
        dt = d.sum()
        if dt == 0.0:
            # no weight at all: fall back to the center of the fit box
            coord = ((x2 + x1 - 1.0) / 2.0, (y2 + y1 - 1.0) / 2.0)
            return coord, 'ERROR:NODATA'
        xc = np.dot(vx, d) / dt
        yc = np.dot(vy, d) / dt
        return ((float(x1 + xc - 1), float(y1 + yc - 1)),
                'WARNING:CENTER-OF-MASS')

    # check arguments:
    if peak_fit_box < 1:
        raise ValueError("peak_fit_box must be at least 1 pixel in size.")

    data = np.asarray(data, dtype=np.double)
    ny, nx = data.shape

    # find index of the pixel having maximum value:
    finite = np.isfinite(data)
    mask = finite if mask is None else np.logical_and(mask, finite)

    j, i = np.indices(data.shape)
    i = i[mask]
    j = j[mask]

    if i.size == 0:
        # no valid data:
        coord = ((nx - 1.0) / 2.0, (ny - 1.0) / 2.0)
        return coord, 'ERROR:NODATA', np.s_[0:ny, 0:nx]

    ind = np.argmax(data[mask])
    imax = i[ind]
    jmax = j[ind]
    coord = (float(imax), float(jmax))

    if data[jmax, imax] < 1:
        # no valid data: we need some counts in the histogram bins
        coord = ((nx - 1.0) / 2.0, (ny - 1.0) / 2.0)
        return coord, 'ERROR:NODATA', np.s_[0:ny, 0:nx]

    # choose a box around maxval pixel:
    x1 = max(0, imax - peak_fit_box // 2)
    x2 = min(nx, x1 + peak_fit_box)
    y1 = max(0, jmax - peak_fit_box // 2)
    y2 = min(ny, y1 + peak_fit_box)

    # if peak is at the edge of the box, return integer indices of the max:
    if imax == x1 or imax == x2 - 1 or jmax == y1 or jmax == y2 - 1:
        return coord, 'WARNING:EDGE', np.s_[y1:y2, x1:x2]

    # expand the box if needed:
    if (x2 - x1) < peak_fit_box:  # pragma: no branch
        if x1 == 0:  # pragma: no branch
            x2 = min(nx, x1 + peak_fit_box)
        if x2 == nx:  # pragma: no branch
            x1 = max(0, x2 - peak_fit_box)

    if (y2 - y1) < peak_fit_box:  # pragma: no branch
        if y1 == 0:  # pragma: no branch
            y2 = min(ny, y1 + peak_fit_box)
        if y2 == ny:  # pragma: no branch
            y1 = max(0, y2 - peak_fit_box)

    assert x2 - x1 > 0 or y2 - y1 > 0

    fit_slice = np.s_[y1:y2, x1:x2]

    # fit a 2D 2nd degree polynomial to data; coordinates are 1-based
    # relative to the lower-left corner of the fit box:
    m = mask[fit_slice].ravel()
    xi = np.arange(x1, x2) - (x1 - 1)
    yi = np.arange(y1, y2) - (y1 - 1)
    x, y = np.meshgrid(xi, yi)
    x = x.ravel()
    y = y.ravel()
    v = np.vstack((np.ones_like(x), x, y, x * y, x * x, y * y)).T[m]
    d = data[fit_slice].ravel()[m]
    if d.size < 6:
        # we need at least 6 points to fit a 2D quadratic polynomial
        # attempt center-of-mass instead:
        coord, fit_status = _center_of_mass(v, d, x1, x2, y1, y2)
        return coord, fit_status, fit_slice

    try:
        c = np.linalg.lstsq(v, d, rcond=None)[0]
        if not np.all(np.isfinite(c)):
            raise np.linalg.LinAlgError("Results of the fit are not finite.")
    except np.linalg.LinAlgError as e:
        log.warning("Least squares failed!\n{}".format(e))

        # attempt center-of-mass instead:
        coord, fit_status = _center_of_mass(v, d, x1, x2, y1, y2)
        return coord, fit_status, fit_slice

    # find maximum of the polynomial:
    _, c10, c01, c11, c20, c02 = c
    det = 4 * c02 * c20 - c11**2
    if det <= 0 or ((c20 > 0.0 and c02 >= 0.0) or (c20 >= 0.0 and c02 > 0.0)):
        # polynomial does not have a maximum: fall back to a center-of-mass
        # estimate and flag the result as a bad fit:
        coord, fit_status = _center_of_mass(v, d, x1, x2, y1, y2)
        if fit_status.startswith('ERROR'):
            return coord, fit_status, fit_slice
        return coord, 'WARNING:BADFIT', fit_slice

    # stationary point of the quadratic, converted back to array indices:
    xm = (c01 * c11 - 2.0 * c02 * c10) / det + x1 - 1
    ym = (c10 * c11 - 2.0 * c01 * c20) / det + y1 - 1

    if x1 <= xm <= (x2 - 1.0) and y1 <= ym <= (y2 - 1.0):
        coord = (xm, ym)
        fit_status = 'SUCCESS'

    else:
        # fitted maximum falls outside of the fit box:
        coord, fit_status = _center_of_mass(v, d, x1, x2, y1, y2)

    return coord, fit_status, fit_slice


# Deprecated alias kept for backward compatibility; use ``XYXYMatch``.
@deprecated(since='0.8.0', alternative='XYXYMatch')
class TPMatch(XYXYMatch):
    pass