diff --git a/pyproject.toml b/pyproject.toml index 97311fe5b..2eee17cf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,9 @@ dependencies = [ "PyYAML>=6.0" ] +[project.scripts] +archiveheaders = "kbmod.mocking.dump_headers:main" + [project.urls] Documentation = "https://epyc.astro.washington.edu/~kbmod/" Repository = "https://github.com/dirac-institute/kbmod" diff --git a/src/kbmod/mocking/__init__.py b/src/kbmod/mocking/__init__.py new file mode 100644 index 000000000..1644f4b3a --- /dev/null +++ b/src/kbmod/mocking/__init__.py @@ -0,0 +1,6 @@ +from . import utils +from .catalogs import * +from .headers import * +from .data import * +from .fits import * +from .callbacks import * diff --git a/src/kbmod/mocking/callbacks.py b/src/kbmod/mocking/callbacks.py new file mode 100644 index 000000000..22d13ea62 --- /dev/null +++ b/src/kbmod/mocking/callbacks.py @@ -0,0 +1,84 @@ +import random + +from astropy.time import Time +import astropy.units as u + + +__all__ = ["IncrementObstime", "ObstimeIterator"] + + +class IncrementObstime: + """Endlessly incrementing FITS-standard timestamp. + + Parameters + ---------- + start : `astropy.time.Time` + Starting timestamp, or a value from which AstroPy can instantiate one. + dt : `float` or `astropy.units.Quantity` + Size of time-step to take. Assumed to be in days by default. + + Examples + -------- + >>> from kbmod.mocking import IncrementObstime + >>> obst = IncrementObstime("2021-01-01T00:00:00.0000", 1) + >>> obst() + '2021-01-01T00:00:00.000' + >>> obst() + '2021-01-02T00:00:00.000' + >>> import astropy.units as u + >>> obst = IncrementObstime("2021-01-01T00:00:00.0000", 1*u.hour) + >>> obst(); obst() + '2021-01-01T00:00:00.000' + '2021-01-01T01:00:00.000' + """ + + default_unit = "day" + + def __init__(self, start, dt): + self.start = Time(start) + if not isinstance(dt, u.Quantity): + dt = dt * getattr(u, self.default_unit) + self.dt = dt + + def __call__(self, header=None): + curr = self.start + self.start += self.dt + return curr.fits + + +class ObstimeIterator: + """Iterate through given timestamps. + + Parameters + ---------- + obstimes : `astropy.time.Time` + Starting timestamp, or a value from which AstroPy can instantiate one. + + Raises + ------ + StopIteration + When all the obstimes are exhausted. + + Examples + -------- + >>> from astropy.time import Time + >>> times = Time(range(60310, 60313, 1), format="mjd") + >>> from kbmod.mocking import ObstimeIterator + >>> obst = ObstimeIterator(times) + >>> obst(); obst(); obst(); obst() + '2024-01-01T00:00:00.000' + '2024-01-02T00:00:00.000' + '2024-01-03T00:00:00.000' + Traceback (most recent call last): + File "", line 1, in + File "/local/tmp/kbmod/src/kbmod/mocking/callbacks.py", line 49, in __call__ + + StopIteration + """ + + def __init__(self, obstimes, **kwargs): + self.obstimes = Time(obstimes, **kwargs) + self.generator = (t for t in obstimes) + + def __call__(self, header=None): + return Time(next(self.generator)).fits diff --git a/src/kbmod/mocking/catalogs.py b/src/kbmod/mocking/catalogs.py new file mode 100644 index 000000000..bd862a9d8 --- /dev/null +++ b/src/kbmod/mocking/catalogs.py @@ -0,0 +1,543 @@ +import numpy as np +from astropy.table import QTable +from astropy.coordinates import SkyCoord + +from .config import Config + + +__all__ = [ + "gen_random_catalog", + "SimpleCatalog", + "SourceCatalogConfig", + "SourceCatalog", + "ObjectCatalogConfig", + "ObjectCatalog", +] + + +def expand_gaussian_cols(cat): + """Expands columns ``flux`` and ``stddev`` into ``amplitude``, ``x_stddev`` + and ``y_stddev`` assuming the intended catalog model is a symmetric 2D + Gaussian. + + Amplitude is caluclated as: + A = flux/(2*pi*sigma_x*sigma_y) + + Parameters + ---------- + cat : `astropy.table.Table` + A catalog of simplified model parameters. + + Returns + ------ + expanded : `astropy.table.Table` + A catalog of AstroPy model parameters. + """ + if "x_stddev" not in cat.columns and "stddev" in cat.columns: + cat["x_stddev"] = cat["stddev"] + if "y_stddev" not in cat.columns and "stddev" in cat.columns: + cat["y_stddev"] = cat["stddev"] + + if "flux" in cat.columns and "amplitude" not in cat.columns: + cat["amplitude"] = cat["flux"] / (2.0 * np.pi * cat["x_stddev"] * cat["y_stddev"]) + + return cat + + +def gen_random_catalog(n, param_ranges, seed=None, assume_gaussian=True): + """Generates a random catalog of parameters of n sources based on + AstroPy's 2D models. + + The object parameters are specified as a dict where keys become columns and + the values represent the range from which each parameter is uniformly + randomly drawn from. + + If parameter ranges contain ``flux``, a column ``amplitude`` will be added + which value will be calculated assuming a 2D Gaussian model. If ``stddev`` + column is preset, then ``x_stddev`` and ``y_stddev`` columns are added, + assuming the model intended to be used with the catalog is a symmetrical + 2D Gaussian. + + Parameters + ---------- + n : `int` + Number of objects to create + param_ranges : `dict` + Dictionary whose keys become columns of the catalog and which values + define the ranges from which the catalog values are drawn from. + seed : `int` + NumPy's random number generator seed. + assume_gaussian : `bool` + Assume the catalog is intended for use with a 2D Gaussian model and + expand ``flux`` and ``stddev`` columns appropriately. See + `expand_gaussian_cols`/` + + Returns + ------- + catalog : `astropy.table.Table` + Catalog + + Examples + -------- + >>> import kbmod.mocking as kbmock + >>> kbmock.gen_random_catalog(3, {"amplitude": [0, 3], "x": [3, 6], "y": [6, 9]}, seed=100) + + amplitude x y + float64 float64 float64 + ------------------ ----------------- ----------------- + 2.504944891506027 3.128854712082634 8.370789493256163 + 1.7896620809036619 5.920963185318643 8.73101814388636 + 0.8665897250736108 4.789415112194066 8.064463342775236 + """ + cat = QTable() + rng = np.random.default_rng(seed) + + for param_name, (lower, upper) in param_ranges.items(): + cat[param_name] = rng.uniform(lower, upper, n) + + if assume_gaussian: + cat = expand_gaussian_cols(cat) + + return cat + + +class SimpleCatalogConfig(Config): + """A simple catalog configuration.""" + + mode = "static" + """Static, progressive or folding; static catalogs remain the same for every + realization, progressive catalogs modify values of realized catalog columns + and folding catalogs select rows from a larger catalog to return as a realization. + """ + + kind = "pixel" + """Either ``pixel`` or ``world``. The kind of position coordinates encoded + by the catalog. On-sky, world, coordinates require a list of WCSs to be given + to the mocking method""" + + return_copy = False + """For static catalogs, return a reference to the underlying catalog or + a copy that can be modified.""" + + seed = None + """Random number generator seed. When `None` a different seed is used for + every catalog factory.""" + + n = 100 + """Number of objects to generate.""" + + param_ranges = {} + """Default parameter ranges of the catalog values and columns that will be + generated.""" + + pix_pos_cols = ["x_mean", "y_mean"] + """Which """ + + pix_vel_cols = ["vx", "vy"] + + world_pos_cols = ["ra_mean", "dec_mean"] + + world_vel_cols = ["v_ra", "v_dec"] + + fold_col = "obstime" + + +class SimpleCatalog: + """Default base class for mocked catalogs factory. + + This base class will always generate empty catalogs and is intended to be + inherited from. + + A catalog is an `astropy.table.Table` of model values (position, amplitude + etc.) of sources. A catalog factory mock realizations of catalogs. There + are 2 ``kind``s of catalogs that can be mocked in 3 different ``modes``. + + In progressive catalogs mode an existing base catalog template column + values are incremented, or otherwise updated for each new realization. + Static catalogs always return the same realization of a catalog. And the + folding catalogs depend on an underlying larger catalog, from which they + select which rows to return as a new realization. This is namely most + appropriate for catalogs with timestamps, where a different realization of + a catalog is returned per timestamp. + When the catalog mode is ``folding`` the mocking method expects the values + on which to fold on. By default, these are timestamps `t`. + + If the catalog coordinate kind is ``pixel``, then the positions are + interpreted as pixel coordiantes. If the kind of catalog coordinates are + ``world`` then the positions are interpreted as on-sky coordinates in + decimal degrees and a list of WCSs is expected to be provided to the mocking + method. + + Parameters + ---------- + config : `SimpleCatalogConfig` + Factory configuration. + table : `astropy.table.Table` + The catalog template from which new realizations will be generated. + + Attributes + ---------- + config : `SimpleCatalogConfig` + The instance-bound configuration of the factory + table : `astropy.table.Table` + The template catalog used to create new realizations. + _realization : `astropy.table.Table` + The last realization of a catalog. + current : `int` + The current iterator counter. + + Examples + -------- + Directly instantiate a simple static catalog: + + >>> import kbmod.mocking as kbmock + >>> table = kbmock.gen_random_catalog(3, {"A": [0, 1], "x": [1, 2], "y": [2, 3]}, seed=100) + >>> f = kbmock.SimpleCatalog(table) + >>> f.mock() + [ + A x y + float64 float64 float64 + ------------------ ------------------ ------------------ + 0.8349816305020089 1.0429515706942114 2.7902631644187212 + 0.5965540269678873 1.9736543951062142 2.9103393812954526 + 0.2888632416912036 1.5964717040646885 2.688154447591745] + + Instantiating from factory methods will derive additional information + regarding the catalog contents: + + >>> f2 = kbmock.SimpleCatalog.from_table(table) + >>> f2.mock() + [ + A x y + float64 float64 float64 + ------------------ ------------------ ------------------ + 0.8349816305020089 1.0429515706942114 2.7902631644187212 + 0.5965540269678873 1.9736543951062142 2.9103393812954526 + 0.2888632416912036 1.5964717040646885 2.688154447591745] + + >>> f2.config["param_ranges"] + {'A': (0.2888632416912036, 0.8349816305020089), 'x': (1.0429515706942114, 1.9736543951062142), 'y': (2.688154447591745, 2.9103393812954526)} + >>> f.config["param_ranges"] + {} + + Folding catalogs just return subsets of the template catalog: + + >>> table["obstime"] = [1, 1, 2] + >>> f = kbmock.SimpleCatalog(table, mode="folding") + >>> f.mock(t=[1, 2]) + [ + A x y obstime + float64 float64 float64 int64 + ------------------ ------------------ ------------------ ------- + 0.8349816305020089 1.0429515706942114 2.7902631644187212 1 + 0.5965540269678873 1.9736543951062142 2.9103393812954526 1, + A x y obstime + float64 float64 float64 int64 + ------------------ ------------------ ----------------- ------- + 0.2888632416912036 1.5964717040646885 2.688154447591745 2] + + And progressive catalogs increment selected column values (note the velocities + were assigned the default expected column names but positions weren't): + + >>> table["vx"] = [1, 1, 1] + >>> table["vy"] = [10, 10, 10] + >>> f = kbmock.SimpleCatalog(table, mode="progressive", pix_pos_cols=["x", "y"]) + >>> _ = f.mock(dt=1); f.mock(dt=1) + [ + A x y vx vy + float64 float64 float64 float64 float64 + ------------------ ------------------ ------------------ ------- ------- + 0.8349816305020089 2.0429515706942114 12.790263164418722 1.0 10.0 + 0.5965540269678873 2.973654395106214 12.910339381295453 1.0 10.0 + 0.2888632416912036 2.5964717040646885 12.688154447591746 1.0 10.0] + """ + + default_config = SimpleCatalogConfig + + def __init__(self, table, config=None, **kwargs): + self.config = self.default_config(config=config, **kwargs) + self.table = table + self.current = 0 + self._realization = self.table.copy() + self.mode = self.config["mode"] + self.kind = self.config["kind"] + + @classmethod + def from_defaults(cls, param_ranges=None, **kwargs): + """Create a catalog factory using its default config. + + Parameters + ---------- + param_ranges : `dict` + Default parameter ranges of the catalog values and columns that will be + generated. See `gen_random_catalog`. + kwargs : `dict` + Any additional keyword arguments will be used to supplement or override + any matching default configuration parameters. + + Returns + ------- + factory : `SimpleCatalog` + Simple catalog factory. + """ + config = cls.default_config(**kwargs) + if param_ranges is not None: + config["param_ranges"].update(param_ranges) + table = gen_random_catalog(config["n"], config["param_ranges"], config["seed"]) + return cls(table=table, config=config) + + @classmethod + def from_table(cls, table, **kwargs): + """Create a factory from a table template, deriving parameters, their + value ranges and number of objects from the table. + + Optionally expands the given table columns assuming the intended source + model is a 2D Gaussian. + + Parameters + ---------- + table : `astropy.table.Table` + Catalog template. + kwargs : `dict` + Any additional keyword arguments will be used to supplement or override + any matching default configuration parameters. + + Returns + ------- + factory : `SimpleCatalog` + Simple catalog factory. + """ + table = expand_gaussian_cols(table) + + config = cls.default_config(**kwargs) + config["n"] = len(table) + params = {} + for col in table.keys(): + params[col] = (table[col].min(), table[col].max()) + config["param_ranges"] = params + return cls(table=table, config=config) + + @property + def mode(self): + """Catalog mode, ``static``, ``folding`` or ``progressive``.""" + return self._mode + + @mode.setter + def mode(self, val): + if val == "folding": + self._gen_realization = self.fold + self.config["return_copy"] = True + elif val == "progressive": + self._gen_realization = self.next + self.config["return_copy"] = True + elif val == "static": + self._gen_realization = self.static + else: + raise ValueError( + "Unrecognized object catalog mode. Expected 'static', " + f"'progressive', or 'folding', got {val} instead." + ) + self._mode = val + + @property + def kind(self): + """Catalog coordinate kind, ``pixel`` or ``world``""" + return self._kind + + @kind.setter + def kind(self, val): + if val == "pixel": + self._cat_keys = self.config["pix_pos_cols"] + self.config["pix_vel_cols"] + elif val == "world": + self._cat_keys = self.config["world_pos_cols"] + self.config["world_vel_cols"] + else: + raise ValueError( + "Unrecognized coordinate kind. Expected 'world' or 'pixel, got" f"{val} instead." + ) + self._kind = val + + def reset(self): + """Reset the iteration counter reset the realization to the initial one.""" + self.current = 0 + self._realization = self.table.copy() + + def static(self, **kwargs): + """Return the initial template as a catalog realization. + + Returns + ------- + catalog : `astropy.table.Table` + Catalog realization. + """ + self.current += 1 + if self.config["return_copy"]: + return self.table.copy() + return self.table + + def next(self, dt): + """Return the next catalog realization by incrementing the position + columns by the value of the velocity column and number of current `dt` steps. + + Parameters + ---------- + dt : `float` + Time increment of each step. + + Returns + ------- + catalog : `astropy.table.Table` + Catalog realization. + """ + a, b, va, vb = self._cat_keys + self._realization[a] = self.table[a] + self.current * self.table[va] * dt + self._realization[b] = self.table[b] + self.current * self.table[vb] * dt + self.current += 1 + return self._realization.copy() + + def fold(self, t, **kwargs): + """Return the next catalog realization by selecting those rows that + match the given parameter ``t``. By default the folding column is + ``obstime``. + + Parameters + ---------- + t : `float` + Value which to select from template catalog. + + Returns + ------- + catalog : `astropy.table.Table` + Catalog realization. + """ + self._realization = self.table[self.table[self.config["fold_col"]] == t] + self.current += 1 + return self._realization.copy() + + def mock(self, n=1, dt=None, t=None, wcs=None): + """Return the next realization(s) of the catalogs. + + Selects the appropriate mocking function. Ignores keywords not + appropriate for use given some catalog generation method and coordinate + kind. + + Parameters + ---------- + n : `int`, optional + Number of catalogs to mock. Default 1. + dt : `float`, optional. + Timestep between each step (arbitrary units) + t : `list[float]` or `list[astropy.time.Time]`, optional + Values on which to fold the template catalog. + wcs : `list[astropy.wcs.WCS]`, optional + WCS to use in conversion of on-sky coordinates to pixel coordinates, + for each realization. + """ + data = [] + + if self.mode == "folding": + for i, ts in enumerate(t): + data.append(self.fold(t=ts)) + else: + for i in range(n): + data.append(self._gen_realization(dt=dt)) + + if self.kind == "world": + racol, deccol = self.config["world_pos_cols"] + xpixcol, ypixcol = self.config["pix_pos_cols"] + for cat, w in zip(data, wcs): + x, y = w.world_to_pixel(SkyCoord(ra=cat[racol], dec=cat[deccol], unit="deg")) + cat[xpixcol] = x + cat[ypixcol] = y + + return data + + +class SourceCatalogConfig(SimpleCatalogConfig): + """Source catalog config. + + Assumes sources are static, asymmetric 2D Gaussians. + + Parameter ranges + ---------------- + amplitude : [1, 10] + Amplitude of the model. + x_mean : [0, 4096] + Real valued x coordinate of the object's centroid. + y_mean : [0, 2048] + Real valued y coordinate of the object's centroid. + x_stddev : [1, 3] + Real valued standard deviation of the model distribution, in x. + y_stddev : [1, 3] + Real valued standard deviation of the model distribution, in y. + theta : `[0, np.pi]` + Rotation of the model's covariance matrix, increases counterclockwise. + In radians. + """ + + param_ranges = { + "amplitude": [1.0, 10.0], + "x_mean": [0.0, 4096.0], + "y_mean": [0.0, 2048.0], + "x_stddev": [1.0, 3.0], + "y_stddev": [1.0, 3.0], + "theta": [0.0, np.pi], + } + + +class SourceCatalog(SimpleCatalog): + """A static catalog representing stars and galaxies. + + Coordinates defined in pixel space. + """ + + default_config = SourceCatalogConfig + + +class ObjectCatalogConfig(SimpleCatalogConfig): + """Object catalog config. + + Assumes objects are symmetric 2D Gaussians moving in a linear fashion. + + Parameter ranges + ---------------- + amplitude : [1, 10] + Amplitude of the model. + x_mean : [0, 4096] + Real valued x coordinate of the object's centroid. + y_mean : [0, 2048] + Real valued y coordinate of the object's centroid. + x_stddev : [1, 3] + Real valued standard deviation of the model distribution, in x. + y_stddev : [1, 3] + Real valued standard deviation of the model distribution, in y. + theta : `[0, np.pi]` + Rotation of the model's covariance matrix, increases counterclockwise. + In radians. + """ + + mode = "progressive" + param_ranges = { + "amplitude": [0.1, 3.0], + "x_mean": [0.0, 4096.0], + "y_mean": [0.0, 2048.0], + "vx": [500.0, 1000.0], + "vy": [500.0, 1000.0], + "stddev": [0.25, 1.5], + "theta": [0.0, np.pi], + } + + +class ObjectCatalog(SimpleCatalog): + """A catalog of moving objects. + + Assumed to be symmetric 2D Gaussians whose centroids are defined in pixel + space and moving in linear fashion with velocity also defined in pixel space. + The units are relative to the timestep. + """ + + default_config = ObjectCatalogConfig + + def __init__(self, table, **kwargs): + # Obj cat always has to return a copy + kwargs["return_copy"] = True + super().__init__(table=table, **kwargs) diff --git a/src/kbmod/mocking/config.py b/src/kbmod/mocking/config.py new file mode 100644 index 000000000..88090b272 --- /dev/null +++ b/src/kbmod/mocking/config.py @@ -0,0 +1,105 @@ +import copy + +__all__ = ["Config", "ConfigurationError"] + + +class ConfigurationError(Exception): + """Error that is raised when configuration parameters contain a logical error.""" + + +class Config: + """Base class for Standardizer configuration. + + Not all standardizers will (can) use the same parameters so refer to their + respective documentation for a more complete list. + + Parameters + ---------- + config : `dict`, `Config` or `None`, optional + Collection of configuration key-value pairs. + kwargs : optional + Keyword arguments, assigned as configuration key-values. + """ + + def __init__(self, config=None, **kwargs): + # This is a bit hacky, but it makes life a lot easier because it + # enables automatic loading of the default configuration and separation + # of default config from instance bound config + keys = list(set(dir(self.__class__)) - set(dir(Config))) + + # First fill out all the defaults by copying cls attrs + self._conf = {k: getattr(self, k) for k in keys} + + # Then override with any user-specified values + if config is not None: + self._conf.update(config) + self._conf.update(kwargs) + + # now just shortcut the most common dict operations + def __getitem__(self, key): + return self._conf[key] + + def __setitem__(self, key, value): + self._conf[key] = value + + def __str__(self): + res = f"{self.__class__.__name__}(" + for k, v in self.items(): + res += f"{k}: {v}, " + return res[:-2] + ")" + + def __len__(self): + return len(self._conf) + + def __contains__(self, key): + return key in self._conf + + def __eq__(self, other): + if isinstance(other, type(self)): + return self._conf == other._conf + elif isinstance(other, dict): + return self._conf == other + else: + return super().__eq__(other) + + def __iter__(self): + return iter(self._conf) + + def __or__(self, other): + if isinstance(other, type(self)): + return self.__class__(config=other._conf | self._conf) + elif isinstance(other, dict): + return self.__class__(config=self._conf | other) + else: + raise TypeError("unsupported operand type(s) for |: {type(self)} " "and {type(other)}") + + def keys(self): + """A set-like object providing a view on config's keys.""" + return self._conf.keys() + + def values(self): + """A set-like object providing a view on config's values.""" + return self._conf.values() + + def items(self): + """A set-like object providing a view on config's items.""" + return self._conf.items() + + def update(self, conf=None, **kwargs): + """Update this config from dict/other config/iterable. + + A dict-like update. If ``conf`` is present and has a ``.keys()`` + method, then does: ``for k in conf: this[k] = conf[k]``. If ``conf`` + is present but lacks a ``.keys()`` method, then does: + ``for k, v in conf: this[k] = v``. + + In either case, this is followed by: + ``for k in kwargs: this[k] = kwargs[k]`` + """ + if conf is not None: + self._conf.update(conf) + self._conf.update(kwargs) + + def toDict(self): + """Return this config as a dict.""" + return self._conf diff --git a/src/kbmod/mocking/data.py b/src/kbmod/mocking/data.py new file mode 100644 index 000000000..ecf5b3f82 --- /dev/null +++ b/src/kbmod/mocking/data.py @@ -0,0 +1,1034 @@ +import numpy as np +from astropy.io.fits import PrimaryHDU, CompImageHDU, ImageHDU, BinTableHDU, TableHDU +from astropy.modeling import models + +from .config import Config, ConfigurationError +from kbmod import Logging + + +__all__ = [ + "add_model_objects", + "DataFactory", + "SimpleImage", + "SimpleMask", + "SimpleVariance", + "SimulatedImage", +] + + +logger = Logging.getLogger(__name__) + + +def add_model_objects(img, catalog, model): + """Adds a catalog of model objects to the image. + + Parameters + ---------- + img : `np.array` + Image. + catalog : `astropy.table.QTable` + Table of objects, a catalog + model : `astropy.modelling.Model` + Astropy's model of the surface brightness of an source. + + Returns + ------- + img: `np.array` + Image including the rendenred models. + """ + shape = img.shape + yidx, xidx = np.indices(shape) + + # find catalog columns that exist for the model + params_to_set = [] + for param in catalog.colnames: + if param in model.param_names: + params_to_set.append(param) + + # Save the initial model parameters so we can set them back + init_params = {param: getattr(model, param) for param in params_to_set} + + # model could throw a value error if drawn amplitude was too large, we must + # restore the model back to its starting values to cover for a general + # use-case because Astropy modelling is a bit awkward. + try: + for i, source in enumerate(catalog): + for param in params_to_set: + setattr(model, param, source[param]) + model.render(img) + except ValueError as e: + # ignore rendering models larger than the image + message = "The `bounding_box` is larger than the input out in one or more dimensions." + if message in str(e): + pass + finally: + for param, value in init_params.items(): + setattr(model, param, value) + + return img + + +class DataFactoryConfig(Config): + """Data factory configuration primarily controls mutability of the given + and returned mocked datasets. + """ + + default_img_shape = (5, 5) + """Default image size, used if mocking ImageHDU or CompImageHDUs.""" + + default_img_bit_width = 32 + """Default image data type is float32; the value of BITPIX flag in headers. + See bitpix_type_map for other codes. + """ + + default_tbl_length = 5 + """Default table length, used if mocking BinTableHDU or TableHDU HDUs.""" + + default_tbl_dtype = np.dtype([("a", int), ("b", int)]) + """Default table dtype, used when mocking table-HDUs that do not contain + a description of table layout. + """ + + writeable = False + """Sets the base array ``writeable`` flag. Default `False`.""" + + return_copy = False + """When `True`, the `DataFactory.mock` returns a copy of the final object, + otherwise the original (possibly mutable!) object is returned. Default `False`. + """ + + # https://archive.stsci.edu/fits/fits_standard/node39.html#s:man + bitpix_type_map = { + # or char + 8: int, + # actually no idea what dtype, or C type for that matter, + # are used to represent these values. But default Headers return them + 16: np.float16, + 32: np.float32, + 64: np.float64, + # classic IEEE float and double + -32: np.float32, + -64: np.float64, + } + """Map between FITS header BITPIX keyword value and NumPy return type.""" + + +class DataFactory: + """Generic data factory that can mock table and image HDUs from default + settings or given header definitions. + + Given a template, this factory repeats it for each mock. + A reference to the base template is returned whenever possible for + performance reasons. To prevent accidental mutation of the shared + array, the default behavior is that the returned data is not writable. + + A base template value of `None` is accepted as valid to satisfy FITS + factory use-case of generating HDUList stubs containing only headers. + + Primary purpose of this factory is to derive the template data given a + table, HDU or a Header object. When the base has no data, but just a + description of one, such as Headers, the default is to return "zeros" + for that datatype. This can be a zero length string, literal integer + zero, a float zero etc... + + Attributes + ---------- + base : `np.array`, `np.recarray` or `None` + Base data template. + shape : `tuple` + Shape of base array when it exists. + dtype : `type` + Numpy type of the base array, when it exists. + counter : `int` + Data factory tracks an internal counter of generated objects that can + be used as a ticker for generating new data. + + Parameters + ---------- + base : `np.array` + Static data shared by all mocked instances. + kwargs : + Additional keyword arguments are applied as configuration + overrides. + + Examples + -------- + >>> from astropy.io.fits import Header, CompImageHDU, BinTableHDU + >>> import kbmod.mocking as kbmock + >>> import numpy as np + >>> base = np.zeros((2, 2)) + >>> hdu = CompImageHDU(base) + >>> kbmock.DataFactory.from_hdu(hdu).mock() + array([[[0., 0.], + [0., 0.]]]) + >>> kbmock.DataFactory.from_header("image", hdu.header).mock() + array([[[0., 0.], + [0., 0.]]]) + >>> base = np.array([("test1", 10), ("test2", 11)], dtype=[("col1", "U5"), ("col2", int)]) + >>> hdu = BinTableHDU(base) + >>> kbmock.DataFactory.from_hdu(hdu).mock() + array([[(b'test1', 10), (b'test2', 11)]], + dtype=(numpy.record, [('col1', 'S5'), ('col2', '>> kbmock.DataFactory.from_header("table", hdu.header).mock() + array([[(b'', 0), (b'', 0)]], + dtype=(numpy.record, [('col1', 'S5'), ('col2', '>i8')])) + """ + + default_config = DataFactoryConfig + """Default configuration.""" + + def __init__(self, base, **kwargs): + self.config = self.default_config(**kwargs) + + self.base = base + if self.base is not None: + self.shape = base.shape + self.dtype = base.dtype + self.base.flags.writeable = self.config["writeable"] + self.counter = 0 + + @classmethod + def gen_image(cls, header=None, **kwargs): + """Generate an image from a complete or partial header and config. + + If a header is given, it trumps the default config values. When the + header is not complete, config values are used. Config overrides are + applied before the data description is evaluated. + + Parameters + ---------- + header : `None`, `Header` or dict-like, optional + Header, or dict-like object, containing the image-data descriptors. + kwargs : + Any additional keyword arguments are applied as config overrides. + + Returns + ------- + image : `np.array` + Image + """ + conf = cls.default_config(**kwargs) + metadata = {} if header is None else header + cols = metadata.get("NAXIS1", conf["default_img_shape"][0]) + rows = metadata.get("NAXIS2", conf["default_img_shape"][1]) + bitwidth = metadata.get("BITPIX", conf["default_img_bit_width"]) + dtype = conf.bitpix_type_map[bitwidth] + shape = (cols, rows) + return np.zeros(shape, dtype) + + @classmethod + def gen_table(cls, metadata=None, **kwargs): + """Generate an table from a complete or partial header and config. + + If a header is given, it trumps the default config values. When the + header is not complete, config values are used. Config overrides are + applied before the data description is evaluated. + + Parameters + ---------- + header : `None`, `Header` or dict-like, optional + Header, or dict-like object, containing the image-data descriptors. + kwargs : + Any additional keyword arguments are applied as config overrides. + + Returns + ------- + table : `np.array` + Table, a structured array. + + Notes + ----- + FITS format standards prescribe FORTRAN-77-like input format strings + for different data types, but the base set has been extended and/or + altered significantly by various pipelines to support their objects + internal to their pipelines. Constructing objects, or values, described + by non-standard strings will result in a failure. For a list of supported + column-types see: + https://docs.astropy.org/en/stable/io/fits/usage/table.html#column-creation + """ + conf = cls.default_config(**kwargs) + + # https://github.com/lsst/afw/blob/main/src/fits.cc#L207 + # So we really don't have much of a choice but to force a default + # AstroPy HDU and then call the update. This might not preserve the + # header or the data formats exactly and if metadata isn't given + # could even assume a wrong class all together. The TableHDU is + # almost never used by us however - so hopefully this keeps on working. + table_cls = BinTableHDU + data = None + if metadata is not None: + if metadata["XTENSION"] == "BINTABLE": + table_cls = BinTableHDU + elif metadata["XTENSION"] == "TABLE": + table_cls = TableHDU + + hdu = table_cls() + hdu.header.update(metadata) + + rows = metadata.get("NAXIS2", conf.default_tbl_length) + shape = (rows,) + data = np.zeros(shape, dtype=hdu.data.dtype) + else: + hdu = table_cls() + shape = (conf.default_tbl_length,) + data = np.zeros(shape, dtype=conf.default_tbl_dtype) + + return data + + @classmethod + def from_hdu(cls, hdu, **kwargs): + """Create the factory from an HDU with or without data and with or + without a complete Header. + + If the given HDU has data, it is preferred over creating a zero-array + based on the header. If the header is not complete, config defaults are + used. Config overrides are applied beforehand. + + Parameters + ---------- + hdu : `HDU` + One of AstroPy's Header Data Unit classes. + kwargs : + Config overrides. + + Returns + ------- + data : `np.array` + Data array, an ndarray or a recarray depending on the HDU. + """ + if isinstance(hdu, (PrimaryHDU, CompImageHDU, ImageHDU)): + base = hdu.data if hdu.data is not None else cls.gen_image(hdu.header) + return cls(base=base, **kwargs) + elif isinstance(hdu, (TableHDU, BinTableHDU)): + base = hdu.data if hdu.data is not None else cls.gen_table(hdu.header) + return cls(base=base, **kwargs) + else: + raise TypeError(f"Expected an HDU, got {type(hdu)} instead.") + + @classmethod + def from_header(cls, header, kind=None, **kwargs): + """Create the factory from an complete or partial Header. + + Provide the ``kind`` of data the header represents in situations where + the Header does not have an well defined ``XTENSION`` card. + + Parameters + ---------- + header : `astropy.io.fits.Header` + Header + kind : `str` or `None`, optional + Kind of data the header is representing. + kwargs : + Config overrides. + + Returns + ------- + data : `np.array` + Data array, an ndarray or a recarray depending on the Header and kind. + """ + hkind = header.get("XTENSION", False) + if hkind and "table" in hkind.lower(): + kind = "table" + elif hkind and "image" in hkind.lower(): + kind = "image" + elif kind is None: + raise ValueError("Must provide a header with XTENSION or ``kind``") + else: + # kind was defined as keyword arg, so all is right + pass + + if kind.lower() == "image": + return cls(base=cls.gen_image(header), **kwargs) + elif kind.lower() == "table": + return cls(base=cls.gen_table(header), **kwargs) + else: + raise TypeError(f"Expected an 'image' or 'table', got {kind} instead.") + + def mock(self, n=1): + """Mock one or multiple data arrays. + + Parameters + ---------- + n : `int` + Number of data to mock. + """ + if self.base is None: + raise ValueError( + "Expected a DataFactory that has a base, but none was set. " + "Use `zeros` or `from_hdu` to construct this object correctly." + ) + + if self.config["return_copy"]: + base = np.repeat(self.base[np.newaxis,], (n,), axis=0) + else: + base = np.broadcast_to(self.base, (n, *self.shape)) + base.flags.writeable = self.config["writeable"] + + return base + + +class SimpleVarianceConfig(DataFactoryConfig): + """Configure noise and gain of a simple variance factory.""" + + read_noise = 0.0 + """Read noise""" + + gain = 1.0 + "Gain." + + +class SimpleVariance(DataFactory): + """Simple variance factory. + + Variance is calculated as the + + variance = image/gain + read_noise^2 + + thus variance has to be calculated for each individual mocked image. + + Parameters + ---------- + image : `np.array` + Science image from which the variance will be derived from. + config : `DataFactoryConfig` + Configuration of the data factory. + **kwargs : + Additional keyword arguments are applied as config + overrides. + + Examples + -------- + >>> import kbmod.mocking as kbmock + >>> si = kbmock.SimpleImage(shape=(3, 3), add_noise=True, seed=100) + >>> sv = kbmock.SimpleVariance(gain=10) + >>> imgs = si.mock() + >>> imgs + array([[[ 8.694266, 9.225379, 10.046582], + [ 8.768851, 10.201585, 8.870326], + [10.702058, 9.910087, 9.283925]]], dtype=float32) + >>> sv.mock(imgs) + array([[[0.8694266 , 0.9225379 , 1.0046582 ], + [0.8768851 , 1.0201585 , 0.8870326 ], + [1.0702058 , 0.99100864, 0.9283925 ]]], dtype=float32) + """ + + default_config = SimpleVarianceConfig + + def __init__(self, image=None, **kwargs): + super().__init__(base=image, **kwargs) + if image is not None: + self.base = image / self.config["gain"] + self.config["read_noise"] ** 2 + + def mock(self, images=None): + """Mock one or multiple data arrays. + + Parameters + ---------- + images : `list[np.array]`, optional + List, or otherwise a collection, of images from which the variances + will be generated. When not provided, and base template was + defined, returns the base template. + """ + if images is None: + return self.base + return images / self.config["gain"] + self.config["read_noise"] ** 2 + + +class SimpleMaskConfig(DataFactoryConfig): + """Simple mask configuration.""" + + dtype = np.float32 + """Data type""" + + threshold = 1e-05 + """Default pixel value threshold above which every pixel in the template + will be masked. + """ + + shape = (5, 5) + """Default image shape.""" + + padding = 0 + """Number of pixels near the edge that are masked.""" + + bad_columns = [] + """List of columns marked as bad.""" + + patches = [] + """Default patches to mask. This is a list of tuples. Each tuple consists of + a patch and a value. The patch can be any combination of array coordinates + such as ``(int, int)`` for individual pixels, ``(slice, int)`` or + ``(int, slice)`` for columns and rows respectively or ``(slice, slice)`` + for areas. See `SimpleMask.from_params` for an example. + """ + + +class SimpleMask(DataFactory): + """Simple mask factory. + + Masks are assumed to be shared, static data. To create an instance of this + factory use one of the provided class methods. Created mask will correspond + to a bitmask already appropriate for use with KBMOD. + + Parameters + ---------- + mask : `np.array` + Bitmask array. + kwargs : + Config overrides. + + Examples + -------- + >>> import kbmod.mocking as kbmock + >>> si = kbmock.SimpleImage(shape=(3, 3), add_noise=True, seed=100) + >>> imgs = si.mock() + >>> imgs + array([[[ 8.694266, 9.225379, 10.046582], + [ 8.768851, 10.201585, 8.870326], + [10.702058, 9.910087, 9.283925]]], dtype=float32) + >>> sm = kbmock.SimpleMask.from_image(imgs, threshold=9) + >>> sm.base + array([[[0., 1., 1.], + [0., 1., 0.], + [1., 1., 1.]]], dtype=float32) + """ + + default_config = SimpleMaskConfig + """Default configuration.""" + + def __init__(self, mask, **kwargs): + super().__init__(base=mask, **kwargs) + + @classmethod + def from_image(cls, image, **kwargs): + """Create a factory instance out of an image, masking all pixels above + a threshold. + + Parameters + ---------- + image : `np.array` + Template image from which a mask is created. + kwargs : + Config overrides. + """ + config = cls.default_config(**kwargs) + mask = image.copy() + mask[image > config["threshold"]] = 1 + mask[image <= config["threshold"]] = 0 + return cls(mask) + + @classmethod + def from_params(cls, **kwargs): + """Create a factory instance from config parameters. + + Parameters + ---------- + kwargs : + Config overrides. + + Examples + -------- + >>> SimpleMask.from_params( + shape=(10, 10), + padding=1, + bad_columns=[2, 3], + patches=[ + ((5, 5), 2), + ((slice(6, 8), slice(6, 8)), 3) + ] + ) + array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 0., 1., 1., 0., 0., 0., 0., 0., 1.], + [1., 0., 1., 1., 0., 0., 0., 0., 0., 1.], + [1., 0., 1., 1., 0., 0., 0., 0., 0., 1.], + [1., 0., 1., 1., 0., 0., 0., 0., 0., 1.], + [1., 0., 1., 1., 0., 1., 0., 0., 0., 1.], + [1., 0., 1., 1., 0., 0., 1., 1., 0., 1.], + [1., 0., 1., 1., 0., 0., 1., 1., 0., 1.], + [1., 0., 1., 1., 0., 0., 0., 0., 0., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]) + """ + config = cls.default_config(**kwargs) + mask = np.zeros(config["shape"], dtype=config["dtype"]) + + shape, padding = config["shape"], config["padding"] + + # padding + mask[:padding] = 1 + mask[shape[0] - padding :] = 1 + mask[:, :padding] = 1 + mask[: shape[1] - padding :] = 1 + + # bad columns + for col in config["bad_columns"]: + mask[:, col] = 1 + + for patch, value in config["patches"]: + if isinstance(patch, tuple): + mask[patch] = 1 + elif isinstance(slice): + mask[slice] = 1 + else: + raise ValueError(f"Expected a tuple (x, y), (slice, slice) or slice, got {patch} instead.") + + return cls(mask, **config) + + +class SimpleImageConfig(DataFactoryConfig): + """Simple image configuration.""" + + return_copy = True + + shape = (100, 100) + """Dimensions of the generated images.""" + + add_noise = False + """Add noise to the base image.""" + + seed = None + """Seed of the random number generator used to generate noise.""" + + noise = 10 + """Mean of the standard Gaussian distribution of the noise.""" + + noise_std = 1.0 + """Standard deviation of the Gaussian distribution of the noise.""" + + model = models.Gaussian2D(x_stddev=1, y_stddev=1) + """Source and object model used to render them on the image.""" + + dtype = np.float32 + """Numpy data type.""" + + +class SimpleImage(DataFactory): + """Simple image data factory. + + Simple image consists of an blank empty base, onto which noise, sources + and objects can be added. All returned images act as a copy of the base + image. + + Noise realization is drawn from a Gaussian distribution with the given + standard deviation and mean. + + Parameters + ---------- + image : `np.array` + Science image that will be used as a base onto which to render details. + src_cat : `CatalogFactory` + Static source catalog. + obj_cat : `CatalogFactory` + Moving object catalog factory. + kwargs : + Additional keyword arguments are applied as config. + overrides. + + Examples + -------- + >>> import kbmod.mocking as kbmock + >>> si = kbmock.SimpleImage() + >>> si.mock() + array([[[0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + ..., + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32) + >>> si = kbmock.SimpleImage(shape=(3, 3), add_noise=True, seed=100) + >>> si.mock() + array([[[ 8.694266, 9.225379, 10.046582], + [ 8.768851, 10.201585, 8.870326], + [10.702058, 9.910087, 9.283925]]], dtype=float32) + """ + + default_config = SimpleImageConfig + """Default configuration.""" + + def __init__(self, image=None, src_cat=None, obj_cat=None, **kwargs): + super().__init__(image, **kwargs) + + if image is None: + image = np.zeros(self.config["shape"], dtype=self.config["dtype"]) + else: + image = image + self.config["shape"] = image.shape + + # Astropy throws a strange ValueError instead of reporting a non-writeable + # array, This must be a bug TODO: report. It's not safe to edit a + # non-writeable array and then revert writeability so make a copy. + self.src_cat = src_cat + if self.src_cat is not None: + image = image if image.flags.writeable else image.copy() + add_model_objects(image, src_cat.table, self.config["model"]) + image.flags.writeable = self.config["writeable"] + + self.base = image + self._base_contains_data = image.sum() != 0 + + @classmethod + def add_noise(cls, images, config): + """Adds gaussian noise to the images. + + Parameters + ---------- + images : `np.array` + A ``(n_images, image_width, image_height)`` shaped array of images. + config : `SimpleImageConfig` + Configuration. + + Returns + ------- + images : `np.array` + A ``(n_images, image_width, image_height)`` shaped array of images. + """ + rng = np.random.default_rng(seed=config["seed"]) + shape = images.shape + + # noise has to be resampled for every image + rng.standard_normal(size=shape, dtype=images.dtype, out=images) + + # There's a lot of multiplications that happen, skip if possible + if config["noise_std"] != 1.0: + images *= config["noise_std"] + images += config["noise"] + + return images + + def mock(self, n=1, obj_cats=None, **kwargs): + """Creates a set of images. + + Parameters + ---------- + n : `int`, optional + Number of images to create, default: 1. + obj_cats : `list[Catalog]` + A list of catalogs as long as the number of requested images of + moving objects that will be inserted into the image. + + Returns + ------- + images : `np.array` + A ``(n_images, image_width, image_height)`` shaped array of images. + """ + shape = (n, *self.config["shape"]) + images = np.zeros(shape, dtype=self.config["dtype"]) + + if self.config["add_noise"]: + images = self.add_noise(images=images, config=self.config) + + # if base has no data (no sources, bad cols etc) skip + if self._base_contains_data: + images += self.base + + # same with moving objects, each image has to have a new realization of + # a catalog in which moving objects have different coordinate. This way + # any trajectory can be mocked. When we have only 1 mocked image though + # zip will attempt to iterate over the next available dimension, and + # that's rows of the image and the table - we don't want that. + if obj_cats is not None: + pairs = [(images[0], obj_cats[0])] if n == 1 else zip(images, obj_cats) + for i, (img, cat) in enumerate(pairs): + add_model_objects(img, cat, self.config["model"]) + + return images + + +class SimulatedImageConfig(DataFactoryConfig): + """Simulated image configuration. + + Simulated image attempts to add noise to the image in a statistically + meaningful sense, but it does not reproduce the noise qualities in the same + way an optical simulation would. Noise sources added are: + - bad columns + - hot pixels + - read noise + - dark current + - sky level + + The quantities are expressed in physical units and the defaults were + selected to sort of make sense. + + Control over which source of noise are included in the image can be done by + setting the + - add_noise + - add_bad_cols + - add_hot_pix + flags to `False`. For a more fine-grained control set the distribution + parameters, f.e. mean and standard deviation, such that they do not produce + measurable values in the image. + + Expect the mean value of pixel counts to be: + + bias + mean(dark_current)*exposure + mean(sky_level) + + The deviation of the pixel counts should be expected to be: + + sqrt( std(read_noise)^2 + sqrt(sky_level)^2 ) + """ + + # not sure this is a smart idea to put here + rng = np.random.default_rng() + + seed = None + """Random number generator seed shared by all number generators.""" + + # image properties + shape = (100, 100) + """Dimensions of the created images.""" + + # detector properties + add_noise = True + """Add noise (includes read noise, dark current and sky) to the image.""" + + read_noise_gen = rng.normal + """Read noise follows a Gaussian distribution.""" + + read_noise = 5 + """Standard deviation of read noise distribution, in electrons.""" + + gain = 1.0 + """Gain in electrons/count.""" + + bias = 0.0 + """Bias in counts.""" + + add_bad_columns = False + """Add bad columns to the image.""" + + bad_cols_method = "random" + """Method by which bad columns are picked. If not 'random', 'bad col_locs' + must be provided.""" + + bad_col_locs = [] + """Location, column indices, of bad columns.""" + + n_bad_cols = 5 + """When bad columns method is random, sets the number of bad columns.""" + + bad_cols_seed = seed + """Seed for the bad columns random number generator.""" + + bad_col_offset = 100 + """Bad column signal offset (in counts) with regards to the baseline noise.""" + + bad_col_pattern_offset = 10 + """Random-looking noise variation along the length of the bad columns is + offset from the mean bad column counts by this amount.""" + + dark_current_gen = rng.poisson + """Dark current follows a Poisson distribution.""" + + dark_current = 0.1 + """Dark current mean in electrons/pixel/sec. Typically ~0.1-0.2.""" + + add_hot_pixels = False + """Simulate hot pixels.""" + + hot_pix_method = "random" + """When `random` the hop pixels are selected randomly, otherwise their + indices must be provided.""" + + hot_pix_locs = [] + """A `list[tuple]` of hot pixel indices.""" + + hot_pix_seed = seed + """Seed for hot pixel random number generator.""" + + n_hot_pix = 10 + """Number of hot pixels to insert into the image.""" + + hot_pix_offset = 10000 + """Offset of hot pixel counts from the baseline. Usally a very large number.""" + + # Observation properties + exposure_time = 120.0 + """Exposure time of the simulated image, affects noise properties.""" + + sky_count_gen = rng.poisson + """Sky background random number generator.""" + + sky_level = 20 + """Sky level, in counts.""" + + # Object and Source properties + model = models.Gaussian2D(x_stddev=1, y_stddev=1) + """Source and object model used to render them on the image.""" + + dtype = np.float32 + """Numpy data type.""" + + +class SimulatedImage(SimpleImage): + """Simulated image attempt to include a more realistic noise profile. + + Noise sources added are: + - bad columns + - hot pixels + - read noise + - dark current + - sky level + + Static or moving objects may be added to the simulated image. + + Parameters + ---------- + image : `np.array` + Base template image on which details will be rendered. + src_cat : `CatalogFactory` + Static source catalog. + obj_cat : `CatalogFactory` + Moving object catalog factory. + **kwargs : + Additional keyword arguments are applied as config + overrides. + """ + + default_config = SimulatedImageConfig + """Default config.""" + + @classmethod + def add_bad_cols(cls, image, config): + """Adds bad columns to the image based on the configuration. + + Columns can be sampled randomly, or a list of bad column indices can + be provided. + + Parameters + ---------- + image : `np.array` + Image. + config : `SimpleImageConfig` + Configuration. + + Returns + ------- + image : `np.array` + Image. + """ + if not config["add_bad_columns"]: + return image + + shape = image.shape + rng = np.random.RandomState(seed=config["bad_cols_seed"]) + if config["bad_cols_method"] == "random": + bad_cols = rng.randint(0, shape[1], size=config["n_bad_cols"]) + elif config["bad_col_locs"]: + bad_cols = config["bad_col_locs"] + else: + raise ConfigurationError( + "Bad columns method is not 'random', but `bad_col_locs` contains no column indices." + ) + + col_pattern = rng.randint(low=0, high=int(config["bad_col_pattern_offset"]), size=shape[0]) + + for col in bad_cols: + image[:, col] += col_pattern + config["bad_col_offset"] + + return image + + @classmethod + def add_hot_pixels(cls, image, config): + """Adds hot pixels to the image based on the configuration. + + Indices of hot pixels can be sampled randomly, or a list of hot pixel + indices can be provided. + + Parameters + ---------- + image : `np.array` + Image. + config : `SimpleImageConfig` + Configuration. + + Returns + ------- + image : `np.array` + Image. + """ + if not config["add_hot_pixels"]: + return image + + shape = image.shape + if config["hot_pix_method"] == "random": + rng = np.random.RandomState(seed=config["hot_pix_seed"]) + x = rng.randint(0, shape[1], size=config["n_hot_pix"]) + y = rng.randint(0, shape[0], size=config["n_hot_pix"]) + hot_pixels = np.column_stack([x, y]) + elif config["hot_pix_locs"]: + hot_pixels = config["hot_pix_locs"] + else: + raise ConfigurationError( + "Hot pixels method is not 'random', but `hot_pix_locs` contains " + "no (col, row) location indices of hot pixels." + ) + + for pix in hot_pixels: + image[pix] += config["hot_pix_offset"] + + return image + + @classmethod + def add_noise(cls, images, config): + """Adds read noise (gaussian), dark noise (poissonian) and sky + background (poissonian) noise to the image. + + Parameters + ---------- + image : `np.array` + Image. + config : `SimpleImageConfig` + Configuration. + + Returns + ------- + image : `np.array` + Image. + """ + shape = images.shape + + # add read noise + images += config["read_noise_gen"](scale=config["read_noise"] / config["gain"], size=shape) + + # add dark current + current = config["dark_current"] * config["exposure_time"] / config["gain"] + images += config["dark_current_gen"](current, size=shape) + + # add sky counts + images += ( + config["sky_count_gen"](lam=config["sky_level"] * config["gain"], size=shape) / config["gain"] + ) + + return images + + @classmethod + def gen_base_image(cls, config=None, src_cat=None): + """Generate base image from configuration. + + Parameters + ---------- + config : `SimpleImageConfig` + Configuration. + src_cat : `CatalogFactory` + Static source catalog. + + Returns + ------- + image : `np.array` + Image. + """ + config = cls.default_config(config) + + # empty image + base = np.zeros(config["shape"], dtype=config["dtype"]) + base += config["bias"] + base = cls.add_hot_pixels(base, config) + base = cls.add_bad_cols(base, config) + if src_cat is not None: + add_model_objects(base, src_cat.table, config["model"]) + + return base + + def __init__(self, image=None, src_cat=None, obj_cat=None, **kwargs): + conf = self.default_config(**kwargs) + super().__init__(image=self.gen_base_image(conf), src_cat=src_cat, obj_cat=obj_cat, **conf) diff --git a/src/kbmod/mocking/dump_headers.py b/src/kbmod/mocking/dump_headers.py new file mode 100644 index 000000000..f5a26ec07 --- /dev/null +++ b/src/kbmod/mocking/dump_headers.py @@ -0,0 +1,309 @@ +# Modified from the original Astropy code to add a card format to the +# tabular output. All rights belong to the original authors. + +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +Modified Astropy ``archiveheaders`` utility that adds a datatype column for each +header card to the output. All rights belong to the original authors. + +``archiveheaders`` is a command line script based on astropy.io.fits for printing +the header(s) of one or more FITS file(s) to the standard output in a human- +readable format. + +The modifications to the script include: +- supporting only tabular output +- making "ascii.ecsv" the default output format +- appending a datatype to each serialized header card, describing the type of + the cards value. + +Example uses of ``archiveheaders``: + +1. Print the header of all the HDUs of a .fits file:: + + $ archiveheaders filename.fits + +2. Dump the header keywords of all the files in the current directory into a + machine-readable ecsv file:: + + $ archiveheaders *.fits > keywords.csv + +3. Sorting the output along a specified keyword:: + + $ archiveheaders -f -s DATE-OBS *.fits + +4. Sort first by OBJECT, then DATE-OBS:: + + $ archiveheaders -f -s OBJECT -s DATE-OBS *.fits + +Note that compressed images (HDUs of type +:class:`~astropy.io.fits.CompImageHDU`) really have two headers: a real +BINTABLE header to describe the compressed data, and a fake IMAGE header +representing the image that was compressed. Astropy returns the latter by +default. You must supply the ``--compressed`` option if you require the real +header that describes the compression. + +With Astropy installed, please run ``archiveheaders --help`` to see the full usage +documentation. +""" + +import argparse +import sys + +import numpy as np + +from astropy import __version__, log +from astropy.io import fits + +DESCRIPTION = """ +Print the header(s) of a FITS file. Optional arguments allow the desired +extension(s), keyword(s), and output format to be specified. +Note that in the case of a compressed image, the decompressed header is +shown by default. + +This script is part of the Astropy package. See +https://docs.astropy.org/en/latest/io/fits/usage/scripts.html#module-astropy.io.fits.scripts.fitsheader +for further documentation. +""".strip() + + +class ExtensionNotFoundException(Exception): + """Raised if an HDU extension requested by the user does not exist.""" + + +class HeaderFormatter: + """Class to format the header(s) of a FITS file for display by the + `fitsheader` tool; essentially a wrapper around a `HDUList` object. + + Example usage: + fmt = HeaderFormatter('/path/to/file.fits') + print(fmt.parse()) + + Parameters + ---------- + filename : str + Path to a single FITS file. + verbose : bool + Verbose flag, to show more information about missing extensions, + keywords, etc. + + Raises + ------ + OSError + If `filename` does not exist or cannot be read. + """ + + def __init__(self, filename, verbose=True): + self.filename = filename + self.verbose = verbose + self._hdulist = fits.open(filename) + + def parse(self, compressed=False): + """Returns the FITS file header(s) in a readable format. + + Parameters + ---------- + compressed : bool, optional + If True, shows the header describing the compression, rather than + the header obtained after decompression. (Affects FITS files + containing `CompImageHDU` extensions only.) + + Returns + ------- + formatted_header : str or astropy.table.Table + Traditional 80-char wide format in the case of `HeaderFormatter`; + an Astropy Table object in the case of `TableHeaderFormatter`. + """ + hdukeys = range(len(self._hdulist)) # Display all by default + return self._parse_internal(hdukeys, compressed) + + def _parse_internal(self, hdukeys, compressed): + """The meat of the formatting; in a separate method to allow overriding.""" + result = [] + for idx, hdu in enumerate(hdukeys): + try: + cards = self._get_cards(hdu, compressed) + except ExtensionNotFoundException: + continue + + if idx > 0: # Separate HDUs by a blank line + result.append("\n") + result.append(f"# HDU {hdu} in {self.filename}:\n") + for c in cards: + result.append(f"{c}\n") + return "".join(result) + + def _get_cards(self, hdukey, compressed): + """Returns a list of `astropy.io.fits.card.Card` objects. + + This function will return the desired header cards, taking into + account the user's preference to see the compressed or uncompressed + version. + + Parameters + ---------- + hdukey : int or str + Key of a single HDU in the HDUList. + + compressed : bool, optional + If True, shows the header describing the compression. + + Raises + ------ + ExtensionNotFoundException + If the hdukey does not correspond to an extension. + """ + # First we obtain the desired header + try: + if compressed: + # In the case of a compressed image, return the header before + # decompression (not the default behavior) + header = self._hdulist[hdukey]._header + else: + header = self._hdulist[hdukey].header + except (IndexError, KeyError): + message = f"{self.filename}: Extension {hdukey} not found." + if self.verbose: + log.warning(message) + raise ExtensionNotFoundException(message) + + # return all cards + cards = header.cards + return cards + + def close(self): + self._hdulist.close() + + +class TableHeaderFormatter(HeaderFormatter): + """Class to convert the header(s) of a FITS file into a Table object. + The table returned by the `parse` method will contain four columns: + filename, hdu, keyword, and value. + + Subclassed from HeaderFormatter, which contains the meat of the formatting. + """ + + def _parse_internal(self, hdukeys, compressed): + """Method called by the parse method in the parent class.""" + tablerows = [] + for hdu in hdukeys: + try: + for card in self._get_cards(hdu, compressed): + tablerows.append( + { + "filename": self.filename, + "hdu": hdu, + "keyword": card.keyword, + "value": str(card.value), + "format": type(card.value).__name__, + } + ) + except ExtensionNotFoundException: + pass + + if tablerows: + from astropy import table + + return table.Table(tablerows) + return None + + +def print_headers_as_table(args): + """Prints FITS header(s) in a machine-readable table format. + + Parameters + ---------- + args : argparse.Namespace + Arguments passed from the command-line as defined below. + """ + tables = [] + # Create a Table object for each file + for filename in args.filename: # Support wildcards + formatter = None + try: + formatter = TableHeaderFormatter(filename) + tbl = formatter.parse(args.compressed) + if tbl: + tables.append(tbl) + except OSError as e: + log.error(str(e)) # file not found or unreadable + finally: + if formatter: + formatter.close() + + # Concatenate the tables + if len(tables) == 0: + return False + elif len(tables) == 1: + resulting_table = tables[0] + else: + from astropy import table + + resulting_table = table.vstack(tables) + # Print the string representation of the concatenated table + resulting_table.write(sys.stdout, format=args.table) + + +def main(args=None): + """This is the main function called by the `fitsheader` script.""" + parser = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter + ) + mode_group = parser.add_mutually_exclusive_group() + mode_group.add_argument( + "-t", + "--table", + nargs="?", + default="ascii.ecsv", + metavar="FORMAT", + help=( + '"The default format is "ascii.ecsv" (can be "ascii.csv", "ascii.html", ' + '"ascii.latex", "fits", etc)' + ), + ) + mode_group.add_argument( + "-f", + "--fitsort", + action="store_true", + help=("print the headers as a table with each unique keyword in a given column (fitsort format) "), + ) + parser.add_argument( + "-s", + "--sort", + metavar="SORT_KEYWORD", + action="append", + type=str, + help=( + "sort output by the specified header keywords, can be repeated to " + "sort by multiple keywords; Only supported with -f/--fitsort" + ), + ) + parser.add_argument( + "-c", + "--compressed", + action="store_true", + help=( + "for compressed image data, show the true header which describes " + "the compression rather than the data" + ), + ) + parser.add_argument( + "filename", + nargs="+", + help="path to one or more files; wildcards are supported", + ) + args = parser.parse_args() + + if args.sort: + args.sort = [key.replace(".", " ") for key in args.sort] + if not args.fitsort: + log.error("Sorting with -s/--sort is only supported in conjunction with" " -f/--fitsort") + # 2: Unix error convention for command line syntax + sys.exit(2) + + # Now print the desired headers + try: + print_headers_as_table(args) + except OSError: + # A 'Broken pipe' OSError may occur when stdout is closed prematurely, + # eg. when calling `fitsheader file.fits | head`. We let this pass. + pass diff --git a/src/kbmod/mocking/fits.py b/src/kbmod/mocking/fits.py new file mode 100644 index 000000000..bc492335f --- /dev/null +++ b/src/kbmod/mocking/fits.py @@ -0,0 +1,410 @@ +from astropy.io.fits import HDUList, PrimaryHDU, CompImageHDU, BinTableHDU +from astropy.wcs import WCS + +from .callbacks import IncrementObstime +from .headers import HeaderFactory, ArchivedHeader +from .data import ( + DataFactoryConfig, + DataFactory, + SimpleImage, + SimulatedImage, + SimpleVariance, + SimpleMask, +) + + +__all__ = [ + "EmptyFits", + "SimpleFits", + "DECamImdiff", +] + + +class NoneFactory: + """Factory that returns `None`.""" + + def mock(self, n): + return [ + None, + ] * n + + +class EmptyFits: + """Mock FITS files containing zeroed arrays. + + Mocks a FITS file containing 4 extensions: + - primary + - image + - variance + - mask + that contain no data. By default the created data is immutable. Useful when + data is added after mocking a collection of HDULists. + + The primary header contains an incrementing timestamp. + + Attributes + ---------- + prim_hdr : `HeaderFactory` + Primary header factory + img_hdr : `HeaderFactory` + Image header factory, contains WCS. + var_hdr : `HeaderFactory` + Variance header factory, contains WCS. + mask_hdr: `HeaderFactory` + Mask header factory, contains WCS. + img_data : `DataFactory` + Image data factory, used to also create variances. + mask_data : `DataFactory` + Mask data factory. + current : `int` + Counter of current mocking iteration. + + Parameters + ---------- + header : `dict-like` or `None`, optional + Keyword-value pairs that will use to create and update the primary header + metadata. + shape : `tuple`, optional + Size of the image arrays, 100x100 pixels by default. + start_t : `str` or `astropy.time.Time`, optional + A timestamp string interpretable by Astropy, or a time object. + step_t : `float` or `astropy.time.TimeDelta`, optional + Timestep between each mocked instance. In units of days by default. + editable_images : `bool`, optional + Make mocked images writable and independent. `False` by default. + Makes the variance images editable too. + editable_masks : `bool`, optional + Make masks writable and independent. `False` by default. Masks can + usually be shared, so leaving them immutable reduces the memory footprint + and time it takes to mock large images. + """ + + def __init__( + self, + header=None, + shape=(100, 100), + start_t="2024-01-01T00:00:00.00", + step_t=0.001, + editable_images=False, + editable_masks=False, + ): + self.prim_hdr = HeaderFactory.from_primary_template( + overrides=header, mutables=["DATE-OBS"], callbacks=[IncrementObstime(start=start_t, dt=step_t)] + ) + + self.img_hdr = HeaderFactory.from_ext_template({"EXTNAME": "IMAGE"}, shape=shape) + self.var_hdr = HeaderFactory.from_ext_template({"EXTNAME": "VARIANCE"}, shape=shape) + self.mask_hdr = HeaderFactory.from_ext_template({"EXTNAME": "MASK"}, shape=shape) + + self.img_data = DataFactory.from_header( + header=self.img_hdr.header, kind="image", writeable=editable_images, return_copy=editable_images + ) + self.mask_data = DataFactory.from_header( + header=self.mask_hdr.header, kind="image", return_copy=editable_masks, writeable=editable_masks + ) + + self.current = 0 + + def mock(self, n=1): + """Mock n empty fits files.""" + prim_hdrs = self.prim_hdr.mock(n=n) + img_hdrs = self.img_hdr.mock(n=n) + var_hdrs = self.var_hdr.mock(n=n) + mask_hdrs = self.mask_hdr.mock(n=n) + + images = self.img_data.mock(n=n) + variances = self.img_data.mock(n=n) + masks = self.mask_data.mock(n=n) + + hduls = [] + for ph, ih, vh, mh, imd, vd, md in zip( + prim_hdrs, img_hdrs, var_hdrs, mask_hdrs, images, variances, masks + ): + hduls.append( + HDUList( + hdus=[ + PrimaryHDU(header=ph), + CompImageHDU(header=ih, data=imd), + CompImageHDU(header=vh, data=vd), + CompImageHDU(header=mh, data=md), + ] + ) + ) + + self.current += n + return hduls + + +class SimpleFits: + """Mock FITS files containing data. + + Mocks a FITS file containing 4 extensions: + - primary + - image + - variance + - mask + that contain no data. By default the created data is mutable. + The primary header contains an incrementing timestamp. + + Attributes + ---------- + prim_hdr : `HeaderFactory` + Primary header factory + img_hdr : `HeaderFactory` + Image header factory, contains WCS. + var_hdr : `HeaderFactory` + Variance header factory, contains WCS. + mask_hdr: `HeaderFactory` + Mask header factory, contains WCS. + img_data : `SimpleImage` or `SimulatedImage` + Image data factory. + var_data : `SimpleVariance`` + Variance data factory. + mask_data : `SimpleMask` + Mask data factory. + current : `int` + Counter of current mocking iteration. + + Parameters + ---------- + shared_header_metadata : `dict-like` or `None`, optional + Keyword-value pairs that will use to create and update all of the headers. + shape : `tuple`, optional + Size of the image arrays, 100x100 pixels by default. + start_t : `str` or `astropy.time.Time`, optional + A timestamp string interpretable by Astropy, or a time object. + step_t : `float` or `astropy.time.TimeDelta`, optional + Timestep between each mocked instance. In units of days by default. + with_noise : `bool` + Add noise into the images. + noise : `str` + Noise profile to use, "simplistic" is simple Gaussian noise and + "realistic" simulates several noise sources and adds them together. + src_cat : `SourceCatalog` + Source catalog of static objects to add into the images. + obj_cat : `ObjectCatalog` + Object catalog of moving objects to add into the images. + wcs_factory : `WCSFactory` + Factory used to create and update WCS data in headers of mocked FITS + files. + """ + + def __init__( + self, + shared_header_metadata=None, + shape=(100, 100), + start_t="2024-01-01T00:00:00.00", + step_t=0.001, + with_noise=False, + noise="simplistic", + src_cat=None, + obj_cat=None, + wcs_factory=None, + ): + # 2. Set up Header and Data factories that go into creating HDUs + # 2.1) First headers, since that metadata specified data formats + self.prim_hdr = HeaderFactory.from_primary_template( + overrides=shared_header_metadata, + mutables=["DATE-OBS"], + callbacks=[IncrementObstime(start=start_t, dt=step_t)], + ) + + wcs = None + if wcs_factory is not None: + wcs = wcs_factory + + if shared_header_metadata is None: + shared_header_metadata = {"EXTNAME": "IMAGE"} + + self.img_hdr = HeaderFactory.from_ext_template( + overrides=shared_header_metadata.copy(), shape=shape, wcs=wcs + ) + shared_header_metadata["EXTNAME"] = "VARIANCE" + self.var_hdr = HeaderFactory.from_ext_template( + overrides=shared_header_metadata.copy(), shape=shape, wcs=wcs + ) + shared_header_metadata["EXTNAME"] = "MASK" + self.mask_hdr = HeaderFactory.from_ext_template( + overrides=shared_header_metadata.copy(), shape=shape, wcs=wcs + ) + + # 2.2) Then data factories + if noise == "realistic": + self.img_data = SimulatedImage(shape=shape, src_cat=src_cat, add_noise=with_noise) + else: + self.img_data = SimpleImage(shape=shape, src_cat=src_cat, add_noise=with_noise) + self.var_data = SimpleVariance(self.img_data.base) + self.mask_data = SimpleMask.from_image(self.img_data.base) + + self.start_t = start_t + self.step_t = step_t + self.obj_cat = obj_cat + self.current = 0 + + def mock(self, n=1): + """Mock n simple FITS files.""" + prim_hdrs = self.prim_hdr.mock(n=n) + img_hdrs = self.img_hdr.mock(n=n) + var_hdrs = self.var_hdr.mock(n=n) + mask_hdrs = self.mask_hdr.mock(n=n) + + obj_cats = None + if self.obj_cat is not None: + obj_cats = self.obj_cat.mock( + n=n, + dt=self.step_t, + t=[hdr["DATE-OBS"] for hdr in prim_hdrs], + wcs=[WCS(hdr) for hdr in img_hdrs], + ) + + images = self.img_data.mock(n, obj_cats=obj_cats) + variances = self.var_data.mock(images=images) + masks = self.mask_data.mock(n) + + hduls = [] + for ph, ih, vh, mh, imd, vd, md in zip( + prim_hdrs, img_hdrs, var_hdrs, mask_hdrs, images, variances, masks + ): + hduls.append( + HDUList( + hdus=[ + PrimaryHDU(header=ph), + CompImageHDU(header=ih, data=imd), + CompImageHDU(header=vh, data=vd), + CompImageHDU(header=mh, data=md), + ] + ) + ) + + self.current += n + return hduls + + +class DECamImdiff: + """Mock FITS files from archived headers of Rubin Science Pipelines + "differenceExp" dataset type headers (arXiv:2109.03296). + + Each FITS file contains 16 HDUs, one PRIMARY, 3 images (image, mask and + variance) and supporting data such as PSF, ArchiveId etc. stored in + `BinTableHDU`s. + + The exported data contains approximately 60 real header data of a Rubin + Science Pipelines ``differenceExp`` dataset type. The data was created from + DEEP B1a field, as described in arXiv:2310.03678. + + By default only headers are mocked. + + Attributes + ---------- + hdr_factory : `ArchivedHeader` + Header factory for all 16 HDUs. + data_factories : `list[DataFactory]` + Data factories, one for each HDU being mocked. + hdu_layout : `list[HDU]` + List of HDU types (PrimaryHDU, CompImageHDU...) used to create an HDU + from a header and a data factory. + img_data : `SimpleImage` or `SimulatedImage` + Reference to second element in `data_factories`, an image data factory. + var_data : `SimpleVariance`` + Reference to third element in `data_factories`, an variance data factory. + Variance uses read_noise of 7.0 e- and gain of 5.0 e-/count as described + by the Table 2.2 in DECam Data Handbook Version 2.05 March 2014. + mask_data : `SimpleMask` + Reference to fourth element in `data_factories`, an mask data factory. + current : `int` + Counter of current mocking iteration. + + Parameters + ---------- + shared_header_metadata : `dict-like` or `None`, optional + Keyword-value pairs that will use to create and update all of the headers. + shape : `tuple`, optional + Size of the image arrays, 100x100 pixels by default. + start_t : `str` or `astropy.time.Time`, optional + A timestamp string interpretable by Astropy, or a time object. + step_t : `float` or `astropy.time.TimeDelta`, optional + Timestep between each mocked instance. In units of days by default. + with_noise : `bool` + Add noise into the images. + noise : `str` + Noise profile to use, "simplistic" is simple Gaussian noise and + "realistic" simulates several noise sources and adds them together. + src_cat : `SourceCatalog` + Source catalog of static objects to add into the images. + obj_cat : `ObjectCatalog` + Object catalog of moving objects to add into the images. + wcs_factory : `WCSFactory` + Factory used to create and update WCS data in headers of mocked FITS + files. + """ + + def __init__(self, with_data=False, with_noise=False, noise="simplistic", src_cat=None, obj_cat=None): + if obj_cat is not None and obj_cat.mode == "progressive": + raise ValueError( + "Only folding or static object catalogs can be used with" + "default archived DECam headers since header timestamps are not " + "required to be equally spaced." + ) + + self.hdr_factory = ArchivedHeader("headers_archive.tar.bz2", "decam_imdiff_headers.ecsv") + + self.data_factories = [NoneFactory()] * 16 + if with_data: + headers = self.hdr_factory.get(0) + + shape = (headers[1]["NAXIS1"], headers[1]["NAXIS2"]) + dtype = DataFactoryConfig.bitpix_type_map[headers[1]["BITPIX"]] + + # Read noise and gain are typical values. DECam has 2 amps per CCD, + # each powering ~half of the plane. Their values and areas are + # recorded in the header, but that would mean I would have to + # produce an image which has different zero-offsets for the two + # halves which is too much detail for this use-case. Typical values + # are taken from the DECam Data Handbook Version 2.05 March 2014 + # Table 2.2 + if noise == "realistic": + self.img_data = SimulatedImage(src_cat=src_cat, shape=shape, dtype=dtype) + else: + self.img_data = SimpleImage(src_cat=src_cat, shape=shape, dtype=dtype) + self.var_data = SimpleVariance(self.img_data.base, read_noise=7.0, gain=4.0) + self.mask_data = SimpleMask.from_image(self.img_data.base) + + self.data_factories[1] = self.img_data + self.data_factories[2] = self.mask_data + self.data_factories[3] = self.mask_data + self.data_factories[4:] = [DataFactory.from_header(h, kind="table") for h in headers[4:]] + + self.with_data = with_data + self.src_cat = src_cat + self.obj_cat = obj_cat + self.hdu_layout = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU] + self.hdu_layout.extend([BinTableHDU] * 12) + self.current = 0 + + def mock(self, n=1): + """Mock n differenceExp dataset type-like FITS files.""" + headers = self.hdr_factory.mock(n=n) + + obj_cats = None + if self.obj_cat is not None: + kwargs = {"t": [hdrs[0][0]["DATE-AVG"] for hdr in hdrs]} + obj_cats = self.obj_cat.mock(n=n, **kwargs) + + if self.with_data: + images = self.img_data.mock(n=n, obj_cats=obj_cats) + masks = self.mask_data.mock(n=n) + variances = self.var_data.mock(images=images) + data = [NoneFactory().mock(n=n), images, masks, variances] + data.extend([factory.mock(n=n) for factory in self.data_factories[4:]]) + else: + data = [factory.mock(n=n) for factory in self.data_factories] + + hduls = [] + for i, hdrs in enumerate(headers): + hdus = [] + for j, (layer, hdr) in enumerate(zip(self.hdu_layout, hdrs)): + hdus.append(layer(header=hdr, data=data[j][i])) + hduls.append(HDUList(hdus=hdus)) + + self.current += n + return hduls diff --git a/src/kbmod/mocking/headers.py b/src/kbmod/mocking/headers.py new file mode 100644 index 000000000..2b265e9e8 --- /dev/null +++ b/src/kbmod/mocking/headers.py @@ -0,0 +1,637 @@ +import random +import warnings +import itertools + +import numpy as np + +from astropy.wcs import WCS +from astropy.io.fits import Header +from astropy.io.fits.verify import VerifyWarning + +from .utils import header_archive_to_table +from .config import Config + + +__all__ = [ + "WCSFactory", + "HeaderFactory", + "ArchivedHeader", +] + + +class WCSFactory: + """WCS Factory. + + Used to generate collections of ICRS TAN WCS from parameters, or as a way to + update the given header with a new WCS. + + The new WCS is generated by inheriting the header and updating its properties, + or by completely overwriting the Header WCS and replacing it with a new WCS. + + Attribute + --------- + current : `int` + Counter of the number of mocked WCSs. + template : `WCS` + Template WCS. + + Parameters + ---------- + pointing : `tuple`, optional + Ra and Dec w.r.t ICRS coordinate system, in decimal degrees. + rotation : `float`, optional + Rotation, in degrees, from ICRS equator (ecliptic). + pixscale : `float`, optional + Pixel scale, in arcseconds per pixel. + dither_pos : `bool`, optional + Dither positions of mocked WCSs. + dither_rot : `bool`, optional + Dither rotations of mocked WCSs. + dither_amplitudes : `tuple`, optional + A set of 3 values, the amplitude of dither in ra direction, the + amplitude of dither in dec direction and the amplitude of dither in + rotations. In decimal degrees. + cycle : `list[WCS]`, optional + A list of pre-created WCS objects through which to iterate. + + Examples + -------- + >>> from astropy.io.fits import Header + >>> import kbmod.mocking as kbmock + >>> wcsf = kbmock.WCSFactory(pointing=(10, 10), rotation=45) + >>> wcsf.mock(Header()) + WCSAXES = 2 / Number of coordinate axes + CRPIX1 = 0.0 / Pixel coordinate of reference point + CRPIX2 = 0.0 / Pixel coordinate of reference point + PC1_1 = -3.928369684292E-05 / Coordinate transformation matrix element + PC1_2 = 3.9283723288914E-05 / Coordinate transformation matrix element + PC2_1 = 3.9283723288914E-05 / Coordinate transformation matrix element + PC2_2 = 3.928369684292E-05 / Coordinate transformation matrix element + CDELT1 = 1.0 / [deg] Coordinate increment at reference point + CDELT2 = 1.0 / [deg] Coordinate increment at reference point + CUNIT1 = 'deg' / Units of coordinate increment and value + CUNIT2 = 'deg' / Units of coordinate increment and value + CTYPE1 = 'RA---TAN' / Right ascension, gnomonic projection + CTYPE2 = 'DEC--TAN' / Declination, gnomonic projection + CRVAL1 = 10.0 / [deg] Coordinate value at reference point + CRVAL2 = 10.0 / [deg] Coordinate value at reference point + LONPOLE = 180.0 / [deg] Native longitude of celestial pole + LATPOLE = 10.0 / [deg] Native latitude of celestial pole + MJDREF = 0.0 / [d] MJD of fiducial time + RADESYS = 'ICRS' / Equatorial coordinate system, + """ + + def __init__( + self, + mode="static", + pointing=(351.0, -5), + rotation=0, + pixscale=0.2, + dither_pos=False, + dither_rot=False, + dither_amplitudes=(0.01, 0.01, 0.0), + cycle=None, + ): + self.pointing = pointing + self.rotation = rotation + self.pixscale = pixscale + + self.dither_pos = dither_pos + self.dither_rot = dither_rot + self.dither_amplitudes = dither_amplitudes + self.cycle = cycle + + self.template = self.gen_wcs(self.pointing, self.rotation, self.pixscale) + + self.mode = mode + if dither_pos or dither_rot or cycle is not None: + self.mode = "dynamic" + self.current = 0 + + @classmethod + def gen_wcs(cls, pointing, rotation=0, pixscale=1, shape=None): + """ + Create a simple celestial `~astropy.wcs.WCS` object in ICRS + coordinate system. + + Parameters + ---------- + shape : tuple[int] + Two-tuple, dimensions of the WCS footprint + center_coords : tuple[int] + Two-tuple of on-sky coordinates of the center of the WCS in + decimal degrees, in ICRS. + rotation : float, optional + Rotation in degrees, from ICRS equator. In decimal degrees. + pixscale : float + Pixel scale in arcsec/pixel. + + Returns + ------- + wcs : `astropy.wcs.WCS` + The world coordinate system. + + Examples + -------- + >>> import kbmod.mocking as kbmock + >>> [kbmock.WCSFactory.gen_wcs((10, 10), rot, 0.2) for rot in (0, 90)] + [WCS Keywords + + Number of WCS axes: 2 + CTYPE : 'RA---TAN' 'DEC--TAN' + CRVAL : 10.0 10.0 + CRPIX : 0.0 0.0 + PC1_1 PC1_2 : -5.555555555555556e-05 0.0 + PC2_1 PC2_2 : 0.0 5.555555555555556e-05 + CDELT : 1.0 1.0 + NAXIS : 0 0, WCS Keywords + + Number of WCS axes: 2 + CTYPE : 'RA---TAN' 'DEC--TAN' + CRVAL : 10.0 10.0 + CRPIX : 0.0 0.0 + PC1_1 PC1_2 : 3.7400283533421276e-11 5.555555555554297e-05 + PC2_1 PC2_2 : 5.555555555554297e-05 -3.7400283533421276e-11 + CDELT : 1.0 1.0 + NAXIS : 0 0] + """ + wcs = WCS(naxis=2) + rho = rotation * 0.0174533 # deg to rad + scale = pixscale / 3600.0 # arcsec/pixel to deg/pix + + if shape is not None: + wcs.pixel_shape = shape + wcs.wcs.crpix = [shape[1] // 2, shape[0] // 2] + else: + wcs.wcs.crpix = [0, 0] + wcs.wcs.crval = pointing + wcs.wcs.cunit = ["deg", "deg"] + wcs.wcs.pc = np.array( + [[-scale * np.cos(rho), scale * np.sin(rho)], [scale * np.sin(rho), scale * np.cos(rho)]] + ) + wcs.wcs.radesys = "ICRS" + wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] + + return wcs + + def next(self): + """Iteratively return WCS from the cycle.""" + for wcs in self.cycle: + yield wcs + + def update_from_header(self, header): + """Update WCS template using a header, only updates the cards template + shares with the given header. + + Updates the template WCS in-place. + """ + t = self.template.to_header() + t.update(header) + self.template = WCS(t) + + def update_headers(self, headers): + """Update the headers, in-place, with a new mocked WCS. + + If the header contains a WCS, it is updated to match the template. + + Parameters + ---------- + header : `astropy.io.fits.Header` + Header + + Returns + ------- + header : `astropy.io.fits.Header` + Update header. + """ + wcs = self.template + + for header in headers: + if self.cycle is not None: + wcs = self.next() + + if self.dither_pos: + dra = random.uniform(-self.dither_amplitudes[0], self.dither_amplitudes[0]) + ddec = random.uniform(-self.dither_amplitudes[1], self.dither_amplitudes[1]) + wcs.wcs.crval += [dra, ddec] + + if self.dither_rot: + ddec = random.uniform(-self.dither_amplitudes[2], self.dither_amplitudes[2]) + rho = self.dither_amplitudes[2] * 0.0174533 # deg to rad + rot_matrix = np.array([[-np.cos(rho), np.sin(rho)], [np.sin(rho), np.cos(rho)]]) + new_pc = wcs.wcs.pc @ rot_matrix + wcs.wcs.pc = new_pc + + self.current += 1 + header.update(wcs.to_header()) + + +class HeaderFactory: + """Mocks a header from a given template. + + Callback functions can be defined for individual header cards which are + executed and their respective values are updated for each new mocked header. + + A WCS factory can be attached to this factory to update the related WCS + header keywords. + + Provides two base templates from which to create Headers from, based on + CTIO observatory location and DECam instrument. + + Can generate headers from dict-like template and card overrides. + + Attributes + ---------- + is_dynamic : `bool` + True when factory has a mutable Header cards or a WCS factory. + + Parameters + --------- + metadata : `dict-like` + Header, dict, or a list of cards from which a header is created. + mutables : `list[str]` or `None`, optional + A list of strings, matching header card keys, designating them as a + card that has an associated callback. + callbacks : `list[func]`, `list[class]` or `None` + List of callbacks, functions or classes with a ``__call__`` method, that + will be executed in order to update the respective key in ``mutables`` + See already provided callbacks in `kbmod.mocking.callbacks` module. + has_wcs : `bool` + Attach a WCS to each produced header. + wcs_factory : `WCSFactory` or `None` + A WCS factory to use, if `None` and `has_wcs` is `True`, uses the default + WCS Factory. See `WCSFactory` for details. + + Examples + -------- + >>> import kbmod.mocking as kbmock + >>> hf = kbmock.HeaderFactory.from_primary_template() + >>> hf.mock() + [EXTNAME = 'PRIMARY ' + NAXIS = 0 + BITPIX = 8 + DATE-OBS= '2021-03-19T00:27:21.140552' + NEXTEND = 3 + OBS-LAT = -30.166 + OBS-LONG= -70.814 + OBS-ELEV= 2200 + OBSERVAT= 'CTIO ' ] + >>> hf = kbmock.HeaderFactory.from_ext_template() + >>> hf.mock() + [NAXIS = 2 + NAXIS1 = 2048 + NAXIS2 = 4096 + CRPIX1 = 1024.0 / Pixel coordinate of reference point + CRPIX2 = 2048.0 / Pixel coordinate of reference point + BITPIX = 32 + WCSAXES = 2 / Number of coordinate axes + PC1_1 = -5.5555555555556E-05 / Coordinate transformation matrix element + PC2_2 = 5.5555555555556E-05 / Coordinate transformation matrix element + CDELT1 = 1.0 / [deg] Coordinate increment at reference point + CDELT2 = 1.0 / [deg] Coordinate increment at reference point + CUNIT1 = 'deg' / Units of coordinate increment and value + CUNIT2 = 'deg' / Units of coordinate increment and value + CTYPE1 = 'RA---TAN' / Right ascension, gnomonic projection + CTYPE2 = 'DEC--TAN' / Declination, gnomonic projection + CRVAL1 = 351.0 / [deg] Coordinate value at reference point + CRVAL2 = -5.0 / [deg] Coordinate value at reference point + LONPOLE = 180.0 / [deg] Native longitude of celestial pole + LATPOLE = -5.0 / [deg] Native latitude of celestial pole + MJDREF = 0.0 / [d] MJD of fiducial time + RADESYS = 'ICRS' / Equatorial coordinate system ] + """ + + primary_template = { + "EXTNAME": "PRIMARY", + "NAXIS": 0, + "BITPIX": 8, + "DATE-OBS": "2021-03-19T00:27:21.140552", + "NEXTEND": 3, + "OBS-LAT": -30.166, + "OBS-LONG": -70.814, + "OBS-ELEV": 2200, + "OBSERVAT": "CTIO", + } + """Template for the Primary header content.""" + + ext_template = {"NAXIS": 2, "NAXIS1": 2048, "NAXIS2": 4096, "CRPIX1": 1024, "CRPIX2": 2048, "BITPIX": 32} + """Template of an image-like extension header.""" + + def __validate_mutables(self): + """Validate number of mutables is number of callbacks, and that designated + mutable cards exist in the given header template. + """ + # !xor + if bool(self.mutables) != bool(self.callbacks): + raise ValueError( + "When providing a list of mutable cards, you must provide associated callback methods." + ) + + if self.mutables is None: + return + + if len(self.mutables) != len(self.callbacks): + raise ValueError( + "The number of mutable cards does not correspond to the number of given callbacks." + ) + + for k in self.mutables: + if k not in self.header: + raise ValueError( + f"Mutable key {k} does not exists " + "in the header. Please " + "provide the required metadata keys." + ) + + def __init__(self, metadata, mutables=None, callbacks=None, has_wcs=False, wcs_factory=None): + cards = [] if metadata is None else metadata + self.header = Header(cards=cards) + self.mutables = mutables + self.callbacks = callbacks + self.__validate_mutables() + + self.is_dynamic = mutables is not None + + self.has_wcs = has_wcs + if has_wcs: + self.wcs_factory = WCSFactory() if wcs_factory is None else wcs_factory + self.wcs_factory.update_from_header(self.header) + self.is_dynamic = self.is_dynamic or self.wcs_factory.mode != "static" + + self.counter = 0 + + def mock(self, n=1): + """Mocks headers, executing callbacks and creating WCS as necessary. + + Parameters + ---------- + n : `int` + Number of headers to mock. + + Returns + ------- + headers : `list[Header]` + Mocked headers. + """ + headers = [] + # This can't be vectorized because callbacks may share global state + for i in range(n): + if not self.is_dynamic: + header = self.header + else: + header = self.header.copy() + if self.mutables is not None: + for i, mutable in enumerate(self.mutables): + header[mutable] = self.callbacks[i](header[mutable]) + if self.has_wcs: + self.wcs_factory.update_headers([header]) + headers.append(header) + self.counter += 1 + return headers + + @classmethod + def gen_header(cls, base, overrides=None, wcs=None): + """Generate a header from a base template and overrides. + + If a WCS is given, and the header contains only a partially defined + WCS, updates only the missing WCS cards and values. + + Parameters + ---------- + base : `Header` or `dict-like` + Header or a dict-like base template for the header. + overrides : `Header`, `dict-like` or `None`, optional + Keys and values that will either be updated or extended to the base + template. + wcs : `astropy.wcs.WCS` or `None`, optional + WCS template to use to update the header values. + + Returns + ------- + header : `astropy.io.fits.Header` + A header. + """ + header = Header(base) + if overrides is not None: + header.update(overrides) + + if wcs is not None: + # Sync WCS with header + overwrites + wcs_header = wcs.to_header() + wcs_header.update(header) + # then merge back to mocked header template + header.update(wcs_header) + + return header + + @classmethod + def from_primary_template(cls, overrides=None, mutables=None, callbacks=None): + """Create a header assuming the default template of a PRIMARY header. + + Override, or extend the default template with keys and values in override. + Attach callbacks to mutable cards. + + Parameters + ---------- + overrides : `dict-like` or `None`, optional + A header, or different dict-like, object used to override the base + template keys and values. + mutables : `list[str]` or `None`, optional + List of card keys designated as mutable. + callbacks : `list[callable]` or `None`, optional + List of callable functions or classes that match the mutables. + + Returns + ------- + factory : `HeaderFactory` + Header factory. + """ + hdr = cls.gen_header(base=cls.primary_template, overrides=overrides) + return cls(hdr, mutables, callbacks) + + @classmethod + def from_ext_template(cls, overrides=None, mutables=None, callbacks=None, shape=None, wcs=None): + """Create an extension header assuming the default template of an image + like header. + + Override, or extend the default template with keys and values in override. + Attach callbacks to mutable cards. + + Parameters + ---------- + overrides : `dict-like` or `None`, optional + A header, or different dict-like, object used to override the base + template keys and values. + mutables : `list[str]` or `None`, optional + List of card keys designated as mutable. + callbacks : `list[callable]` or `None`, optional + List of callable functions or classes that match the mutables. + shape : `tuple` or `None`, optional + Update the template description of data dimensions and the reference + pixel. + wcs : `astropy.wcs.WCS` or `None`, optional + WCS Factory to use. + + Returns + ------- + factory : `HeaderFactory` + Header factory. + """ + ext_template = cls.ext_template.copy() + + if shape is not None: + ext_template["NAXIS1"] = shape[0] + ext_template["NAXIS2"] = shape[1] + ext_template["CRPIX1"] = shape[0] // 2 + ext_template["CRPIX2"] = shape[1] // 2 + + hdr = cls.gen_header(base=ext_template, overrides=overrides) + return cls(hdr, mutables, callbacks, has_wcs=True, wcs_factory=wcs) + + +class ArchivedHeader(HeaderFactory): + """Archived header factory. + + Archived headers are those that were produced with the modified version of + AstroPy's ``fitsheader`` utility available in this module. See + + archiveheaders -h + + for more details. To produce an archive, with KBMOD installed, execute the + following: + + archiveheaders *fits > archive.ecsv | tar -cf headers.tar.bz2 archive.ecsv + + Attributes + ---------- + lexical_type_map : `dict` + A map between the serialized names of built-in types and the built-in + types. Used to cast the serialized card values before creating a Header. + compression : `str` + By default it's assumed the TAR archive was compressed with the ``bz2`` + compression algorithm. + format : `str` + The format in which the file of serialized header cards was written in. + An AstroPy By default, this is assumed to be ``ascii.ecsv`` + + Parameters + ---------- + archive_name : `str` + Name of the TAR archive containing serialized headers. + fname : `str` + Name of the file, within the archive, containing the headers. + external : `bool` + When `True`, file will be searched for relative to the current working + directory. Otherwise, the file is searched for within the header archive + provided with this module. + """ + + # will almost never be anything else. Rather, it would be a miracle if it + # were something else, since FITS standard shouldn't allow it. Further + # casting by some packages will always be casting implemented in terms of + # parsing these types. + lexical_type_map = { + "int": int, + "str": str, + "float": float, + "bool": bool, + } + """Map between type names and types themselves.""" + + compression = "bz2" + """Compression used to compress the archived headers.""" + + format = "ascii.ecsv" + """Format of the archive, and AstroPy'S ASCII module valid identifier of a format.""" + + def __init__(self, archive_name, fname, external=False): + super().__init__({}) + self.table = header_archive_to_table( + archive_name, fname, self.compression, self.format, external=external + ) + + # Create HDU groups for easier iteration + self.table = self.table.group_by("filename") + self.n_hdus = len(self.table) + + def lexical_cast(self, value, vtype): + """Cast str literal of a type to the type itself. Supports just + the builtin Python types. + """ + if vtype in self.lexical_type_map: + return self.lexical_type_map[vtype](value) + return value + + def get_item(self, group_idx, hdr_idx): + """Get an extension of an HDUList within the archive without + incrementing the mocking counter. + + Parameters + ---------- + group_idx : `int` + Index of the HDUList to fetch from the archive. + hdr_idx : `int` + Index of the extension within the HDUList. + + Returns + ------- + header : `Header` + Header of the given extension of the targeted HDUList. + """ + header = Header() + # this is equivalent to one hdulist worth of headers + group = self.table.groups[group_idx] + # this is equivalent to one HDU's header + subgroup = group.group_by("hdu") + for k, v, f in subgroup.groups[hdr_idx]["keyword", "value", "format"]: + header[k] = self.lexical_cast(v, f) + warnings.resetwarnings() + return header + + def get(self, group_idx): + """Get an HDUList within the archive without incrementing the mocking counter. + + Parameters + ---------- + group_idx : `int` + Index of the HDUList to fetch from the archive. + + Returns + ------- + headers : `list[Header]` + All headers of the targeted HDUList. + """ + headers = [] + # this is a bit repetitive but it saves recreating + # groups for one HDUL-equivalent many times + group = self.table.groups[group_idx] + subgroup = group.group_by("hdu") + headers = [] + for subgroup in subgroup.groups: + header = Header() + for k, v, f in subgroup["keyword", "value", "format"]: + # ignore warnings about non-standard keywords + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=VerifyWarning) + header[k] = self.lexical_cast(v, f) + headers.append(header) + return headers + + def mock(self, n=1): + """Mock all headers within n HDULists. + + Parameters + ---------- + n : `int` + Number of HDUList units which headers we want mocked. + + Returns + ------- + headers : `list[list[Header]]` + A list of containing headers of each extension of an HDUList. + """ + res = [] + for _ in range(n): + res.append(self.get(self.counter)) + self.counter += 1 + return res diff --git a/src/kbmod/mocking/headers_archive.tar.bz2 b/src/kbmod/mocking/headers_archive.tar.bz2 new file mode 100644 index 000000000..cdaa2fe9b Binary files /dev/null and b/src/kbmod/mocking/headers_archive.tar.bz2 differ diff --git a/src/kbmod/mocking/utils.py b/src/kbmod/mocking/utils.py new file mode 100644 index 000000000..008a5eeef --- /dev/null +++ b/src/kbmod/mocking/utils.py @@ -0,0 +1,32 @@ +import tarfile +from os import path + +from astropy.table import Table, MaskedColumn + + +# def get_absolute_data_path(file_or_directory): +# test_dir = path.abspath(path.dirname(__file__)) +# data_dir = path.join(test_dir, "archived_data") +# return path.join(data_dir, file_or_directory) +# +# +# def get_absolute_demo_data_path(file_or_directory): +# project_root_dir = path.abspath(path.dirname(path.dirname(__file__))) +# data_dir = path.join(project_root_dir, "archived_data") +# return path.join(data_dir, file_or_directory) + + +def header_archive_to_table(archive_path, fname, compression, format, external=True): + if not external: + mocking_dir = path.abspath(path.dirname(__file__)) + archive_path = path.join(mocking_dir, archive_path) + with tarfile.open(archive_path, f"r:{compression}") as archive: + tblfile = archive.extractfile(fname) + table = Table.read(tblfile.read().decode(), format=format) + # sometimes empty strings get serialized as masked, to cover that + # eventuality we'll just substitute an empty string + if isinstance(table["value"], MaskedColumn): + table["value"].fill_value = "" + table["value"] = table["value"].filled() + + return table diff --git a/src/kbmod/standardizers/butler_standardizer.py b/src/kbmod/standardizers/butler_standardizer.py index ada95074e..f5c48dbc7 100644 --- a/src/kbmod/standardizers/butler_standardizer.py +++ b/src/kbmod/standardizers/butler_standardizer.py @@ -294,7 +294,7 @@ def _fetch_meta(self): # photometric analysis of the results, while the effective # values are too often NaN. The URI location itself is # ultimately not very useful, but helpful for data inspection. - if self.config.standardize_metadata: + if self.config["standardize_metadata"]: meta_ref = self.ref.makeComponentRef("metadata") meta = self.butler.get(meta_ref) @@ -311,13 +311,13 @@ def _fetch_meta(self): self._metadata["GAINB"] = meta["GAINB"] # Will be nan for VR filter so it's optional - if self.config.standardize_effective_summary_stats: + if self.config["standardize_effective_summary_stats"]: self._metadata["effTime"] = summary.effTime self._metadata["effTimePsfSigmaScale"] = summary.effTimePsfSigmaScale self._metadata["effTimeSkyBgScale"] = summary.effTimeSkyBgScale self._metadata["effTimeZeroPointScale"] = summary.effTimeZeroPointScale - if self.config.standardize_uri: + if self.config["standardize_uri"]: self._metadata["location"] = self.butler.getURI( self.ref, collections=[ @@ -348,14 +348,14 @@ def standardizeMetadata(self): def standardizeScienceImage(self): self.exp = self.butler.get(self.ref) if self.exp is None else self.exp - zp_correct = 10 ** ((self._metadata["zeroPoint"] - self.config.zero_point) / 2.5) + zp_correct = 10 ** ((self._metadata["zeroPoint"] - self.config["zero_point"]) / 2.5) return [ self.exp.image.array / zp_correct, ] def standardizeVarianceImage(self): self.exp = self.butler.get(self.ref) if self.exp is None else self.exp - zp_correct = 10 ** ((self._metadata["zeroPoint"] - self.config.zero_point) / 2.5) + zp_correct = 10 ** ((self._metadata["zeroPoint"] - self.config["zero_point"]) / 2.5) return [ self.exp.variance.array / zp_correct**2, ] diff --git a/src/kbmod/standardizers/fits_standardizers/__init__.py b/src/kbmod/standardizers/fits_standardizers/__init__.py index 8cbbbae12..1195f2a82 100644 --- a/src/kbmod/standardizers/fits_standardizers/__init__.py +++ b/src/kbmod/standardizers/fits_standardizers/__init__.py @@ -1,5 +1,6 @@ from .fits_standardizer import * +from .single_extension_fits import * +from .multi_extension_fits import * from .kbmodv05 import * from .kbmodv1 import * -from .multi_extension_fits import * -from .single_extension_fits import * +from .test_data_std import * diff --git a/src/kbmod/standardizers/fits_standardizers/test_data_std.py b/src/kbmod/standardizers/fits_standardizers/test_data_std.py new file mode 100644 index 000000000..d5d9d5e04 --- /dev/null +++ b/src/kbmod/standardizers/fits_standardizers/test_data_std.py @@ -0,0 +1,96 @@ +"""Class for standardizing FITS files produced by the mocking module.""" + +from astropy.time import Time +from .multi_extension_fits import MultiExtensionFits, FitsStandardizerConfig + + +__all__ = [ + "TestDataStdConfig", + "TestDataStd", +] + + +class TestDataStdConfig(FitsStandardizerConfig): + pass + + +class TestDataStd(MultiExtensionFits): + """Standardizer for test-data produced by the mocking module + + The standardizer will never volunteer to process any data, it must be + explicitly forced. + + Parameters + ---------- + location : `str` or `None`, optional + Location of the file (if any) that is standardized. Required if + ``hdulist`` is not provided. + hdulist : `astropy.io.fits.HDUList` or `None`, optional + HDUList to standardize. Required if ``location`` is not provided. + If provided, and ``location`` is not given, defaults to ``:memory:``. + config : `StandardizerConfig`, `dict` or `None`, optional + Configuration key-values used when standardizing the file. + + Attributes + ---------- + hdulist : `~astropy.io.fits.HDUList` + All HDUs found in the FITS file + primary : `~astropy.io.fits.PrimaryHDU` + The primary HDU. + processable : `list` + Any additional extensions marked by the standardizer for further + processing. Does not include the primary HDU if it doesn't contain any + image data. Contains at least 1 entry. + wcs : `list` + WCSs associated with the processable image data. Will contain + at least 1 WCS. + bbox : `list` + Bounding boxes associated with each WCS. + """ + + name = "TestDataStd" + priority = 0 + configClass = TestDataStdConfig + + @classmethod + def resolveTarget(cls, tgt): + return False + + def __init__(self, location=None, hdulist=None, config=None, **kwargs): + super().__init__(location=location, hdulist=hdulist, config=config, **kwargs) + + # this is the only science-image header for Rubin + self.processable = [ + self.hdulist["IMAGE"], + ] + + def translateHeader(self): + """Returns the following metadata, read from the primary header, as a + dictionary: + + ======== ========== =================================================== + Key Header Key Description + ======== ========== =================================================== + mjd DATE-AVG Decimal MJD timestamp of the middle of the exposure + observat OBSERVAT Observatory name + obs_lat OBS-LAT Observatory Latitude + obs_lon OBS-LONG Observatory Longitude + obs_elev OBS-ELEV Observatory elevation. + ======== ========== =================================================== + """ + # required + standardizedHeader = {} + obs_datetime = Time(self.primary["DATE-OBS"]) + standardizedHeader["mjd_mid"] = obs_datetime.mjd + # optional + standardizedHeader["observat"] = self.primary["OBSERVAT"] + standardizedHeader["obs_lat"] = self.primary["OBS-LAT"] + standardizedHeader["obs_lon"] = self.primary["OBS-LONG"] + standardizedHeader["obs_elev"] = self.primary["OBS-ELEV"] + return standardizedHeader + + def standardizeMaskImage(self): + return (self.hdulist["MASK"].data for i in self.processable) + + def standardizeVarianceImage(self): + return (self.hdulist["VARIANCE"].data for i in self.processable) diff --git a/tests/dump_headers.py b/tests/dump_headers.py deleted file mode 100644 index 74fbc60f4..000000000 --- a/tests/dump_headers.py +++ /dev/null @@ -1,500 +0,0 @@ -# Modified from the original Astropy code to add a card format to the -# tabular output. All rights belong to the original authors. - -# Licensed under a 3-clause BSD style license - see LICENSE.rst -""" -``fitsheader`` is a command line script based on astropy.io.fits for printing -the header(s) of one or more FITS file(s) to the standard output in a human- -readable format. - -Example uses of fitsheader: - -1. Print the header of all the HDUs of a .fits file:: - - $ fitsheader filename.fits - -2. Print the header of the third and fifth HDU extension:: - - $ fitsheader --extension 3 --extension 5 filename.fits - -3. Print the header of a named extension, e.g. select the HDU containing - keywords EXTNAME='SCI' and EXTVER='2':: - - $ fitsheader --extension "SCI,2" filename.fits - -4. Print only specific keywords:: - - $ fitsheader --keyword BITPIX --keyword NAXIS filename.fits - -5. Print keywords NAXIS, NAXIS1, NAXIS2, etc using a wildcard:: - - $ fitsheader --keyword NAXIS* filename.fits - -6. Dump the header keywords of all the files in the current directory into a - machine-readable csv file:: - - $ fitsheader --table ascii.csv *.fits > keywords.csv - -7. Specify hierarchical keywords with the dotted or spaced notation:: - - $ fitsheader --keyword ESO.INS.ID filename.fits - $ fitsheader --keyword "ESO INS ID" filename.fits - -8. Compare the headers of different fits files, following ESO's ``fitsort`` - format:: - - $ fitsheader --fitsort --extension 0 --keyword ESO.INS.ID *.fits - -9. Same as above, sorting the output along a specified keyword:: - - $ fitsheader -f -s DATE-OBS -e 0 -k DATE-OBS -k ESO.INS.ID *.fits - -10. Sort first by OBJECT, then DATE-OBS:: - - $ fitsheader -f -s OBJECT -s DATE-OBS *.fits - -Note that compressed images (HDUs of type -:class:`~astropy.io.fits.CompImageHDU`) really have two headers: a real -BINTABLE header to describe the compressed data, and a fake IMAGE header -representing the image that was compressed. Astropy returns the latter by -default. You must supply the ``--compressed`` option if you require the real -header that describes the compression. - -With Astropy installed, please run ``fitsheader --help`` to see the full usage -documentation. -""" - -import argparse -import sys - -import numpy as np - -from astropy import __version__, log -from astropy.io import fits - -DESCRIPTION = """ -Print the header(s) of a FITS file. Optional arguments allow the desired -extension(s), keyword(s), and output format to be specified. -Note that in the case of a compressed image, the decompressed header is -shown by default. - -This script is part of the Astropy package. See -https://docs.astropy.org/en/latest/io/fits/usage/scripts.html#module-astropy.io.fits.scripts.fitsheader -for further documentation. -""".strip() - - -class ExtensionNotFoundException(Exception): - """Raised if an HDU extension requested by the user does not exist.""" - - -class HeaderFormatter: - """Class to format the header(s) of a FITS file for display by the - `fitsheader` tool; essentially a wrapper around a `HDUList` object. - - Example usage: - fmt = HeaderFormatter('/path/to/file.fits') - print(fmt.parse(extensions=[0, 3], keywords=['NAXIS', 'BITPIX'])) - - Parameters - ---------- - filename : str - Path to a single FITS file. - verbose : bool - Verbose flag, to show more information about missing extensions, - keywords, etc. - - Raises - ------ - OSError - If `filename` does not exist or cannot be read. - """ - - def __init__(self, filename, verbose=True): - self.filename = filename - self.verbose = verbose - self._hdulist = fits.open(filename) - - def parse(self, extensions=None, keywords=None, compressed=False): - """Returns the FITS file header(s) in a readable format. - - Parameters - ---------- - extensions : list of int or str, optional - Format only specific HDU(s), identified by number or name. - The name can be composed of the "EXTNAME" or "EXTNAME,EXTVER" - keywords. - - keywords : list of str, optional - Keywords for which the value(s) should be returned. - If not specified, then the entire header is returned. - - compressed : bool, optional - If True, shows the header describing the compression, rather than - the header obtained after decompression. (Affects FITS files - containing `CompImageHDU` extensions only.) - - Returns - ------- - formatted_header : str or astropy.table.Table - Traditional 80-char wide format in the case of `HeaderFormatter`; - an Astropy Table object in the case of `TableHeaderFormatter`. - """ - # `hdukeys` will hold the keys of the HDUList items to display - if extensions is None: - hdukeys = range(len(self._hdulist)) # Display all by default - else: - hdukeys = [] - for ext in extensions: - try: - # HDU may be specified by number - hdukeys.append(int(ext)) - except ValueError: - # The user can specify "EXTNAME" or "EXTNAME,EXTVER" - parts = ext.split(",") - if len(parts) > 1: - extname = ",".join(parts[0:-1]) - extver = int(parts[-1]) - hdukeys.append((extname, extver)) - else: - hdukeys.append(ext) - - # Having established which HDUs the user wants, we now format these: - return self._parse_internal(hdukeys, keywords, compressed) - - def _parse_internal(self, hdukeys, keywords, compressed): - """The meat of the formatting; in a separate method to allow overriding.""" - result = [] - for idx, hdu in enumerate(hdukeys): - try: - cards = self._get_cards(hdu, keywords, compressed) - except ExtensionNotFoundException: - continue - - if idx > 0: # Separate HDUs by a blank line - result.append("\n") - result.append(f"# HDU {hdu} in {self.filename}:\n") - for c in cards: - result.append(f"{c}\n") - return "".join(result) - - def _get_cards(self, hdukey, keywords, compressed): - """Returns a list of `astropy.io.fits.card.Card` objects. - - This function will return the desired header cards, taking into - account the user's preference to see the compressed or uncompressed - version. - - Parameters - ---------- - hdukey : int or str - Key of a single HDU in the HDUList. - - keywords : list of str, optional - Keywords for which the cards should be returned. - - compressed : bool, optional - If True, shows the header describing the compression. - - Raises - ------ - ExtensionNotFoundException - If the hdukey does not correspond to an extension. - """ - # First we obtain the desired header - try: - if compressed: - # In the case of a compressed image, return the header before - # decompression (not the default behavior) - header = self._hdulist[hdukey]._header - else: - header = self._hdulist[hdukey].header - except (IndexError, KeyError): - message = f"{self.filename}: Extension {hdukey} not found." - if self.verbose: - log.warning(message) - raise ExtensionNotFoundException(message) - - if not keywords: # return all cards - cards = header.cards - else: # specific keywords are requested - cards = [] - for kw in keywords: - try: - crd = header.cards[kw] - if isinstance(crd, fits.card.Card): # Single card - cards.append(crd) - else: # Allow for wildcard access - cards.extend(crd) - except KeyError: # Keyword does not exist - if self.verbose: - log.warning(f"{self.filename} (HDU {hdukey}): Keyword {kw} not found.") - return cards - - def close(self): - self._hdulist.close() - - -class TableHeaderFormatter(HeaderFormatter): - """Class to convert the header(s) of a FITS file into a Table object. - The table returned by the `parse` method will contain four columns: - filename, hdu, keyword, and value. - - Subclassed from HeaderFormatter, which contains the meat of the formatting. - """ - - def _parse_internal(self, hdukeys, keywords, compressed): - """Method called by the parse method in the parent class.""" - tablerows = [] - for hdu in hdukeys: - try: - for card in self._get_cards(hdu, keywords, compressed): - tablerows.append( - { - "filename": self.filename, - "hdu": hdu, - "keyword": card.keyword, - "value": str(card.value), - "format": type(card.value).__name__, - } - ) - except ExtensionNotFoundException: - pass - - if tablerows: - from astropy import table - - return table.Table(tablerows) - return None - - -def print_headers_traditional(args): - """Prints FITS header(s) using the traditional 80-char format. - - Parameters - ---------- - args : argparse.Namespace - Arguments passed from the command-line as defined below. - """ - for idx, filename in enumerate(args.filename): # support wildcards - if idx > 0 and not args.keyword: - print() # print a newline between different files - - formatter = None - try: - formatter = HeaderFormatter(filename) - print(formatter.parse(args.extensions, args.keyword, args.compressed), end="") - except OSError as e: - log.error(str(e)) - finally: - if formatter: - formatter.close() - - -def print_headers_as_table(args): - """Prints FITS header(s) in a machine-readable table format. - - Parameters - ---------- - args : argparse.Namespace - Arguments passed from the command-line as defined below. - """ - tables = [] - # Create a Table object for each file - for filename in args.filename: # Support wildcards - formatter = None - try: - formatter = TableHeaderFormatter(filename) - tbl = formatter.parse(args.extensions, args.keyword, args.compressed) - if tbl: - tables.append(tbl) - except OSError as e: - log.error(str(e)) # file not found or unreadable - finally: - if formatter: - formatter.close() - - # Concatenate the tables - if len(tables) == 0: - return False - elif len(tables) == 1: - resulting_table = tables[0] - else: - from astropy import table - - resulting_table = table.vstack(tables) - # Print the string representation of the concatenated table - resulting_table.write(sys.stdout, format=args.table) - - -def print_headers_as_comparison(args): - """Prints FITS header(s) with keywords as columns. - - This follows the dfits+fitsort format. - - Parameters - ---------- - args : argparse.Namespace - Arguments passed from the command-line as defined below. - """ - from astropy import table - - tables = [] - # Create a Table object for each file - for filename in args.filename: # Support wildcards - formatter = None - try: - formatter = TableHeaderFormatter(filename, verbose=False) - tbl = formatter.parse(args.extensions, args.keyword, args.compressed) - if tbl: - # Remove empty keywords - tbl = tbl[np.where(tbl["keyword"] != "")] - else: - tbl = table.Table([[filename]], names=("filename",)) - tables.append(tbl) - except OSError as e: - log.error(str(e)) # file not found or unreadable - finally: - if formatter: - formatter.close() - - # Concatenate the tables - if len(tables) == 0: - return False - elif len(tables) == 1: - resulting_table = tables[0] - else: - resulting_table = table.vstack(tables) - - # If we obtained more than one hdu, merge hdu and keywords columns - hdus = resulting_table["hdu"] - if np.ma.isMaskedArray(hdus): - hdus = hdus.compressed() - if len(np.unique(hdus)) > 1: - for tab in tables: - new_column = table.Column([f"{row['hdu']}:{row['keyword']}" for row in tab]) - tab.add_column(new_column, name="hdu+keyword") - keyword_column_name = "hdu+keyword" - else: - keyword_column_name = "keyword" - - # Check how many hdus we are processing - final_tables = [] - for tab in tables: - final_table = [table.Column([tab["filename"][0]], name="filename")] - if "value" in tab.colnames: - for row in tab: - if row["keyword"] in ("COMMENT", "HISTORY"): - continue - final_table.append(table.Column([row["value"]], name=row[keyword_column_name])) - final_tables.append(table.Table(final_table)) - final_table = table.vstack(final_tables) - # Sort if requested - if args.sort: - final_table.sort(args.sort) - # Reorganise to keyword by columns - final_table.pprint(max_lines=-1, max_width=-1) - - -if __name__ == "__main__": - """This is the main function called by the `fitsheader` script.""" - parser = argparse.ArgumentParser( - description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter - ) - - parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") - - parser.add_argument( - "-e", - "--extension", - metavar="HDU", - action="append", - dest="extensions", - help=( - "specify the extension by name or number; this argument can " - "be repeated to select multiple extensions" - ), - ) - parser.add_argument( - "-k", - "--keyword", - metavar="KEYWORD", - action="append", - type=str, - help=( - "specify a keyword; this argument can be repeated to select " - "multiple keywords; also supports wildcards" - ), - ) - mode_group = parser.add_mutually_exclusive_group() - mode_group.add_argument( - "-t", - "--table", - nargs="?", - default=False, - metavar="FORMAT", - help=( - "print the header(s) in machine-readable table format; the " - 'default format is "ascii.fixed_width" (can be "ascii.csv", ' - '"ascii.html", "ascii.latex", "fits", etc)' - ), - ) - mode_group.add_argument( - "-f", - "--fitsort", - action="store_true", - help=("print the headers as a table with each unique " "keyword in a given column (fitsort format) "), - ) - parser.add_argument( - "-s", - "--sort", - metavar="SORT_KEYWORD", - action="append", - type=str, - help=( - "sort output by the specified header keywords, can be repeated to " - "sort by multiple keywords; Only supported with -f/--fitsort" - ), - ) - parser.add_argument( - "-c", - "--compressed", - action="store_true", - help=( - "for compressed image data, show the true header which describes " - "the compression rather than the data" - ), - ) - parser.add_argument( - "filename", - nargs="+", - help="path to one or more files; wildcards are supported", - ) - args = parser.parse_args() - # If `--table` was used but no format specified, - # then use ascii.fixed_width by default - if args.table is None: - args.table = "ascii.fixed_width" - - if args.sort: - args.sort = [key.replace(".", " ") for key in args.sort] - if not args.fitsort: - log.error("Sorting with -s/--sort is only supported in conjunction with" " -f/--fitsort") - # 2: Unix error convention for command line syntax - sys.exit(2) - - if args.keyword: - args.keyword = [key.replace(".", " ") for key in args.keyword] - - # Now print the desired headers - try: - if args.table: - print_headers_as_table(args) - elif args.fitsort: - print_headers_as_comparison(args) - else: - print_headers_traditional(args) - except OSError: - # A 'Broken pipe' OSError may occur when stdout is closed prematurely, - # eg. when calling `fitsheader file.fits | head`. We let this pass. - pass diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 799d46158..ff4e82b64 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -1,18 +1,350 @@ -import math -import numpy as np -import os -import tempfile +# import math +# import numpy as np +# import os +# import tempfile +# import pytest +# +# from kbmod.fake_data.fake_data_creator import * +# from kbmod.run_search import * +# from kbmod.search import * +# from kbmod.wcs_utils import make_fake_wcs +# from kbmod.work_unit import WorkUnit + +# from utils.utils_for_tests import get_absolute_demo_data_path + + +#### import unittest +import itertools +import random + +import numpy as np +from numpy.lib.recfunctions import structured_to_unstructured + +from astropy.time import Time +from astropy.table import Table, vstack +from astropy.wcs import WCS +from astropy.coordinates import SkyCoord + +from kbmod import ImageCollection +from kbmod.run_search import SearchRunner +from kbmod.configuration import SearchConfiguration +from kbmod.reprojection import reproject_work_unit +import kbmod.mocking as kbmock + + +class TestEmptySearch(unittest.TestCase): + def setUp(self): + self.factory = kbmock.EmptyFits() + + def test_empty(self): + """Test no detections are found on empty images.""" + hduls = self.factory.mock(n=10) + + # create the most permissive search configs you can come up with + # all values in these images are zeros, we should not be returning + # anything + config = SearchConfiguration.from_dict( + { + "average_angle": 0.0, + "v_arr": [10, 20, 10], + "lh_level": 0.1, + "num_obs": 1, + "do_mask": False, + "do_clustering": True, + "do_stamp_filter": False, + } + ) + + ic = ImageCollection.fromTargets(hduls, force="TestDataStd") + wu = ic.toWorkUnit(search_config=config) + results = SearchRunner().run_search_from_work_unit(wu) + self.assertTrue(len(results) == 0) + + def test_static_objects(self): + """Test no detections are found on images containing static objects.""" + src_cat = kbmock.SourceCatalog.from_defaults(seed=100) + factory = kbmock.SimpleFits(src_cat=src_cat) + hduls = factory.mock(10) + + ic = ImageCollection.fromTargets(hduls, force="TestDataStd") + wu = ic.toWorkUnit(search_config=SearchConfiguration()) + results = SearchRunner().run_search_from_work_unit(wu) + self.assertTrue(len(results) == 0) + + +class TestRandomLinearSearch(unittest.TestCase): + def setUp(self): + # Set up shared search values + self.n_imgs = 5 + self.repeat_n_times = 10 + self.shape = (200, 200) + self.start_pos = (85, 115) + self.vxs = [-20, 20] + self.vys = [-20, 20] + + # Set up configs for mocking and search + # These don't change from test to test + self.param_ranges = { + "amplitude": [100, 100], + "x_mean": self.start_pos, + "y_mean": self.start_pos, + "x_stddev": [2.0, 2.0], + "y_stddev": [2.0, 2.0], + "vx": self.vxs, + "vy": self.vys, + } + + self.config = SearchConfiguration.from_dict( + { + "generator_config": { + "name": "VelocityGridSearch", + "min_vx": self.vxs[0], + "max_vx": self.vxs[1], + "min_vy": self.vys[0], + "max_vy": self.vys[1], + "vx_steps": 40, + "vy_steps": 40, + }, + "num_obs": self.n_imgs, + "do_mask": False, + "do_clustering": True, + "do_stamp_filter": False, + } + ) + + def xmatch_best(self, obj, results, match_cols={"x_mean": "x", "y_mean": "y", "vx": "vx", "vy": "vy"}): + """Finds the result that minimizes the L2 distance to the target object. + + Parameters + ---------- + obj : `astropy.table.Row` + Row, or a table with single entry, containing the target object. + results : `astropy.table.Table` + Table of objects from which the closest matching one will be returned. + match_cols : `dict`, optional + Dictionary of column names on which to perform the matching. Keys + of the dictionary are columns from ``obj`` and values of the dict + are columns from ``results``. + + Returns + ------- + result : `astropy.table.Row` + Best matching result + distances: `np.array` + Array of calculated L2 distances of ``obj`` to all given results. + """ + objk, resk = [], [] + for k, v in match_cols.items(): + if k in obj.columns and v in results.table.columns: + objk.append(k) + resk.append(v) + tgt = np.fromiter(obj[tuple(objk)].values(), dtype=float, count=len(objk)) + res = structured_to_unstructured(results[tuple(resk)].as_array(), dtype=float) + diff = np.linalg.norm(tgt - res, axis=1) + if len(results) == 1: + return results[0], diff + return results[diff == diff.min()][0], diff + + def assertResultValuesWithinSpec( + self, expected, result, spec, match_cols={"x_mean": "x", "y_mean": "y", "vx": "vx", "vy": "vy"} + ): + """Asserts expected object matches the given result object within + specification. + + Parameters + ---------- + expected : `astropy.table.Row` + Row, or table with single entry, containing the target object. + result : `astropy.table.Row` + Row, or table with single entry, containing the found object. + spec : `float` + Specification of maximum deviation of the expected values from the + found resulting values. For example, a spec of 3 means results can + be 3 or less pixels away from the expected position. + match_cols : `dict`, optional + Dictionary of column names on which to perform the matching. Keys + of the dictionary are columns from ``obj`` and values of the dict + are columns from ``results``. -from kbmod.fake_data.fake_data_creator import * -from kbmod.run_search import * -from kbmod.search import * -from kbmod.wcs_utils import make_fake_wcs -from kbmod.work_unit import WorkUnit + Raises + ------- + AssertionError - if comparison fails. + """ + for ekey, rkey in match_cols.items(): + info = ( + f"\n Expected: \n {expected[tuple(match_cols.keys())]} \n" + f"Retrieved : \n {result[tuple(match_cols.values())]}" + ) + self.assertLessEqual(abs(expected[ekey] - result[rkey]), spec, info) -# from .utils_for_tests import get_absolute_demo_data_path -# import utils_for_tests -from utils.utils_for_tests import get_absolute_demo_data_path + def run_single_search(self, data, expected, spec=5): + """Runs a KBMOD search on given data and tests the results lie within + specification from the expected. + + Parameters + ---------- + data : `list[str]` or `list[astropy.io.fits.HDUList]` + List of targets processable by the TestDataStandardizer. + expected : `kbmod.mocking.ObjectCatalog` + Object catalog expected to be retrieved from the run. + spec : `float` + Specification of maximum deviation of the expected values from the + found resulting values. For example, a spec of 3 means results can + be 3 or less pixels away from the expected position. + """ + ic = ImageCollection.fromTargets(data, force="TestDataStd") + wu = ic.toWorkUnit(search_config=self.config) + results = SearchRunner().run_search_from_work_unit(wu) + + # Run tests + self.assertGreaterEqual(len(results), 1) + for obj in expected.table: + res, dist = (results[0], None) if len(results) == 1 else self.xmatch_best(obj, results) + self.assertResultValuesWithinSpec(obj, res, spec) + + def test_exact_motion(self): + """Test exact searches are recovered in all 8 cardinal directions.""" + search_vs = list(itertools.product([-20, 0, 20], repeat=2)) + search_vs.remove((0, 0)) + for vx, vy in search_vs: + with self.subTest(f"Cardinal direction: {(vx, vy)}"): + self.config._params["generator_config"] = {"name": "SingleVelocitySearch", "vx": vx, "vy": vy} + obj_cat = kbmock.ObjectCatalog.from_defaults(self.param_ranges, n=1) + obj_cat.table["vx"] = vx + obj_cat.table["vy"] = vy + factory = kbmock.SimpleFits(shape=self.shape, step_t=1, obj_cat=obj_cat) + hduls = factory.mock(n=self.n_imgs) + self.run_single_search(hduls, obj_cat, 1) + + def test_random_motion(self): + """Repeat searches for randomly inserted objects.""" + # Mock the data and repeat tests. The random catalog + # creation guarantees a diverse set of changing test values + for i in range(self.repeat_n_times): + with self.subTest(f"Iteration {i}"): + obj_cat = kbmock.ObjectCatalog.from_defaults(self.param_ranges, n=1) + factory = kbmock.SimpleFits(shape=self.shape, step_t=1, obj_cat=obj_cat) + hduls = factory.mock(n=self.n_imgs) + self.run_single_search(hduls, obj_cat) + + def test_resampled_search(self): + """Search for objects in a set of resampled images; randomly dithered pointings and orientations.""" + # 0. Setup + self.shape = (500, 500) + self.start_pos = (10, 10) # (ra, dec) in deg + n_obj = 1 + pixscale = 0.2 + timestamps = Time(np.arange(58915, 58915 + self.n_imgs, 1), format="mjd") + vx = 0.001 # degrees / day (given the timestamps) + vy = 0.001 + + # 1. Mock data + # - mock catalogs, set expected positions by hand + # - mock WCSs so that they dither around (10, 10) + # - instantiate the required mockers and mock + cats = [] + for i, t in enumerate(timestamps): + cats.append( + Table( + { + "amplitude": [100], + "obstime": [t], + "ra_mean": [self.start_pos[0] + vx * i], + "dec_mean": [self.start_pos[1] + vy * i], + "stddev": [2.0], + } + ) + ) + catalog = vstack(cats) + obj_cat = kbmock.ObjectCatalog.from_table(catalog, kind="world", mode="folding") + + wcs_factory = kbmock.WCSFactory( + pointing=self.start_pos, + rotation=0, + pixscale=pixscale, + dither_pos=True, + dither_rot=True, + dither_amplitudes=(0.001, 0.001, 10), + ) + + prim_hdr_factory = kbmock.HeaderFactory.from_primary_template( + mutables=["DATE-OBS"], + callbacks=[kbmock.ObstimeIterator(timestamps)], + ) + + factory = kbmock.SimpleFits(shape=self.shape, obj_cat=obj_cat, wcs_factory=wcs_factory) + factory.prim_hdr = prim_hdr_factory + hduls = factory.mock(n=self.n_imgs) + + # 2. Run search + # - make an IC + # - determine WCS footprint to reproject to + # - determine the pixel-based velocity to search for + # - reproject + # - run search + ic = ImageCollection.fromTargets(hduls, force="TestDataStd") + + from reproject.mosaicking import find_optimal_celestial_wcs + + opt_wcs, self.shape = find_optimal_celestial_wcs(list(ic.wcs)) + opt_wcs.array_shape = self.shape + + meanvx = -vx * 3600 / pixscale + meanvy = vy * 3600 / pixscale + + # The velocity grid needs to be searched very densely for the realistic + # case (compared to the fact the velocity spread is not that large), and + # we'll still end up ~10 pixels away from the truth. + search_config = SearchConfiguration.from_dict( + { + "generator_config": { + "name": "VelocityGridSearch", + "min_vx": meanvx - 5, + "max_vx": meanvx + 5, + "min_vy": meanvy - 5, + "max_vy": meanvy + 5, + "vx_steps": 40, + "vy_steps": 40, + }, + "num_obs": 1, + "do_mask": False, + "do_clustering": True, + "do_stamp_filter": False, + } + ) + wu = ic.toWorkUnit(search_config) + repr_wu = reproject_work_unit(wu, opt_wcs, parallelize=False) + results = SearchRunner().run_search_from_work_unit(repr_wu) + + # Compare results and validate + # - add in pixel velocities because realistic searches rarely + # find good pixel location match + # - due to that, we also can't rely that we'll get a good match on + # any particular catalog realization. We iterate over all of them + # and find the best matching results in each realization. + # From all realizations find the one that matches the best. + # Select that realization and that best match for comparison. + cats = obj_cat.mock(t=timestamps, wcs=[opt_wcs] * self.n_imgs) + for cat in cats: + cat["vx"] = meanvx + cat["vy"] = meanvy + + dists = np.array([self.xmatch_best(cat, results)[1] for cat in cats]) + min_dist_within_realization = dists.min(axis=0) + min_dist_across_realizations = dists.min() + + best_realization = dists.min(axis=1) == min_dist_across_realizations + best_realization_idx = np.where(best_realization == True)[0][0] + + best_cat = cats[best_realization_idx] + best_res = results[dists[best_realization_idx] == min_dist_across_realizations] + + self.assertGreaterEqual(len(results), 1) + self.assertResultValuesWithinSpec(best_cat, best_res, 10) + + +#### # this is the first test to actually test things like get_all_stamps from @@ -22,107 +354,107 @@ # (instead of RawImages), but hopefully we can deduplicate all this by making # these operations into functions and calling on the .image attribute # apply_stamp_filter for example is literal copy of the C++ code in RawImage? -class test_end_to_end(unittest.TestCase): - def setUp(self): - # Define the path for the data. - im_filepath = get_absolute_demo_data_path("demo") - - # The demo data has an object moving at x_v=10 px/day - # and y_v = 0 px/day. So we search velocities [0, 20] - # and angles [-0.5, 0.5]. - v_arr = [0, 20, 21] - ang_arr = [0.5, 0.5, 11] - - self.input_parameters = { - # Required - "im_filepath": im_filepath, - "res_filepath": None, - "time_file": None, - "output_suffix": "DEMO", - "v_arr": v_arr, - "ang_arr": ang_arr, - # Important - "num_obs": 7, - "do_mask": True, - "lh_level": 10.0, - "gpu_filter": True, - # Fine tuning - "sigmaG_lims": [15, 60], - "mom_lims": [37.5, 37.5, 1.5, 1.0, 1.0], - "peak_offset": [3.0, 3.0], - "chunk_size": 1000000, - "stamp_type": "cpp_median", - "eps": 0.03, - "clip_negative": True, - "mask_num_images": 10, - "cluster_type": "position", - # Override the ecliptic angle for the demo data since we - # know the true angle in pixel space. - "average_angle": 0.0, - "save_all_stamps": True, - } - - @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") - def test_demo_defaults(self): - rs = SearchRunner() - keep = rs.run_search_from_config(self.input_parameters) - self.assertGreaterEqual(len(keep), 1) - self.assertEqual(keep["stamp"][0].shape, (21, 21)) - - @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") - def test_demo_config_file(self): - im_filepath = get_absolute_demo_data_path("demo") - config_file = get_absolute_demo_data_path("demo_config.yml") - rs = SearchRunner() - keep = rs.run_search_from_file( - config_file, - overrides={"im_filepath": im_filepath}, - ) - self.assertGreaterEqual(len(keep), 1) - self.assertEqual(keep["stamp"][0].shape, (21, 21)) - - @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") - def test_demo_stamp_size(self): - self.input_parameters["stamp_radius"] = 15 - self.input_parameters["mom_lims"] = [80.0, 80.0, 50.0, 20.0, 20.0] - - rs = SearchRunner() - keep = rs.run_search_from_config(self.input_parameters) - self.assertGreaterEqual(len(keep), 1) - - self.assertIsNotNone(keep["stamp"][0]) - self.assertEqual(keep["stamp"][0].shape, (31, 31)) - - self.assertIsNotNone(keep["all_stamps"][0]) - for s in keep["all_stamps"][0]: - self.assertEqual(s.shape, (31, 31)) - - @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") - def test_e2e_work_unit(self): - num_images = 10 - - # Create a fake data set with a single bright fake object and all - # the observations on a single day. - fake_times = create_fake_times(num_images, 57130.2, 10, 0.01, 1) - ds = FakeDataSet(128, 128, fake_times, use_seed=True) - trj = Trajectory(x=50, y=60, vx=5.0, vy=0.0, flux=500.0) - ds.insert_object(trj) - - # Set the configuration to pick up the fake object. - config = SearchConfiguration() - config.set("ang_arr", [math.pi, math.pi, 16]) - config.set("v_arr", [0, 10.0, 20]) - - fake_wcs = make_fake_wcs(10.0, 10.0, 128, 128) - work = WorkUnit(im_stack=ds.stack, config=config, wcs=fake_wcs) - - with tempfile.TemporaryDirectory() as dir_name: - file_path = os.path.join(dir_name, "test_workunit.fits") - work.to_fits(file_path) - - rs = SearchRunner() - keep = rs.run_search_from_file(file_path) - self.assertGreaterEqual(len(keep), 1) +# class test_end_to_end(pytest.TestCase): +# def setUp(self): +# # Define the path for the data. +# im_filepath = get_absolute_demo_data_path("demo") +# +# # The demo data has an object moving at x_v=10 px/day +# # and y_v = 0 px/day. So we search velocities [0, 20] +# # and angles [-0.5, 0.5]. +# v_arr = [0, 20, 21] +# ang_arr = [0.5, 0.5, 11] +# +# self.input_parameters = { +# # Required +# "im_filepath": im_filepath, +# "res_filepath": None, +# "time_file": None, +# "output_suffix": "DEMO", +# "v_arr": v_arr, +# "ang_arr": ang_arr, +# # Important +# "num_obs": 7, +# "do_mask": True, +# "lh_level": 10.0, +# "gpu_filter": True, +# # Fine tuning +# "sigmaG_lims": [15, 60], +# "mom_lims": [37.5, 37.5, 1.5, 1.0, 1.0], +# "peak_offset": [3.0, 3.0], +# "chunk_size": 1000000, +# "stamp_type": "cpp_median", +# "eps": 0.03, +# "clip_negative": True, +# "mask_num_images": 10, +# "cluster_type": "position", +# # Override the ecliptic angle for the demo data since we +# # know the true angle in pixel space. +# "average_angle": 0.0, +# "save_all_stamps": True, +# } +# +# @pytest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") +# def test_demo_defaults(self): +# rs = SearchRunner() +# keep = rs.run_search_from_config(self.input_parameters) +# self.assertGreaterEqual(len(keep), 1) +# self.assertEqual(keep["stamp"][0].shape, (21, 21)) +# +# @pytest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") +# def test_demo_config_file(self): +# im_filepath = get_absolute_demo_data_path("demo") +# config_file = get_absolute_demo_data_path("demo_config.yml") +# rs = SearchRunner() +# keep = rs.run_search_from_file( +# config_file, +# overrides={"im_filepath": im_filepath}, +# ) +# self.assertGreaterEqual(len(keep), 1) +# self.assertEqual(keep["stamp"][0].shape, (21, 21)) +# +# @pytest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") +# def test_demo_stamp_size(self): +# self.input_parameters["stamp_radius"] = 15 +# self.input_parameters["mom_lims"] = [80.0, 80.0, 50.0, 20.0, 20.0] +# +# rs = SearchRunner() +# keep = rs.run_search_from_config(self.input_parameters) +# self.assertGreaterEqual(len(keep), 1) +# +# self.assertIsNotNone(keep["stamp"][0]) +# self.assertEqual(keep["stamp"][0].shape, (31, 31)) +# +# self.assertIsNotNone(keep["all_stamps"][0]) +# for s in keep["all_stamps"][0]: +# self.assertEqual(s.shape, (31, 31)) +# +# @pytest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") +# def test_e2e_work_unit(self): +# num_images = 10 +# +# # Create a fake data set with a single bright fake object and all +# # the observations on a single day. +# fake_times = create_fake_times(num_images, 57130.2, 10, 0.01, 1) +# ds = FakeDataSet(128, 128, fake_times, use_seed=True) +# trj = Trajectory(x=50, y=60, vx=5.0, vy=0.0, flux=500.0) +# ds.insert_object(trj) +# +# # Set the configuration to pick up the fake object. +# config = SearchConfiguration() +# config.set("ang_arr", [math.pi, math.pi, 16]) +# config.set("v_arr", [0, 10.0, 20]) +# +# fake_wcs = make_fake_wcs(10.0, 10.0, 128, 128) +# work = WorkUnit(im_stack=ds.stack, config=config, wcs=fake_wcs) +# +# with tempfile.TemporaryDirectory() as dir_name: +# file_path = os.path.join(dir_name, "test_workunit.fits") +# work.to_fits(file_path) +# +# rs = SearchRunner() +# keep = rs.run_search_from_file(file_path) +# self.assertGreaterEqual(len(keep), 1) if __name__ == "__main__": diff --git a/tests/test_mocking.py b/tests/test_mocking.py new file mode 100644 index 000000000..588c020b4 --- /dev/null +++ b/tests/test_mocking.py @@ -0,0 +1,338 @@ +import unittest + +import numpy as np + +from astropy.wcs import WCS +from astropy.time import Time +from astropy.table import Table, vstack + +import kbmod.mocking as kbmock + + +class TestEmptyFits(unittest.TestCase): + def test(self): + """Test basics of EmptyFits factory.""" + factory = kbmock.EmptyFits() + hduls = factory.mock(2) + + hdul = hduls[0] + zeros = np.zeros((100, 100)) + self.assertEqual(len(hduls), 2) + self.assertEqual(len(hduls[0]), 4) + for name, hdu in zip(("PRIMARY", "IMAGE", "VARIANCE", "MASK"), hduls[0]): + self.assertEqual(name, hdu.name) + self.assertEqual(hdul["PRIMARY"].data, None) + self.assertTrue((hdul["IMAGE"].data == zeros).all()) + self.assertTrue((hdul["VARIANCE"].data == zeros).all()) + self.assertTrue((hdul["MASK"].data == zeros).all()) + + factory = kbmock.EmptyFits(shape=(10, 100), step_t=1) + hduls = factory.mock(2) + hdul = hduls[0] + self.assertEqual(hdul["IMAGE"].data.shape, (10, 100)) + self.assertEqual(hdul["VARIANCE"].data.shape, (10, 100)) + self.assertEqual(hdul["MASK"].data.shape, (10, 100)) + dt = Time(hduls[1]["PRIMARY"].header["DATE-OBS"]) - Time(hduls[0]["PRIMARY"].header["DATE-OBS"]) + self.assertEqual(dt.to("day").value, 1) + + with self.assertRaisesRegex(ValueError, "destination is read-only"): + hdul["IMAGE"].data[0, 0] = 0 + + factory = kbmock.EmptyFits(editable_images=True) + hduls = factory.mock(2) + hdul = hduls[0] + hdul["IMAGE"].data[0, 0] = 1 + hdul["VARIANCE"].data[0, 0] = 2 + self.assertEqual(hdul["IMAGE"].data[0, 0], 1) + self.assertEqual(hduls[1]["IMAGE"].data[0, 0], 0) + with self.assertRaisesRegex(ValueError, "destination is read-only"): + hdul["MASK"].data[0, 0] = 0 + + factory = kbmock.EmptyFits(editable_images=True, editable_masks=True) + hduls = factory.mock(2) + hdul = hduls[0] + hdul["MASK"].data[0, 0] = 1 + self.assertEqual(hduls[0]["MASK"].data[0, 0], 1) + self.assertEqual(hduls[1]["MASK"].data[0, 0], 0) + + +class TestSimpleFits(unittest.TestCase): + def setUp(self): + self.n_obj = 5 + self.n_imgs = 3 + self.shape = (100, 300) + self.padded = ((10, 90), (10, 290)) + self.timestamps = Time(np.arange(58915, 58915 + self.n_imgs, 1), format="mjd") + self.step_t = 1 + + def test(self): + """Test basic functionality of SimpleFits factory.""" + factory = kbmock.SimpleFits() + hduls = factory.mock(2) + + hdul = hduls[0] + zeros = np.zeros((100, 100)) + self.assertEqual(len(hduls), 2) + self.assertEqual(len(hduls[0]), 4) + for name, hdu in zip(("PRIMARY", "IMAGE", "VARIANCE", "MASK"), hduls[0]): + self.assertEqual(name, hdu.name) + self.assertEqual(hdul["PRIMARY"].data, None) + self.assertTrue((hdul["IMAGE"].data == zeros).all()) + self.assertTrue((hdul["VARIANCE"].data == zeros).all()) + self.assertTrue((hdul["MASK"].data == zeros).all()) + + factory = kbmock.SimpleFits(shape=(10, 100), step_t=1) + hduls = factory.mock(2) + hdul = hduls[0] + self.assertEqual(hdul["IMAGE"].data.shape, (10, 100)) + self.assertEqual(hdul["VARIANCE"].data.shape, (10, 100)) + self.assertEqual(hdul["MASK"].data.shape, (10, 100)) + step_t = Time(hduls[1]["PRIMARY"].header["DATE-OBS"]) - Time(hduls[0]["PRIMARY"].header["DATE-OBS"]) + self.assertEqual(step_t.to("day").value, 1.0) + + def test_static_src_cat(self): + """Test that static source catalog works and is correctly rendered.""" + src_cat = kbmock.SourceCatalog.from_defaults() + src_cat2 = kbmock.SourceCatalog.from_defaults() + self.assertEqual(src_cat.config["mode"], "static") + self.assertFalse((src_cat.table == src_cat2.table).all()) + + src_cat = kbmock.SourceCatalog.from_defaults(n=self.n_obj) + self.assertEqual(len(src_cat.table), self.n_obj) + + param_ranges = { + "amplitude": [100, 100], + "x_mean": self.padded[1], + "y_mean": self.padded[0], + "x_stddev": [2.0, 2.0], + "y_stddev": [2.0, 2.0], + } + src_cat = kbmock.SourceCatalog.from_defaults(param_ranges, seed=100) + src_cat2 = kbmock.SourceCatalog.from_defaults(param_ranges, seed=100) + self.assertTrue((src_cat.table == src_cat2.table).all()) + self.assertLessEqual(src_cat.table["x_mean"].max(), self.shape[1]) + self.assertLessEqual(src_cat.table["y_mean"].max(), self.shape[0]) + + factory = kbmock.SimpleFits(shape=self.shape, src_cat=src_cat) + hdul = factory.mock()[0] + + x = np.round(src_cat.table["x_mean"].data).astype(int) + y = np.round(src_cat.table["y_mean"].data).astype(int) + self.assertGreaterEqual(hdul["IMAGE"].data[y, x].min(), 90) + + def validate_cat_render(self, hduls, cats, expected_gte=90): + """Validate that catalog objects appear in the given images. + + Parameters + ---------- + hduls : `list[astropy.io.fits.HDUList]` + List of FITS files to check. + cats : `list[astropy.table.Table]` + List of catalog realizations containing the coordinate of objects + to check for. + expected_gte : `float` + Expected minimal value of the pixel at the object's location. + """ + for hdul, cat in zip(hduls, cats): + x = np.round(cat["x_mean"].data).astype(int) + y = np.round(cat["y_mean"].data).astype(int) + self.assertGreaterEqual(hdul["IMAGE"].data[y, x].min(), expected_gte) + self.assertGreaterEqual(hdul["VARIANCE"].data[y, x].min(), expected_gte) + + def test_progressive_obj_cat(self): + """Test progressive catalog renders properly.""" + obj_cat = kbmock.ObjectCatalog.from_defaults() + obj_cat2 = kbmock.ObjectCatalog.from_defaults() + self.assertEqual(obj_cat.config["mode"], "progressive") + self.assertFalse((obj_cat.table == obj_cat2.table).all()) + + obj_cat = kbmock.ObjectCatalog.from_defaults(n=self.n_obj) + self.assertEqual(len(obj_cat.table), self.n_obj) + + param_ranges = { + "amplitude": [100, 100], + "x_mean": (0, 90), + "y_mean": self.padded[0], + "x_stddev": [2.0, 2.0], + "y_stddev": [2.0, 2.0], + "vx": [10, 20], + "vy": [0, 0], + } + seed = 200 + obj_cat = kbmock.ObjectCatalog.from_defaults(param_ranges, seed=seed) + obj_cat2 = kbmock.ObjectCatalog.from_defaults(param_ranges, seed=seed) + self.assertTrue((obj_cat.table == obj_cat2.table).all()) + + obj_cat = kbmock.ObjectCatalog.from_defaults(param_ranges, n=self.n_obj) + factory = kbmock.SimpleFits(shape=self.shape, step_t=self.step_t, obj_cat=obj_cat) + hduls = factory.mock(n=self.n_imgs) + + obj_cat.reset() + cats = obj_cat.mock(n=self.n_imgs, dt=self.step_t) + self.validate_cat_render(hduls, cats) + + def test_folding_obj_cat(self): + """Test folding catalog renders properly.""" + # Set up shared values for the whole setup + # like starting positions of object and timestamps + start_x = np.ones((self.n_obj,)) * 10 + start_y = np.linspace(10, self.shape[0] - 10, self.n_obj) + + # Set up non-linear catalog (objects will move as counter^2*v) + cats = [] + for i, t in enumerate(self.timestamps): + cats.append( + Table( + { + "amplitude": [100] * self.n_obj, + "obstime": [t] * self.n_obj, + "x_mean": start_x + 15 * i * i, + "y_mean": start_y, + "stddev": [2.0] * self.n_obj, + } + ) + ) + catalog = vstack(cats) + + # Mock data based on that catalog + obj_cat = kbmock.ObjectCatalog.from_table(catalog, mode="folding") + + prim_hdr_factory = kbmock.HeaderFactory.from_primary_template( + mutables=["DATE-OBS"], + callbacks=[kbmock.ObstimeIterator(self.timestamps)], + ) + + factory = kbmock.SimpleFits(shape=self.shape, obj_cat=obj_cat) + factory.prim_hdr = prim_hdr_factory + hduls = factory.mock(n=self.n_imgs) + + obj_cat.reset() + cats = obj_cat.mock(n=self.n_imgs, t=self.timestamps) + self.validate_cat_render(hduls, cats) + + def test_progressive_sky_cat(self): + """Test progressive catalog based on on-sky coordinates.""" + # a 10-50 in x by a 10-90 in y box using default WCS + # self.shape = (500, 500) + param_ranges = { + "ra_mean": (350.998, 351.002), + "dec_mean": (-5.0077, -5.0039), + "v_ra": [-0.001, 0.0001], + "v_dec": [0, 0], + "amplitude": [100, 100], + "x_stddev": [2.0, 2.0], + "y_stddev": [2.0, 2.0], + } + catalog = kbmock.gen_random_catalog(self.n_obj, param_ranges) + obj_cat = kbmock.ObjectCatalog.from_table(catalog, kind="world") + + factory = kbmock.SimpleFits(shape=self.shape, step_t=self.step_t, obj_cat=obj_cat) + hduls = factory.mock(n=self.n_imgs) + + # Run tests and ensure we have rendered the object in correct + # positions + obj_cat.reset() + wcs = [WCS(h["IMAGE"].header) for h in hduls] + cats = obj_cat.mock(n=self.n_imgs, dt=self.step_t, wcs=wcs) + self.validate_cat_render(hduls, cats) + + def test_folding_sky_cat(self): + """Test folding catalog based on on-sky coordinates.""" + # a 20x20 box in pixels using a default WCS + start_ra = np.linspace(350.998, 351.002, self.n_obj) + start_dec = np.linspace(-5.0077, -5.0039, self.n_obj) + + cats = [] + for i, t in enumerate(self.timestamps): + cats.append( + Table( + { + "amplitude": [100] * self.n_obj, + "obstime": [t] * self.n_obj, + "ra_mean": start_ra - 0.001 * i, + "dec_mean": start_dec, # + 0.00011 * i, + "stddev": [2.0] * self.n_obj, + } + ) + ) + catalog = vstack(cats) + obj_cat = kbmock.ObjectCatalog.from_table(catalog, kind="world", mode="folding") + + prim_hdr_factory = kbmock.HeaderFactory.from_primary_template( + mutables=["DATE-OBS"], + callbacks=[kbmock.ObstimeIterator(self.timestamps)], + ) + + factory = kbmock.SimpleFits(shape=self.shape, obj_cat=obj_cat) + factory.prim_hdr = prim_hdr_factory + hduls = factory.mock(n=self.n_imgs) + + obj_cat.reset() + wcs = [WCS(h[1].header) for h in hduls] + cats = obj_cat.mock(n=self.n_imgs, t=self.timestamps, wcs=wcs) + self.validate_cat_render(hduls, cats) + + # TODO: move to pytest and mark as xfail + def test_noise_gen(self): + """Test noise renders with expected statistical properties.""" + factory = kbmock.SimpleFits(shape=(1000, 1000), with_noise=True) + hdul = factory.mock()[0] + self.assertAlmostEqual(hdul["IMAGE"].data.mean(), 10, 1) + self.assertAlmostEqual(hdul["IMAGE"].data.std(), 1, 1) + + factory = kbmock.SimpleFits(shape=(1000, 1000), with_noise=True, noise="realistic") + hdul = factory.mock()[0] + self.assertAlmostEqual(hdul["IMAGE"].data.mean(), 32, 1) + self.assertAlmostEqual(hdul["IMAGE"].data.std(), 7.5, 0) + + img_factory = kbmock.SimpleImage(shape=(1000, 1000), add_noise=True, noise=5, noise_std=2.0) + factory = kbmock.SimpleFits() + factory.img_data = img_factory + hduls = factory.mock(n=3) + + for hdul in hduls[1:]: + self.assertFalse((hduls[0]["IMAGE"].data == hdul["IMAGE"].data).all()) + self.assertAlmostEqual(hdul["IMAGE"].data.mean(), 5, 1) + self.assertAlmostEqual(hdul["IMAGE"].data.std(), 2, 1) + + +class TestDiffIm(unittest.TestCase): + def test(self): + """Test basic functionality of SimpleFits factory.""" + factory = kbmock.DECamImdiff() + hduls = factory.mock(2) + + names = [ + "IMAGE", + "MASK", + "VARIANCE", + "ARCHIVE_INDEX", + "FilterLabel", + "Detector", + "TransformMap", + "ExposureSummaryStats", + "Detector", + "KernelPsf", + "FixedKernel", + "SkyWcs", + "ApCorrMap", + "ChebyshevBoundedField", + "ChebyshevBoundedField", + ] + hdul = hduls[0] + self.assertEqual(len(hduls), 2) + self.assertEqual(len(hduls[0]), 16) + for name, hdu in zip(names, hdul[1:]): + self.assertEqual(name, hdu.name) + self.assertEqual(hdul["PRIMARY"].data, None) + + factory = kbmock.DECamImdiff(with_data=True) + hduls = factory.mock(2) + hdul = hduls[0] + self.assertEqual(hdul["IMAGE"].data.shape, (2048, 4096)) + self.assertEqual(hdul["VARIANCE"].data.shape, (2048, 4096)) + self.assertEqual(hdul["MASK"].data.shape, (2048, 4096)) + + +if __name__ == "__main__": + unittest.main()