Skip to content

Commit dca5e7c

Browse files
authored
Merge pull request #757 from aragilar/add_new_astropy_priority_support
Add support for astropy registry priorities
2 parents 8211036 + 2ff82ff commit dca5e7c

File tree

2 files changed

+76
-16
lines changed

2 files changed

+76
-16
lines changed

specutils/io/registers.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
A module containing the mechanics of the specutils io registry.
33
"""
4+
import inspect
45
import os
56
import pathlib
67
import sys
@@ -16,6 +17,16 @@
1617
log = logging.getLogger(__name__)
1718

1819

20+
def _astropy_has_priorities():
21+
"""
22+
Check if astropy has support for loader priorities
23+
"""
24+
sig = inspect.signature(io_registry.register_reader)
25+
if sig.parameters.get("priority") is not None:
26+
return True
27+
return False
28+
29+
1930
def data_loader(label, identifier=None, dtype=Spectrum1D, extensions=None,
2031
priority=0):
2132
"""
@@ -52,7 +63,10 @@ def wrapper(*args, **kwargs):
5263
return wrapper
5364

5465
def decorator(func):
55-
io_registry.register_reader(label, dtype, func)
66+
if _astropy_has_priorities():
67+
io_registry.register_reader(label, dtype, func, priority=priority)
68+
else:
69+
io_registry.register_reader(label, dtype, func)
5670

5771
if identifier is None:
5872
# If the identifier is not defined, but the extensions are, create
@@ -78,17 +92,6 @@ def decorator(func):
7892
# Include the file extensions as attributes on the function object
7993
func.extensions = extensions
8094

81-
# Include priority on the loader function attribute
82-
func.priority = priority
83-
84-
# Sort the io_registry based on priority
85-
sorted_loaders = sorted(io_registry._readers.items(),
86-
key=lambda item: getattr(item[1], 'priority', 0))
87-
88-
# Update the registry with the sorted dictionary
89-
io_registry._readers.clear()
90-
io_registry._readers.update(sorted_loaders)
91-
9295
log.debug("Successfully loaded reader \"{}\".".format(label))
9396

9497
# Automatically register a SpectrumList reader for any data_loader that
@@ -102,7 +105,14 @@ def load_spectrum_list(*args, **kwargs):
102105
load_spectrum_list.extensions = extensions
103106
load_spectrum_list.priority = priority
104107

105-
io_registry.register_reader(label, SpectrumList, load_spectrum_list)
108+
if _astropy_has_priorities():
109+
io_registry.register_reader(
110+
label, SpectrumList, load_spectrum_list, priority=priority,
111+
)
112+
else:
113+
io_registry.register_reader(
114+
label, SpectrumList, load_spectrum_list,
115+
)
106116
io_registry.register_identifier(label, SpectrumList, id_func)
107117
log.debug("Created SpectrumList reader for \"{}\".".format(label))
108118

@@ -113,9 +123,12 @@ def wrapper(*args, **kwargs):
113123
return decorator
114124

115125

116-
def custom_writer(label, dtype=Spectrum1D):
126+
def custom_writer(label, dtype=Spectrum1D, priority=0):
117127
def decorator(func):
118-
io_registry.register_writer(label, Spectrum1D, func)
128+
if _astropy_has_priorities():
129+
io_registry.register_writer(label, dtype, func, priority=priority)
130+
else:
131+
io_registry.register_writer(label, dtype, func)
119132

120133
@wraps(func)
121134
def wrapper(*args, **kwargs):

specutils/tests/test_io.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
This module tests SpecUtils io routines
55
"""
66

7+
from collections import Counter
78
from specutils.io.parsing_utils import generic_spectrum_from_table # or something like that
89
from astropy.io import registry
910
from astropy.table import Table
@@ -15,7 +16,8 @@
1516
import warnings
1617

1718
from specutils import Spectrum1D, SpectrumList
18-
from specutils.io import data_loader
19+
from specutils.io import data_loader, custom_writer
20+
from specutils.io.registers import _astropy_has_priorities
1921

2022

2123
def test_generic_spectrum_from_table(recwarn):
@@ -156,3 +158,48 @@ def reader(*args, **kwargs):
156158
# Clean up after ourselves
157159
registry.unregister_reader(format_name, datatype)
158160
registry.unregister_identifier(format_name, datatype)
161+
162+
163+
@pytest.mark.xfail(
164+
not _astropy_has_priorities(),
165+
reason="Test requires priorities to be implemented in astropy",
166+
raises=registry.IORegistryError,
167+
)
168+
def test_loader_uses_priority(tmpdir):
169+
counter = Counter()
170+
fname = str(tmpdir.join('good.txt'))
171+
172+
with open(fname, 'w') as ff:
173+
ff.write('\n')
174+
175+
def identifier(origin, *args, **kwargs):
176+
fname = args[0]
177+
return 'good' in fname
178+
179+
@data_loader("test_counting_loader1", identifier=identifier, priority=1)
180+
def counting_loader1(*args, **kwargs):
181+
counter["test1"] += 1
182+
wave = np.arange(1,1.1,0.01)*u.AA
183+
return Spectrum1D(
184+
spectral_axis=wave,
185+
flux=np.ones(len(wave))*1.e-14*u.Jy,
186+
)
187+
188+
@data_loader("test_counting_loader2", identifier=identifier, priority=2)
189+
def counting_loader2(*args, **kwargs):
190+
counter["test2"] += 1
191+
wave = np.arange(1,1.1,0.01)*u.AA
192+
return Spectrum1D(
193+
spectral_axis=wave,
194+
flux=np.ones(len(wave))*1.e-14*u.Jy,
195+
)
196+
197+
Spectrum1D.read(fname)
198+
assert counter["test2"] == 1
199+
assert counter["test1"] == 0
200+
201+
for datatype in [Spectrum1D, SpectrumList]:
202+
registry.unregister_reader("test_counting_loader1", datatype)
203+
registry.unregister_identifier("test_counting_loader1", datatype)
204+
registry.unregister_reader("test_counting_loader2", datatype)
205+
registry.unregister_identifier("test_counting_loader2", datatype)

0 commit comments

Comments
 (0)