diff --git a/specutils/spectra/spectrum1d.py b/specutils/spectra/spectrum1d.py index 24def21f3..a0ccf2c0b 100644 --- a/specutils/spectra/spectrum1d.py +++ b/specutils/spectra/spectrum1d.py @@ -768,3 +768,76 @@ def __repr__(self): result = "".format(inner_str) return result + + def plot(self, ax=None, x_name='spectral axis', y_name='flux', + set_quantity_support=True, **kwargs): + """ + Visualize this spectrum using matplotlib in "histogram style". + + Parameters + ---------- + ax : `matplotlib.axes.Axes` or None + The axis to plot this figure into. If None, use the current + ``pyplot`` axes (which will create a new figure if none exists). + x_name : str or None + The name to use for the x axis (units will be automatically added) + or None to not set the x axis label. + y_name : str or None + The name to use for the y axis (units will be automatically added) + or None to not set the y axis label. + set_quantity_support : bool + If True, call `astropy.visualization.quantity_support` to ensure + that the quantities in the plot are properly settable. + + kwargs are passed into `~matplotlib.axes.Axes.plot`, except for + ``drawstyle`` or ``ds``. + + Returns + ------- + ax : `matplotlib.axes.Axes` + Either ``ax``, or the newly created axes object (if the ``ax`` + parameter is None). + """ + # import is intentionally inside the method to make matplotlib an + # "optional" dependency + from matplotlib import pyplot as plt + from astropy.visualization import quantity_support + + if set_quantity_support: + quantity_support() + + if 'drawstyle' in kwargs or 'ds' in kwargs: + raise TypeError("cannot set draw style in a spectrum's plot_quick") + + kwargs['drawstyle'] = 'steps-post' + + if len(self.shape) != 1: + nspecdim = len(self.shape) - 1 + indexing_hint = 'spec[' + ', '.join(['0']*nspecdim) + ']' + raise ValueError(f'plot_quick can only be used on 1d spectra. To ' + 'get the first spectrum, try {indexing_hint}') + + if ax is None: + ax = plt.gca() + + # TODO: replace below with self.bin_edges once it is correct + mid_bin_edges = (self.spectral_axis[1:] + self.spectral_axis[:-1])/2 + bin_edges = np.concatenate([(self.spectral_axis[0]*2-mid_bin_edges[0]).ravel(), + mid_bin_edges, + (self.spectral_axis[-1]*2-mid_bin_edges[-1]).ravel()]) + + # for a plot with steps-post, the last horizontal line requires a repeat + # of the last flux value + extended_flux = np.concatenate([self.flux, [self.flux[-1]]]) + + ax.plot(bin_edges, extended_flux, **kwargs) + + if x_name is not None: + sa_unit = self.spectral_axis.unit.to_string(format='latex_inline') + ax.set_xlabel(x_name + f' [{sa_unit}]') + + if y_name is not None: + flux_unit = self.flux.unit.to_string(format='latex_inline') + ax.set_ylabel(y_name + f' [{flux_unit}]') + + return ax diff --git a/specutils/tests/test_spectrum1d.py b/specutils/tests/test_spectrum1d.py index 07efa7fbb..b4c0db978 100644 --- a/specutils/tests/test_spectrum1d.py +++ b/specutils/tests/test_spectrum1d.py @@ -11,6 +11,12 @@ from .conftest import remote_access from ..spectra import Spectrum1D +try: + import matplotlib + HAS_MATPLOTLIB = True +except ImportError: + HAS_MATPLOTLIB = False + def test_empty_spectrum(): spec = Spectrum1D(spectral_axis=[]*u.um, @@ -538,3 +544,10 @@ def test_spectral_axis_direction(): wave = [3, 2, 1] * u.nm spec1d = Spectrum1D(spectral_axis=wave, flux=flux) assert spec1d.spectral_axis_direction == 'decreasing' + + +@pytest.mark.skipif('not HAS_MATPLOTLIB') +def test_plot(): + spec_single_flux = Spectrum1D([1, 2] * u.Jy, [3, 4] * u.nm) + ax = spec_single_flux.plot() + assert isinstance(ax, matplotlib.axes.Axes)