import abc
import enum
import os
from typing import List, Optional, Tuple, Union, Dict
from collections import deque
from PySide2.QtCore import *
import numpy as np
from astropy import units as u
from wiser.gui.util import get_random_matplotlib_color, get_color_icon
from wiser.raster.dataset import RasterDataSet, SpectralMetadata
from wiser.raster.roi import RegionOfInterest
from wiser.raster.selection import SelectionType
from wiser.raster.serializable import Serializable, SerializedForm
# ============================================================================
# SPECTRAL CALCULATIONS
@enum.unique
class SpectrumAverageMode(enum.Enum):
    """
    The ways in which the spectra of several raster pixels can be collapsed
    into one representative spectrum.
    """

    # Band-by-band arithmetic mean over the selected pixels
    MEAN = 1

    # Band-by-band median over the selected pixels
    MEDIAN = 2
# Human-readable label for each averaging mode; used below when generating
# display names for spectra (e.g. "Mean of 3x3 area around (x, y)").
AVG_MODE_NAMES = {
    SpectrumAverageMode.MEAN: "Mean",
    SpectrumAverageMode.MEDIAN: "Median",
}
def find_rectangles_in_row(row: np.ndarray, y: int) -> List[np.ndarray]:
    """
    Find the horizontal runs of 1-valued cells in a single raster row.

    Each run is reported as a one-row-high rectangle stored in a NumPy array
    ``[x_start, x_end, y, y]``, with both x bounds inclusive.  Runs are
    returned in left-to-right order.  A run is opened by a cell equal to 1
    and closed by a cell equal to 0; any other cell value neither opens nor
    closes a run.
    """
    runs = []
    run_start = None

    for x, value in enumerate(row):
        if run_start is None:
            if value == 1:
                # First cell of a new run
                run_start = x
        elif value == 0:
            # The run ended on the previous cell
            runs.append(np.array([run_start, x - 1, y, y]))
            run_start = None

    # A run that touches the right edge is still open at this point
    if run_start is not None:
        runs.append(np.array([run_start, len(row) - 1, y, y]))

    return runs
def raster_to_combined_rectangles_x_axis(raster):
    """
    Decompose a binary raster into rectangles by run-length encoding each row
    and stacking identical runs from consecutive rows.

    Every row is split into horizontal runs by ``find_rectangles_in_row``.  A
    run is merged with a rectangle that ended on the previous row when both
    cover exactly the same x-range; otherwise the earlier rectangle is
    finished and a new one starts.

    Returns a 2D integer array of shape ``(N, 4)`` whose rows are
    ``[x_start, x_end, y_start, y_end]`` (all bounds inclusive, in raster
    coordinates).  When the raster contains no runs, an empty array of shape
    ``(0, 4)`` is returned so that callers can still slice columns.
    """
    finished = []

    # Rectangles whose last row is the previous row; only these can be
    # extended by a run in the current row.
    open_rects = []

    for y in range(raster.shape[0]):
        current = find_rectangles_in_row(raster[y], y)

        # Runs within one row are disjoint, so their (x_start, x_end) span
        # identifies them uniquely.  Looking spans up in a dict avoids
        # comparing every open rectangle against every run.
        current_by_span = {(int(rect[0]), int(rect[1])): rect for rect in current}

        # Walk the open rectangles from last to first, which keeps the order
        # in which finished rectangles are emitted.
        for prev in reversed(open_rects):
            match = current_by_span.get((int(prev[0]), int(prev[1])))
            if match is not None:
                # Same x-range on adjacent rows: grow the current run upward
                # so it carries the merged rectangle into the next row.
                match[2] = prev[2]
            else:
                # Nothing continues this rectangle, so it is complete.
                finished.append(prev)

        open_rects = current

    # Rectangles touching the last row were never finished inside the loop.
    finished.extend(open_rects)

    if not finished:
        # np.array([]) would be 1-D; keep the documented (N, 4) shape.
        return np.empty((0, 4), dtype=int)
    return np.array(finished)
def raster_to_combined_rectangles_y_axis(raster_y: np.ndarray):
    """
    Decompose a binary raster into rectangles by run-length encoding each
    column and stacking identical runs from consecutive columns.

    This runs ``raster_to_combined_rectangles_x_axis`` on the transposed
    raster and then swaps the x and y column pairs of the result, so the
    returned rows are ``[x_start, x_end, y_start, y_end]`` in the coordinates
    of the original (untransposed) raster.

    Returns an array of shape ``(N, 4)``; ``(0, 4)`` when there are no
    rectangles.
    """
    rectangles_t = raster_to_combined_rectangles_x_axis(raster_y.T)

    if rectangles_t.size == 0:
        # An empty result may be 1-D, which cannot be column-indexed.
        return np.empty((0, 4), dtype=int)

    # In the transposed raster "x" is the original y and vice versa, so move
    # columns (2, 3) to the front.  Fancy indexing returns a new array.
    return rectangles_t[:, [2, 3, 0, 1]]
def array_to_qrects(array):
    """
    Convert rows of inclusive bounds ``[x1, x2, y1, y2]`` into ``QRect``
    objects.

    ``QRect`` is constructed from ``(x, y, width, height)``, so the inclusive
    bounds are turned into sizes by adding one to each difference.
    """
    return [QRect(x1, y1, x2 - x1 + 1, y2 - y1 + 1) for x1, x2, y1, y2 in array]
def create_raster_from_roi(roi: RegionOfInterest) -> np.ndarray:
    """
    Rasterize a region of interest into a binary mask the size of its
    bounding box.

    The result is a ``uint8`` array of shape ``(bbox height, bbox width)``.
    A cell is 1 when the corresponding pixel is reported by
    ``roi.get_all_pixels()`` and 0 otherwise.  Pixel coordinates are shifted
    so that the bounding box's top-left corner maps to index ``[0][0]``.
    """
    bounds = roi.get_bounding_box()
    origin = bounds.topLeft()
    x_origin, y_origin = origin.x(), origin.y()

    mask = np.zeros((bounds.height(), bounds.width()), dtype=np.uint8)
    for pixel in roi.get_all_pixels():
        # Pixels are (x, y); the mask is indexed [row][column], i.e. [y][x]
        mask[pixel[1] - y_origin][pixel[0] - x_origin] = 1

    return mask
def calc_spectrum_fast(dataset: RasterDataSet, roi: RegionOfInterest, mode=SpectrumAverageMode.MEAN):
    """
    Calculate a spectrum over a region of interest from the specified dataset.
    The calculation mode can be specified with the mode argument.

    The ROI is rasterized and decomposed into rectangles, and the dataset is
    read one rectangle at a time rather than one pixel at a time.

    Returns a 1D array with one value per band.  An array of NaNs of length
    ``dataset.num_bands()`` is returned when the ROI contains no pixels or
    when reading any rectangle from the dataset fails.

    Raises ``TypeError`` if the dataset returns a block that is not 2- or
    3-dimensional, and ``ValueError`` if more than one spectrum was collected
    and ``mode`` is not a recognized averaging mode.
    """
    # We make a raster out of all of the pixels in the ROI
    raster = create_raster_from_roi(roi)

    if not raster.any():
        # Nothing selected: there are no rectangles to read and nothing to
        # average, so report "no data" the same way as a failed read.
        return np.full((dataset.num_bands(),), np.nan)

    # We do a variant of the Run Length Encoding (RLE) algorithm in the x
    # direction and the y direction, and keep whichever yields fewer
    # rectangles (ties go to the y direction).
    rect_x_axis = raster_to_combined_rectangles_x_axis(raster)
    rect_y_axis = raster_to_combined_rectangles_y_axis(raster)
    rects = rect_x_axis if len(rect_x_axis) < len(rect_y_axis) else rect_y_axis

    # The rectangles are relative to the ROI's bounding box; shift them into
    # the image coordinate system.
    bbox = roi.get_bounding_box()
    rects[:, :2] += bbox.left()
    rects[:, 2:] += bbox.top()

    # Accessing by rectangular blocks is faster than accessing point by point
    spectra = []
    for qrect in array_to_qrects(rects):
        try:
            s = dataset.get_all_bands_at_rect(qrect.left(), qrect.top(), qrect.width(), qrect.height())
        except Exception:
            # TODO (Joshua G-K): Make this cleaner. Either check in impl or don't let user create
            # ROIs that go out of bounds.
            # Only Exception is caught so that KeyboardInterrupt and
            # SystemExit are not swallowed.
            return np.full((dataset.num_bands(),), np.nan)

        if s.ndim not in (2, 3):
            raise TypeError(f"Expected 2 or 3 dimensions in rectangular aray, but got {s.ndim}")

        # The first axis is treated as the band axis.  Flattening the
        # remaining axes and transposing yields one row per pixel, in the
        # same order as iterating the spatial indices in nested loops.
        spectra.extend(s.reshape(s.shape[0], -1).T)

    if not spectra:
        # The dataset returned only zero-sized blocks
        return np.full((dataset.num_bands(),), np.nan)

    if len(spectra) > 1:
        spectra = np.asarray(spectra)
        # Need to compute mean/median/... of the collection of spectra
        if mode == SpectrumAverageMode.MEAN:
            spectrum = np.nanmean(spectra, axis=0)
        elif mode == SpectrumAverageMode.MEDIAN:
            spectrum = np.nanmedian(spectra, axis=0)
        else:
            raise ValueError(f"Unrecognized average type {mode}")
    else:
        # Only one spectrum, don't need to compute mean/median
        spectrum = spectra[0]

    return spectrum
def calc_rect_spectrum(dataset: RasterDataSet, rect: QRect, mode=SpectrumAverageMode.MEAN):
    """
    Compute a single spectrum for a rectangular area of a dataset.

    Every pixel coordinate covered by ``rect`` (a QRect) is enumerated and
    handed to ``calc_spectrum``, which combines the per-pixel spectra
    according to ``mode``.
    """
    x0, y0 = rect.left(), rect.top()

    points = []
    for dx, dy in np.ndindex(rect.width(), rect.height()):
        points.append((x0 + dx, y0 + dy))

    return calc_spectrum(dataset, points, mode)
def calc_spectrum(dataset: RasterDataSet, points: List[QPoint], mode=SpectrumAverageMode.MEAN):
    """
    Compute a single spectrum from the dataset's spectra at a set of points.

    ``points`` may be any iterable of coordinates that can be indexed with
    ``[0]`` (x) and ``[1]`` (y).  When there is more than one point the
    spectra are combined band-by-band according to ``mode``; with exactly
    one point its spectrum is returned unchanged.
    """
    # Collect the spectra that we need for the calculation
    spectra = [dataset.get_all_bands_at(p[0], p[1]) for p in points]

    if len(spectra) <= 1:
        # Only one spectrum, so there is nothing to average
        return spectra[0]

    if mode == SpectrumAverageMode.MEAN:
        return np.mean(spectra, axis=0)
    if mode == SpectrumAverageMode.MEDIAN:
        return np.median(spectra, axis=0)

    raise ValueError(f"Unrecognized average type {mode}")
def get_all_spectra_in_roi(
    dataset: RasterDataSet, roi: RegionOfInterest
) -> List[Tuple[Tuple[int, int], np.ndarray]]:
    """
    Collect the spectrum of every pixel in a region of interest.

    The result is a list of 2-tuples, sorted by pixel coordinate, where each
    pair holds:

    * The pixel's (x, y) integer coordinates as a 2-tuple
    * A NumPy ndarray object containing the spectrum at that coordinate.

    Note that the spectral data will include NaNs for any value from a bad
    band, or that was set to the "data ignore value".
    """
    # The ROI reports its pixels as an unordered collection; sort them so the
    # output order is deterministic.
    ordered_pixels = sorted(roi.get_all_pixels())

    # Pair each pixel with the spectrum read from the dataset at that pixel
    spectra_by_pixel = []
    for pixel in ordered_pixels:
        spectrum = dataset.get_all_bands_at(x=pixel[0], y=pixel[1])
        spectra_by_pixel.append((pixel, spectrum))

    return spectra_by_pixel
def calc_roi_spectrum(dataset: RasterDataSet, roi: RegionOfInterest, mode=SpectrumAverageMode.MEAN):
    """
    Compute a single spectrum for a Region of Interest of a dataset.

    This is the public entry point for ROI spectra; the work is delegated to
    ``calc_spectrum_fast``, and ``mode`` selects how the per-pixel spectra
    are combined.
    """
    spectrum = calc_spectrum_fast(dataset, roi, mode)
    return spectrum
# ============================================================================
# CLASSES TO REPRESENT SPECTRA
class Spectrum(abc.ABC, Serializable):
    """
    The base class for representing spectra of interest to the user of the
    application.

    The base class stores an optional numeric ID, a name, a display color and
    cached display state.  Subclasses supply the spectral data itself and the
    band/wavelength accessors; any accessor that a subclass does not override
    raises ``NotImplementedError``.
    """

    def __init__(self):
        self._id: Optional[int] = None
        self._name: Optional[str] = None
        self._color = None

        # TODO(donnie):  Should these be in this class, or in display code?
        self._icon = None
        self._visible = True

    def get_id(self) -> Optional[int]:
        """Returns the spectrum's ID, or ``None`` if none has been set."""
        return self._id

    def set_id(self, id: int) -> None:
        """Sets the spectrum's ID."""
        self._id = id

    def get_name(self) -> str:
        """Returns the name of the spectrum."""
        raise NotImplementedError("Must be implemented in subclass")

    def set_name(self, name: str):
        """Sets the name of the spectrum."""
        raise NotImplementedError("Must be implemented in subclass")

    def get_source_name(self) -> str:
        """
        Returns the name of the spectrum's source.
        """
        raise NotImplementedError("Must be implemented in subclass")

    def num_bands(self) -> int:
        """Returns the number of spectral bands in the spectrum."""
        # Raise like the other abstract accessors instead of silently
        # returning None, which would fail later inside get_bad_bands() and
        # get_shape() with a far less helpful error.
        raise NotImplementedError("Must be implemented in subclass")

    def get_bad_bands(self) -> np.ndarray:
        """
        Returns a boolean numpy array indicating which bands are bad (1 for
        bands to keep, 0 for bands to remove).  This base implementation
        defines no bad bands, so it returns an array of ``num_bands()`` True
        values.
        """
        return np.ones(self.num_bands(), dtype=np.bool_)

    def get_shape(self) -> Tuple[int]:
        """
        Returns the shape of the spectrum.  This is always simply
        ``(num_bands)``.
        """
        return (self.num_bands(),)

    def get_elem_type(self) -> np.dtype:
        """
        Returns the element-type of the spectrum.
        """
        # See num_bands(): fail loudly rather than return None.
        raise NotImplementedError("Must be implemented in subclass")

    def has_wavelengths(self) -> bool:
        """
        Returns True if this spectrum has wavelength units for all bands, False
        otherwise.
        """
        raise NotImplementedError("Must be implemented in subclass")

    def get_wavelengths(self) -> List[u.Quantity]:
        """
        Returns a list of wavelength values corresponding to each band.  The
        individual values are astropy values-with-units.
        """
        raise NotImplementedError("Must be implemented in subclass")

    def get_wavelength_units(self) -> Optional[u.Unit]:
        """
        Returns the astropy unit corresponding to the wavelength.
        """
        raise NotImplementedError("Must be implemented in subclass")

    def get_spectrum(self) -> np.ndarray:
        """
        Return the spectrum data as a 1D NumPy array.
        """
        raise NotImplementedError("Must be implemented in subclass")

    def get_color(self) -> Optional[str]:
        """Returns the spectrum's display color, or ``None`` if unset."""
        return self._color

    def set_color(self, color: str) -> None:
        """Sets the display color and discards the cached icon."""
        self._color = color
        # The cached icon was generated for the old color
        self._icon = None

    def is_editable(self) -> bool:
        # By default, spectra are editable.
        return True

    def is_discardable(self) -> bool:
        # By default, spectra are discardable.
        return True

    def get_spectral_metadata(self) -> SpectralMetadata:
        """
        Builds a ``SpectralMetadata`` object describing this spectrum.

        Only the band count and the wavelength information are filled in from
        this spectrum; the remaining fields are passed as ``None``.  Any
        exception raised by the wavelength accessors propagates to the caller.
        """
        return SpectralMetadata(
            band_info=None,
            bad_bands=None,
            default_display_bands=None,
            num_bands=self.num_bands(),
            data_ignore_value=None,
            has_wavelengths=self.has_wavelengths(),
            wavelengths=self.get_wavelengths(),
            wavelength_units=self.get_wavelength_units(),
        )

    @staticmethod
    def deserialize_into_class(spectrum_arr: Union[str, np.ndarray], metadata: Dict) -> "NumPyArraySpectrum":
        """
        This should recreate the object from the serialized form that is obtained from the
        get_serialized_form method.

        ``metadata`` must contain the keys ``"name"``, ``"source_name"``,
        ``"id"``, ``"wavelengths"``, ``"editable"`` and ``"discardable"``; a
        missing key raises ``KeyError``.  The result is always a
        ``NumPyArraySpectrum`` wrapping ``spectrum_arr``.
        """
        spectrum = NumPyArraySpectrum(
            spectrum_arr,
            metadata["name"],
            metadata["source_name"],
            metadata["wavelengths"],
            metadata["editable"],
            metadata["discardable"],
        )

        # Restore the serialized ID only if construction did not assign one
        spectrum_id = metadata["id"]
        if spectrum.get_id() is None:
            spectrum.set_id(spectrum_id)

        return spectrum
# ===============================================================================
# NUMPY ARRAY SPECTRA
# ===============================================================================
class NumPyArraySpectrum(Spectrum):
    """
    This class represents a spectrum that wraps a simple 1D NumPy array.  This
    is generally used for computed spectra.
    """

    def __init__(
        self,
        arr: np.ndarray,
        name: Optional[str] = None,
        source_name: Optional[str] = None,
        wavelengths: Optional[List[u.Quantity]] = None,
        editable=True,
        discardable=True,
    ):
        super().__init__()
        self._arr = arr
        self._name = name

        # A missing or empty source name is stored as "unknown"
        self._source_name = source_name if source_name else "unknown"

        self._wavelengths = wavelengths
        self._editable = editable
        self._discardable = discardable

        # Initially every band is marked as good (True = keep)
        self._bad_bands: np.ndarray = np.ones(arr.shape[0], dtype=np.bool_)

    def get_name(self) -> Optional[str]:
        """
        Returns the current name of the spectrum, or ``None`` if no name has
        been assigned.
        """
        return self._name

    def set_name(self, name: Optional[str]):
        """
        Sets the name of the spectrum.  ``None`` may be specified if the
        spectrum is to be unnamed.
        """
        self._name = name

    def get_source_name(self) -> Optional[str]:
        """
        Returns the name of the spectrum's source.  If no source name was
        given at construction, this is the string ``"unknown"``.
        """
        return self._source_name

    def set_source_name(self, name: str):
        """Sets the name of the spectrum's source."""
        self._source_name = name

    def get_elem_type(self) -> np.dtype:
        """
        Returns the element-type of the spectrum.
        """
        return self._arr.dtype

    def num_bands(self) -> int:
        """Returns the number of spectral bands in the spectrum."""
        return self._arr.shape[0]

    def set_bad_bands(self, bad_bands: np.ndarray):
        """Sets the bad bands array for this spectrum. 1 is keep, 0 is ignore."""
        assert bad_bands.ndim == 1 and bad_bands.shape[0] == self.num_bands(), (
            "Passed in bad bands either doesn't have 1 dimension or doesn't have "
            "same number of bands as this spectrum!"
        )
        self._bad_bands = bad_bands

    def get_bad_bands(self):
        """Returns the bad bands array for this spectrum. 1 is keep, 0 is ignore."""
        return self._bad_bands

    def has_wavelengths(self) -> bool:
        """
        Returns True if this spectrum has wavelength units for all bands, False
        otherwise.

        Only the first wavelength is inspected: the answer is True when it is
        an astropy ``Quantity``.
        """
        # len() rather than truthiness, since the wavelengths may be an array
        if self._wavelengths is None or len(self._wavelengths) == 0:
            return False

        return isinstance(self._wavelengths[0], u.Quantity)

    def get_wavelengths(self) -> List[u.Quantity]:
        """
        Returns a list of wavelength values corresponding to each band.  The
        individual values are astropy values-with-units.

        Raises ``KeyError`` if no wavelengths have been set.
        """
        if self._wavelengths is None:
            raise KeyError("Spectrum doesn't have wavelengths")

        return self._wavelengths

    def set_wavelengths(self, wavelengths: Optional[List[u.Quantity]]):
        """
        Sets the wavelength values that correspond to each band.  The argument
        is a list of astropy values-with-units.  Alternately, this method may
        be used to clear the wavelength information, by passing in ``None`` as
        the argument.

        Raises ``ValueError`` if the number of wavelengths differs from the
        number of bands.
        """
        if wavelengths is not None:
            if len(wavelengths) != self.num_bands():
                raise ValueError(
                    f"Spectrum has {self.num_bands()} bands, but "
                    + f"{len(wavelengths)} wavelengths were specified"
                )

            # Make a copy of the incoming list
            wavelengths = list(wavelengths)

        self._wavelengths = wavelengths

    def get_wavelength_units(self) -> Optional[u.Unit]:
        """
        Returns the astropy unit of the first wavelength, or ``None`` if this
        spectrum has no wavelengths with units.
        """
        if self.has_wavelengths():
            return self._wavelengths[0].unit
        return None

    def copy_spectral_metadata(self, source: SpectralMetadata):
        """Copies the wavelengths of ``source`` onto this spectrum."""
        assert source.get_wavelengths(), "SpectralMetadata has no wavelengths"
        self.set_wavelengths(source.get_wavelengths())

    def get_spectrum(self) -> np.ndarray:
        """
        Return the spectrum data as a 1D NumPy array.
        """
        return self._arr

    def is_editable(self):
        return self._editable

    def is_discardable(self):
        return self._discardable

    def __eq__(self, other: "NumPyArraySpectrum") -> bool:
        """
        Two spectra are equal when they have the same element type, the same
        data values, the same wavelengths and the same wavelength units.

        NaN values never compare equal, so a spectrum containing NaNs is not
        equal to itself.  Because this class defines ``__eq__`` without
        ``__hash__``, its instances are not hashable.
        """
        if not isinstance(other, Spectrum):
            return NotImplemented

        if self.get_elem_type() != other.get_elem_type():
            return False

        # Comparing arrays with == gives an elementwise array whose truth
        # value is ambiguous; array_equal reduces it to one bool and also
        # handles differing shapes.
        if not np.array_equal(self.get_spectrum(), other.get_spectrum()):
            return False

        # get_wavelengths() raises KeyError when no wavelengths are set; use
        # the stored value for self and treat "unset" as None for other.
        mine = self._wavelengths
        try:
            theirs = other.get_wavelengths()
        except KeyError:
            theirs = None

        if mine is None or theirs is None:
            if mine is not None or theirs is not None:
                return False
        elif len(mine) != len(theirs) or not all(bool(a == b) for a, b in zip(mine, theirs)):
            return False

        # Only ask the other spectrum for units when it reports having them
        their_units = other.get_wavelength_units() if other.has_wavelengths() else None
        return bool(self.get_wavelength_units() == their_units)
# ===============================================================================
# RASTER DATA-SET SPECTRA
# ===============================================================================
class RasterDataSetSpectrum(Spectrum):
    """
    Base class for spectra whose data is computed from a raster data set.

    The spectrum data is calculated lazily by the subclass's
    ``_calculate_spectrum()`` and cached.  The name may either be set
    explicitly or generated from the spectrum's configuration by the
    subclass's ``_generate_name()``.
    """

    def __init__(self, dataset):
        super().__init__()

        self._dataset: RasterDataSet = dataset
        self._avg_mode = SpectrumAverageMode.MEAN

        self._name = None
        self._use_generated_name = True

        # This field holds the spectrum data.  It is generated lazily, so it
        # won't be set until it is requested.  Additionally, it may be set back
        # to None if the details of how the spectrum is generated are changed.
        self._spectrum = None

    def __str__(self) -> str:
        return f"RasterDataSetSpectrum[{self.get_source_name()}, " + f"name={self.get_name()}]"

    def _reset_internal_state(self):
        """
        This internal helper function should be called when important details
        of this object change, possibly necessitating the recalculation of the
        spectrum data and/or a generated name for the spectrum.
        """
        self._spectrum = None
        if self._use_generated_name:
            self._name = None

    def get_name(self) -> Optional[str]:
        """
        Returns the spectrum's name, generating (and caching) one first if no
        name is set and generated names are enabled.
        """
        if self._name is None and self._use_generated_name:
            self._name = self._generate_name()

        return self._name

    def set_name(self, name):
        """Sets the spectrum's name."""
        self._name = name

    def set_use_generated_name(self, use_generated: bool) -> None:
        """
        Enables or disables generated names.  Enabling them discards the
        current name so that it is regenerated on the next ``get_name()``.
        """
        self._use_generated_name = use_generated
        if use_generated:
            self._name = None

    def use_generated_name(self) -> bool:
        """Returns True if the name is generated from the configuration."""
        return self._use_generated_name

    def _generate_name(self) -> str:
        """
        This helper function generates a name for this spectrum from its
        configuration details.
        """
        raise NotImplementedError("Must be implemented in subclass")

    def get_source_name(self):
        """
        Returns the base name of the data set's first file path, or
        ``"unknown"`` if the data set reports no file paths.
        """
        filenames = self._dataset.get_filepaths()
        if filenames is not None and len(filenames) > 0:
            return os.path.basename(filenames[0])
        return "unknown"

    def get_dataset(self):
        """Returns the raster data set this spectrum is computed from."""
        return self._dataset

    def get_avg_mode(self):
        """Returns the averaging mode used to compute the spectrum."""
        return self._avg_mode

    def set_avg_mode(self, avg_mode):
        """
        Sets the averaging mode and invalidates the cached spectrum (and the
        generated name, if one is in use).

        Raises ``ValueError`` if ``avg_mode`` is not a ``SpectrumAverageMode``
        member.
        """
        # isinstance rather than "in": for a non-member operand the latter
        # raises TypeError on Python 3.8-3.11 and accepts raw values such as
        # 1 on Python 3.12+.
        if not isinstance(avg_mode, SpectrumAverageMode):
            raise ValueError("avg_mode must be a value from SpectrumAverageMode")

        self._avg_mode = avg_mode
        self._reset_internal_state()

    def num_bands(self) -> int:
        """Returns the number of spectral bands in the raster data."""
        return self._dataset.num_bands()

    def get_bad_bands(self):
        """
        Returns the data set's bad-band list as a boolean array (True = keep).
        If the data set has no bad-band list, every band is reported as good.
        """
        bad_bands = self._dataset.get_bad_bands()
        if bad_bands is None:
            return np.ones(self.num_bands(), dtype=np.bool_)
        return np.array(bad_bands, dtype=np.bool_)

    def get_elem_type(self) -> np.dtype:
        """
        Returns the element-type of the spectrum.
        """
        return self._dataset.get_elem_type()

    def has_wavelengths(self) -> bool:
        """
        Returns True if this spectrum has wavelength units for all bands, False
        otherwise.
        """
        return self._dataset.has_wavelengths()

    def get_wavelengths(self, filter_bad_bands=False) -> List[u.Quantity]:
        """
        Returns a list of wavelength values corresponding to each band.  The
        individual values are astropy values-with-units.

        If the first band has no ``"wavelength"`` entry, each band's
        ``"index"`` entry is returned instead.  With ``filter_bad_bands`` set,
        only the bands marked as good by ``get_bad_bands()`` are included.
        An empty list is returned when the data set has no bands.
        """
        band_list = self._dataset.band_list()
        if not band_list:
            return []

        # The first band decides which key is read from every band
        key = "wavelength" if "wavelength" in band_list[0] else "index"
        bands = [b[key] for b in band_list]

        if filter_bad_bands:
            # Use this class's accessor, which copes with data sets that
            # have no bad-band list.
            keep = self.get_bad_bands()
            bands = [band for band, good in zip(bands, keep) if good]

        return bands

    def get_wavelength_units(self) -> Optional[u.Unit]:
        """
        Gets the wavelength units, taken from the first band's value.
        Returns ``None`` if there are no bands or if the band values carry no
        unit (for example when only band indexes are available).
        """
        wavelengths = self.get_wavelengths()
        if not wavelengths:
            return None
        return getattr(wavelengths[0], "unit", None)

    def _calculate_spectrum(self):
        """
        This internal helper method computes and stores the spectrum data for
        this object, based on its current configuration.
        """
        raise NotImplementedError("Must be implemented in subclass")

    def get_spectrum(self) -> np.ndarray:
        """
        Return the spectrum data as a 1D NumPy array.
        """
        if self._spectrum is None:
            self._calculate_spectrum()

        return self._spectrum

    def __eq__(self, other: "Spectrum") -> bool:
        """
        Two spectra are equal when they have the same element type, the same
        data values, the same wavelengths and the same wavelength units.

        NaN values never compare equal, so a spectrum containing NaNs is not
        equal to itself.  Because this class defines ``__eq__`` without
        ``__hash__``, its instances are not hashable.
        """
        if not isinstance(other, Spectrum):
            return NotImplemented

        if self.get_elem_type() != other.get_elem_type():
            return False

        # Comparing arrays with == gives an elementwise array whose truth
        # value is ambiguous; array_equal reduces it to one bool and also
        # handles differing shapes.
        if not np.array_equal(self.get_spectrum(), other.get_spectrum()):
            return False

        mine = self.get_wavelengths()
        try:
            theirs = other.get_wavelengths()
        except KeyError:
            # NumPyArraySpectrum raises KeyError when it has no wavelengths
            theirs = None

        if theirs is None:
            return False
        if len(mine) != len(theirs) or not all(bool(a == b) for a, b in zip(mine, theirs)):
            return False

        # Only ask the other spectrum for units when it reports having them
        their_units = other.get_wavelength_units() if other.has_wavelengths() else None
        return bool(self.get_wavelength_units() == their_units)
class SpectrumAtPoint(RasterDataSetSpectrum):
    """
    This class represents the spectrum at or around a point in a raster data
    set.  A rectangular area may be specified, and an average spectrum will be
    computed over that area.
    """

    def __init__(
        self,
        dataset: RasterDataSet,
        point: Tuple[int, int],
        area: Tuple[int, int] = (1, 1),
        avg_mode=SpectrumAverageMode.MEAN,
    ):
        super().__init__(dataset)

        self._point: Tuple[int, int] = point

        # Start unset so that set_area() always sees a change and applies
        # its validation.
        self._area: Tuple[int, int] = None
        self.set_area(area)

        self.set_avg_mode(avg_mode)

    def _generate_name(self) -> str:
        """
        This helper function generates a name for this spectrum from its
        configuration details.
        """
        if self._area == (1, 1):
            name = f"Spectrum at ({self._point[0]}, {self._point[1]})"
        else:
            name = (
                f"{AVG_MODE_NAMES[self._avg_mode]} of "
                + f"{self._area[0]}x{self._area[1]} "
                + f"area around ({self._point[0]}, {self._point[1]})"
            )

        return name

    def _calculate_spectrum(self):
        """
        This internal helper method computes and stores the spectrum data for
        this object, based on its current configuration.

        If reading from the data set fails, the spectrum is set to an array
        of NaNs with one entry per band.
        """
        (x, y) = self._point

        if self._area == (1, 1):
            try:
                self._spectrum = self._dataset.get_all_bands_at(x, y)
            except Exception:
                self._spectrum = np.full((self._dataset.num_bands(),), np.nan)

        else:
            (width, height) = self._area

            # The area is centered on the point.  Both dimensions are odd, so
            # integer division is exact and keeps the QRect arguments ints.
            left = x - (width - 1) // 2
            top = y - (height - 1) // 2
            rect = QRect(left, top, width, height)

            try:
                self._spectrum = calc_rect_spectrum(self._dataset, rect, mode=self._avg_mode)
            except Exception:
                # A spectrum has one value per band, regardless of the area
                self._spectrum = np.full((self._dataset.num_bands(),), np.nan)

    def get_point(self):
        """Returns the (x, y) point this spectrum is centered on."""
        return self._point

    def get_area(self):
        """Returns the (width, height) of the area that is averaged."""
        return self._area

    def set_area(self, area) -> None:
        """
        Sets the (width, height) of the area around the point over which the
        spectrum is computed.  A change invalidates the cached spectrum (and
        the generated name, if one is in use).

        Raises ``ValueError`` if either dimension is not odd.
        """
        if area[0] % 2 != 1 or area[1] % 2 != 1:
            raise ValueError(f"area values must be odd; got {area}")

        # Store a tuple so that comparisons against (1, 1) also work when the
        # caller passes a list.
        area = (area[0], area[1])

        if self._area != area:
            self._area = area
            self._reset_internal_state()
class ROIAverageSpectrum(RasterDataSetSpectrum):
    """
    The spectrum obtained by combining the spectra of all pixels in a Region
    of Interest of a raster data set, using the configured averaging mode.
    """

    def __init__(
        self,
        dataset: RasterDataSet,
        roi: RegionOfInterest,
        avg_mode=SpectrumAverageMode.MEAN,
    ):
        super().__init__(dataset)
        self._roi: RegionOfInterest = roi
        self.set_avg_mode(avg_mode)

    def get_roi(self) -> RegionOfInterest:
        """Returns the Region of Interest this spectrum is computed over."""
        return self._roi

    def _generate_name(self) -> str:
        """
        Builds the display name from the averaging mode's label and the name
        of the Region of Interest.
        """
        mode_label = AVG_MODE_NAMES[self._avg_mode]
        roi_name = self._roi.get_name()
        return f'{mode_label} of "{roi_name}" Region of Interest'

    def _calculate_spectrum(self):
        """
        Computes the spectrum over the Region of Interest with the current
        averaging mode and caches it on this object.
        """
        self._spectrum = calc_roi_spectrum(self._dataset, self._roi, self._avg_mode)