Skip to content

API Reference

Core Modules

IO Module

Input/Output module for ShallowLearn. Handles loading and writing satellite data with VRT generation and metadata preservation.

GeoTIFFCollection

Collection manager for multiple GeoTIFF files.

Useful for handling datasets like Planetscope with multiple single-band files or collections of classification/analysis results.

Source code in ShallowLearn/io/satellite_data.py
class GeoTIFFCollection:
    """
    Collection manager for multiple GeoTIFF files.

    Useful for handling datasets like Planetscope with multiple single-band files
    or collections of classification/analysis results.
    """

    def __init__(self, directory: str, pattern: str = "*.tif"):
        """
        Initialize collection from directory.

        Parameters:
        -----------
        directory : str
            Directory containing GeoTIFF files
        pattern : str, default "*.tif"
            Glob pattern to match files

        Raises:
        -------
        FileNotFoundError
            If the directory does not exist
        ValueError
            If no files match the pattern
        """
        self.directory = Path(directory)
        self.pattern = pattern
        self.files: List[Path] = []
        self.images: List[GeoTIFFImage] = []

        if not self.directory.exists():
            raise FileNotFoundError(f"Directory not found: {directory}")

        self._discover_files()

    def _discover_files(self):
        """Discover and sort GeoTIFF files in directory."""
        # glob() yields files in arbitrary order; sort for deterministic indexing.
        self.files = sorted(self.directory.glob(self.pattern))

        if not self.files:
            raise ValueError(f"No files matching pattern '{self.pattern}' found in {self.directory}")

    def load_all(self) -> List[GeoTIFFImage]:
        """
        Load all GeoTIFF files in the collection.

        Files that fail to load are skipped with a printed warning
        (best-effort loading), so the returned list may be shorter than
        ``self.files`` and its indices may not line up with them.

        Returns:
        --------
        List[GeoTIFFImage]
            List of successfully loaded GeoTIFF images
        """
        self.images = []
        for file_path in self.files:
            try:
                img = GeoTIFFImage(str(file_path))
                img.load()
                self.images.append(img)
            except Exception as e:
                # Deliberately best-effort: one bad file should not abort
                # loading the rest of the collection.
                print(f"Warning: Failed to load {file_path}: {e}")
                continue

        return self.images

    def get_file_list(self) -> List[Path]:
        """Get list of discovered files (defensive copy)."""
        return self.files.copy()

    def stack_images(self) -> np.ndarray:
        """
        Stack all images into a single array.

        Returns:
        --------
        np.ndarray
            Stacked images with shape (n_images, bands, height, width)

        Raises:
        -------
        ValueError
            If no images could be loaded, or (via np.stack) if the loaded
            images do not share a common shape.
        """
        if not self.images:
            self.load_all()

        if not self.images:
            raise ValueError("No images successfully loaded")

        # np.stack itself raises ValueError if the shapes are incompatible.
        stacked = np.stack([img.image for img in self.images], axis=0)
        return stacked

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> GeoTIFFImage:
        """Get image by index, loading lazily if necessary.

        Negative indices are supported like a regular sequence (the
        original bound check let them through and they then either hit a
        stale cache entry or raised a confusing IndexError).
        """
        n = len(self.files)
        if index < 0:
            index += n  # normalize negative indices against the file list
        if not 0 <= index < n:
            raise IndexError(f"Index {index} out of range for {n} files")

        if index >= len(self.images):
            # Lazily load missing images up to the requested index.
            # NOTE(review): if load_all() previously skipped unreadable
            # files, self.images can be shorter than self.files and offset
            # relative to it — confirm callers don't mix the two paths.
            for i in range(len(self.images), index + 1):
                img = GeoTIFFImage(str(self.files[i]))
                img.load()
                self.images.append(img)

        return self.images[index]

    def __repr__(self) -> str:
        return f"<GeoTIFFCollection: {len(self.files)} files in {self.directory.name}>"

__getitem__(index)

Get image by index, loading if necessary.

Source code in ShallowLearn/io/satellite_data.py
def __getitem__(self, index: int) -> GeoTIFFImage:
    """Get image by index, loading if necessary."""
    if index >= len(self.files):
        raise IndexError(f"Index {index} out of range for {len(self.files)} files")

    if index >= len(self.images):
        # Load missing images up to requested index
        for i in range(len(self.images), index + 1):
            img = GeoTIFFImage(str(self.files[i]))
            img.load()
            self.images.append(img)

    return self.images[index]

__init__(directory, pattern='*.tif')

Initialize collection from directory.

Parameters:

directory : str — Directory containing GeoTIFF files. pattern : str, default "*.tif" — Glob pattern to match files.

Source code in ShallowLearn/io/satellite_data.py
def __init__(self, directory: str, pattern: str = "*.tif"):
    """
    Initialize collection from directory.

    Parameters:
    -----------
    directory : str
        Directory containing GeoTIFF files
    pattern : str, default "*.tif"
        Glob pattern to match files
    """
    self.directory = Path(directory)
    self.pattern = pattern
    self.files = []
    self.images = []

    if not self.directory.exists():
        raise FileNotFoundError(f"Directory not found: {directory}")

    self._discover_files()

get_file_list()

Get list of discovered files.

Source code in ShallowLearn/io/satellite_data.py
def get_file_list(self) -> List[Path]:
    """Get list of discovered files."""
    return self.files.copy()

load_all()

Load all GeoTIFF files in the collection.

Returns:

List[GeoTIFFImage] List of loaded GeoTIFF images

Source code in ShallowLearn/io/satellite_data.py
def load_all(self) -> List[GeoTIFFImage]:
    """
    Load all GeoTIFF files in the collection.

    Returns:
    --------
    List[GeoTIFFImage]
        List of loaded GeoTIFF images
    """
    self.images = []
    for file_path in self.files:
        try:
            img = GeoTIFFImage(str(file_path))
            img.load()
            self.images.append(img)
        except Exception as e:
            print(f"Warning: Failed to load {file_path}: {e}")
            continue

    return self.images

stack_images()

Stack all images into a single array.

Returns:

np.ndarray Stacked images with shape (n_images, bands, height, width)

Source code in ShallowLearn/io/satellite_data.py
def stack_images(self) -> np.ndarray:
    """
    Stack all images into a single array.

    Returns:
    --------
    np.ndarray
        Stacked images with shape (n_images, bands, height, width)
    """
    if not self.images:
        self.load_all()

    if not self.images:
        raise ValueError("No images successfully loaded")

    # Assume all images have compatible shapes
    stacked = np.stack([img.image for img in self.images], axis=0)
    return stacked

GeoTIFFImage

Generic GeoTIFF loader with backwards compatibility to LoadGeoTIFF.

Supports various GeoTIFF types including: Planetscope individual band files, GBR benthic classification data, and generic single/multi-band GeoTIFF files.

Source code in ShallowLearn/io/satellite_data.py
class GeoTIFFImage:
    """
    Generic GeoTIFF loader with backwards compatibility to LoadGeoTIFF.

    Supports various GeoTIFF types including:
    - Planetscope individual band files
    - GBR benthic classification data
    - Generic single/multi-band GeoTIFF files
    """

    def __init__(self, file_path: str):
        """
        Initialize GeoTIFF loader.

        Parameters:
        -----------
        file_path : str
            Path to the GeoTIFF file

        Raises:
        -------
        FileNotFoundError
            If the file does not exist
        """
        self.data_source = file_path  # Maintain compatibility with LoadGeoTIFF
        self.path = Path(file_path)
        self.metadata = None  # populated lazily by get_metadata()
        self.bounds = None  # populated lazily by get_bounds()
        self.image = None  # populated by load()

        if not self.path.exists():
            raise FileNotFoundError(f"GeoTIFF file not found: {file_path}")

    def load(self) -> np.ndarray:
        """
        Load GeoTIFF data with backwards compatibility.

        Returns:
        --------
        np.ndarray
            Image data with shape (bands, height, width)

        Raises:
        -------
        RuntimeError
            If the file cannot be read (original exception is chained)
        """
        try:
            with rio.open(self.data_source) as src:
                # Read all bands: (bands, height, width)
                data = src.read()

                # Nodata values are intentionally NOT masked here, matching
                # the original LoadGeoTIFF behavior; callers can consult
                # get_metadata()['nodata'] and mask downstream if needed.
                self.image = data
                return data

        except Exception as e:
            # Chain the cause so debugging keeps the original traceback.
            raise RuntimeError(f"Failed to load GeoTIFF {self.data_source}: {e}") from e

    def get_metadata(self) -> dict:
        """
        Get rasterio metadata for the GeoTIFF file.

        Returns:
        --------
        dict
            Rasterio metadata dictionary
        """
        try:
            with rio.open(self.data_source) as src:
                self.metadata = src.meta.copy()
            return self.metadata
        except Exception as e:
            raise RuntimeError(f"Failed to get metadata for {self.data_source}: {e}") from e

    def get_bounds(self) -> rio.coords.BoundingBox:
        """
        Get spatial bounds of the GeoTIFF file.

        Returns:
        --------
        rasterio.coords.BoundingBox
            Bounding box (left, bottom, right, top)
        """
        try:
            with rio.open(self.data_source) as src:
                self.bounds = src.bounds
            return self.bounds
        except Exception as e:
            raise RuntimeError(f"Failed to get bounds for {self.data_source}: {e}") from e

    def get_crs(self):
        """
        Get coordinate reference system.

        Returns:
        --------
        rasterio.crs.CRS
            Coordinate reference system
        """
        try:
            with rio.open(self.data_source) as src:
                return src.crs
        except Exception as e:
            raise RuntimeError(f"Failed to get CRS for {self.data_source}: {e}") from e

    def get_transform(self):
        """
        Get affine transform.

        Returns:
        --------
        rasterio.Affine
            Affine transformation
        """
        try:
            with rio.open(self.data_source) as src:
                return src.transform
        except Exception as e:
            raise RuntimeError(f"Failed to get transform for {self.data_source}: {e}") from e

    @property
    def shape(self) -> Tuple[int, ...]:
        """Get image shape without loading full data."""
        if self.image is not None:
            return self.image.shape
        try:
            with rio.open(self.data_source) as src:
                return (src.count, src.height, src.width)
        except Exception as e:
            raise RuntimeError(f"Failed to get shape for {self.data_source}: {e}") from e

    @property
    def dtype(self):
        """Get image data type without loading full data."""
        try:
            with rio.open(self.data_source) as src:
                return src.dtypes[0]  # Assume all bands have same dtype
        except Exception as e:
            raise RuntimeError(f"Failed to get dtype for {self.data_source}: {e}") from e

    def __repr__(self) -> str:
        # The original `hasattr(self, 'shape')` was always True (shape is a
        # property), so it never protected against the property raising for
        # an unreadable file; catch that failure explicitly instead.
        try:
            shape_str = f"{self.shape}"
        except RuntimeError:
            shape_str = "Unknown"
        return f"<GeoTIFFImage: {self.path.name}, Shape: {shape_str}>"

dtype property

Get image data type without loading full data.

shape property

Get image shape without loading full data.

__init__(file_path)

Initialize GeoTIFF loader.

Parameters:

file_path : str Path to the GeoTIFF file

Source code in ShallowLearn/io/satellite_data.py
def __init__(self, file_path: str):
    """
    Initialize GeoTIFF loader.

    Parameters:
    -----------
    file_path : str
        Path to the GeoTIFF file
    """
    self.data_source = file_path  # Maintain compatibility with LoadGeoTIFF
    self.path = Path(file_path)
    self.metadata = None
    self.bounds = None
    self.image = None

    if not self.path.exists():
        raise FileNotFoundError(f"GeoTIFF file not found: {file_path}")

get_bounds()

Get spatial bounds of the GeoTIFF file.

Returns:

rasterio.coords.BoundingBox Bounding box (left, bottom, right, top)

Source code in ShallowLearn/io/satellite_data.py
def get_bounds(self) -> rio.coords.BoundingBox:
    """
    Get spatial bounds of the GeoTIFF file.

    Returns:
    --------
    rasterio.coords.BoundingBox
        Bounding box (left, bottom, right, top)
    """
    try:
        with rio.open(self.data_source) as src:
            self.bounds = src.bounds
        return self.bounds
    except Exception as e:
        raise RuntimeError(f"Failed to get bounds for {self.data_source}: {e}")

get_crs()

Get coordinate reference system.

Returns:

rasterio.crs.CRS Coordinate reference system

Source code in ShallowLearn/io/satellite_data.py
def get_crs(self):
    """
    Get coordinate reference system.

    Returns:
    --------
    rasterio.crs.CRS
        Coordinate reference system
    """
    try:
        with rio.open(self.data_source) as src:
            return src.crs
    except Exception as e:
        raise RuntimeError(f"Failed to get CRS for {self.data_source}: {e}")

get_metadata()

Get rasterio metadata for the GeoTIFF file.

Returns:

dict Rasterio metadata dictionary

Source code in ShallowLearn/io/satellite_data.py
def get_metadata(self) -> dict:
    """
    Get rasterio metadata for the GeoTIFF file.

    Returns:
    --------
    dict
        Rasterio metadata dictionary
    """
    try:
        with rio.open(self.data_source) as src:
            self.metadata = src.meta.copy()
        return self.metadata
    except Exception as e:
        raise RuntimeError(f"Failed to get metadata for {self.data_source}: {e}")

get_transform()

Get affine transform.

Returns:

rasterio.Affine Affine transformation

Source code in ShallowLearn/io/satellite_data.py
def get_transform(self):
    """
    Get affine transform.

    Returns:
    --------
    rasterio.Affine
        Affine transformation
    """
    try:
        with rio.open(self.data_source) as src:
            return src.transform
    except Exception as e:
        raise RuntimeError(f"Failed to get transform for {self.data_source}: {e}")

load()

Load GeoTIFF data with backwards compatibility.

Returns:

np.ndarray Image data with shape (bands, height, width)

Source code in ShallowLearn/io/satellite_data.py
def load(self) -> np.ndarray:
    """
    Load GeoTIFF data with backwards compatibility.

    Returns:
    --------
    np.ndarray
        Image data with shape (bands, height, width)
    """
    try:
        with rio.open(self.data_source) as src:
            # Read all bands
            data = src.read()

            # Handle nodata values similar to original LoadGeoTIFF
            no_data = src.nodatavals
            if no_data and any(nd is not None for nd in no_data):
                # Create mask for nodata values but don't apply it yet
                # (maintaining compatibility with original behavior)
                pass

            self.image = data
            return data

    except Exception as e:
        raise RuntimeError(f"Failed to load GeoTIFF {self.data_source}: {e}")

LandsatImage

Bases: SatelliteImage

Landsat image with strict band ordering and missing band handling.

Source code in ShallowLearn/io/satellite_data.py
class LandsatImage(SatelliteImage):
    """Landsat image with strict band ordering and missing band handling.

    Bands are always stored in the canonical order given by ``band_order``;
    any band absent from the source file is replaced by a zero-filled
    placeholder (see ``_create_ordered_array``) and flagged False in
    ``band_status``.
    """

    @property
    def band_order(self) -> Dict[str, int]:
        """Canonical Landsat band order with index mapping."""
        return {
            "B1": 0,  # Coastal/Aerosol
            "B2": 1,  # Blue
            "B3": 2,  # Green
            "B4": 3,  # Red
            "B5": 4,  # NIR
            "B6": 5,  # SWIR1 (Landsat 8/9) or Thermal (Landsat 7)
            "B7": 6,  # SWIR2
            "B8": 7,  # Panchromatic (Landsat 8/9)
            "B9": 8,  # Cirrus (Landsat 8/9)
            "B10": 9,  # Thermal 1 (Landsat 8/9)
            "B11": 10,  # Thermal 2 (Landsat 8/9)
            "SAA": 11,  # Solar Azimuth Angle
            "SZA": 12,  # Solar Zenith Angle
            "VAA": 13,  # View Azimuth Angle
            "VZA": 14,  # View Zenith Angle
            "PIXEL": 15,  # Pixel QA
            "RADSAT": 16,  # Radiometric Saturation
        }

    def _load_image(self):
        """Load Landsat image data from VRT file.

        Reads bands one at a time (spectral and QA/angle bands in a VRT can
        have different dtypes), records which canonical bands are present,
        then assembles a fixed-order array via ``_create_ordered_array``.
        """
        self.mtl_tags = {}

        with rio.open(self.path) as src:
            # Store metadata
            self.meta = src.meta.copy()
            self.tags = src.tags()
            self.mtl_tags = src.tags(ns="MTL")

            # Read bands individually to handle mixed dtypes in Landsat VRTs
            # VRTs can contain spectral bands + metadata bands with different dtypes
            band_data = {}
            for i in range(src.count):
                # Fall back to a positional name when the VRT carries no band
                # description; such names are not in band_order and get skipped.
                band_desc = (
                    src.descriptions[i] if src.descriptions[i] else f"Band_{i + 1}"
                )
                if band_desc in self.band_order:
                    if band_desc in band_data.keys():
                        # Keep the first occurrence of a duplicated band name.
                        continue
                    # Read individual band to handle mixed dtypes
                    try:
                        band_array = src.read(i + 1)
                        band_data[band_desc] = band_array
                        self.present_bands.add(band_desc)
                    except Exception as e:
                        # Best-effort: one unreadable band should not abort
                        # loading the whole image.
                        print(f"Warning: Could not read band {band_desc}: {e}")
                        continue

            # Create ordered array with placeholders for missing bands
            self._create_ordered_array(band_data)

    def _create_ordered_array(self, band_data: Dict[str, np.ndarray]):
        """Create ordered array with zero-filled placeholders for missing bands.

        Placeholders are zeros (not NaN), so a missing band is
        indistinguishable from genuinely zero data unless ``band_status``
        is consulted.
        """
        ordered_bands = []
        self.band_status = {}

        for band_name in sorted(self.band_order, key=lambda x: self.band_order[x]):
            if band_name in band_data:
                ordered_bands.append(band_data[band_name])
                self.band_status[band_name] = True
            else:
                # Create zero placeholder with same shape/dtype as existing bands
                if band_data:
                    first_band = next(iter(band_data.values()))
                    placeholder = np.zeros_like(first_band)
                else:
                    # Nothing was read at all: fall back to raster dimensions
                    # from metadata with a uint16 dtype.
                    placeholder = np.zeros(
                        (self.meta["height"], self.meta["width"]), dtype="uint16"
                    )
                ordered_bands.append(placeholder)
                self.band_status[band_name] = False

        self.image = np.stack(ordered_bands, axis=0)
        # (height, width, channels)
        self.image = np.transpose(self.image, (1, 2, 0))

    def get_rgb_bands(self) -> Tuple[str, str, str]:
        """Get RGB band combination for Landsat."""
        return ("B4", "B3", "B2")  # Red, Green, Blue

    def get_spectral_bands(self) -> List[str]:
        """Get list of spectral bands (excluding QA and angle bands)."""
        spectral_bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
        if self.has_band("B8"):  # Landsat 8/9 has panchromatic
            spectral_bands.append("B8")
        if self.has_band("B9"):  # Landsat 8/9 has cirrus
            spectral_bands.append("B9")
        return spectral_bands

band_order property

Canonical Landsat band order with index mapping.

get_rgb_bands()

Get RGB band combination for Landsat.

Source code in ShallowLearn/io/satellite_data.py
def get_rgb_bands(self) -> Tuple[str, str, str]:
    """Get RGB band combination for Landsat."""
    return ("B4", "B3", "B2")  # Red, Green, Blue

get_spectral_bands()

Get list of spectral bands (excluding QA and angle bands).

Source code in ShallowLearn/io/satellite_data.py
def get_spectral_bands(self) -> List[str]:
    """Get list of spectral bands (excluding QA and angle bands)."""
    spectral_bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
    if self.has_band("B8"):  # Landsat 8/9 has panchromatic
        spectral_bands.append("B8")
    if self.has_band("B9"):  # Landsat 8/9 has cirrus
        spectral_bands.append("B9")
    return spectral_bands

LandsatImageCollection

Bases: SatelliteImageCollection

Managed collection of Landsat images with date sorting and strict band order.

Source code in ShallowLearn/io/satellite_data.py
class LandsatImageCollection(SatelliteImageCollection):
    """Managed collection of Landsat images with date sorting and strict band order."""

    def _get_sorted_image_files(self) -> List[Path]:
        """Return the directory's Landsat VRT files ordered by acquisition date."""

        def extract_date(filename):
            """Extract the YYYYMMDD date token from a Landsat filename, or ''."""
            # The date is expected as the fourth underscore-separated token.
            tokens = filename.stem.split("_")
            if len(tokens) > 3 and re.match(r"\d{8}", tokens[3]):
                return tokens[3]
            return ""

        # Landsat 7 (LE07) scenes are excluded from the collection.
        candidates = (f for f in self.directory.glob("*.vrt") if "LE07" not in f.name)
        return sorted(candidates, key=extract_date)

    def _create_image(self, file_path: Path) -> LandsatImage:
        """Create LandsatImage instance."""
        return LandsatImage(file_path)

SatelliteImage

Bases: ABC

Abstract base class for satellite images with consistent interface.

Source code in ShallowLearn/io/satellite_data.py
class SatelliteImage(ABC):
    """Abstract base class for satellite images with consistent interface.

    Subclasses define the canonical band layout (``band_order``) and the
    loading logic (``_load_image``); loading happens eagerly at
    construction time.
    """

    def __init__(self, file_path: str):
        # Eagerly loads the image: __init__ ends by calling _load_image().
        self.path = Path(file_path)
        self.meta = {}  # rasterio-style metadata, filled by _load_image
        self.tags = {}  # file-level tags, filled by _load_image
        self.present_bands = set()  # canonical band names actually read from disk
        self.band_status = {}  # band name -> True if real data, False if placeholder
        self.image = None  # pixel data; layout is subclass-specific

        self._load_image()

    @property
    @abstractmethod
    def band_order(self) -> Dict[str, int]:
        """Define the canonical band order for this satellite type."""
        pass

    @abstractmethod
    def _load_image(self):
        """Load image data with satellite-specific logic."""
        pass

    def __repr__(self):
        # One "NAME ✓/✗" entry per canonical band, in canonical order.
        band_list = [
            f"{b} {'✓' if self.band_status.get(b, False) else '✗'}"
            for b in sorted(self.band_order, key=lambda x: self.band_order[x])
        ]
        missing_count = sum(not v for v in self.band_status.values())
        return (
            f"<{self.__class__.__name__}: {self.path}\n"
            f"  Bands: {band_list}\n"
            f"  Shape: {self.image.shape if self.image is not None else 'Not loaded'}\n"
            f"  Missing: {missing_count} bands>"
        )

    def get_band_data(self, band_name: str) -> Optional[np.ndarray]:
        """Get data for a specific band.

        Returns None when the image is not loaded or the band index is out
        of range. NOTE: this may return a zero placeholder for a band that
        was not actually present in the file — check ``has_band`` first.

        Raises ValueError for names not in ``band_order``.
        """
        if band_name not in self.band_order:
            raise ValueError(f"Unknown band: {band_name}")

        band_index = self.band_order[band_name]
        # Assumes self.image is laid out (height, width, channels) —
        # TODO confirm for all subclasses.
        if self.image is not None and band_index < self.image.shape[2]:
            return self.image[:, :, band_index]
        return None

    def has_band(self, band_name: str) -> bool:
        """Check if band is present (not a placeholder)."""
        return self.band_status.get(band_name, False)

    def get_metadata(self) -> Dict:
        """Get image metadata (empty dict if none was recorded)."""
        return self.meta if hasattr(self, 'meta') and self.meta else {}

    def get_bounds(self):
        """Get image bounds.

        Derives a BoundingBox from the affine transform stored in ``meta``,
        or returns None when no transform metadata is available. Only the
        scale terms (a, e) and offsets (c, f) are used; rotation/shear
        terms (b, d) are ignored, so the result is only correct for
        axis-aligned (north-up) transforms.
        """
        if hasattr(self, 'meta') and self.meta and 'transform' in self.meta:
            # Calculate bounds from transform and dimensions
            transform = self.meta['transform']
            width = self.meta['width']
            height = self.meta['height']

            # Calculate corner coordinates. transform.e is typically
            # negative for north-up rasters, placing bottom below top.
            left = transform.c
            top = transform.f
            right = transform.c + width * transform.a
            bottom = transform.f + height * transform.e

            # Return bounds in a format similar to rasterio.coords.BoundingBox
            from rasterio.coords import BoundingBox
            return BoundingBox(left=left, bottom=bottom, right=right, top=top)
        return None

    def get_rgb_bands(self) -> Tuple[str, str, str]:
        """Get the typical RGB band combination for this satellite."""
        # Default implementation - should be overridden by subclasses
        return ("B4", "B3", "B2")  # Red, Green, Blue for most satellites

band_order abstractmethod property

Define the canonical band order for this satellite type.

get_band_data(band_name)

Get data for a specific band.

Source code in ShallowLearn/io/satellite_data.py
def get_band_data(self, band_name: str) -> Optional[np.ndarray]:
    """Get data for a specific band."""
    if band_name not in self.band_order:
        raise ValueError(f"Unknown band: {band_name}")

    band_index = self.band_order[band_name]
    if self.image is not None and band_index < self.image.shape[2]:
        return self.image[:, :, band_index]
    return None

get_bounds()

Get image bounds.

Source code in ShallowLearn/io/satellite_data.py
def get_bounds(self):
    """Get image bounds."""
    if hasattr(self, 'meta') and self.meta and 'transform' in self.meta:
        # Calculate bounds from transform and dimensions
        transform = self.meta['transform']
        width = self.meta['width']
        height = self.meta['height']

        # Calculate corner coordinates
        left = transform.c
        top = transform.f
        right = transform.c + width * transform.a
        bottom = transform.f + height * transform.e

        # Return bounds in a format similar to rasterio.coords.BoundingBox
        from rasterio.coords import BoundingBox
        return BoundingBox(left=left, bottom=bottom, right=right, top=top)
    return None

get_metadata()

Get image metadata.

Source code in ShallowLearn/io/satellite_data.py
def get_metadata(self) -> Dict:
    """Get image metadata."""
    return self.meta if hasattr(self, 'meta') and self.meta else {}

get_rgb_bands()

Get the typical RGB band combination for this satellite.

Source code in ShallowLearn/io/satellite_data.py
def get_rgb_bands(self) -> Tuple[str, str, str]:
    """Get the typical RGB band combination for this satellite."""
    # Default implementation - should be overridden by subclasses
    return ("B4", "B3", "B2")  # Red, Green, Blue for most satellites

has_band(band_name)

Check if band is present (not a placeholder).

Source code in ShallowLearn/io/satellite_data.py
def has_band(self, band_name: str) -> bool:
    """Check if band is present (not a placeholder)."""
    return self.band_status.get(band_name, False)

SatelliteImageCollection

Bases: ABC

Abstract base class for collections of satellite images.

Source code in ShallowLearn/io/satellite_data.py
class SatelliteImageCollection(ABC):
    """Abstract base class for collections of satellite images.

    Discovers and eagerly loads every image in the directory at
    construction time via the subclass hooks ``_get_sorted_image_files``
    and ``_create_image``.
    """

    def __init__(self, directory: str):
        self.directory = Path(directory)
        self.image_files = self._get_sorted_image_files()
        # Eager load: every file becomes a SatelliteImage immediately.
        self.images = [self._create_image(f) for f in self.image_files]

    @abstractmethod
    def _get_sorted_image_files(self) -> List[Path]:
        """Get sorted list of image files."""
        pass

    @abstractmethod
    def _create_image(self, file_path: Path) -> SatelliteImage:
        """Create appropriate satellite image instance."""
        pass

    def __iter__(self):
        return iter(self.images)

    def __getitem__(self, index):
        return self.images[index]

    def __len__(self):
        return len(self.images)

    def __repr__(self):
        return f"<{self.__class__.__name__} count={len(self)}>"

    def common_bands(self) -> Set[str]:
        """Get bands present in ALL images (intersection of present_bands)."""
        if not self.images:
            return set()
        return set.intersection(*[img.present_bands for img in self.images])

    def get_common_bands_array(self) -> np.ndarray:
        """Get array of common bands across all images with consistent spatial dimensions.

        Images smaller than the largest one are upsampled band-by-band with
        PIL bilinear resampling. Returns an empty array when there are no
        images or no common bands.

        NOTE(review): band indices are taken from the first image's
        band_order — assumes all images share the same canonical layout;
        Image.fromarray supports a limited set of dtypes (e.g. some
        integer dtypes are not accepted) — confirm for the band dtypes
        used here.
        """
        common_bands = self.common_bands()
        if not common_bands or not self.images:
            return np.array([])

        # Get canonical band indices for common bands
        first_image = self.images[0]
        band_indices = sorted([first_image.band_order[b] for b in common_bands])

        # Find maximum spatial dimensions
        max_height = max(img.image.shape[0] for img in self.images)
        max_width = max(img.image.shape[1] for img in self.images)

        # Resize and stack images
        resized_images = []
        for img in self.images:
            resized_bands = []
            for channel in range(img.image.shape[2]):
                band = img.image[:, :, channel]

                # Convert to PIL Image and resize if needed
                if band.shape != (max_height, max_width):
                    pil_band = Image.fromarray(band)
                    # PIL's resize takes (width, height) — note the swap.
                    resized_band = pil_band.resize(
                        (max_width, max_height), Image.BILINEAR
                    )
                    resized_bands.append(np.array(resized_band))
                else:
                    resized_bands.append(band)

            # Stack resized bands and select common channels
            resized_img = np.stack(resized_bands, axis=2)
            resized_images.append(resized_img[:, :, band_indices])

        # (n_images, height, width, n_common_bands)
        return np.stack(resized_images, axis=0)

common_bands()

Get bands present in ALL images.

Source code in ShallowLearn/io/satellite_data.py
def common_bands(self) -> Set[str]:
    """Get bands present in ALL images."""
    if not self.images:
        return set()
    return set.intersection(*[img.present_bands for img in self.images])

get_common_bands_array()

Get array of common bands across all images with consistent spatial dimensions.

Source code in ShallowLearn/io/satellite_data.py
def get_common_bands_array(self) -> np.ndarray:
    """Get array of common bands across all images with consistent spatial dimensions."""
    common_bands = self.common_bands()
    if not common_bands or not self.images:
        return np.array([])

    # Get canonical band indices for common bands
    first_image = self.images[0]
    band_indices = sorted([first_image.band_order[b] for b in common_bands])

    # Find maximum spatial dimensions
    max_height = max(img.image.shape[0] for img in self.images)
    max_width = max(img.image.shape[1] for img in self.images)

    # Resize and stack images
    resized_images = []
    for img in self.images:
        resized_bands = []
        for channel in range(img.image.shape[2]):
            band = img.image[:, :, channel]

            # Convert to PIL Image and resize if needed
            if band.shape != (max_height, max_width):
                pil_band = Image.fromarray(band)
                resized_band = pil_band.resize(
                    (max_width, max_height), Image.BILINEAR
                )
                resized_bands.append(np.array(resized_band))
            else:
                resized_bands.append(band)

        # Stack resized bands and select common channels
        resized_img = np.stack(resized_bands, axis=2)
        resized_images.append(resized_img[:, :, band_indices])

    return np.stack(resized_images, axis=0)

Sentinel2Image

Bases: SatelliteImage

Sentinel-2 image with band ordering and missing band handling.

Source code in ShallowLearn/io/satellite_data.py
(Line-number gutter from the rendered source listing removed; the source code follows.)
class Sentinel2Image(SatelliteImage):
    """Sentinel-2 image with band ordering and missing band handling."""

    def __init__(self, file_path: str, load_all_bands: bool = False, target_resolution: str = "10m", 
                 clip_geometry=None, buffer_meters: float = 0):
        """
        Initialize Sentinel-2 image.

        Parameters:
        -----------
        file_path : str
            Path to Sentinel-2 file (.SAFE directory, .zip file, or MTD XML file)
        load_all_bands : bool
            If True, loads all 13 bands by resampling from different resolution subdatasets.
            If False, loads only the native resolution bands (default: 4 bands at 10m)
        target_resolution : str
            Target resolution when load_all_bands=True ("10m", "20m", "60m")
        clip_geometry : shapely geometry or GeoDataFrame, optional
            Geometry to clip to during loading for efficiency
        buffer_meters : float
            Buffer distance in meters to add around clip_geometry
        """
        # NOTE: these attributes must be assigned BEFORE super().__init__(),
        # which presumably triggers _load_image() and reads them — TODO
        # confirm against the SatelliteImage base class (not visible here).
        self.load_all_bands = load_all_bands
        self.target_resolution = target_resolution
        self.clip_geometry = clip_geometry
        self.buffer_meters = buffer_meters
        super().__init__(file_path)

    @property
    def band_order(self) -> Dict[str, int]:
        """Canonical Sentinel-2 band order with index mapping."""
        # Fixed channel layout used by _create_ordered_array; missing bands
        # get zero placeholders so indices stay stable across products.
        return {
            "B01": 0,  # Coastal aerosol (60m)
            "B02": 1,  # Blue (10m)
            "B03": 2,  # Green (10m)
            "B04": 3,  # Red (10m)
            "B05": 4,  # Red Edge 1 (20m)
            "B06": 5,  # Red Edge 2 (20m)
            "B07": 6,  # Red Edge 3 (20m)
            "B08": 7,  # NIR (10m)
            "B8A": 8,  # NIR narrow (20m)
            "B09": 9,  # Water vapour (60m)
            "B10": 10,  # Cirrus (60m)
            "B11": 11,  # SWIR 1 (20m)
            "B12": 12,  # SWIR 2 (20m)
        }

    def _load_image(self):
        """Load Sentinel-2 image data using subdatasets approach like original LoadSentinel2L1C."""
        # Accepts three input forms: a pre-built .vrt, a .zip archive, or a
        # .SAFE directory / MTD XML file. The first branch returns early.
        import zipfile
        import os

        # Handle different input types
        if str(self.path).endswith(".vrt"):
            # Handle VRT files created by the VRT builder
            self.is_vrt = True
            # NOTE(review): is_zip is NOT set on this path — downstream code
            # should not rely on it after loading a VRT; confirm.
            with rio.open(self.path) as src:
                # Store metadata
                self.meta = src.meta.copy()
                self.tags = src.tags()

                # Read bands individually to handle any mixed dtypes
                band_data = {}
                for i in range(src.count):
                    band_desc = (
                        src.descriptions[i] if src.descriptions[i] else f"Band_{i + 1}"
                    )
                    # Handle the naming convention (B2 vs B02)
                    if band_desc.startswith('B') and len(band_desc) == 2:
                        band_desc = f"B0{band_desc[1]}"

                    if band_desc in self.band_order:
                        try:
                            band_array = src.read(i + 1)
                            band_data[band_desc] = band_array
                            self.present_bands.add(band_desc)
                        except Exception as e:
                            print(f"Warning: Could not read band {band_desc}: {e}")
                            continue

                # Create ordered array with placeholders for missing bands
                self._create_ordered_array(band_data)
                return

        elif str(self.path).endswith(".zip"):
            self.is_zip = True
            # Locate the single product metadata XML inside the archive.
            with zipfile.ZipFile(self.path, 'r') as zip_ref:
                files = [
                    f for f in zip_ref.namelist()
                    if "MTD_MSIL1C.xml" in f or "MTD_MSIL2A.xml" in f
                ]
            if len(files) != 1:
                raise Exception("Multiple or no MTD files found in ZIP.")

            # GDAL's /vsizip/ virtual filesystem lets rasterio read inside
            # the archive without extracting it.
            zip_path = f"/vsizip/{self.path}"
            metadata_file = os.path.join(zip_path, files[0])
        else:
            # Handle .SAFE directories or direct XML files
            self.is_zip = False
            if str(self.path).endswith(".xml"):
                metadata_file = str(self.path)
            else:
                # Assume it's a .SAFE directory - use simple file finding for now
                # This avoids import issues with missing modules
                from pathlib import Path
                safe_path = Path(self.path)
                mtd_files = list(safe_path.rglob("MTD_MSIL1C*.xml"))
                if not mtd_files:
                    mtd_files = list(safe_path.rglob("MTD_MSIL2A*.xml"))
                if len(mtd_files) != 1:
                    raise Exception(f"Found {len(mtd_files)} MTD files, expected 1")
                metadata_file = str(mtd_files[0])

        # Load subdatasets using the original approach
        with rio.open(metadata_file) as dataset:
            subdatasets = dataset.subdatasets

        if not subdatasets:
            raise ValueError("No subdatasets found in the Sentinel-2 file")

        # Get metadata from the first subdataset
        with rio.open(subdatasets[0]) as ds:
            self.tags = ds.tags()
            self.meta = ds.meta.copy()

        # Load bands based on configuration
        if self.load_all_bands:
            # Load all bands from multiple resolution subdatasets
            self._load_all_resolution_bands(subdatasets, self.target_resolution)
        else:
            # Use 10m resolution bands as default (original behavior)
            resolution_10m = [s for s in subdatasets if "10m" in s]

            if resolution_10m:
                # Calculate clipping window if geometry provided
                clip_window = None
                if self.clip_geometry is not None:
                    clip_window = self._calculate_clip_window(resolution_10m[0])
                    if clip_window is None:
                        print("Warning: Invalid clip window, loading full 10m image")

                # Load the 10m resolution data
                with rio.open(resolution_10m[0]) as src:
                    try:
                        # Read all bands (with clipping if specified)
                        if clip_window:
                            data = src.read(window=clip_window)  # Shape: (bands, height, width)
                            # Update metadata for clipped region
                            from rasterio.windows import transform as window_transform
                            self.meta['transform'] = window_transform(clip_window, src.meta['transform'])
                            self.meta['width'] = int(clip_window.width)
                            self.meta['height'] = int(clip_window.height)
                        else:
                            data = src.read()  # Shape: (bands, height, width)
                    except Exception as e:
                        print(f"Error reading 10m data: {e}")
                        if clip_window:
                            print(f"Window: {clip_window.width} x {clip_window.height}")
                        # Fallback to full image
                        data = src.read()

                    # Map bands based on descriptions
                    band_data = {}
                    for i in range(src.count):
                        band_desc = (
                            src.descriptions[i].split(",")[0] if src.descriptions[i] else f"Band_{i + 1}"
                        )
                        # Handle the naming convention (B2 vs B02)
                        if band_desc.startswith('B') and len(band_desc) == 2:
                            band_desc = f"B0{band_desc[1]}"

                        if band_desc in self.band_order:
                            band_data[band_desc] = data[i]
                            self.present_bands.add(band_desc)

                    # Create ordered array with placeholders for missing bands
                    self._create_ordered_array(band_data)
            else:
                # Fallback: use the first available subdataset
                with rio.open(subdatasets[0]) as src:
                    data = src.read()
                    self.image = np.transpose(data, (1, 2, 0))  # (height, width, bands)

    def _create_ordered_array(self, band_data: Dict[str, np.ndarray]):
        """Create ordered array with NaN placeholders for missing bands."""
        # Despite the docstring, placeholders are ZEROS, not NaN — NaN can't
        # be cast into the integer dtypes Sentinel-2 data typically uses.
        # band_status records which channels hold real data vs placeholders.
        ordered_bands = []
        self.band_status = {}

        for band_name in sorted(self.band_order, key=lambda x: self.band_order[x]):
            if band_name in band_data:
                ordered_bands.append(band_data[band_name])
                self.band_status[band_name] = True
            else:
                # Create zero placeholder (avoid NaN casting issues with integer dtypes)
                if band_data:
                    first_band = next(iter(band_data.values()))
                    placeholder = np.zeros_like(first_band)
                else:
                    placeholder = np.zeros(
                        (self.meta["height"], self.meta["width"]), dtype="uint16"
                    )
                ordered_bands.append(placeholder)
                self.band_status[band_name] = False

        self.image = np.stack(ordered_bands, axis=0)
        # (height, width, channels)
        self.image = np.transpose(self.image, (1, 2, 0))

    def get_rgb_bands(self) -> Tuple[str, str, str]:
        """Get RGB band combination for Sentinel-2."""
        return ("B04", "B03", "B02")  # Red, Green, Blue

    def get_spectral_bands(self) -> List[str]:
        """Get list of all spectral bands."""
        return list(self.band_order.keys())

    def get_resolution_groups(self) -> Dict[str, List[str]]:
        """Get bands grouped by native resolution."""
        return {
            "10m": ["B02", "B03", "B04", "B08"],
            "20m": ["B05", "B06", "B07", "B8A", "B11", "B12"],
            "60m": ["B01", "B09", "B10"],
        }

    def _load_all_resolution_bands(self, subdatasets: List[str], target_resolution: str = "10m"):
        """
        Load all bands from multiple resolution subdatasets and resample to target resolution.
        Applies clipping during loading if clip_geometry is specified for efficiency.

        Parameters:
        -----------
        subdatasets : List[str]
            List of subdataset URIs
        target_resolution : str
            Target resolution to resample all bands to ("10m", "20m", "60m")
        """
        from rasterio.warp import reproject, Resampling
        from rasterio.enums import Resampling as ResamplingEnum
        from rasterio.windows import from_bounds

        # Find subdatasets by resolution
        # Relies on GDAL's ":10m:" / ":20m:" / ":60m:" URI markers.
        resolution_subdatasets = {}
        for subdataset in subdatasets:
            if ":10m:" in subdataset:
                resolution_subdatasets["10m"] = subdataset
            elif ":20m:" in subdataset:
                resolution_subdatasets["20m"] = subdataset
            elif ":60m:" in subdataset:
                resolution_subdatasets["60m"] = subdataset

        if not resolution_subdatasets:
            raise ValueError("No resolution subdatasets found")

        # Get target resolution parameters
        target_subdataset = resolution_subdatasets.get(target_resolution)
        if not target_subdataset:
            # Fallback to 10m if target not available
            target_resolution = "10m"
            target_subdataset = resolution_subdatasets.get(target_resolution)

        if not target_subdataset:
            raise ValueError("No suitable target resolution found")

        # Calculate clipping window if geometry provided
        clip_window = None
        if self.clip_geometry is not None:
            clip_window = self._calculate_clip_window(target_subdataset)
            if clip_window is None:
                print("Warning: Invalid clip window, loading full image")

        # Get target grid parameters (from clipped region if applicable)
        with rio.open(target_subdataset) as target_ds:
            if clip_window:
                target_width = int(clip_window.width)
                target_height = int(clip_window.height)
                from rasterio.windows import transform as window_transform
                target_transform = window_transform(clip_window, target_ds.transform)
            else:
                target_width = target_ds.width
                target_height = target_ds.height
                target_transform = target_ds.transform
            target_crs = target_ds.crs

        # Load and resample all bands
        band_data = {}
        for resolution, subdataset in resolution_subdatasets.items():
            with rio.open(subdataset) as src:
                # Calculate appropriate window for this resolution
                if clip_window and resolution != target_resolution:
                    # Scale window to match resolution ratio
                    scale_factor = self._get_resolution_scale_factor(target_resolution, resolution)
                    res_window = self._scale_window(clip_window, scale_factor)
                else:
                    res_window = clip_window

                for i in range(src.count):
                    # Parse band name from description
                    band_desc = src.descriptions[i].split(",")[0] if src.descriptions[i] else f"Band_{i + 1}"
                    # Handle naming convention (B2 vs B02)
                    if band_desc.startswith('B') and len(band_desc) == 2:
                        band_desc = f"B0{band_desc[1]}"

                    if band_desc in self.band_order:
                        try:
                            # Read the band (with clipping if window specified)
                            if res_window:
                                band_array = src.read(i + 1, window=res_window)
                            else:
                                band_array = src.read(i + 1)
                        except Exception as e:
                            print(f"Error reading band {band_desc} from {resolution}: {e}")
                            if res_window:
                                print(f"Window: {res_window.width} x {res_window.height}")
                            continue

                        if resolution == target_resolution:
                            # No resampling needed
                            band_data[band_desc] = band_array
                        else:
                            # Resample to target resolution
                            resampled_array = np.empty((target_height, target_width), dtype=band_array.dtype)

                            # Get source transform for this window
                            if res_window:
                                from rasterio.windows import transform as window_transform
                                src_transform = window_transform(res_window, src.transform)
                            else:
                                src_transform = src.transform

                            # Warp the coarser/finer band onto the target grid.
                            reproject(
                                band_array,
                                resampled_array,
                                src_transform=src_transform,
                                src_crs=src.crs,
                                dst_transform=target_transform,
                                dst_crs=target_crs,
                                resampling=ResamplingEnum.bilinear
                            )
                            band_data[band_desc] = resampled_array

                        self.present_bands.add(band_desc)

        # Create ordered array with placeholders for missing bands
        self._create_ordered_array(band_data)

    def _calculate_clip_window(self, reference_subdataset: str):
        """Calculate clipping window from geometry for the reference subdataset."""
        import geopandas as gpd
        from rasterio.windows import from_bounds

        # Handle different geometry types
        if hasattr(self.clip_geometry, 'geometry'):
            # It's a GeoDataFrame
            gdf = self.clip_geometry
        else:
            # It's a geometry - create a GeoDataFrame
            # NOTE(review): bare geometries are assumed to be WGS84
            # (EPSG:4326) — confirm this matches what callers pass in.
            gdf = gpd.GeoDataFrame([1], geometry=[self.clip_geometry], crs="EPSG:4326")

        # Get CRS from reference subdataset
        with rio.open(reference_subdataset) as ref_ds:
            target_crs = ref_ds.crs

            # Reproject geometry to match image CRS if needed
            if gdf.crs != target_crs:
                gdf = gdf.to_crs(target_crs)

            # Apply buffer if specified
            if self.buffer_meters > 0:
                gdf_buffered = gdf.copy()
                gdf_buffered.geometry = gdf.geometry.buffer(self.buffer_meters)
                bounds = gdf_buffered.total_bounds
            else:
                bounds = gdf.total_bounds

            # Calculate window from bounds
            window = from_bounds(
                bounds[0], bounds[1], bounds[2], bounds[3],  # left, bottom, right, top
                transform=ref_ds.transform
            )

            # Round and clip to dataset bounds
            from rasterio.windows import Window
            window = window.round_lengths().round_offsets()
            dataset_window = Window(0, 0, ref_ds.width, ref_ds.height)
            window = window.intersection(dataset_window)

            # Ensure window has valid dimensions
            if window.width <= 0 or window.height <= 0:
                print(f"Warning: Invalid window dimensions: {window.width} x {window.height}")
                print(f"Bounds: {bounds}")
                print(f"Dataset size: {ref_ds.width} x {ref_ds.height}")
                print(f"Transform: {ref_ds.transform}")
                return None

            print(f"Calculated clip window: {window.width} x {window.height} at ({window.col_off}, {window.row_off})")
            return window

    def _get_resolution_scale_factor(self, target_res: str, source_res: str) -> float:
        """Get scale factor between resolutions."""
        # e.g. target 10m, source 20m -> 2.0 (source pixels are twice as large)
        resolution_values = {"10m": 10, "20m": 20, "60m": 60}
        return resolution_values[source_res] / resolution_values[target_res]

    def _scale_window(self, window, scale_factor: float):
        """Scale a window by the given factor."""
        from rasterio.windows import Window
        if window is None:
            return None

        scaled = Window(
            col_off=window.col_off / scale_factor,  # Inverse scaling for higher resolution
            row_off=window.row_off / scale_factor,
            width=window.width / scale_factor,
            height=window.height / scale_factor
        ).round_lengths().round_offsets()

        # Ensure valid dimensions
        if scaled.width <= 0 or scaled.height <= 0:
            print(f"Warning: Invalid scaled window: {scaled.width} x {scaled.height}")
            return None

        return scaled

    def clip_to_bounds(self, bounds, buffer_pixels: int = 0):
        """
        Clip image data to specified bounds.

        Parameters:
        -----------
        bounds : tuple or BoundingBox
            Bounds to clip to (left, bottom, right, top) or rasterio BoundingBox
        buffer_pixels : int
            Number of pixels to add as buffer around the clipped area

        Returns:
        --------
        Sentinel2Image
            New Sentinel2Image instance with clipped data
        """
        if self.image is None or not hasattr(self, 'meta'):
            raise ValueError("Image and metadata must be loaded before clipping")

        from rasterio.coords import BoundingBox
        from rasterio.windows import from_bounds
        from copy import deepcopy

        # Ensure bounds is a BoundingBox
        if not isinstance(bounds, BoundingBox):
            bounds = BoundingBox(*bounds)

        # Calculate window from bounds
        window = from_bounds(
            bounds.left, bounds.bottom, bounds.right, bounds.top,
            transform=self.meta['transform']
        )

        # Apply buffer if specified
        # NOTE(review): rasterio's Window does not appear to define an
        # .expand() method in current releases — this branch may raise
        # AttributeError when buffer_pixels > 0; confirm against the
        # installed rasterio version.
        if buffer_pixels > 0:
            window = window.expand(buffer_pixels)

        # Round window to integer pixels
        window = window.round_lengths().round_offsets()

        # Clip the window to image boundaries  
        from rasterio.windows import Window
        image_window = Window(0, 0, self.meta['width'], self.meta['height'])
        window = window.intersection(image_window)

        if window.width <= 0 or window.height <= 0:
            raise ValueError("Clipping bounds do not intersect with image")

        # Extract the clipped data
        row_slice = slice(int(window.row_off), int(window.row_off + window.height))
        col_slice = slice(int(window.col_off), int(window.col_off + window.width))

        clipped_image = self.image[row_slice, col_slice, :]

        # Create new instance with clipped data
        # __new__ bypasses __init__ so the clipped copy does not re-read
        # the source file from disk; state is copied over manually.
        clipped_s2 = self.__class__.__new__(self.__class__)
        clipped_s2.path = self.path
        clipped_s2.present_bands = self.present_bands.copy()
        clipped_s2.band_status = self.band_status.copy()
        clipped_s2.tags = self.tags.copy()
        clipped_s2.image = clipped_image

        # Update metadata
        clipped_s2.meta = deepcopy(self.meta)
        clipped_s2.meta['width'] = int(window.width)
        clipped_s2.meta['height'] = int(window.height)

        # Update transform
        from rasterio.windows import transform as window_transform
        clipped_s2.meta['transform'] = window_transform(window, self.meta['transform'])

        return clipped_s2

    def clip_to_geometry(self, geometry, buffer_meters: float = 0):
        """
        Clip image data to a geometry (e.g., from a GeoDataFrame).

        Parameters:
        -----------
        geometry : shapely geometry or GeoDataFrame
            Geometry to clip to
        buffer_meters : float
            Buffer distance in meters to add around the geometry

        Returns:
        --------
        Sentinel2Image
            New Sentinel2Image instance with clipped data
        """
        import geopandas as gpd
        from shapely.geometry import box  # NOTE(review): unused import — candidate for removal

        # Handle GeoDataFrame input
        if hasattr(geometry, 'geometry'):
            # It's a GeoDataFrame
            gdf = geometry
        else:
            # It's a geometry - create a GeoDataFrame
            # NOTE(review): bare geometries assumed WGS84 (EPSG:4326) — confirm callers.
            gdf = gpd.GeoDataFrame([1], geometry=[geometry], crs="EPSG:4326")

        # Reproject to image CRS if needed
        if gdf.crs != self.meta['crs']:
            gdf = gdf.to_crs(self.meta['crs'])

        # Apply buffer if specified
        if buffer_meters > 0:
            gdf_buffered = gdf.copy()
            gdf_buffered.geometry = gdf.geometry.buffer(buffer_meters)
            geometry_bounds = gdf_buffered.total_bounds
        else:
            geometry_bounds = gdf.total_bounds

        # Clip to bounding box first
        # (clips to the geometry's axis-aligned bounding box, not an exact mask)
        return self.clip_to_bounds(geometry_bounds)

band_order property

Canonical Sentinel-2 band order with index mapping.

__init__(file_path, load_all_bands=False, target_resolution='10m', clip_geometry=None, buffer_meters=0)

Initialize Sentinel-2 image.

Parameters:

- `file_path` (str): Path to Sentinel-2 file (.SAFE directory, .zip file, or MTD XML file)
- `load_all_bands` (bool): If True, loads all 13 bands by resampling from different resolution subdatasets. If False, loads only the native resolution bands (default: 4 bands at 10m)
- `target_resolution` (str): Target resolution when `load_all_bands=True` ("10m", "20m", "60m")
- `clip_geometry` (shapely geometry or GeoDataFrame, optional): Geometry to clip to during loading for efficiency
- `buffer_meters` (float): Buffer distance in meters to add around `clip_geometry`

Source code in ShallowLearn/io/satellite_data.py
def __init__(self, file_path: str, load_all_bands: bool = False, target_resolution: str = "10m", 
             clip_geometry=None, buffer_meters: float = 0):
    """
    Initialize Sentinel-2 image.

    Parameters:
    -----------
    file_path : str
        Path to Sentinel-2 file (.SAFE directory, .zip file, or MTD XML file)
    load_all_bands : bool
        If True, loads all 13 bands by resampling from different resolution subdatasets.
        If False, loads only the native resolution bands (default: 4 bands at 10m)
    target_resolution : str
        Target resolution when load_all_bands=True ("10m", "20m", "60m")
    clip_geometry : shapely geometry or GeoDataFrame, optional
        Geometry to clip to during loading for efficiency
    buffer_meters : float
        Buffer distance in meters to add around clip_geometry
    """
    # Configuration is stored before super().__init__(), which presumably
    # triggers the base class's image loading and reads these attributes —
    # TODO confirm against the SatelliteImage base class.
    self.load_all_bands = load_all_bands
    self.target_resolution = target_resolution
    self.clip_geometry = clip_geometry
    self.buffer_meters = buffer_meters
    super().__init__(file_path)

clip_to_bounds(bounds, buffer_pixels=0)

Clip image data to specified bounds.

Parameters:

- `bounds` (tuple or BoundingBox): Bounds to clip to (left, bottom, right, top) or a rasterio BoundingBox
- `buffer_pixels` (int): Number of pixels to add as buffer around the clipped area

Returns:

Sentinel2Image New Sentinel2Image instance with clipped data

Source code in ShallowLearn/io/satellite_data.py
def clip_to_bounds(self, bounds, buffer_pixels: int = 0):
    """
    Clip image data to specified bounds.

    Parameters:
    -----------
    bounds : tuple or BoundingBox
        Bounds to clip to (left, bottom, right, top) or rasterio BoundingBox
    buffer_pixels : int
        Number of pixels to add as buffer around the clipped area

    Returns:
    --------
    Sentinel2Image
        New Sentinel2Image instance with clipped data
    """
    if self.image is None or not hasattr(self, 'meta'):
        raise ValueError("Image and metadata must be loaded before clipping")

    from rasterio.coords import BoundingBox
    from rasterio.windows import from_bounds
    from copy import deepcopy

    # Ensure bounds is a BoundingBox
    if not isinstance(bounds, BoundingBox):
        bounds = BoundingBox(*bounds)

    # Calculate window from bounds
    window = from_bounds(
        bounds.left, bounds.bottom, bounds.right, bounds.top,
        transform=self.meta['transform']
    )

    # Apply buffer if specified
    # NOTE(review): rasterio's Window does not appear to define an .expand()
    # method in current releases — this branch may raise AttributeError when
    # buffer_pixels > 0; confirm against the installed rasterio version.
    if buffer_pixels > 0:
        window = window.expand(buffer_pixels)

    # Round window to integer pixels
    window = window.round_lengths().round_offsets()

    # Clip the window to image boundaries  
    from rasterio.windows import Window
    image_window = Window(0, 0, self.meta['width'], self.meta['height'])
    window = window.intersection(image_window)

    if window.width <= 0 or window.height <= 0:
        raise ValueError("Clipping bounds do not intersect with image")

    # Extract the clipped data
    row_slice = slice(int(window.row_off), int(window.row_off + window.height))
    col_slice = slice(int(window.col_off), int(window.col_off + window.width))

    clipped_image = self.image[row_slice, col_slice, :]

    # Create new instance with clipped data
    # __new__ bypasses __init__ so the clipped copy does not re-read the
    # source file from disk; state is copied over manually.
    clipped_s2 = self.__class__.__new__(self.__class__)
    clipped_s2.path = self.path
    clipped_s2.present_bands = self.present_bands.copy()
    clipped_s2.band_status = self.band_status.copy()
    clipped_s2.tags = self.tags.copy()
    clipped_s2.image = clipped_image

    # Update metadata
    clipped_s2.meta = deepcopy(self.meta)
    clipped_s2.meta['width'] = int(window.width)
    clipped_s2.meta['height'] = int(window.height)

    # Update transform
    from rasterio.windows import transform as window_transform
    clipped_s2.meta['transform'] = window_transform(window, self.meta['transform'])

    return clipped_s2

clip_to_geometry(geometry, buffer_meters=0)

Clip image data to a geometry (e.g., from a GeoDataFrame).

Parameters:

- `geometry` (shapely geometry or GeoDataFrame): Geometry to clip to
- `buffer_meters` (float): Buffer distance in meters to add around the geometry

Returns:

Sentinel2Image New Sentinel2Image instance with clipped data

Source code in ShallowLearn/io/satellite_data.py
def clip_to_geometry(self, geometry, buffer_meters: float = 0):
    """
    Clip image data to a geometry (e.g., from a GeoDataFrame).

    The clip is performed against the geometry's axis-aligned bounding box
    (optionally buffered), not an exact polygon mask.

    Parameters:
    -----------
    geometry : shapely geometry or GeoDataFrame
        Geometry to clip to
    buffer_meters : float
        Buffer distance in meters to add around the geometry

    Returns:
    --------
    Sentinel2Image
        New Sentinel2Image instance with clipped data
    """
    # Fix: dropped the unused `from shapely.geometry import box` import.
    import geopandas as gpd

    # Normalize input: wrap a bare geometry in a GeoDataFrame.
    # NOTE(review): bare geometries are assumed WGS84 (EPSG:4326) — confirm callers.
    if hasattr(geometry, 'geometry'):
        # It's a GeoDataFrame
        gdf = geometry
    else:
        # It's a geometry - create a GeoDataFrame
        gdf = gpd.GeoDataFrame([1], geometry=[geometry], crs="EPSG:4326")

    # Reproject to image CRS so the bounds align with the image transform.
    if gdf.crs != self.meta['crs']:
        gdf = gdf.to_crs(self.meta['crs'])

    # Apply the optional metric buffer before taking the bounding box.
    if buffer_meters > 0:
        gdf_buffered = gdf.copy()
        gdf_buffered.geometry = gdf.geometry.buffer(buffer_meters)
        geometry_bounds = gdf_buffered.total_bounds
    else:
        geometry_bounds = gdf.total_bounds

    # Clip to the (possibly buffered) bounding box.
    return self.clip_to_bounds(geometry_bounds)

get_resolution_groups()

Get bands grouped by native resolution.

Source code in ShallowLearn/io/satellite_data.py
def get_resolution_groups(self) -> Dict[str, List[str]]:
    """Map each native ground-sample distance to its Sentinel-2 band names."""
    groups: Dict[str, List[str]] = {}
    groups["10m"] = ["B02", "B03", "B04", "B08"]
    groups["20m"] = ["B05", "B06", "B07", "B8A", "B11", "B12"]
    groups["60m"] = ["B01", "B09", "B10"]
    return groups

get_rgb_bands()

Get RGB band combination for Sentinel-2.

Source code in ShallowLearn/io/satellite_data.py
def get_rgb_bands(self) -> Tuple[str, str, str]:
    """Return the band names for a true-colour composite as (red, green, blue)."""
    red, green, blue = "B04", "B03", "B02"
    return (red, green, blue)

get_spectral_bands()

Get list of all spectral bands.

Source code in ShallowLearn/io/satellite_data.py
def get_spectral_bands(self) -> List[str]:
    """Return every spectral band name, in the canonical band_order sequence."""
    return [name for name in self.band_order]

Sentinel2ImageCollection

Bases: SatelliteImageCollection

Managed collection of Sentinel-2 images with date sorting.

Source code in ShallowLearn/io/satellite_data.py
class Sentinel2ImageCollection(SatelliteImageCollection):
    """Managed collection of Sentinel-2 images with date sorting."""

    def _get_sorted_image_files(self) -> List[Path]:
        """Return the directory's Sentinel-2 VRT files ordered by acquisition date."""

        def acquisition_key(path):
            """Pull the YYYYMMDD token out of a Sentinel-2 filename ('' if absent)."""
            # Filenames look like S2A_MSIL2A_20210101T000000_...; the date is the
            # underscore-separated token starting with "20" with a "T" right
            # after the 8-digit date.
            for token in path.stem.split("_"):
                if token.startswith("20") and len(token) >= 8 and token[8:9] == "T":
                    return token[:8]
            return ""

        return sorted(self.directory.glob("*.vrt"), key=acquisition_key)

    def _create_image(self, file_path: Path) -> Sentinel2Image:
        """Wrap a VRT path in a Sentinel2Image instance."""
        return Sentinel2Image(file_path)

batch_compile_geotiffs(source_list, output_dir, satellite_type=None, **kwargs)

Batch compile multiple satellite sources to GeoTIFF files.

Parameters:

source_list : List[str] List of source file/directory paths output_dir : str Output directory for GeoTIFF files satellite_type : str, optional Force specific satellite type **kwargs Additional arguments for compiler

Returns:

List[str] List of created GeoTIFF file paths

Source code in ShallowLearn/io/geotiff_compiler.py
def batch_compile_geotiffs(
    source_list: List[str],
    output_dir: str,
    satellite_type: Optional[str] = None,
    **kwargs,
) -> List[str]:
    """
    Batch compile multiple satellite sources to GeoTIFF files.

    Parameters:
    -----------
    source_list : List[str]
        List of source file/directory paths
    output_dir : str
        Output directory for GeoTIFF files
    satellite_type : str, optional
        Force specific satellite type; auto-detected from the first source's
        filename when omitted (defaults to Sentinel-2 if ambiguous)
    **kwargs
        Additional arguments forwarded to both the compiler factory and each
        compile_geotiff() call

    Returns:
    --------
    List[str]
        List of created GeoTIFF file paths (sources that fail are skipped)
    """
    if not source_list:
        print("No sources provided")
        return []

    # Auto-detect satellite type from the first source's filename if not provided
    if satellite_type is None:
        first_source = Path(source_list[0]).name.upper()
        if any(sat in first_source for sat in ("S2A", "S2B", "MSI")):
            satellite_type = "sentinel2"
        elif any(sat in first_source for sat in ("LC08", "LC09", "LE07", "LT05", "LT04")):
            satellite_type = "landsat"
        else:
            # Default to Sentinel-2 if uncertain
            satellite_type = "sentinel2"

    # Create compiler
    compiler = create_geotiff_compiler(satellite_type, output_dir, **kwargs)

    # Process each source independently; one failure does not stop the batch.
    created_geotiffs = []
    for source_path in source_list:  # fix: dropped unused enumerate() index
        try:
            source_name = Path(source_path).stem
            output_name = f"{source_name}_compiled.tiff"

            geotiff_path = compiler.compile_geotiff(source_path, output_name, **kwargs)
            created_geotiffs.append(geotiff_path)
            print(f"Created: {geotiff_path}")
        except Exception as e:
            print(f"Failed to process {source_path}: {e}")

    print(f"Successfully created {len(created_geotiffs)} GeoTIFF files")
    return created_geotiffs

batch_process_archives(archive_list, output_dir, bounds=None, satellite_type=None, **kwargs)

Batch process multiple satellite archives to VRTs.

Parameters:

archive_list : List[str] List of archive file paths output_dir : str Output directory for VRT files bounds : gpd.GeoDataFrame, optional Geographic bounds for cropping satellite_type : str, optional Force specific satellite type **kwargs Additional arguments for VRT builder

Source code in ShallowLearn/io/vrt_builder.py
def batch_process_archives(
    archive_list: List[str],
    output_dir: str,
    bounds: Optional[gpd.GeoDataFrame] = None,
    satellite_type: Optional[str] = None,
    **kwargs,
):
    """
    Batch process multiple satellite archives to VRTs.

    Parameters:
    -----------
    archive_list : List[str]
        List of archive file paths
    output_dir : str
        Output directory for VRT files
    bounds : gpd.GeoDataFrame, optional
        Geographic bounds for cropping
    satellite_type : str, optional
        Force specific satellite type; auto-detected from the first archive's
        name (or file extension) when omitted
    **kwargs
        Additional arguments forwarded to both the VRT-builder factory and
        each build_vrt() call

    Returns:
    --------
    List[str]
        Paths of the VRT files that were created (archives that fail are skipped)
    """
    if not archive_list:
        print("No archives provided")
        # fix: return an empty list instead of None so the return type is
        # consistent (matches batch_compile_geotiffs and the non-empty path)
        return []

    # Auto-detect satellite type if not provided
    if satellite_type is None:
        first_archive = Path(archive_list[0]).name.upper()
        if any(sat in first_archive for sat in ("LC08", "LC09", "LE07", "LT05", "LT04")):
            satellite_type = "landsat"
        elif any(sat in first_archive for sat in ("S2A", "S2B")):
            satellite_type = "sentinel2"
        else:
            # Fall back to the archive container format
            lowered = archive_list[0].lower()
            if lowered.endswith(".tar"):
                satellite_type = "landsat"
            elif lowered.endswith(".zip"):
                satellite_type = "sentinel2"
            else:
                raise ValueError(
                    "Cannot auto-detect satellite type. Please specify satellite_type parameter."
                )

    # Create VRT builder
    vrt_builder = create_vrt_builder(satellite_type, output_dir, **kwargs)

    # Process each archive independently; one failure does not stop the batch.
    created_vrts = []
    for archive_path in archive_list:
        try:
            vrt_path = vrt_builder.build_vrt(archive_path, bounds, **kwargs)
            created_vrts.append(vrt_path)
        except Exception as e:
            print(f"Failed to process {archive_path}: {e}")

    print(f"Successfully created {len(created_vrts)} VRT files")
    return created_vrts

create_geotiff_compiler(satellite_type, output_dir, **kwargs)

Factory function to create appropriate GeoTIFF compiler.

Parameters:

satellite_type : str Type of satellite ("landsat" or "sentinel2") output_dir : str Output directory for GeoTIFF files **kwargs Additional arguments for compiler

Returns:

GeoTIFFCompiler Appropriate GeoTIFF compiler instance

Source code in ShallowLearn/io/geotiff_compiler.py
def create_geotiff_compiler(
    satellite_type: str, 
    output_dir: str, 
    **kwargs
) -> GeoTIFFCompiler:
    """
    Factory function to create appropriate GeoTIFF compiler.

    Parameters:
    -----------
    satellite_type : str
        Type of satellite ("landsat" or "sentinel2")
    output_dir : str
        Output directory for GeoTIFF files
    **kwargs
        Additional arguments for compiler

    Returns:
    --------
    GeoTIFFCompiler
        Appropriate GeoTIFF compiler instance

    Raises:
    -------
    ValueError
        If satellite_type is not a recognized identifier
    """
    kind = satellite_type.lower()
    if kind in ("sentinel2", "sentinel-2", "s2"):
        return Sentinel2GeoTIFFCompiler(output_dir, **kwargs)
    if kind == "landsat":
        return LandsatGeoTIFFCompiler(output_dir, **kwargs)
    raise ValueError(f"Unknown satellite type: {satellite_type}")

create_satellite_collection(directory, satellite_type=None)

Factory function to create appropriate satellite image collection.

Parameters:

directory : str Directory containing satellite images satellite_type : str, optional Force specific satellite type ('landsat' or 'sentinel2')

Returns:

SatelliteImageCollection Appropriate satellite image collection

Source code in ShallowLearn/io/satellite_data.py
def create_satellite_collection(
    directory: str, satellite_type: Optional[str] = None
) -> SatelliteImageCollection:
    """
    Factory function to create appropriate satellite image collection.

    Parameters:
    -----------
    directory : str
        Directory containing satellite images
    satellite_type : str, optional
        Force specific satellite type ('landsat' or 'sentinel2');
        auto-detected from the first VRT filename when omitted

    Returns:
    --------
    SatelliteImageCollection
        Appropriate satellite image collection

    Raises:
    -------
    ValueError
        If satellite_type is unknown, or no VRT files exist for auto-detection
    """
    if satellite_type:
        kind = satellite_type.lower()
        if kind == "landsat":
            return LandsatImageCollection(directory)
        if kind in ("sentinel2", "sentinel-2", "s2"):
            return Sentinel2ImageCollection(directory)
        raise ValueError(f"Unknown satellite type: {satellite_type}")

    # Auto-detect based on files in directory
    vrt_files = list(Path(directory).glob("*.vrt"))
    if not vrt_files:
        raise ValueError(f"No VRT files found in {directory}")

    # Sniff the satellite from the first filename
    first_name = vrt_files[0].name.upper()
    if any(tag in first_name for tag in ("LC08", "LC09", "LE07", "LT05", "LT04")):
        return LandsatImageCollection(directory)
    if any(tag in first_name for tag in ("S2A", "S2B")):
        return Sentinel2ImageCollection(directory)
    # Fall back to Landsat when the name is ambiguous
    return LandsatImageCollection(directory)

create_satellite_image(file_path)

Factory function to create appropriate satellite image based on file path.

Parameters:

file_path : str Path to the satellite image file

Returns:

SatelliteImage Appropriate satellite image instance

Source code in ShallowLearn/io/satellite_data.py
def create_satellite_image(file_path: str) -> SatelliteImage:
    """
    Factory function to create appropriate satellite image based on file path.

    Parameters:
    -----------
    file_path : str
        Path to the satellite image file

    Returns:
    --------
    SatelliteImage
        Appropriate satellite image instance
    """
    filename = Path(file_path).name.upper()

    # Landsat scene IDs embed the mission code at the front of the filename
    if any(code in filename for code in ("LC08", "LC09", "LE07", "LT05", "LT04")):
        return LandsatImage(file_path)
    # Sentinel-2 products are prefixed with the platform name
    if any(code in filename for code in ("S2A", "S2B")):
        return Sentinel2Image(file_path)
    # Unknown naming scheme: fall back to Landsat
    # (could be enhanced with content-based detection)
    return LandsatImage(file_path)

create_vrt_builder(satellite_type, output_dir, **kwargs)

Factory function to create appropriate VRT builder.

Parameters:

satellite_type : str Type of satellite ("landsat" or "sentinel2") output_dir : str Output directory for VRT files **kwargs Additional arguments for VRT builder

Returns:

VRTBuilder Appropriate VRT builder instance

Source code in ShallowLearn/io/vrt_builder.py
def create_vrt_builder(satellite_type: str, output_dir: str, **kwargs) -> VRTBuilder:
    """
    Factory function to create appropriate VRT builder.

    Parameters:
    -----------
    satellite_type : str
        Type of satellite ("landsat" or "sentinel2")
    output_dir : str
        Output directory for VRT files
    **kwargs
        Additional arguments for VRT builder

    Returns:
    --------
    VRTBuilder
        Appropriate VRT builder instance

    Raises:
    -------
    ValueError
        If satellite_type is not a recognized identifier
    """
    kind = satellite_type.lower()
    if kind == "landsat":
        return LandsatVRTBuilder(output_dir, **kwargs)
    if kind in ("sentinel2", "sentinel-2", "s2"):
        return Sentinel2VRTBuilder(output_dir, **kwargs)
    raise ValueError(f"Unknown satellite type: {satellite_type}")

load_image(path, return_meta=False, clip=False, file_format=None, gdf_clip=None)

High-level image loading function with auto-detection and proper orientation.

This function serves as a replacement for ImageHelper.load_img with enhanced capabilities for handling different satellite data formats and file types.

Parameters:

path : str or Path Path to the image file return_meta : bool, default False Whether to return metadata and bounds along with the image clip : bool, default False Whether to clip values to 0-10000 range file_format : str, optional Force specific format handling ('geotiff', 'sentinel2', 'landsat') If None, format is auto-detected gdf_clip : GeoDataFrame, optional GeoDataFrame with geometries for clipping the image. If provided, the image will be clipped to the geometries

Returns:

np.ndarray or tuple If return_meta=False: Image array with shape (height, width, bands) If return_meta=True: Tuple of (image, metadata, bounds)

Raises:

FileNotFoundError If the specified file does not exist ValueError If the file format is not supported or auto-detection fails

Source code in ShallowLearn/io/image_loader.py
def load_image(
    path: Union[str, Path],
    return_meta: bool = False,
    clip: bool = False,
    file_format: Optional[str] = None,
    gdf_clip: Optional[object] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, dict, object]]:
    """
    High-level image loading function with auto-detection and proper orientation.

    This function serves as a replacement for ImageHelper.load_img with enhanced
    capabilities for handling different satellite data formats and file types.

    Parameters:
    -----------
    path : str or Path
        Path to the image file
    return_meta : bool, default False
        Whether to return metadata and bounds along with the image
    clip : bool, default False
        Whether to clip values to 0-10000 range
    file_format : str, optional
        Force specific format handling ('geotiff', 'sentinel2', 'landsat')
        If None, format is auto-detected
    gdf_clip : GeoDataFrame, optional
        GeoDataFrame with geometries for clipping the image. If provided, the image will be clipped to the geometries

    Returns:
    --------
    np.ndarray or tuple
        If return_meta=False: Image array with shape (height, width, bands)
        If return_meta=True: Tuple of (image, metadata, bounds)

    Raises:
    -------
    FileNotFoundError
        If the specified file does not exist
    ValueError
        If the file format is not supported or auto-detection fails
    """
    path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Image file not found: {path}")

    # Auto-detect format if not specified
    if file_format is None:
        file_format = _detect_file_format(path)

    # Load using appropriate loader
    if file_format == "sentinel2":
        loader = Sentinel2Image(str(path))
        img = loader.image
        if img is None:
            img = loader._load_image()  # Ensure image is loaded
    elif file_format == "landsat":
        loader = LandsatImage(str(path))
        img = loader.image
        if img is None:
            img = loader._load_image()  # Ensure image is loaded
    else:
        # Default to GeoTIFF loader
        loader = GeoTIFFImage(str(path))
        img = loader.load()

    # Apply legacy transformations for backwards compatibility
    img = _apply_legacy_transformations(img, path)

    # Always ensure channels-last format (height, width, bands)
    if len(img.shape) == 3:
        # Heuristic: a leading axis of <= 20 with a large trailing axis is
        # likely (bands, height, width) and needs transposing; otherwise we
        # assume (height, width, bands) and leave the array untouched.
        if img.shape[0] <= 20 and img.shape[2] > 20:
            img = np.transpose(img, (1, 2, 0))

    # fix: fetch metadata at most once (it was previously queried separately
    # for the gdf_clip branch and again for the return_meta branch)
    metadata = None
    if gdf_clip is not None or return_meta:
        metadata = loader.get_metadata() if hasattr(loader, "get_metadata") else {}

    # Apply GeoDataFrame clipping if requested
    if gdf_clip is not None:
        img = clip_image_with_gdf(img, gdf_clip, metadata.get('crs'), metadata.get('transform'))

    # Apply value clipping if requested
    if clip:
        img = clip_array(img)

    # Return with metadata if requested
    if return_meta:
        bounds = loader.get_bounds() if hasattr(loader, "get_bounds") else None
        return img, metadata, bounds

    return img

load_image_collection(directory, pattern='*.tif', **load_kwargs)

Load multiple images from a directory.

Parameters:

directory : str or Path Directory containing image files pattern : str, default "*.tif" Glob pattern to match files **load_kwargs Additional arguments passed to load_image()

Returns:

list List of loaded image arrays

Source code in ShallowLearn/io/image_loader.py
def load_image_collection(
    directory: Union[str, Path], pattern: str = "*.tif", **load_kwargs
) -> list:
    """
    Load multiple images from a directory.

    Files are processed in sorted order; any file that fails to load is
    reported with a warning and skipped.

    Parameters:
    -----------
    directory : str or Path
        Directory containing image files
    pattern : str, default "*.tif"
        Glob pattern to match files
    **load_kwargs
        Additional arguments passed to load_image()

    Returns:
    --------
    list
        List of loaded image arrays
    """
    loaded = []
    for candidate in sorted(Path(directory).glob(pattern)):
        try:
            loaded.append(load_image(candidate, **load_kwargs))
        except Exception as e:
            print(f"Warning: Failed to load {candidate}: {e}")
    return loaded

ML Module

Machine Learning module for ShallowLearn Contains dimensionality reduction, clustering, and analysis components

DimensionalityReducer

Bases: ABC

Abstract base for dimensionality reduction methods

Source code in ShallowLearn/ml/quicklook_ml.py
class DimensionalityReducer(ABC):
    """Abstract base for dimensionality reduction methods"""

    @abstractmethod
    def fit_transform(self, data: np.ndarray) -> np.ndarray:
        """Fit the reducer to *data* and return the reduced representation."""
        ...

    @abstractmethod
    def get_name(self) -> str:
        """Return a human-readable name for the reduction method."""
        ...

QuickLookConfig dataclass

Configuration for QuickLook processing

Source code in ShallowLearn/ml/quicklook_ml.py
@dataclass
class QuickLookConfig:
    """Configuration for QuickLook processing"""
    # Dimensionality reduction
    reduction_method: str = "pca"  # "pca", "tsne", "umap", "svd"
    n_components: Union[int, float] = 0.95  # Components or variance explained

    # Clustering
    clustering_method: str = "dbscan"  # "dbscan", "kmeans", "gmm"
    # fix: annotated Optional — the default is None and __post_init__ fills
    # in the method-specific defaults when no params are supplied
    clustering_params: Optional[Dict[str, Any]] = None

    # Image processing
    target_size: Tuple[int, int] = (343, 343)  # Native Sentinel-2 size
    normalize: bool = True

    # Thumbnail handling
    download_thumbnails: bool = True
    cache_dir: Optional[str] = None

    def __post_init__(self):
        """Fill in sensible per-method clustering defaults when none were given."""
        if self.clustering_params is None:
            if self.clustering_method == "dbscan":
                self.clustering_params = {"eps": 50, "min_samples": 5}
            elif self.clustering_method == "kmeans":
                self.clustering_params = {"n_clusters": 4}
            elif self.clustering_method == "gmm":
                self.clustering_params = {"n_components": 4}

QuickLookFilter

Main QuickLook filtering system for satellite products

Source code in ShallowLearn/ml/quicklook_ml.py
class QuickLookFilter:
    """Main QuickLook filtering system for satellite products"""

    def __init__(self, config: Optional[QuickLookConfig] = None):
        """Build loader and reducer from *config* (defaults to QuickLookConfig())."""
        self.config = config or QuickLookConfig()
        self.thumbnail_loader = ThumbnailLoader(self.config.cache_dir)
        self.reducer = self._create_reducer()

        # State
        self.products = []  # products kept after thumbnail filtering
        self.thumbnails = []  # loaded thumbnail arrays, aligned with self.products
        self.transformed_data = None  # reducer output for the loaded thumbnails
        self.labels = None  # cluster label per product
        self.class_dict = {}  # cluster id -> (hex color, display name)

    def _create_reducer(self) -> DimensionalityReducer:
        """Create dimensionality reducer based on config"""
        method = self.config.reduction_method.lower()
        n_components = self.config.n_components

        if method == "pca":
            return PCAReducer(n_components)
        elif method == "tsne":
            # t-SNE/UMAP require an integer component count; a float
            # (variance-explained) config falls back to a 2-D embedding.
            return TSNEReducer(n_components if isinstance(n_components, int) else 2)
        elif method == "umap":
            return UMAPReducer(n_components if isinstance(n_components, int) else 2)
        elif method == "svd":
            return SVDReducer(n_components)
        else:
            raise ValueError(f"Unknown reduction method: {method}")

    def process_products(self, products: List) -> Dict[str, List]:
        """Process satellite products through QuickLook pipeline

        Returns:
            Dictionary with cluster labels as keys and filtered product lists as values
        """
        print(f"Processing {len(products)} products with {self.reducer.get_name()}...")

        # Load thumbnails; products whose thumbnail fails to load are dropped.
        self.products = products
        self.thumbnails = []
        valid_indices = []

        for i, product in enumerate(products):
            thumbnail = self.thumbnail_loader.load_thumbnail(product, self.config.target_size)
            if thumbnail is not None:
                self.thumbnails.append(thumbnail)
                valid_indices.append(i)

        if not self.thumbnails:
            raise ValueError("No valid thumbnails loaded")

        print(f"Loaded {len(self.thumbnails)} valid thumbnails")

        # Keep only products with valid thumbnails
        self.products = [products[i] for i in valid_indices]

        # Prepare data for dimensionality reduction
        thumbnail_data = np.array(self.thumbnails)
        if self.config.normalize:
            # assumes 8-bit thumbnails — scales pixel values into [0, 1]
            thumbnail_data = thumbnail_data.astype(np.float32) / 255.0

        # Flatten each thumbnail to one feature vector for dimensionality reduction
        flattened_data = thumbnail_data.reshape(len(self.thumbnails), -1)

        # Apply dimensionality reduction
        print(f"Applying {self.reducer.get_name()}...")
        self.transformed_data = self.reducer.fit_transform(flattened_data)

        # Apply clustering
        print(f"Clustering with {self.config.clustering_method}...")
        self.labels = self._apply_clustering(self.transformed_data)

        # Generate class dictionary
        self._generate_class_dict()

        # Return filtered products by cluster
        return self._group_products_by_cluster()

    def _apply_clustering(self, data: np.ndarray) -> np.ndarray:
        """Apply clustering algorithm and return one label per row of *data*."""
        method = self.config.clustering_method.lower()
        params = self.config.clustering_params

        if method == "dbscan":
            clusterer = DBSCAN(**params)
        elif method == "kmeans":
            clusterer = KMeans(**params)
        elif method == "gmm":
            clusterer = GaussianMixture(**params)
            # GaussianMixture.fit_predict already returns the most likely
            # component label per sample, so no extra argmax step is needed.
            labels = clusterer.fit_predict(data)
            return labels
        else:
            raise ValueError(f"Unknown clustering method: {method}")

        return clusterer.fit_predict(data)

    def _generate_class_dict(self):
        """Generate class dictionary for visualization - semantic labels only for PCA"""
        unique_labels = np.unique(self.labels)

        # Check if we're using PCA method for semantic labeling
        is_pca = hasattr(self.reducer, 'get_name') and 'PCA' in self.reducer.get_name()

        if is_pca:
            # Calculate mean brightness for each cluster to determine cloud/clear classification
            cluster_means = {}
            for label in unique_labels:
                if label == -1:  # DBSCAN noise
                    continue
                mask = self.labels == label
                cluster_thumbnails = np.array(self.thumbnails)[mask]
                cluster_means[label] = np.mean(cluster_thumbnails)

            # Sort clusters by brightness (darker = clearer, brighter = cloudier)
            # NOTE(review): this is a heuristic — bright non-cloud surfaces
            # (snow, sand) would also be labelled cloudy; confirm for the AOI.
            if cluster_means:
                sorted_clusters = sorted(cluster_means.items(), key=lambda x: x[1])

                # Assign semantic labels for PCA only
                self.class_dict = {-1: ("#808080", "Noise")}  # Gray for noise

                colors = ["#2ca02c", "#ff7f0e", "#d62728", "#1f77b4"]  # Green, orange, red, blue
                labels_semantic = ["Clear Sky", "Partially Cloudy", "Cloudy", "Very Cloudy"]

                for i, (cluster_id, _) in enumerate(sorted_clusters):
                    # Colors wrap around; the last semantic label is reused for
                    # any clusters beyond the four named ones.
                    color_idx = i % len(colors)
                    label_idx = min(i, len(labels_semantic) - 1)
                    self.class_dict[cluster_id] = (colors[color_idx], labels_semantic[label_idx])
            else:
                # Fallback for no valid clusters in PCA
                self._generate_generic_labels(unique_labels)
        else:
            # For non-PCA methods, use generic cluster labels
            self._generate_generic_labels(unique_labels)

    def _generate_generic_labels(self, unique_labels):
        """Generate generic cluster labels for non-PCA methods"""
        colors = ["#2ca02c", "#ff7f0e", "#d62728", "#1f77b4", "#9467bd", "#8c564b", "#e377c2"]
        self.class_dict = {}

        for i, label in enumerate(unique_labels):
            if label == -1:  # DBSCAN noise
                self.class_dict[label] = ("#808080", "Noise")
            else:
                color_idx = i % len(colors)
                self.class_dict[label] = (colors[color_idx], f"Cluster {label}")

    def _group_products_by_cluster(self) -> Dict[str, List]:
        """Group products by their cluster labels"""
        clusters = {}

        for product, label in zip(self.products, self.labels):
            # Fall back to a generated name for labels missing from class_dict
            label_name = self.class_dict.get(label, (None, f"Cluster_{label}"))[1]

            if label_name not in clusters:
                clusters[label_name] = []

            clusters[label_name].append(product)

        return clusters

    def get_clear_sky_products(self) -> List:
        """Get products classified as clear sky"""
        # Find the first cluster whose display name mentions "clear"
        clear_key = None
        for label, (_, name) in self.class_dict.items():
            if "clear" in name.lower():
                clear_key = label
                break

        if clear_key is not None:
            mask = self.labels == clear_key
            return [self.products[i] for i in range(len(self.products)) if mask[i]]
        else:
            print("No clear sky cluster identified")
            return []

    def get_clustering_summary(self) -> Dict[str, int]:
        """Get summary of clustering results"""
        summary = {}
        for label in np.unique(self.labels):
            count = np.sum(self.labels == label)
            label_name = self.class_dict.get(label, (None, f"Cluster_{label}"))[1]
            summary[label_name] = count

        return summary

get_clear_sky_products()

Get products classified as clear sky

Source code in ShallowLearn/ml/quicklook_ml.py
def get_clear_sky_products(self) -> List:
    """Return the products whose cluster was labelled as clear sky."""
    # Locate the first cluster id whose display name mentions "clear".
    clear_label = next(
        (label for label, (_, name) in self.class_dict.items() if "clear" in name.lower()),
        None,
    )

    if clear_label is None:
        print("No clear sky cluster identified")
        return []

    selected = self.labels == clear_label
    return [product for i, product in enumerate(self.products) if selected[i]]

get_clustering_summary()

Get summary of clustering results

Source code in ShallowLearn/ml/quicklook_ml.py
def get_clustering_summary(self) -> Dict[str, int]:
    """Return a mapping of cluster display name to number of member products."""
    unique_ids = np.unique(self.labels)
    return {
        # Fall back to a generated name when the id is missing from class_dict.
        self.class_dict.get(cluster_id, (None, f"Cluster_{cluster_id}"))[1]: np.sum(
            self.labels == cluster_id
        )
        for cluster_id in unique_ids
    }

process_products(products)

Process satellite products through QuickLook pipeline

Returns:

Dict[str, List]

Dictionary with cluster labels as keys and filtered product lists as values

Source code in ShallowLearn/ml/quicklook_ml.py
def process_products(self, products: List) -> Dict[str, List]:
    """Run the full QuickLook pipeline over *products*.

    Returns:
        Dictionary with cluster labels as keys and filtered product lists as values
    """
    print(f"Processing {len(products)} products with {self.reducer.get_name()}...")

    # Fetch a thumbnail per product, remembering which products yielded one.
    self.products = products
    self.thumbnails = []
    valid_indices = []
    for index, product in enumerate(products):
        thumb = self.thumbnail_loader.load_thumbnail(product, self.config.target_size)
        if thumb is None:
            continue
        self.thumbnails.append(thumb)
        valid_indices.append(index)

    if not self.thumbnails:
        raise ValueError("No valid thumbnails loaded")

    print(f"Loaded {len(self.thumbnails)} valid thumbnails")

    # Drop products whose thumbnail could not be loaded.
    self.products = [products[index] for index in valid_indices]

    # Stack thumbnails and optionally rescale 8-bit values into [0, 1].
    stacked = np.array(self.thumbnails)
    if self.config.normalize:
        stacked = stacked.astype(np.float32) / 255.0

    # One flat feature vector per thumbnail.
    features = stacked.reshape(len(self.thumbnails), -1)

    print(f"Applying {self.reducer.get_name()}...")
    self.transformed_data = self.reducer.fit_transform(features)

    print(f"Clustering with {self.config.clustering_method}...")
    self.labels = self._apply_clustering(self.transformed_data)

    # Build the id -> (color, name) mapping, then bucket products by cluster.
    self._generate_class_dict()
    return self._group_products_by_cluster()

ThumbnailLoader

Handles thumbnail/PVI loading for different satellite types

Source code in ShallowLearn/ml/quicklook_ml.py
class ThumbnailLoader:
    """Handles thumbnail/PVI loading for different satellite types"""

    def __init__(self, cache_dir: Optional[str] = None):
        """Create the loader, caching downloads in *cache_dir* (or a fresh temp dir)."""
        self.cache_dir = cache_dir or tempfile.mkdtemp()
        os.makedirs(self.cache_dir, exist_ok=True)

    def load_thumbnail(self, product, target_size: Tuple[int, int] = (343, 343)) -> Optional[np.ndarray]:
        """Load thumbnail for a satellite product.

        Returns an array resized to *target_size*, or None when no thumbnail
        source exists or loading fails (failures are printed, not raised).
        """
        try:
            if product.thumbnail_url:
                return self._load_from_url(product.thumbnail_url, target_size)
            elif product.satellite == "sentinel2":
                return self._load_sentinel2_pvi(product, target_size)
            elif product.satellite == "landsat":
                return self._load_landsat_thumbnail(product, target_size)
            else:
                print(f"No thumbnail method for {product.satellite}")
                return None
        except Exception as e:
            print(f"Failed to load thumbnail for {product.product_id}: {e}")
            return None

    def _load_from_url(self, url: str, target_size: Tuple[int, int]) -> np.ndarray:
        """Download (and cache) a thumbnail from *url*, returning an RGB array."""
        import hashlib  # local import: only needed for the cache key

        # fix: builtin hash() is salted per process (PYTHONHASHSEED), so cache
        # filenames were not stable across runs and the cache was never reused.
        # Use a deterministic digest of the URL instead.
        cache_key = hashlib.sha256(url.encode("utf-8")).hexdigest()
        cache_file = os.path.join(self.cache_dir, f"{cache_key}.jpg")

        if not os.path.exists(cache_file):
            # fix: set a timeout so a stalled server cannot hang the pipeline
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            with open(cache_file, 'wb') as f:
                f.write(response.content)

        with Image.open(cache_file) as img:
            img = img.convert('RGB')
            img = img.resize(target_size)
            return np.array(img)

    def _load_sentinel2_pvi(self, product, target_size: Tuple[int, int]) -> np.ndarray:
        """Load Sentinel-2 PVI from product metadata or zip file"""
        # This would integrate with the existing PVI_Dataloader logic
        # For now, create a placeholder that would work with actual PVI files
        if hasattr(product, 'metadata') and product.metadata:
            # Try to extract PVI URL from metadata
            props = product.metadata.get('properties', {})
            thumbnail = props.get('thumbnail')
            if thumbnail:
                return self._load_from_url(thumbnail, target_size)

        # Fallback: generate placeholder thumbnail
        return self._generate_placeholder_thumbnail(product, target_size)

    def _load_landsat_thumbnail(self, product, target_size: Tuple[int, int]) -> np.ndarray:
        """Load Landsat thumbnail - would integrate with Landsat thumbnail generation"""
        # Landsat thumbnails would need to be generated from the downloaded data
        # For now, create placeholder
        return self._generate_placeholder_thumbnail(product, target_size)

    def _generate_placeholder_thumbnail(self, product, target_size: Tuple[int, int]) -> np.ndarray:
        """Generate a placeholder thumbnail based on product metadata"""
        # Create a simple colored thumbnail based on cloud cover (default 50%)
        cloud_cover = getattr(product, 'cloud_cover', 50) / 100.0

        # Create gradient based on cloud cover
        thumbnail = np.zeros((*target_size, 3), dtype=np.uint8)

        # Blue channel increases with cloud cover
        thumbnail[:, :, 2] = int(255 * cloud_cover)
        # Green channel decreases with cloud cover
        thumbnail[:, :, 1] = int(255 * (1 - cloud_cover))
        # Red stays constant
        thumbnail[:, :, 0] = 128

        return thumbnail

load_thumbnail(product, target_size=(343, 343))

Load thumbnail for a satellite product

Source code in ShallowLearn/ml/quicklook_ml.py
def load_thumbnail(self, product, target_size: Tuple[int, int] = (343, 343)) -> Optional[np.ndarray]:
    """Load a thumbnail image for a satellite product.

    Tries the product's direct thumbnail URL first, then falls back to a
    satellite-specific loader. Any failure is reported and swallowed so a
    missing thumbnail simply comes back as None.
    """
    try:
        # A direct URL always wins, regardless of satellite.
        if product.thumbnail_url:
            return self._load_from_url(product.thumbnail_url, target_size)

        # Per-satellite fallbacks, dispatched by method name.
        loader_names = {
            "sentinel2": "_load_sentinel2_pvi",
            "landsat": "_load_landsat_thumbnail",
        }
        loader_name = loader_names.get(product.satellite)
        if loader_name is None:
            print(f"No thumbnail method for {product.satellite}")
            return None
        return getattr(self, loader_name)(product, target_size)
    except Exception as e:
        print(f"Failed to load thumbnail for {product.product_id}: {e}")
        return None

Core Module

Core utilities for ShallowLearn. Contains clean, reusable functions with minimal dependencies.

API Module

API module for satellite data access

LandsatUSGSDownloader

Landsat downloader using USGS M2M API - matches existing BarAlHikman implementation

Source code in ShallowLearn/api/unified_satellite_api.py
class LandsatUSGSDownloader:
    """Landsat downloader using USGS M2M API - matches existing BarAlHikman implementation.

    Credentials are read from the environment: ``LSAT_TOKEN`` (application
    token) and ``LSAT_USER`` (account name, lower-cased for the API). Every
    public operation logs in with ``login-token``, does its work, and logs
    out in a ``finally`` block so M2M session keys are never leaked.
    """

    def __init__(self):
        """Set up endpoint/dataset constants and load credentials.

        Raises:
            ValueError: If LSAT_TOKEN or LSAT_USER is missing from the environment.
        """
        # Base endpoint for the USGS M2M JSON API (stable version).
        self.service_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
        self.datasets = {
            "L1": "landsat_ot_c2_l1",  # Landsat 8-9 OLI/TIRS Collection 2 Level 1
            "L2": "landsat_ot_c2_l2",  # Landsat 8-9 OLI/TIRS Collection 2 Level 2
        }
        # Load credentials directly from environment
        self.token = os.getenv("LSAT_TOKEN")
        self.username = os.getenv("LSAT_USER")
        if self.username:
            self.username = (
                self.username.lower()
            )  # Force lowercase for USGS API compatibility

        if not self.token or not self.username:
            raise ValueError("Landsat credentials not found. Check .env file.")

    def _send_request(self, endpoint: str, data: Dict, api_key: str = None) -> Dict:
        """Send a request to the USGS M2M API and return its "data" payload.

        Parameters:
        -----------
        endpoint : str
            M2M endpoint name appended to service_url (e.g. "scene-search").
        data : Dict
            JSON-serializable request body (may be None, e.g. for logout).
        api_key : str, optional
            Session key from "login-token"; sent as X-Auth-Token when given.

        Raises:
        -------
        Exception
            If the API reports an in-band errorCode, or on HTTP errors.
        """
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["X-Auth-Token"] = api_key

        response = requests.post(
            self.service_url + endpoint, json=data, headers=headers
        )
        response.raise_for_status()
        output = response.json()

        # The M2M API signals failures in-band; surface them as exceptions.
        if output.get("errorCode"):
            raise Exception(f"{output['errorCode']}: {output['errorMessage']}")
        return output["data"]

    def _extract_landsat_metadata(self, scene: Dict) -> Dict:
        """Extract and standardize Landsat-specific metadata fields.

        The API returns metadata as a list of {fieldName, value} items with
        string values; this flattens the list and converts key fields with
        None-tolerant parsers.
        """
        metadata_dict = {}

        # Convert metadata list to dictionary
        if "metadata" in scene and isinstance(scene["metadata"], list):
            for item in scene["metadata"]:
                if "fieldName" in item and "value" in item:
                    metadata_dict[item["fieldName"]] = item["value"]

        # Safe converters: values arrive as strings and may be missing/blank.
        def safe_int(value):
            try:
                return int(float(value)) if value else None
            except (ValueError, TypeError):
                return None

        def safe_float(value):
            try:
                return float(value) if value else None
            except (ValueError, TypeError):
                return None

        return {
            "wrs_path": safe_int(metadata_dict.get("WRS_PATH")),
            "wrs_row": safe_int(metadata_dict.get("WRS_ROW")),
            "scene_id": metadata_dict.get("LANDSAT_SCENE_ID"),
            "spacecraft_id": metadata_dict.get("SPACECRAFT_ID"),
            "sensor_id": metadata_dict.get("SENSOR_ID"),
            "sun_azimuth": safe_float(metadata_dict.get("SUN_AZIMUTH")),
            "sun_elevation": safe_float(metadata_dict.get("SUN_ELEVATION")),
            "earth_sun_distance": safe_float(metadata_dict.get("EARTH_SUN_DISTANCE")),
            "collection_number": metadata_dict.get("Collection Number"),
            "collection_category": metadata_dict.get("Collection Category"),
            # OLI and TIRS quality fields are mutually exclusive per sensor.
            "image_quality": metadata_dict.get("IMAGE_QUALITY_OLI")
            or metadata_dict.get("IMAGE_QUALITY_TIRS"),
            "geometric_rmse": safe_float(metadata_dict.get("GEOMETRIC_RMSE_MODEL")),
        }

    def _convert_geometry_to_mbr(self, geometry) -> Dict:
        """Convert geometry to USGS MBR (minimum bounding rectangle) format.

        Accepts a shapely Point (buffered by ~1 km) or a 4-tuple bounding box
        (lon_min, lat_min, lon_max, lat_max).

        Raises:
            ValueError: For any other geometry type.
        """
        if isinstance(geometry, Point):
            # For point queries, create small bounding box around the point
            lon, lat = geometry.x, geometry.y
            buffer = 0.01  # ~1km buffer
            return {
                "filterType": "mbr",
                "lowerLeft": {"latitude": lat - buffer, "longitude": lon - buffer},
                "upperRight": {"latitude": lat + buffer, "longitude": lon + buffer},
            }
        elif isinstance(geometry, tuple) and len(geometry) == 4:
            # Bounding box: (lon_min, lat_min, lon_max, lat_max)
            lon_min, lat_min, lon_max, lat_max = geometry
            return {
                "filterType": "mbr",
                "lowerLeft": {"latitude": lat_min, "longitude": lon_min},
                "upperRight": {"latitude": lat_max, "longitude": lon_max},
            }
        else:
            raise ValueError(f"Unsupported geometry type: {type(geometry)}")

    @staticmethod
    def _sensor_from_entity_id(entity_id: str) -> str:
        """Map an entity-ID prefix to a sensor name (default: OLI/TIRS)."""
        # str.startswith accepts a tuple: one call per sensor family.
        if entity_id.startswith(("LC08", "LC09", "LC8", "LC9")):
            return "OLI/TIRS"
        if entity_id.startswith(("LE07", "LC7")):
            return "ETM+"
        if entity_id.startswith(("LT05", "LT04", "LC5", "LC4")):
            return "TM"
        # Default to most common modern sensor for unknown patterns
        return "OLI/TIRS"

    @staticmethod
    def _thumbnail_url_from_scene(scene: Dict) -> Optional[str]:
        """Pull a browse/thumbnail URL out of a raw scene record, if any."""
        browse_info = scene.get("browse")
        if isinstance(browse_info, dict):
            return browse_info.get("browsePath") or browse_info.get("browseUrl")
        if isinstance(browse_info, list) and len(browse_info) > 0:
            return browse_info[0].get("browsePath") or browse_info[0].get("browseUrl")
        return None

    @staticmethod
    def _format_acquisition_date(raw_date: Optional[str]) -> Optional[str]:
        """Normalize a USGS timestamp to the Sentinel-2 ISO style.

        e.g. "2023-06-11 00:00:00" -> "2023-06-11T00:00:00.000000Z"
        """
        if not raw_date:
            return None
        if " " in raw_date:
            return raw_date.replace(" ", "T") + ".000000Z"
        if "." not in raw_date:
            # Add microseconds if missing
            if raw_date.endswith("Z"):
                return raw_date[:-1] + ".000000Z"
            return raw_date + ".000000Z"
        return raw_date

    def _scene_to_product(self, scene: Dict, processing_level: str) -> SatelliteProduct:
        """Convert one raw M2M scene record into a SatelliteProduct."""
        entity_id = scene.get("entityId", "")
        return SatelliteProduct(
            product_id=scene.get("entityId"),
            satellite="landsat",
            sensor=self._sensor_from_entity_id(entity_id),
            acquisition_date=self._format_acquisition_date(
                scene.get("temporalCoverage", {}).get("startDate")
            ),
            # cloudCover can be present but null; treat that as 0.
            cloud_cover=float(scene.get("cloudCover") or 0),
            processing_level=processing_level,
            thumbnail_url=self._thumbnail_url_from_scene(scene),
            bounds=scene.get("spatialBounds"),
            metadata=scene,
            # Landsat-specific metadata fields
            **self._extract_landsat_metadata(scene),
        )

    def search(self, query: SatelliteQuery) -> List[SatelliteProduct]:
        """Search Landsat scenes - matches existing BarAlHikman data_download.py logic.

        Authenticates, runs scene-search with spatial/temporal/cloud filters
        built from *query*, and converts each scene record into a
        SatelliteProduct. Always logs out, even on error.
        """
        print("Searching Landsat scenes...")

        # Step 1: Authenticate
        auth_payload = {"username": self.username, "token": self.token}
        api_key = self._send_request("login-token", auth_payload)

        try:
            # Step 2: Build search payload using same structure as existing code
            spatial_filter = self._convert_geometry_to_mbr(query.geometry)
            acquisition_filter = {
                "start": query.date_range[0],
                "end": query.date_range[1],
            }
            cloud_cover_filter = {"min": 0, "max": query.cloud_cover_max}

            # Unknown processing levels fall back to Level 1.
            dataset_name = self.datasets.get(
                query.processing_level, self.datasets["L1"]
            )

            search_payload = {
                "datasetName": dataset_name,
                "maxResults": query.max_results,
                "startingNumber": 1,
                "sceneFilter": {
                    "spatialFilter": spatial_filter,
                    "acquisitionFilter": acquisition_filter,
                    "cloudCoverFilter": cloud_cover_filter,
                },
            }

            # Step 3: Execute search
            scenes = self._send_request("scene-search", search_payload, api_key)

            if scenes["recordsReturned"] == 0:
                print("No Landsat scenes found for query")
                return []

            print(f"Found {scenes['recordsReturned']} Landsat scenes")

            # Step 4: Convert to standardized format
            products = []
            for scene in scenes["results"]:
                products.append(
                    self._scene_to_product(scene, query.processing_level)
                )

            return products

        finally:
            # Step 5: Logout
            self._send_request("logout", None, api_key)

    def download_product(self, product: SatelliteProduct, output_dir: str) -> str:
        """Download a Landsat product using USGS M2M API - matches BarAlHikman implementation.

        Parameters:
        -----------
        product : SatelliteProduct
            Product from search(); its product_id is the M2M entityId.
        output_dir : str
            Directory to write the archive into (created if missing).

        Returns:
        --------
        str
            Path of the downloaded file.

        Raises:
        -------
        Exception
            If no downloadable product or download URL is available.
        """
        import datetime
        import os
        import re
        import time

        # Step 1: Authenticate
        payload = {"username": self.username, "token": self.token}
        api_key = self._send_request("login-token", payload)

        try:
            # Step 2: Get download options (exact BarAlHikman approach)
            payload = {
                "datasetName": self.datasets.get(
                    product.processing_level, "landsat_ot_c2_l1"
                ),
                "entityIds": [product.product_id],
            }

            download_options = self._send_request("download-options", payload, api_key)

            # Step 3: Find available products (exact BarAlHikman logic)
            downloads = []
            for option in download_options:
                if option.get("available"):
                    downloads.append(
                        {"entityId": option["entityId"], "productId": option["id"]}
                    )

            if not downloads:
                raise Exception(
                    f"No available products to download for {product.product_id}"
                )

            # Step 4: Request downloads (exact BarAlHikman approach)
            label = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            payload = {"downloads": downloads, "label": label}

            request_results = self._send_request("download-request", payload, api_key)

            # Step 5: Handle download availability (exact BarAlHikman logic)
            download_url = None

            # If downloads are preparing, poll until ready (BarAlHikman approach)
            if request_results.get("preparingDownloads"):
                print(f"Download preparing for {product.product_id}...")
                payload = {"label": label}

                # Poll until the first URL appears. (The old
                # "len(available) >= len(downloads)" break was unreachable:
                # it only ran when `available` was empty.)
                while download_url is None:
                    more_download_urls = self._send_request(
                        "download-retrieve", payload, api_key
                    )
                    available = more_download_urls.get("available", [])

                    if available:
                        download_url = available[0]["url"]
                        break

                    print("Waiting for downloads to become available...")
                    time.sleep(30)
            else:
                # Downloads immediately available (BarAlHikman approach)
                available_downloads = request_results.get("availableDownloads", [])
                if available_downloads:
                    download_url = available_downloads[0]["url"]

            if not download_url:
                raise Exception(f"No download URL available for {product.product_id}")

            # Step 6: Download the file using requests (BarAlHikman pattern)
            response = requests.get(download_url, stream=True)
            response.raise_for_status()

            # Determine filename from response headers or URL
            disposition = response.headers.get("content-disposition", "")
            if "filename=" in disposition:
                filename = re.findall("filename=(.+)", disposition)[0].strip('"')
            else:
                filename = download_url.split("/")[-1].split("?")[0]

            if not filename:
                filename = f"{product.product_id}.tar.gz"  # Default extension

            # Ensure the target directory exists (mirrors the Sentinel-2 path).
            os.makedirs(output_dir, exist_ok=True)
            file_path = os.path.join(output_dir, filename)

            # Download with progress indication
            print(f"Downloading {filename}...")
            with open(file_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

            print(f"Downloaded {filename}")
            return file_path

        finally:
            # Step 7: Logout
            self._send_request("logout", None, api_key)

download_product(product, output_dir)

Download a Landsat product using USGS M2M API - matches BarAlHikman implementation

Source code in ShallowLearn/api/unified_satellite_api.py
def download_product(self, product: SatelliteProduct, output_dir: str) -> str:
    """Download a Landsat product using USGS M2M API - matches BarAlHikman implementation.

    Authenticates, requests a download for the product's entityId, waits for
    the URL to become available, streams the archive to *output_dir*
    (created if missing), and always logs out.

    Returns:
    --------
    str
        Path of the downloaded file.

    Raises:
    -------
    Exception
        If no downloadable product or download URL is available.
    """
    import datetime
    import os
    import re
    import time

    # Step 1: Authenticate
    payload = {"username": self.username, "token": self.token}
    api_key = self._send_request("login-token", payload)

    try:
        # Step 2: Get download options (exact BarAlHikman approach)
        payload = {
            "datasetName": self.datasets.get(
                product.processing_level, "landsat_ot_c2_l1"
            ),
            "entityIds": [product.product_id],
        }

        download_options = self._send_request("download-options", payload, api_key)

        # Step 3: Find available products (exact BarAlHikman logic)
        downloads = []
        for option in download_options:
            if option.get("available"):
                downloads.append(
                    {"entityId": option["entityId"], "productId": option["id"]}
                )

        if not downloads:
            raise Exception(
                f"No available products to download for {product.product_id}"
            )

        # Step 4: Request downloads (exact BarAlHikman approach)
        label = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        payload = {"downloads": downloads, "label": label}

        request_results = self._send_request("download-request", payload, api_key)

        # Step 5: Handle download availability (exact BarAlHikman logic)
        download_url = None

        # If downloads are preparing, poll until ready (BarAlHikman approach)
        if request_results.get("preparingDownloads"):
            print(f"Download preparing for {product.product_id}...")
            payload = {"label": label}

            # Poll until the first URL appears. (The old
            # "len(available) >= len(downloads)" break was unreachable:
            # it only ran when `available` was empty.)
            while download_url is None:
                more_download_urls = self._send_request(
                    "download-retrieve", payload, api_key
                )
                available = more_download_urls.get("available", [])

                if available:
                    download_url = available[0]["url"]
                    break

                print("Waiting for downloads to become available...")
                time.sleep(30)
        else:
            # Downloads immediately available (BarAlHikman approach)
            available_downloads = request_results.get("availableDownloads", [])
            if available_downloads:
                download_url = available_downloads[0]["url"]

        if not download_url:
            raise Exception(f"No download URL available for {product.product_id}")

        # Step 6: Download the file using requests (BarAlHikman pattern)
        response = requests.get(download_url, stream=True)
        response.raise_for_status()

        # Determine filename from response headers or URL
        disposition = response.headers.get("content-disposition", "")
        if "filename=" in disposition:
            filename = re.findall("filename=(.+)", disposition)[0].strip('"')
        else:
            filename = download_url.split("/")[-1].split("?")[0]

        if not filename:
            filename = f"{product.product_id}.tar.gz"  # Default extension

        # Ensure the target directory exists (mirrors the Sentinel-2 path).
        os.makedirs(output_dir, exist_ok=True)
        file_path = os.path.join(output_dir, filename)

        # Download with progress indication
        print(f"Downloading {filename}...")
        with open(file_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

        print(f"Downloaded {filename}")
        return file_path

    finally:
        # Step 7: Logout
        self._send_request("logout", None, api_key)

search(query)

Search Landsat scenes - matches existing BarAlHikman data_download.py logic

Source code in ShallowLearn/api/unified_satellite_api.py
def search(self, query: SatelliteQuery) -> List[SatelliteProduct]:
    """Search Landsat scenes - matches existing BarAlHikman data_download.py logic.

    Logs in to the USGS M2M API, runs a scene-search built from *query*
    (spatial, acquisition-date and cloud-cover filters), converts each raw
    scene record into a SatelliteProduct, and always logs out afterwards.
    """
    print("Searching Landsat scenes...")

    # Step 1: Authenticate
    auth_payload = {"username": self.username, "token": self.token}
    api_key = self._send_request("login-token", auth_payload)

    try:
        # Step 2: Build search payload using same structure as existing code
        spatial_filter = self._convert_geometry_to_mbr(query.geometry)
        acquisition_filter = {
            "start": query.date_range[0],
            "end": query.date_range[1],
        }
        cloud_cover_filter = {"min": 0, "max": query.cloud_cover_max}

        # Unknown processing levels fall back to Level 1.
        dataset_name = self.datasets.get(
            query.processing_level, self.datasets["L1"]
        )

        search_payload = {
            "datasetName": dataset_name,
            "maxResults": query.max_results,
            "startingNumber": 1,
            "sceneFilter": {
                "spatialFilter": spatial_filter,
                "acquisitionFilter": acquisition_filter,
                "cloudCoverFilter": cloud_cover_filter,
            },
        }

        # Step 3: Execute search
        scenes = self._send_request("scene-search", search_payload, api_key)

        if scenes["recordsReturned"] == 0:
            print("No Landsat scenes found for query")
            return []

        print(f"Found {scenes['recordsReturned']} Landsat scenes")

        # Step 4: Convert to standardized format
        products = []
        for scene in scenes["results"]:
            # Map the entity-ID prefix to a sensor name; str.startswith
            # accepts a tuple, so one call covers each sensor family.
            entity_id = scene.get("entityId", "")
            if entity_id.startswith(("LC08", "LC09", "LC8", "LC9")):
                sensor = "OLI/TIRS"
            elif entity_id.startswith(("LE07", "LC7")):
                sensor = "ETM+"
            elif entity_id.startswith(("LT05", "LT04", "LC5", "LC4")):
                sensor = "TM"
            else:
                # Default to most common modern sensor for unknown patterns
                sensor = "OLI/TIRS"

            # Look for browse/thumbnail URL in various possible locations
            thumbnail_url = None
            browse_info = scene.get("browse")
            if isinstance(browse_info, dict):
                thumbnail_url = browse_info.get("browsePath") or browse_info.get(
                    "browseUrl"
                )
            elif isinstance(browse_info, list) and len(browse_info) > 0:
                thumbnail_url = browse_info[0].get("browsePath") or browse_info[
                    0
                ].get("browseUrl")

            # Extract Landsat-specific metadata
            landsat_metadata = self._extract_landsat_metadata(scene)

            # Format acquisition date to match Sentinel-2 format, e.g.
            # "2023-06-11 00:00:00" -> "2023-06-11T00:00:00.000000Z"
            raw_date = scene.get("temporalCoverage", {}).get("startDate")
            if not raw_date:
                acquisition_date = None
            elif " " in raw_date:
                acquisition_date = raw_date.replace(" ", "T") + ".000000Z"
            elif "." not in raw_date:
                # Add microseconds if missing
                if raw_date.endswith("Z"):
                    acquisition_date = raw_date[:-1] + ".000000Z"
                else:
                    acquisition_date = raw_date + ".000000Z"
            else:
                acquisition_date = raw_date

            product = SatelliteProduct(
                product_id=scene.get("entityId"),
                satellite="landsat",
                sensor=sensor,
                acquisition_date=acquisition_date,
                # cloudCover can be present but null; treat that as 0.
                cloud_cover=float(scene.get("cloudCover") or 0),
                processing_level=query.processing_level,
                thumbnail_url=thumbnail_url,
                bounds=scene.get("spatialBounds"),
                metadata=scene,
                # Landsat-specific metadata fields
                **landsat_metadata,
            )
            products.append(product)

        return products

    finally:
        # Step 5: Logout
        self._send_request("logout", None, api_key)

SatelliteProduct dataclass

Standardized satellite product representation

Source code in ShallowLearn/api/unified_satellite_api.py
@dataclass
class SatelliteProduct:
    """Standardized satellite product representation.

    One record type unifies Sentinel-2 (CDSE) and Landsat (USGS M2M)
    search results; mission-specific fields simply stay None for the
    other mission.
    """

    # --- Common fields -------------------------------------------------
    product_id: str
    satellite: str  # 'landsat' or 'sentinel2'
    sensor: str
    acquisition_date: str
    cloud_cover: float
    processing_level: str
    thumbnail_url: Optional[str] = None
    download_url: Optional[str] = None
    bounds: Optional[Dict] = None
    metadata: Optional[Dict] = None

    # --- Sentinel-2 specific metadata fields ---------------------------
    orbit_number: Optional[int] = None
    relative_orbit_number: Optional[int] = None
    processing_baseline: Optional[str] = None
    product_type: Optional[str] = None
    platform: Optional[str] = None
    instrument: Optional[str] = None
    timeliness: Optional[str] = None
    snow_cover: Optional[float] = None
    orbit_direction: Optional[str] = None

    # --- Landsat specific metadata fields ------------------------------
    wrs_path: Optional[int] = None
    wrs_row: Optional[int] = None
    scene_id: Optional[str] = None
    spacecraft_id: Optional[str] = None
    sensor_id: Optional[str] = None
    sun_azimuth: Optional[float] = None
    sun_elevation: Optional[float] = None
    earth_sun_distance: Optional[float] = None
    collection_number: Optional[str] = None
    collection_category: Optional[str] = None
    image_quality: Optional[str] = None
    geometric_rmse: Optional[float] = None

    def to_dict(self) -> Dict:
        """Serialize to a plain dict; the bulky raw `metadata` is left out."""
        return dict(
            product_id=self.product_id,
            satellite=self.satellite,
            sensor=self.sensor,
            acquisition_date=self.acquisition_date,
            cloud_cover=self.cloud_cover,
            processing_level=self.processing_level,
            thumbnail_url=self.thumbnail_url,
            download_url=self.download_url,
            bounds=self.bounds,
            # Sentinel-2 fields
            orbit_number=self.orbit_number,
            relative_orbit_number=self.relative_orbit_number,
            processing_baseline=self.processing_baseline,
            product_type=self.product_type,
            platform=self.platform,
            instrument=self.instrument,
            timeliness=self.timeliness,
            snow_cover=self.snow_cover,
            orbit_direction=self.orbit_direction,
            # Landsat fields
            wrs_path=self.wrs_path,
            wrs_row=self.wrs_row,
            scene_id=self.scene_id,
            spacecraft_id=self.spacecraft_id,
            sensor_id=self.sensor_id,
            sun_azimuth=self.sun_azimuth,
            sun_elevation=self.sun_elevation,
            earth_sun_distance=self.earth_sun_distance,
            collection_number=self.collection_number,
            collection_category=self.collection_category,
            image_quality=self.image_quality,
            geometric_rmse=self.geometric_rmse,
        )

SatelliteQuery dataclass

Unified query parameters for satellite data

Source code in ShallowLearn/api/unified_satellite_api.py
@dataclass
class SatelliteQuery:
    """Unified query parameters for satellite data.

    One query object drives both the Landsat (USGS M2M) and the Sentinel-2
    (CDSE) search back ends; each downloader translates these fields into
    its own API's filter format.
    """

    # Spatial filter: a shapely Point(lon, lat) or a bounding box tuple.
    geometry: Union[
        Point, Tuple[float, float, float, float]
    ]  # Point(lon, lat) or (lon_min, lat_min, lon_max, lat_max)
    # Inclusive (start, end) date strings.
    date_range: Tuple[str, str]
    cloud_cover_max: int = 100
    processing_level: str = "L1C"  # L1C|L2A for S2, L1|L2 for Landsat
    # None means "both missions"; normalized in __post_init__.
    satellites: Optional[List[str]] = None
    max_results: int = 1000

    def __post_init__(self):
        """Apply the default mission list when the caller did not choose."""
        if self.satellites is None:
            self.satellites = ["landsat", "sentinel2"]

    def to_dict(self) -> Dict:
        """Serialize for logging/JSON; geometry is stringified."""
        return {
            "geometry": str(self.geometry),
            "date_range": self.date_range,
            "cloud_cover_max": self.cloud_cover_max,
            "processing_level": self.processing_level,
            "satellites": self.satellites,
            "max_results": self.max_results,
        }

Sentinel2CDSEDownloader

Sentinel-2 downloader using CDSE API - matches existing ShallowLearn implementation

Source code in ShallowLearn/api/unified_satellite_api.py
class Sentinel2CDSEDownloader:
    """Sentinel-2 downloader using CDSE API - matches existing ShallowLearn implementation.

    Credentials come from the environment (``SEN_USER`` / ``SEN_PASS``);
    searching and downloading are delegated to ``cdsetool``.
    """

    def __init__(self):
        """Load credentials and verify cdsetool availability.

        Raises:
            ValueError: If SEN_USER or SEN_PASS is missing.
            ImportError: If cdsetool could not be imported at module load.
        """
        # Load credentials directly from environment
        self.username = os.getenv("SEN_USER")
        self.password = os.getenv("SEN_PASS")

        if not self.username or not self.password:
            raise ValueError("Sentinel-2 credentials not found. Check .env file.")

        # query_features is bound at module import time; None means the
        # cdsetool import failed there.
        if query_features is None:
            raise ImportError("cdsetool not available. Cannot use Sentinel-2 API.")

    def _convert_geometry_to_bbox(self, geometry):
        """Convert geometry to a shapely box for the CDSE API.

        Accepts a shapely Point (buffered by ~1 km) or a 4-tuple bounding
        box (lon_min, lat_min, lon_max, lat_max).

        Raises:
            ValueError: For any other geometry type.
        """
        if isinstance(geometry, Point):
            # For point queries, create small bounding box
            lon, lat = geometry.x, geometry.y
            buffer = 0.01
            return box(lon - buffer, lat - buffer, lon + buffer, lat + buffer)
        elif isinstance(geometry, tuple) and len(geometry) == 4:
            # Bounding box: (lon_min, lat_min, lon_max, lat_max)
            lon_min, lat_min, lon_max, lat_max = geometry
            return box(lon_min, lat_min, lon_max, lat_max)
        else:
            raise ValueError(f"Unsupported geometry type: {type(geometry)}")

    def search(self, query: SatelliteQuery) -> List[SatelliteProduct]:
        """Search Sentinel-2 scenes - matches existing ShallowLearn DownloadData.py logic."""
        print("Searching Sentinel-2 scenes...")

        # Convert query parameters to CDSE format
        bbox = self._convert_geometry_to_bbox(query.geometry)

        # Map processing levels (unknown levels fall back to L1C).
        processing_level_map = {"L1C": "S2MSI1C", "L2A": "S2MSI2A"}
        processing_level = processing_level_map.get(query.processing_level, "S2MSI1C")

        # Build search terms using same structure as existing code
        search_terms = {
            "startDate": query.date_range[0],
            "completionDate": query.date_range[1],
            "processingLevel": processing_level,
            "geometry": bbox,
            "maxRecords": query.max_results,
        }

        # Execute search using cdsetool
        features = list(query_features("Sentinel2", search_terms))

        if not features:
            print("No Sentinel-2 scenes found for query")
            return []

        print(f"Found {len(features)} Sentinel-2 scenes")

        # Convert to standardized format (limit results since cdsetool doesn't always respect maxRecords)
        features = features[: query.max_results]
        products = []
        for feature in features:
            props = feature["properties"]

            product = SatelliteProduct(
                product_id=props.get("title"),
                satellite="sentinel2",
                sensor="MSI",
                acquisition_date=props.get("startDate"),
                # cloudCover can be present but null; treat that as 0.
                cloud_cover=float(props.get("cloudCover") or 0),
                processing_level=query.processing_level,
                thumbnail_url=props.get("thumbnail"),
                download_url=props.get("services", {}).get("download", {}).get("url"),
                bounds=feature.get("geometry"),
                metadata=feature,
                # Enhanced metadata fields
                orbit_number=props.get("orbitNumber"),
                relative_orbit_number=props.get("relativeOrbitNumber"),
                processing_baseline=props.get("processingBaseline"),
                product_type=props.get("productType"),
                # platform can be missing or null; normalize to full names.
                platform=(props.get("platform") or "")
                .replace("S2A", "Sentinel-2A")
                .replace("S2B", "Sentinel-2B"),
                instrument=props.get("instrument"),
                timeliness=props.get("timeliness"),
                snow_cover=props.get("snowCover"),
                orbit_direction=props.get("orbitDirection"),
            )
            products.append(product)

        return products

    def download_product(self, product: SatelliteProduct, output_dir: str) -> str:
        """Download a Sentinel-2 product using CDSE API - matches ShallowLearn implementation.

        Returns the path of the downloaded file. Failures are re-raised as
        Exception with the original error chained (``raise ... from e``).
        """
        try:
            from cdsetool.credentials import Credentials
            from cdsetool.download import download_features
        except ImportError as import_error:
            raise Exception(
                "cdsetool not available. Cannot download Sentinel-2 products."
            ) from import_error

        import os

        # The product metadata should contain download info
        if not product.metadata:
            raise Exception(f"No metadata available for product {product.product_id}")

        try:
            # Create features list with just this product
            features = [product.metadata]

            # Use cdsetool to download
            os.makedirs(output_dir, exist_ok=True)
            print(f"Downloading {product.product_id}...")

            # Single-threaded download; no status monitor because it does not
            # work from worker threads.
            config = {
                "concurrency": 1,  # Single file download
                "credentials": Credentials(self.username, self.password),
            }

            # Download features returns a generator - consume it with list()
            download_results = list(download_features(features, output_dir, config))

            if not download_results:
                raise Exception(
                    f"Download did not return any results for {product.product_id}"
                )

            # Prefer a file named after the product; otherwise fall back to the
            # newest zip, then to the newest file of any kind.
            downloaded_files = list(Path(output_dir).glob(f"{product.product_id}*"))
            if downloaded_files:
                file_path = str(downloaded_files[0])
                print(f"Downloaded {product.product_id}")
                return file_path

            zip_files = list(Path(output_dir).glob("*.zip"))
            if zip_files:
                latest_zip = max(zip_files, key=os.path.getctime)
                print(f"Downloaded {product.product_id} as {latest_zip.name}")
                return str(latest_zip)

            all_files = list(Path(output_dir).glob("*"))
            if all_files:
                latest_file = max(all_files, key=os.path.getctime)
                print(f"Downloaded {product.product_id} as {latest_file.name}")
                return str(latest_file)

            raise Exception(
                f"Download completed but no files found in {output_dir}"
            )

        except Exception as e:
            # Preserve the original traceback for debugging.
            raise Exception(
                f"Download failed for {product.product_id}: {str(e)}"
            ) from e

download_product(product, output_dir)

Download a Sentinel-2 product using CDSE API - matches ShallowLearn implementation

Source code in ShallowLearn/api/unified_satellite_api.py
def download_product(self, product: SatelliteProduct, output_dir: str) -> str:
    """Download a Sentinel-2 product using CDSE API - matches ShallowLearn implementation

    Parameters
    ----------
    product : SatelliteProduct
        Product to download; its ``metadata`` must hold the CDSE feature dict.
    output_dir : str
        Directory to store the downloaded archive (created if missing).

    Returns
    -------
    str
        Path to the downloaded file: an exact ``product_id`` match when one
        exists, otherwise the newest ``.zip`` (or newest file of any kind)
        in ``output_dir``.

    Raises
    ------
    Exception
        If cdsetool is unavailable, the product has no metadata, or the
        download fails / produces no files.
    """
    try:
        from cdsetool.credentials import Credentials
        from cdsetool.download import download_features
        from cdsetool.monitor import StatusMonitor  # noqa: F401 -- kept for parity
    except ImportError:
        raise Exception(
            "cdsetool not available. Cannot download Sentinel-2 products."
        )

    import os

    # The product metadata should contain download info
    if not product.metadata:
        raise Exception(f"No metadata available for product {product.product_id}")

    def _resolve_download(directory: str) -> str:
        """Locate the file produced by the download, preferring an exact id match."""
        exact = list(Path(directory).glob(f"{product.product_id}*"))
        if exact:
            print(f"Downloaded {product.product_id}")
            return str(exact[0])
        # Fall back to the most recently created .zip, then to any newest file
        # (cdsetool may name the archive differently from the product id).
        for pattern in ("*.zip", "*"):
            candidates = list(Path(directory).glob(pattern))
            if candidates:
                newest = max(candidates, key=os.path.getctime)
                print(f"Downloaded {product.product_id} as {newest.name}")
                return str(newest)
        raise Exception(f"Download completed but no files found in {directory}")

    try:
        # Create features list with just this product
        features = [product.metadata]

        os.makedirs(output_dir, exist_ok=True)
        print(f"Downloading {product.product_id}...")

        # Download using cdsetool - matches existing ShallowLearn pattern
        # Create config with credentials (no monitor in threads)
        config = {
            "concurrency": 1,  # Single file download
            "credentials": Credentials(self.username, self.password),
        }

        # Download features returns a generator - consume it with list()
        download_results = list(download_features(features, output_dir, config))

        if not download_results:
            raise Exception(
                f"Download did not return any results for {product.product_id}"
            )

        return _resolve_download(output_dir)

    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise Exception(f"Download failed for {product.product_id}: {str(e)}") from e

search(query)

Search Sentinel-2 scenes - matches existing ShallowLearn DownloadData.py logic

Source code in ShallowLearn/api/unified_satellite_api.py
def search(self, query: SatelliteQuery) -> List[SatelliteProduct]:
    """Search Sentinel-2 scenes - matches existing ShallowLearn DownloadData.py logic

    Args:
        query: Search parameters (geometry, date range, processing level,
            max_results).

    Returns:
        List of standardized SatelliteProduct objects (empty when nothing
        matches).
    """
    print("Searching Sentinel-2 scenes...")

    # Convert query parameters to CDSE format
    bbox = self._convert_geometry_to_bbox(query.geometry)

    # Map processing levels; default to L1C when the level is unrecognized
    processing_level_map = {"L1C": "S2MSI1C", "L2A": "S2MSI2A"}
    processing_level = processing_level_map.get(query.processing_level, "S2MSI1C")

    # Build search terms using same structure as existing code
    search_terms = {
        "startDate": query.date_range[0],
        "completionDate": query.date_range[1],
        "processingLevel": processing_level,
        "geometry": bbox,
        "maxRecords": query.max_results,
    }

    # Execute search using cdsetool
    features = list(query_features("Sentinel2", search_terms))

    if not features:
        print("No Sentinel-2 scenes found for query")
        return []

    print(f"Found {len(features)} Sentinel-2 scenes")

    # Convert to standardized format (limit results since cdsetool doesn't always respect maxRecords)
    features = features[: query.max_results]
    products = []
    for feature in features:
        props = feature["properties"]

        # CDSE can return explicit nulls for these fields; the "or" guards
        # prevent float(None) / None.replace / None.get crashes.
        platform = (
            (props.get("platform") or "")
            .replace("S2A", "Sentinel-2A")
            .replace("S2B", "Sentinel-2B")
        )
        download_url = (
            ((props.get("services") or {}).get("download") or {}).get("url")
        )

        product = SatelliteProduct(
            product_id=props.get("title"),
            satellite="sentinel2",
            sensor="MSI",
            acquisition_date=props.get("startDate"),
            cloud_cover=float(props.get("cloudCover") or 0),
            processing_level=query.processing_level,
            thumbnail_url=props.get("thumbnail"),
            download_url=download_url,
            bounds=feature.get("geometry"),
            metadata=feature,
            # Enhanced metadata fields
            orbit_number=props.get("orbitNumber"),
            relative_orbit_number=props.get("relativeOrbitNumber"),
            processing_baseline=props.get("processingBaseline"),
            product_type=props.get("productType"),
            platform=platform,
            instrument=props.get("instrument"),
            timeliness=props.get("timeliness"),
            snow_cover=props.get("snowCover"),
            orbit_direction=props.get("orbitDirection"),
        )
        products.append(product)

    return products

UnifiedSatelliteAPI

Unified interface for both Landsat and Sentinel-2 APIs

Source code in ShallowLearn/api/unified_satellite_api.py
class UnifiedSatelliteAPI:
    """Unified interface for both Landsat and Sentinel-2 APIs

    Wraps the per-constellation downloader classes behind a single
    search / download / filter facade. The underlying APIs are created
    lazily, only when a query actually requests that satellite.
    """

    def __init__(self):
        # Downloader instances; populated on demand by _initialize_apis().
        self.landsat_api = None
        self.sentinel2_api = None

    def _initialize_apis(self, satellites: List[str]):
        """Initialize only the requested satellite APIs"""
        # Initialization failures are reported but swallowed so the other
        # constellation's API can still be used.
        if "landsat" in satellites and self.landsat_api is None:
            try:
                self.landsat_api = LandsatUSGSDownloader()
            except Exception as e:
                print(f"Warning: Could not initialize Landsat API: {e}")

        if "sentinel2" in satellites and self.sentinel2_api is None:
            try:
                self.sentinel2_api = Sentinel2CDSEDownloader()
            except Exception as e:
                print(f"Warning: Could not initialize Sentinel-2 API: {e}")

    def search(self, query: SatelliteQuery) -> List[SatelliteProduct]:
        """Search all requested satellites independently"""
        self._initialize_apis(query.satellites)
        all_products = []

        # Search Landsat independently (a failure here does not abort Sentinel-2)
        if "landsat" in query.satellites and self.landsat_api is not None:
            try:
                landsat_products = self.landsat_api.search(query)
                all_products.extend(landsat_products)
                print(f"✓ Landsat API returned {len(landsat_products)} products")
            except Exception as e:
                print(f"✗ Landsat API failed: {e}")

        # Search Sentinel-2 independently
        if "sentinel2" in query.satellites and self.sentinel2_api is not None:
            try:
                sentinel2_products = self.sentinel2_api.search(query)
                all_products.extend(sentinel2_products)
                print(f"✓ Sentinel-2 API returned {len(sentinel2_products)} products")
            except Exception as e:
                print(f"✗ Sentinel-2 API failed: {e}")

        return all_products

    def download(
        self, products: List[SatelliteProduct], output_dir: str, max_concurrent: int = 3
    ) -> Dict[str, str]:
        """Download satellite products to specified directory

        Args:
            products: List of SatelliteProduct objects to download
            output_dir: Directory to save downloaded files
            max_concurrent: Maximum concurrent downloads

        Returns:
            Dictionary mapping product_id to local file path (or error message)
        """
        import os
        import threading

        os.makedirs(output_dir, exist_ok=True)
        # NOTE(review): worker threads write into this dict concurrently;
        # single-key assignment is safe under CPython's GIL — confirm if porting.
        results = {}
        semaphore = threading.Semaphore(max_concurrent)
        threads = []

        def download_single(product: SatelliteProduct):
            # Semaphore caps the number of simultaneous downloads at max_concurrent.
            semaphore.acquire()
            try:
                if product.satellite == "landsat" and self.landsat_api:
                    file_path = self.landsat_api.download_product(product, output_dir)
                    results[product.product_id] = file_path
                elif product.satellite == "sentinel2" and self.sentinel2_api:
                    file_path = self.sentinel2_api.download_product(product, output_dir)
                    results[product.product_id] = file_path
                else:
                    results[product.product_id] = (
                        f"Error: No API available for {product.satellite}"
                    )

            except Exception as e:
                # Per-product failures are recorded, not raised, so one bad
                # download cannot abort the batch.
                results[product.product_id] = f"Error: {str(e)}"
            finally:
                semaphore.release()

        # Start download threads
        for product in products:
            thread = threading.Thread(target=download_single, args=(product,))
            threads.append(thread)
            thread.start()

        # Wait for all downloads to complete
        for thread in threads:
            thread.join()

        return results

    def quicklook_filter(
        self, products: List[SatelliteProduct], config=None
    ) -> Dict[str, List[SatelliteProduct]]:
        """Apply QuickLook filtering to products based on thumbnails

        Args:
            products: List of SatelliteProduct objects
            config: Optional QuickLookConfig for customization

        Returns:
            Dictionary with cluster names as keys and filtered product lists as values
        """
        # Imported lazily so the ml subpackage is only required when filtering.
        from ..ml import QuickLookConfig, QuickLookFilter

        if config is None:
            config = QuickLookConfig()

        filter_system = QuickLookFilter(config)
        clusters = filter_system.process_products(products)

        # Print summary
        summary = filter_system.get_clustering_summary()
        print("\nQuickLook Clustering Summary:")
        for cluster_name, count in summary.items():
            print(f"  {cluster_name}: {count} products")

        return clusters

    def filter_by_clusters(
        self,
        products: List[SatelliteProduct],
        clusters: Dict[str, List[SatelliteProduct]],
        target_clusters: List[str],
    ) -> List[SatelliteProduct]:
        """Filter products by specific cluster names

        Args:
            products: Original list of products
            clusters: Dictionary with cluster names as keys and product lists as values
            target_clusters: List of cluster names to include

        Returns:
            Filtered list of products from selected clusters
        """
        filtered_products = []

        for cluster_name in target_clusters:
            if cluster_name in clusters:
                filtered_products.extend(clusters[cluster_name])
                print(
                    f"✓ Added {len(clusters[cluster_name])} products from '{cluster_name}' cluster"
                )
            else:
                print(
                    f"⚠️ Cluster '{cluster_name}' not found. Available clusters: {list(clusters.keys())}"
                )

        print("\n📊 Filtering summary:")
        print(f"   • Total products: {len(products)}")
        print(f"   • Filtered products: {len(filtered_products)}")
        # NOTE(review): division below raises ZeroDivisionError when
        # `products` is empty — confirm callers never pass [].
        print(
            f"   • Selection ratio: {len(filtered_products) / len(products) * 100:.1f}%"
        )

        return filtered_products

    def create_download_manifest(
        self,
        products: List[SatelliteProduct],
        output_path: str = "download_manifest.csv",
    ) -> str:
        """Create a CSV manifest of products for download tracking

        Args:
            products: List of products to include in manifest
            output_path: Path to save the manifest CSV file

        Returns:
            Path to the created manifest file
        """
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                "pandas is required for manifest creation. Install with: pip install pandas"
            )

        manifest_data = []
        for product in products:
            # Optional attributes default to "" via getattr so rows stay rectangular
            # even for products from APIs that do not supply them.
            manifest_data.append(
                {
                    "product_id": product.product_id,
                    "satellite": product.satellite,
                    "acquisition_date": product.acquisition_date,
                    "cloud_cover": product.cloud_cover,
                    "download_url": product.download_url,
                    "thumbnail_url": getattr(product, "thumbnail_url", ""),
                    "bounds": str(product.bounds) if product.bounds else "",
                    "processing_level": getattr(product, "processing_level", ""),
                    "orbit_number": getattr(product, "orbit_number", ""),
                    "relative_orbit_number": getattr(
                        product, "relative_orbit_number", ""
                    ),
                    "wrs_path": getattr(product, "wrs_path", ""),
                    "wrs_row": getattr(product, "wrs_row", ""),
                    "local_filename": f"{product.product_id}.zip",
                }
            )

        df = pd.DataFrame(manifest_data)
        df.to_csv(output_path, index=False)

        print(f"📋 Download manifest created: {output_path}")
        print(f"   • Products: {len(products)}")
        print(f"   • Columns: {len(df.columns)}")

        return output_path

    def search_and_filter(
        self, query: SatelliteQuery, quicklook_config=None
    ) -> Dict[str, List[SatelliteProduct]]:
        """Combined search and QuickLook filtering workflow

        Args:
            query: SatelliteQuery object
            quicklook_config: Optional QuickLookConfig

        Returns:
            Dictionary with cluster names as keys and filtered product lists as values
        """
        # First search for products
        products = self.search(query)

        if not products:
            print("No products found to filter")
            return {}

        print(f"\nApplying QuickLook filtering to {len(products)} products...")

        # Apply QuickLook filtering
        return self.quicklook_filter(products, quicklook_config)

    def products_to_dataframe(self, products: List[SatelliteProduct]) -> "pd.DataFrame":
        """Convert list of satellite products to pandas DataFrame for analysis"""
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                "pandas is required for DataFrame conversion. Install with: pip install pandas"
            )

        if not products:
            return pd.DataFrame()

        # Convert products to list of dictionaries
        data = [product.to_dict() for product in products]
        df = pd.DataFrame(data)

        # Convert date strings to datetime for better analysis
        if "acquisition_date" in df.columns:
            # Use utc=True to handle mixed timezone formats
            df["acquisition_date"] = pd.to_datetime(
                df["acquisition_date"], errors="coerce", utc=True
            )

        # Sort by acquisition date for better organization
        if "acquisition_date" in df.columns:
            df = df.sort_values("acquisition_date").reset_index(drop=True)

        return df

create_download_manifest(products, output_path='download_manifest.csv')

Create a CSV manifest of products for download tracking

Parameters:

Name Type Description Default
products List[SatelliteProduct]

List of products to include in manifest

required
output_path str

Path to save the manifest CSV file

'download_manifest.csv'

Returns:

Type Description
str

Path to the created manifest file

Source code in ShallowLearn/api/unified_satellite_api.py
def create_download_manifest(
    self,
    products: List[SatelliteProduct],
    output_path: str = "download_manifest.csv",
) -> str:
    """Write a CSV manifest describing *products* for download tracking.

    Args:
        products: Products to record in the manifest.
        output_path: Destination path for the CSV file.

    Returns:
        Path to the created manifest file.

    Raises:
        ImportError: If pandas is not installed.
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas is required for manifest creation. Install with: pip install pandas"
        )

    def _row(p):
        # Optional fields fall back to "" so every row has the same columns.
        return {
            "product_id": p.product_id,
            "satellite": p.satellite,
            "acquisition_date": p.acquisition_date,
            "cloud_cover": p.cloud_cover,
            "download_url": p.download_url,
            "thumbnail_url": getattr(p, "thumbnail_url", ""),
            "bounds": str(p.bounds) if p.bounds else "",
            "processing_level": getattr(p, "processing_level", ""),
            "orbit_number": getattr(p, "orbit_number", ""),
            "relative_orbit_number": getattr(p, "relative_orbit_number", ""),
            "wrs_path": getattr(p, "wrs_path", ""),
            "wrs_row": getattr(p, "wrs_row", ""),
            "local_filename": f"{p.product_id}.zip",
        }

    df = pd.DataFrame([_row(p) for p in products])
    df.to_csv(output_path, index=False)

    print(f"📋 Download manifest created: {output_path}")
    print(f"   • Products: {len(products)}")
    print(f"   • Columns: {len(df.columns)}")

    return output_path

download(products, output_dir, max_concurrent=3)

Download satellite products to specified directory

Parameters:

Name Type Description Default
products List[SatelliteProduct]

List of SatelliteProduct objects to download

required
output_dir str

Directory to save downloaded files

required
max_concurrent int

Maximum concurrent downloads

3

Returns:

Type Description
Dict[str, str]

Dictionary mapping product_id to local file path (or error message)

Source code in ShallowLearn/api/unified_satellite_api.py
def download(
    self, products: List[SatelliteProduct], output_dir: str, max_concurrent: int = 3
) -> Dict[str, str]:
    """Download satellite products into *output_dir* with bounded concurrency.

    Args:
        products: SatelliteProduct objects to download.
        output_dir: Directory to save downloaded files (created if missing).
        max_concurrent: Maximum number of simultaneous downloads.

    Returns:
        Dictionary mapping product_id to local file path (or error message).
    """
    import os
    import threading

    os.makedirs(output_dir, exist_ok=True)
    results = {}
    gate = threading.Semaphore(max_concurrent)

    def _worker(item: SatelliteProduct):
        # The semaphore caps how many downloads run at once.
        with gate:
            try:
                if item.satellite == "landsat" and self.landsat_api:
                    results[item.product_id] = self.landsat_api.download_product(
                        item, output_dir
                    )
                elif item.satellite == "sentinel2" and self.sentinel2_api:
                    results[item.product_id] = self.sentinel2_api.download_product(
                        item, output_dir
                    )
                else:
                    results[item.product_id] = (
                        f"Error: No API available for {item.satellite}"
                    )
            except Exception as e:
                # Record the failure so one bad product cannot abort the batch.
                results[item.product_id] = f"Error: {str(e)}"

    workers = [threading.Thread(target=_worker, args=(p,)) for p in products]
    for t in workers:
        t.start()
    for t in workers:
        t.join()

    return results

filter_by_clusters(products, clusters, target_clusters)

Filter products by specific cluster names

Parameters:

Name Type Description Default
products List[SatelliteProduct]

Original list of products

required
clusters Dict[str, List[SatelliteProduct]]

Dictionary with cluster names as keys and product lists as values

required
target_clusters List[str]

List of cluster names to include

required

Returns:

Type Description
List[SatelliteProduct]

Filtered list of products from selected clusters

Source code in ShallowLearn/api/unified_satellite_api.py
def filter_by_clusters(
    self,
    products: List[SatelliteProduct],
    clusters: Dict[str, List[SatelliteProduct]],
    target_clusters: List[str],
) -> List[SatelliteProduct]:
    """Filter products by specific cluster names

    Args:
        products: Original list of products
        clusters: Dictionary with cluster names as keys and product lists as values
        target_clusters: List of cluster names to include

    Returns:
        Filtered list of products from selected clusters
    """
    filtered_products = []

    for cluster_name in target_clusters:
        if cluster_name in clusters:
            filtered_products.extend(clusters[cluster_name])
            print(
                f"✓ Added {len(clusters[cluster_name])} products from '{cluster_name}' cluster"
            )
        else:
            print(
                f"⚠️ Cluster '{cluster_name}' not found. Available clusters: {list(clusters.keys())}"
            )

    # Guard against an empty product list so the summary line cannot
    # raise ZeroDivisionError.
    ratio = len(filtered_products) / len(products) * 100 if products else 0.0

    print("\n📊 Filtering summary:")
    print(f"   • Total products: {len(products)}")
    print(f"   • Filtered products: {len(filtered_products)}")
    print(f"   • Selection ratio: {ratio:.1f}%")

    return filtered_products

products_to_dataframe(products)

Convert list of satellite products to pandas DataFrame for analysis

Source code in ShallowLearn/api/unified_satellite_api.py
def products_to_dataframe(self, products: List[SatelliteProduct]) -> "pd.DataFrame":
    """Convert list of satellite products to pandas DataFrame for analysis"""
    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas is required for DataFrame conversion. Install with: pip install pandas"
        )

    if not products:
        return pd.DataFrame()

    # One row per product, built from each product's own serialization.
    df = pd.DataFrame([p.to_dict() for p in products])

    if "acquisition_date" in df.columns:
        # utc=True normalizes mixed timezone formats; unparsable values become NaT.
        df["acquisition_date"] = pd.to_datetime(
            df["acquisition_date"], errors="coerce", utc=True
        )
        # Chronological order makes downstream analysis easier.
        df = df.sort_values("acquisition_date").reset_index(drop=True)

    return df

quicklook_filter(products, config=None)

Apply QuickLook filtering to products based on thumbnails

Parameters:

Name Type Description Default
products List[SatelliteProduct]

List of SatelliteProduct objects

required
config

Optional QuickLookConfig for customization

None

Returns:

Type Description
Dict[str, List[SatelliteProduct]]

Dictionary with cluster names as keys and filtered product lists as values

Source code in ShallowLearn/api/unified_satellite_api.py
def quicklook_filter(
    self, products: List[SatelliteProduct], config=None
) -> Dict[str, List[SatelliteProduct]]:
    """Cluster products by their thumbnails using the QuickLook system.

    Args:
        products: SatelliteProduct objects whose thumbnails are analysed.
        config: Optional QuickLookConfig; a default one is built when omitted.

    Returns:
        Mapping of cluster name to the products assigned to that cluster.
    """
    # Imported lazily so the ml subpackage is only required when filtering.
    from ..ml import QuickLookConfig, QuickLookFilter

    effective_config = QuickLookConfig() if config is None else config
    filter_system = QuickLookFilter(effective_config)
    clusters = filter_system.process_products(products)

    # Report how many products landed in each cluster.
    print("\nQuickLook Clustering Summary:")
    for cluster_name, count in filter_system.get_clustering_summary().items():
        print(f"  {cluster_name}: {count} products")

    return clusters

search(query)

Search all requested satellites independently

Source code in ShallowLearn/api/unified_satellite_api.py
def search(self, query: SatelliteQuery) -> List[SatelliteProduct]:
    """Search all requested satellites independently"""
    self._initialize_apis(query.satellites)
    all_products = []

    # Each constellation is queried on its own so that one API failing
    # cannot hide results from the other.
    sources = (
        ("landsat", self.landsat_api, "Landsat"),
        ("sentinel2", self.sentinel2_api, "Sentinel-2"),
    )
    for key, api, label in sources:
        if key not in query.satellites or api is None:
            continue
        try:
            found = api.search(query)
            all_products.extend(found)
            print(f"✓ {label} API returned {len(found)} products")
        except Exception as e:
            print(f"✗ {label} API failed: {e}")

    return all_products

search_and_filter(query, quicklook_config=None)

Combined search and QuickLook filtering workflow

Parameters:

Name Type Description Default
query SatelliteQuery

SatelliteQuery object

required
quicklook_config

Optional QuickLookConfig

None

Returns:

Type Description
Dict[str, List[SatelliteProduct]]

Dictionary with cluster names as keys and filtered product lists as values

Source code in ShallowLearn/api/unified_satellite_api.py
def search_and_filter(
    self, query: SatelliteQuery, quicklook_config=None
) -> Dict[str, List[SatelliteProduct]]:
    """Run a search, then cluster the hits with QuickLook filtering.

    Args:
        query: SatelliteQuery describing what to search for.
        quicklook_config: Optional QuickLookConfig forwarded to filtering.

    Returns:
        Mapping of cluster name to product list; empty dict when the
        search finds nothing.
    """
    products = self.search(query)

    # Guard clause: nothing to cluster.
    if not products:
        print("No products found to filter")
        return {}

    print(f"\nApplying QuickLook filtering to {len(products)} products...")
    return self.quicklook_filter(products, quicklook_config)

Features Module

Spectral analysis utilities for ShallowLearn. Contains water quality and marine remote sensing indices.

Segmentation Module

Image segmentation utilities for ShallowLearn. Clean superpixel and DII extraction functions.

Visualization Module

Visualization utilities for ShallowLearn. Clean plotting and display functions for remote sensing data.

plot_rgb = plot_rgb_enhanced module-attribute

Plots an RGB image using specified band indices.

Parameters:

img : np.ndarray
    Input image array with shape (height, width, bands)
band_indices : List[int]
    List of 3 band indices for R, G, B channels
title : str, default="RGB Image"
    Title for the plot
figsize : Tuple[int, int], default=(8, 8)
    Figure size
show : bool, default=True
    Whether to display the plot

Returns:

plt.Figure or None
    Figure object if show=False, otherwise None

QuickLookVisualizer

Handles visualization of QuickLook results - separated from ML processing

Source code in ShallowLearn/visualization/quicklook_viz.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
class QuickLookVisualizer:
    """Handles visualization of QuickLook results - separated from ML processing.

    Wraps a fitted ``QuickLookProcessor`` (from the ml module) and renders its
    reduced 2-D embedding, cluster labels and processed thumbnail images as
    scatter plots, thumbnail overlays, temporal distributions and per-cluster
    statistics.
    """

    # Shared qualitative palette for cluster coloring.  Hoisted to a class
    # constant so _generate_class_dict and plot_publication_quality stay in sync.
    DEFAULT_COLORS = [
        "#2ca02c",
        "#ff7f0e",
        "#d62728",
        "#1f77b4",
        "#9467bd",
        "#8c564b",
        "#e377c2",
    ]

    def __init__(self, quicklook_processor):
        """Initialize with a QuickLookProcessor instance from ml module.

        Parameters
        ----------
        quicklook_processor : QuickLookProcessor
            A processor that has already run, exposing ``transformed_data``
            (n_samples x n_components array), ``labels`` (cluster ids,
            -1 = outlier), ``processed_images`` and ``metadata_df``.
        """
        self.processor = quicklook_processor
        self.transformed_data = quicklook_processor.transformed_data
        self.labels = quicklook_processor.labels
        self.images = quicklook_processor.processed_images
        self.metadata_df = quicklook_processor.metadata_df

        # Map label -> (hex color, display name) used by every plot method.
        self.class_dict = self._generate_class_dict()

        # For compatibility with existing visualization code.
        self.thumbnails = self.images  # processed images act as thumbnails
        self.products = self._create_product_objects()

    def _generate_class_dict(self):
        """Generate a class dictionary for visualization colors and labels.

        Returns
        -------
        dict
            Maps each unique label to a ``(color, name)`` tuple.  Label -1 is
            always named "Outliers"; others are named from cluster metadata.
        """
        if self.labels is None:
            return {}

        unique_labels = np.unique(self.labels)
        colors = self.DEFAULT_COLORS

        class_dict = {}
        for i, label in enumerate(unique_labels):
            # Cycle through the palette if there are more clusters than colors.
            color = colors[i % len(colors)]

            if label == -1:
                name = "Outliers"
            else:
                # For L2A data, try to use AOT retrieval method for labeling.
                name = self._get_cluster_name(label)

            class_dict[label] = (color, name)

        return class_dict

    def _get_cluster_name(self, label):
        """Generate cluster name based on metadata, especially AOT retrieval method for L2A.

        Falls back to ``"Cluster {label}"`` whenever metadata is missing or
        inconclusive.  Never raises on absent metadata columns.
        """
        if self.metadata_df is None:
            return f"Cluster {label}"

        # Rows of the metadata frame belonging to this cluster.
        cluster_mask = self.labels == label
        cluster_metadata = self.metadata_df[cluster_mask]

        if len(cluster_metadata) == 0:
            return f"Cluster {label}"

        # BUG FIX: guard column access — non-Sentinel metadata frames may lack
        # these columns, which previously raised KeyError.
        if "aot_retrieval_method" in cluster_metadata.columns:
            aot_methods = cluster_metadata["aot_retrieval_method"].dropna()
            if len(aot_methods) > 0:
                # Most common AOT retrieval method in this cluster.
                most_common_aot = aot_methods.value_counts().index[0]
                if most_common_aot != "N/A":
                    return f"AOT-{most_common_aot}"

        if "processing_level" in cluster_metadata.columns:
            processing_levels = cluster_metadata["processing_level"].dropna()
            if len(processing_levels) > 0:
                most_common_level = processing_levels.value_counts().index[0]
                if "L2A" in most_common_level or "Level-2A" in most_common_level:
                    return f"L2A-C{label}"
                elif "L1C" in most_common_level or "Level-1C" in most_common_level:
                    return f"L1C-C{label}"

        return f"Cluster {label}"

    def _create_product_objects(self):
        """Create product-like objects for compatibility with existing viz code.

        Each object exposes ``product_id``, ``satellite``, ``cloud_cover`` and
        ``acquisition_date`` attributes, mirroring the processor's metadata rows.
        """
        products = []

        if self.metadata_df is not None:
            for _, row in self.metadata_df.iterrows():
                # Minimal anonymous object with the attributes the plot methods read.
                product = type(
                    "Product",
                    (),
                    {
                        "product_id": Path(row["file_path"]).name,
                        "satellite": row.get("satellite_type", "unknown"),
                        "cloud_cover": row.get("cloud_cover", 0),
                        "acquisition_date": "2023-01-01",  # Default date if not available
                    },
                )()
                products.append(product)

        return products

    def plot_clusters_scatter(self, figsize_base=8, save_path=None):
        """Create scatter plot of clusters in reduced dimensional space.

        Parameters
        ----------
        figsize_base : int
            Base size for the square figure.
        save_path : str, optional
            If given, the figure is saved there at 300 dpi.

        Returns
        -------
        (Figure, Axes)
        """
        if self.transformed_data is None:
            raise ValueError(
                "No transformed data available. Run process_products first."
            )

        fig, ax = create_square_figure(figsize_base)

        # Build per-label colors and legend text (with member counts).
        unique_labels = np.unique(self.labels)
        colors = []
        legend_labels = []

        for label in unique_labels:
            color, name = self.class_dict.get(label, ("#808080", f"Cluster_{label}"))
            colors.append(color)
            legend_labels.append(f"{name} ({np.sum(self.labels == label)})")

        # Discrete colormap: one entry per cluster.
        cmap = ListedColormap(colors)

        scatter = ax.scatter(
            self.transformed_data[:, 0],
            self.transformed_data[:, 1],
            c=self.labels,
            cmap=cmap,
            s=50,
            alpha=0.7,
        )

        # Colorbar doubles as the legend, one tick per cluster.
        cbar = plt.colorbar(scatter, ticks=unique_labels)
        cbar.set_ticklabels(legend_labels)

        method_name = self.processor.reducer.get_name()
        ax.set_xlabel(f"{method_name} Component 1", fontsize=12)
        ax.set_ylabel(f"{method_name} Component 2", fontsize=12)
        ax.set_title(f"QuickLook Clustering Results ({method_name})", fontsize=14)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")

        return fig, ax

    def plot_thumbnails_on_scatter(
        self,
        show_points=True,
        show_thumbnails=True,
        zoom=0.1,
        figsize_base=10,
        max_images=50,
        sample_method="random",
        add_borders=True,
        save_path=None,
    ):
        """
        Plot thumbnail images on their cluster coordinates with flexible options

        Args:
            show_points: Whether to show scatter points underneath thumbnails
            show_thumbnails: Whether to show thumbnail images
            zoom: Zoom level for thumbnails (0.05-0.2 recommended)
            figsize_base: Base size for square figure
            max_images: Maximum number of thumbnails to show (None = all)
            sample_method: 'random', 'cluster' (sample from each cluster), or 'all'
            add_borders: Whether to add colored borders to thumbnails matching clusters
            save_path: Path to save the plot
        """
        if self.transformed_data is None:
            raise ValueError("No data available. Run process_products first.")

        fig, ax = create_square_figure(figsize_base)

        unique_labels = np.unique(self.labels)
        colors = ["#2ca02c", "#ff7f0e", "#d62728", "#1f77b4", "#808080"]

        # Scatter layer (faded when thumbnails are drawn on top).
        if show_points:
            for i, label in enumerate(unique_labels):
                mask = self.labels == label
                label_name = self.class_dict.get(label, (None, f"Cluster {label}"))[1]

                ax.scatter(
                    self.transformed_data[mask, 0],
                    self.transformed_data[mask, 1],
                    c=colors[i % len(colors)],
                    label=f"{label_name} ({sum(mask)})",
                    s=50 if not show_thumbnails else 20,
                    alpha=0.7 if not show_thumbnails else 0.3,
                    edgecolors="black" if not show_thumbnails else "none",
                    linewidth=0.5 if not show_thumbnails else 0,
                )

        if show_thumbnails and self.thumbnails:
            # Choose which thumbnail indices to draw.
            if sample_method == "all" or max_images is None:
                indices = range(len(self.thumbnails))
            elif sample_method == "cluster":
                # Sample evenly from each cluster.
                indices = []
                for label in unique_labels:
                    mask = self.labels == label
                    label_indices = np.where(mask)[0]
                    n_samples = min(
                        max_images // len(unique_labels), len(label_indices)
                    )
                    if n_samples > 0:
                        sampled = np.random.choice(
                            label_indices, n_samples, replace=False
                        )
                        indices.extend(sampled)
            else:  # random
                n_images = min(len(self.thumbnails), max_images)
                indices = np.random.choice(
                    len(self.thumbnails), n_images, replace=False
                )

            for i in indices:
                try:
                    thumbnail = self.thumbnails[i]
                    original_dtype = thumbnail.dtype

                    # Downscale so the AnnotationBbox stays readable.
                    shape = thumbnail.shape
                    resized_img = resize(
                        thumbnail,
                        output_shape=(int(shape[0] * zoom), int(shape[1] * zoom)),
                        anti_aliasing=True,
                        preserve_range=True,
                    ).astype(original_dtype)

                    imagebox = OffsetImage(resized_img, zoom=zoom)

                    if add_borders:
                        # Border colored by the thumbnail's cluster.
                        label = self.labels[i]
                        label_idx = list(unique_labels).index(label)
                        border_color = colors[label_idx % len(colors)]
                        ab = AnnotationBbox(
                            imagebox,
                            (self.transformed_data[i, 0], self.transformed_data[i, 1]),
                            frameon=True,
                            bboxprops=dict(edgecolor=border_color, linewidth=2),
                        )
                    else:
                        ab = AnnotationBbox(
                            imagebox,
                            (self.transformed_data[i, 0], self.transformed_data[i, 1]),
                            frameon=False,
                        )

                    ax.add_artist(ab)
                except Exception as e:
                    # Best effort: a single bad thumbnail should not abort the plot.
                    print(f"Error adding thumbnail {i}: {e}")

        # Axis limits with 10% padding on each side.
        x_min, x_max = (
            self.transformed_data[:, 0].min(),
            self.transformed_data[:, 0].max(),
        )
        y_min, y_max = (
            self.transformed_data[:, 1].min(),
            self.transformed_data[:, 1].max(),
        )

        x_padding = (x_max - x_min) * 0.1
        y_padding = (y_max - y_min) * 0.1

        ax.set_xlim(x_min - x_padding, x_max + x_padding)
        ax.set_ylim(y_min - y_padding, y_max + y_padding)

        method_name = (
            self.processor.reducer.get_name()
            if hasattr(self.processor, "reducer")
            else "Reduced"
        )
        ax.set_xlabel(f"{method_name} Component 1", fontsize=12)
        ax.set_ylabel(f"{method_name} Component 2", fontsize=12)

        if show_thumbnails and self.thumbnails:
            n_shown = len(list(indices))
            title = f"Satellite Thumbnails in {method_name} Space ({n_shown}/{len(self.thumbnails)} shown)"
        else:
            title = f"Satellite Clustering in {method_name} Space"
        ax.set_title(title, fontsize=14)

        if show_points:
            ax.legend(loc="best")

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")

        return fig, ax

    def plot_temporal_distribution(self, figsize=(14, 8), save_path=None):
        """Plot temporal distribution of products by cluster.

        Returns ``(None, None)`` when no product has a parseable date.
        """
        if not self.products:
            raise ValueError("No products available")

        dates = []
        cloud_covers = []
        # BUG FIX: remember WHICH products parsed successfully so each row is
        # paired with its own cluster label.  The old code sliced
        # ``self.labels[:len(dates)]``, which misaligns labels whenever a date
        # fails to parse anywhere except at the end of the list.
        valid_indices = []

        for idx, product in enumerate(self.products):
            try:
                date_str = product.acquisition_date
                if "T" in date_str:  # ISO format
                    date = pd.to_datetime(date_str.split("T")[0])
                else:
                    date = pd.to_datetime(date_str.split()[0])  # Split by space
                dates.append(date)
                cloud_covers.append(product.cloud_cover)
                valid_indices.append(idx)
            except Exception as e:
                print(f"Error parsing date for {product.product_id}: {e}")
                continue

        if not dates:
            print("No valid dates found in products")
            return None, None

        df = pd.DataFrame(
            {
                "date": dates,
                "cloud_cover": cloud_covers,
                "label": np.asarray(self.labels)[valid_indices],
            }
        )

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)

        # Plot 1: cloud cover over time, colored by cluster.
        unique_labels = np.unique(self.labels)
        for label in unique_labels:
            cluster_data = df[df["label"] == label]
            if len(cluster_data) > 0:
                color, name = self.class_dict.get(
                    label, ("#808080", f"Cluster_{label}")
                )
                ax1.scatter(
                    cluster_data["date"],
                    cluster_data["cloud_cover"],
                    c=color,
                    label=f"{name} ({len(cluster_data)})",
                    alpha=0.7,
                    s=50,
                )

        ax1.set_xlabel("Date", fontsize=12)
        ax1.set_ylabel("Cloud Cover (%)", fontsize=12)
        ax1.set_title("Temporal Distribution of Products by Cluster", fontsize=14)
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax1.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
        ax1.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
        plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45)

        # Plot 2: stacked monthly counts per cluster.
        df["month"] = df["date"].dt.to_period("M")
        monthly_counts = df.groupby(["month", "label"]).size().unstack(fill_value=0)

        colors_dict = {
            label: self.class_dict.get(label, ("#808080", f"Cluster_{label}"))[0]
            for label in unique_labels
        }

        monthly_counts.plot(
            kind="bar",
            stacked=True,
            ax=ax2,
            color=[colors_dict[label] for label in monthly_counts.columns],
        )

        ax2.set_xlabel("Month", fontsize=12)
        ax2.set_ylabel("Number of Products", fontsize=12)
        ax2.set_title("Monthly Product Count by Cluster", fontsize=14)
        ax2.legend(title="Cluster")
        ax2.grid(True, alpha=0.3)
        plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")

        return fig, (ax1, ax2)

    def plot_cluster_statistics(self, figsize=(12, 8), save_path=None):
        """Plot statistics about each cluster.

        Returns
        -------
        (Figure, ndarray of Axes, dict)
            The dict maps cluster display name to its statistics.
        """
        if not self.products:
            raise ValueError("No products available")

        cluster_stats = {}

        for label in np.unique(self.labels):
            mask = self.labels == label
            cluster_products = [
                self.products[i] for i in range(len(self.products)) if mask[i]
            ]
            # BUG FIX: select thumbnails with a comprehension instead of
            # np.array(self.thumbnails)[mask], which raises when thumbnails
            # differ in shape; compute brightness over all pixels regardless.
            cluster_thumbnails = [
                self.thumbnails[i] for i in range(len(self.thumbnails)) if mask[i]
            ]

            stats = {
                "count": len(cluster_products),
                "mean_cloud_cover": np.mean([p.cloud_cover for p in cluster_products]),
                "std_cloud_cover": np.std([p.cloud_cover for p in cluster_products]),
                "mean_brightness": float(
                    np.mean(np.concatenate([np.ravel(t) for t in cluster_thumbnails]))
                ),
                "satellites": {},
            }

            # Count products per satellite type.
            for product in cluster_products:
                sat_type = product.satellite
                stats["satellites"][sat_type] = stats["satellites"].get(sat_type, 0) + 1

            cluster_name = self.class_dict.get(label, (None, f"Cluster_{label}"))[1]
            cluster_stats[cluster_name] = stats

        fig, axes = plt.subplots(2, 2, figsize=figsize)

        # Plot 1: product count by cluster.
        names = list(cluster_stats.keys())
        counts = [stats["count"] for stats in cluster_stats.values()]
        colors = [
            self.class_dict.get(label, ("#808080", ""))[0]
            for label in np.unique(self.labels)
        ]

        axes[0, 0].bar(names, counts, color=colors)
        axes[0, 0].set_title("Product Count by Cluster")
        axes[0, 0].set_ylabel("Number of Products")
        plt.setp(axes[0, 0].xaxis.get_majorticklabels(), rotation=45)

        # Plot 2: mean cloud cover with std-dev error bars.
        cloud_means = [stats["mean_cloud_cover"] for stats in cluster_stats.values()]
        cloud_stds = [stats["std_cloud_cover"] for stats in cluster_stats.values()]

        axes[0, 1].bar(names, cloud_means, yerr=cloud_stds, color=colors, alpha=0.7)
        axes[0, 1].set_title("Mean Cloud Cover by Cluster")
        axes[0, 1].set_ylabel("Cloud Cover (%)")
        plt.setp(axes[0, 1].xaxis.get_majorticklabels(), rotation=45)

        # Plot 3: mean thumbnail brightness.
        brightness_means = [
            stats["mean_brightness"] for stats in cluster_stats.values()
        ]

        axes[1, 0].bar(names, brightness_means, color=colors)
        axes[1, 0].set_title("Mean Thumbnail Brightness by Cluster")
        axes[1, 0].set_ylabel("Brightness (0-255)")
        plt.setp(axes[1, 0].xaxis.get_majorticklabels(), rotation=45)

        # Plot 4: satellite type distribution as a stacked bar chart.
        all_satellites = set()
        for stats in cluster_stats.values():
            all_satellites.update(stats["satellites"].keys())
        all_satellites = list(all_satellites)

        satellite_data = {}
        for sat in all_satellites:
            satellite_data[sat] = [
                stats["satellites"].get(sat, 0) for stats in cluster_stats.values()
            ]

        bottom = np.zeros(len(names))
        for i, sat in enumerate(all_satellites):
            axes[1, 1].bar(
                names, satellite_data[sat], bottom=bottom, label=sat, alpha=0.8
            )
            bottom += satellite_data[sat]

        axes[1, 1].set_title("Satellite Type Distribution by Cluster")
        axes[1, 1].set_ylabel("Number of Products")
        axes[1, 1].legend()
        plt.setp(axes[1, 1].xaxis.get_majorticklabels(), rotation=45)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")

        return fig, axes, cluster_stats

    def create_cloudcover_meshgrid(self, ax, resolution=50, alpha=0.3):
        """Create a cloud cover background meshgrid.

        Interpolates product cloud-cover values over a regular grid in the
        embedding space and draws a filled contour.  Returns the contour set,
        or None if there are no products or interpolation fails.
        """
        if not self.products:
            return

        x_min, x_max = (
            self.transformed_data[:, 0].min(),
            self.transformed_data[:, 0].max(),
        )
        y_min, y_max = (
            self.transformed_data[:, 1].min(),
            self.transformed_data[:, 1].max(),
        )

        # Regular grid extended 10% beyond the data bounds.
        x_range = x_max - x_min
        y_range = y_max - y_min
        xx, yy = np.meshgrid(
            np.linspace(x_min - x_range * 0.1, x_max + x_range * 0.1, resolution),
            np.linspace(y_min - y_range * 0.1, y_max + y_range * 0.1, resolution),
        )

        from scipy.interpolate import griddata

        try:
            cloud_values = np.array([p.cloud_cover for p in self.products])
            coords = self.transformed_data

            # Always interpolate over the first 2 components (2-D plot).
            grid_values = griddata(
                coords[:, :2],
                cloud_values,
                (xx, yy),
                method="linear",  # Linear is more stable for scattered data
                fill_value=np.mean(cloud_values),
            )

            contour = ax.contourf(
                xx, yy, grid_values, levels=20, alpha=alpha, cmap="Blues_r"
            )

            return contour

        except Exception as e:
            # Interpolation can fail for degenerate layouts; background is optional.
            print(f"Could not create meshgrid: {e}")
            return None

    def plot_publication_quality(
        self,
        method_name="PCA",
        show_thumbnails=True,
        show_meshgrid=True,
        thumbnail_sample="cluster",
        figsize=(7, 4),
        dpi=300,
        save_path=None,
    ):  # Figure size for A4 page
        """
        Create publication-quality plot with all enhancements

        Args:
            method_name: Name of dimensionality reduction method
            show_thumbnails: Whether to overlay thumbnails
            show_meshgrid: Whether to show cloud cover background
            thumbnail_sample: 'cluster', 'random', or 'all'
            figsize: Figure size
            dpi: Resolution for saving
            save_path: Path to save the plot
        """
        if self.transformed_data is None:
            raise ValueError("No data available")

        # Display at 100 dpi; the save call below uses the requested dpi.
        fig, ax = plt.subplots(
            figsize=figsize, dpi=100
        )  # matplotlib will scale for save

        if show_meshgrid:
            contour = self.create_cloudcover_meshgrid(ax, alpha=0.2)
            if contour:
                cbar = plt.colorbar(contour, ax=ax, shrink=0.8, pad=0.02)
                cbar.set_label("Cloud Cover (%)", fontsize=12, fontweight="bold")
                cbar.ax.tick_params(labelsize=10)

        unique_labels = np.unique(self.labels)
        colors = self.DEFAULT_COLORS

        # Scatter points drawn above the meshgrid (zorder=3).
        scatter_handles = []
        for i, label in enumerate(unique_labels):
            mask = self.labels == label
            label_name = self.class_dict.get(label, (None, f"Cluster {label}"))[1]

            scatter = ax.scatter(
                self.transformed_data[mask, 0],
                self.transformed_data[mask, 1],
                c=colors[i % len(colors)],
                label=f"{label_name} ({sum(mask)})",
                s=80 if not show_thumbnails else 40,
                alpha=0.8 if not show_thumbnails else 0.6,
                edgecolors="white",
                linewidth=1.5,
                zorder=3,
            )
            scatter_handles.append(scatter)

        if show_thumbnails and self.thumbnails:
            # Choose which thumbnails to show (up to 12 total).
            if thumbnail_sample == "cluster":
                indices = []
                max_per_cluster = max(1, 12 // len(unique_labels))
                for label in unique_labels:
                    mask = self.labels == label
                    label_indices = np.where(mask)[0]
                    n_samples = min(max_per_cluster, len(label_indices))
                    if n_samples > 0:
                        sampled = np.random.choice(
                            label_indices, n_samples, replace=False
                        )
                        indices.extend(sampled)
            elif thumbnail_sample == "all":
                indices = range(len(self.thumbnails))
            else:  # random
                n_images = min(len(self.thumbnails), 12)
                indices = np.random.choice(
                    len(self.thumbnails), n_images, replace=False
                )

            zoom = 0.12  # Larger thumbnails for better visibility
            for i in indices:
                try:
                    thumbnail = self.thumbnails[i]

                    resized_img = resize(
                        thumbnail,
                        output_shape=(
                            int(thumbnail.shape[0] * zoom),
                            int(thumbnail.shape[1] * zoom),
                        ),
                        anti_aliasing=True,
                        preserve_range=True,
                    ).astype(thumbnail.dtype)

                    # Border colored by the thumbnail's cluster.
                    label = self.labels[i]
                    label_idx = list(unique_labels).index(label)
                    border_color = colors[label_idx % len(colors)]

                    imagebox = OffsetImage(resized_img, zoom=zoom)
                    ab = AnnotationBbox(
                        imagebox,
                        (self.transformed_data[i, 0], self.transformed_data[i, 1]),
                        frameon=True,
                        bboxprops=dict(
                            edgecolor=border_color,
                            linewidth=2,
                            facecolor="white",
                            alpha=0.9,
                        ),
                        zorder=4,
                    )
                    ax.add_artist(ab)
                except Exception as e:
                    print(f"Error adding thumbnail {i}: {e}")

        # Axis limits with 15% padding for the larger markers/thumbnails.
        x_min, x_max = (
            self.transformed_data[:, 0].min(),
            self.transformed_data[:, 0].max(),
        )
        y_min, y_max = (
            self.transformed_data[:, 1].min(),
            self.transformed_data[:, 1].max(),
        )

        x_range = x_max - x_min
        y_range = y_max - y_min
        padding_x = x_range * 0.15
        padding_y = y_range * 0.15

        ax.set_xlim(x_min - padding_x, x_max + padding_x)
        ax.set_ylim(y_min - padding_y, y_max + padding_y)

        ax.set_xlabel(f"{method_name} Component 1", fontsize=14, fontweight="bold")
        ax.set_ylabel(f"{method_name} Component 2", fontsize=14, fontweight="bold")

        # Title reflects whether one or several satellites are present.
        satellite_types = set(p.satellite for p in self.products)
        if len(satellite_types) == 1:
            sat_name = list(satellite_types)[0].title()
        else:
            sat_name = "Multi-satellite"

        title = f"{sat_name} Thumbnail Clustering ({method_name})"
        if show_meshgrid:
            title += "\nwith Cloud Cover Background"

        ax.set_title(title, fontsize=16, fontweight="bold", pad=20)

        legend = ax.legend(
            handles=scatter_handles,
            loc="upper right",
            fontsize=11,
            frameon=True,
            fancybox=True,
            shadow=True,
            framealpha=0.9,
            edgecolor="black",
        )
        legend.get_frame().set_linewidth(1.5)

        ax.grid(True, alpha=0.3, linestyle="--", linewidth=0.8)
        ax.tick_params(axis="both", labelsize=11, width=1.2, length=4)

        # Square aspect ratio so distances are comparable on both axes.
        ax.set_aspect("equal", adjustable="box")

        plt.tight_layout()

        if save_path:
            plt.savefig(
                save_path,
                dpi=dpi,
                bbox_inches="tight",
                facecolor="white",
                edgecolor="none",
            )
            print(f"✅ Saved: {save_path}")

        return fig, ax

    def generate_all_plots(self, output_dir="publication_plots", dpi=300):
        """Generate all possible visualization plots.

        Saves each plot to ``output_dir`` and returns the list of file names.
        Figures are closed after saving to avoid matplotlib memory buildup.
        """
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        plots_generated = []

        method_name = (
            self.processor.reducer.get_name()
            if hasattr(self.processor, "reducer")
            else "Unknown"
        )
        satellite_name = (
            list(set(p.satellite for p in self.products))[0]
            if self.products
            else "unknown"
        )

        # 1. Basic scatter plot
        fig1, ax1 = self.plot_clusters_scatter(
            save_path=f"{output_dir}/{satellite_name}_{method_name}_scatter.png"
        )
        plots_generated.append(f"{satellite_name}_{method_name}_scatter.png")
        plt.close(fig1)  # BUG FIX: release each figure once saved

        # 2. Thumbnails without meshgrid
        fig2, ax2 = self.plot_publication_quality(
            method_name=method_name.split("(")[0],  # Clean method name
            show_thumbnails=True,
            show_meshgrid=False,
            save_path=f"{output_dir}/{satellite_name}_{method_name}_thumbnails.png",
            dpi=dpi,
        )
        plots_generated.append(f"{satellite_name}_{method_name}_thumbnails.png")
        plt.close(fig2)

        # 3. Thumbnails with cloud cover meshgrid
        fig3, ax3 = self.plot_publication_quality(
            method_name=method_name.split("(")[0],
            show_thumbnails=True,
            show_meshgrid=True,
            save_path=f"{output_dir}/{satellite_name}_{method_name}_with_meshgrid.png",
            dpi=dpi,
        )
        plots_generated.append(f"{satellite_name}_{method_name}_with_meshgrid.png")
        plt.close(fig3)

        # 4. Temporal distribution (skipped for tiny collections)
        if len(self.products) > 3:
            fig4, axes4 = self.plot_temporal_distribution(
                save_path=f"{output_dir}/{satellite_name}_{method_name}_temporal.png"
            )
            plots_generated.append(f"{satellite_name}_{method_name}_temporal.png")
            if fig4 is not None:  # plot_temporal_distribution may return (None, None)
                plt.close(fig4)

        # 5. Cluster statistics
        fig5, axes5, stats = self.plot_cluster_statistics(
            save_path=f"{output_dir}/{satellite_name}_{method_name}_statistics.png"
        )
        plots_generated.append(f"{satellite_name}_{method_name}_statistics.png")
        plt.close(fig5)

        print(f"\n✅ Generated {len(plots_generated)} publication-quality plots:")
        for plot in plots_generated:
            print(f"   • {plot}")

        return plots_generated

__init__(quicklook_processor)

Initialize with a QuickLookProcessor instance from ml module

Source code in ShallowLearn/visualization/quicklook_viz.py
def __init__(self, quicklook_processor):
    """Initialize with a QuickLookProcessor instance from ml module"""
    self.processor = quicklook_processor
    self.transformed_data = quicklook_processor.transformed_data
    self.labels = quicklook_processor.labels
    self.images = quicklook_processor.processed_images
    self.metadata_df = quicklook_processor.metadata_df

    # Generate class dictionary for visualization
    self.class_dict = self._generate_class_dict()

    # For compatibility with existing visualization code
    self.thumbnails = self.images  # processed images act as thumbnails
    self.products = self._create_product_objects()

create_cloudcover_meshgrid(ax, resolution=50, alpha=0.3)

Create a cloud cover background meshgrid

Source code in ShallowLearn/visualization/quicklook_viz.py
def create_cloudcover_meshgrid(self, ax, resolution=50, alpha=0.3):
    """
    Draw a cloud cover background meshgrid on the given axis.

    Interpolates each product's cloud cover percentage across the first
    two components of the reduced space and renders the result as a
    filled contour background layer.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        Axis to draw on.
    resolution : int, default 50
        Number of grid points per dimension of the interpolation grid.
    alpha : float, default 0.3
        Transparency of the contour layer.

    Returns:
    --------
    The contourf result on success; None when interpolation failed
    (and implicitly None when no products are available).
    """
    if not self.products:
        return

    # Get data bounds
    x_min, x_max = (
        self.transformed_data[:, 0].min(),
        self.transformed_data[:, 0].max(),
    )
    y_min, y_max = (
        self.transformed_data[:, 1].min(),
        self.transformed_data[:, 1].max(),
    )

    # Create grid, padded 10% beyond the data bounds on each side
    x_range = x_max - x_min
    y_range = y_max - y_min
    xx, yy = np.meshgrid(
        np.linspace(x_min - x_range * 0.1, x_max + x_range * 0.1, resolution),
        np.linspace(y_min - y_range * 0.1, y_max + y_range * 0.1, resolution),
    )

    # Interpolate cloud cover values to create smooth background
    from scipy.interpolate import griddata

    try:
        # Get cloud cover values and coordinates
        cloud_values = np.array([p.cloud_cover for p in self.products])
        coords = self.transformed_data

        # Interpolate cloud cover across the grid - always use first 2 components
        grid_values = griddata(
            coords[:, :2],  # Always use first 2 components for 2D interpolation
            cloud_values,
            (xx, yy),
            method="linear",  # Linear is more stable for scattered data
            fill_value=np.mean(cloud_values),
        )

        # Create contour plot
        contour = ax.contourf(
            xx, yy, grid_values, levels=20, alpha=alpha, cmap="Blues_r"
        )

        return contour

    except Exception as e:
        # Best-effort: interpolation can fail on degenerate geometry;
        # report and let the caller skip the background layer.
        print(f"Could not create meshgrid: {e}")
        return None

generate_all_plots(output_dir='publication_plots', dpi=300)

Generate all possible visualization plots

Source code in ShallowLearn/visualization/quicklook_viz.py
def generate_all_plots(self, output_dir="publication_plots", dpi=300, close_figures=False):
    """
    Generate the full suite of publication-quality visualization plots.

    Parameters:
    -----------
    output_dir : str, default "publication_plots"
        Directory the PNG files are written to (created if missing).
    dpi : int, default 300
        Resolution passed to the publication-quality plots.
    close_figures : bool, default False
        If True, close each figure after it is recorded, so batch runs
        do not accumulate open matplotlib figures. The default (False)
        preserves the previous behavior of leaving figures open (e.g.
        for inline notebook display).

    Returns:
    --------
    list of str
        File names (not full paths) of the plots actually generated.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)

    plots_generated = []

    # Resolve the name components once; they are reused in every file name.
    method_name = (
        self.processor.reducer.get_name()
        if hasattr(self.processor, "reducer")
        else "Unknown"
    )
    satellite_name = (
        list(set(p.satellite for p in self.products))[0]
        if self.products
        else "unknown"
    )
    prefix = f"{satellite_name}_{method_name}"

    def _finish(fig, filename):
        """Record the generated file and optionally release the figure."""
        plots_generated.append(filename)
        if close_figures and fig is not None:
            plt.close(fig)

    # 1. Basic scatter plot
    fig1, ax1 = self.plot_clusters_scatter(
        save_path=f"{output_dir}/{prefix}_scatter.png"
    )
    _finish(fig1, f"{prefix}_scatter.png")

    # 2. Thumbnails without meshgrid
    fig2, ax2 = self.plot_publication_quality(
        method_name=method_name.split("(")[0],  # Strip parameter suffix, e.g. "PCA(...)"
        show_thumbnails=True,
        show_meshgrid=False,
        save_path=f"{output_dir}/{prefix}_thumbnails.png",
        dpi=dpi,
    )
    _finish(fig2, f"{prefix}_thumbnails.png")

    # 3. Thumbnails with cloud cover meshgrid
    fig3, ax3 = self.plot_publication_quality(
        method_name=method_name.split("(")[0],
        show_thumbnails=True,
        show_meshgrid=True,
        save_path=f"{output_dir}/{prefix}_with_meshgrid.png",
        dpi=dpi,
    )
    _finish(fig3, f"{prefix}_with_meshgrid.png")

    # 4. Temporal distribution (only meaningful with enough products)
    if len(self.products) > 3:
        fig4, axes4 = self.plot_temporal_distribution(
            save_path=f"{output_dir}/{prefix}_temporal.png"
        )
        # plot_temporal_distribution returns (None, None) when no product
        # dates could be parsed and saves nothing; only record the file
        # when it was actually drawn (the old code recorded it anyway).
        if fig4 is not None:
            _finish(fig4, f"{prefix}_temporal.png")

    # 5. Cluster statistics
    fig5, axes5, stats = self.plot_cluster_statistics(
        save_path=f"{output_dir}/{prefix}_statistics.png"
    )
    _finish(fig5, f"{prefix}_statistics.png")

    print(f"\n✅ Generated {len(plots_generated)} publication-quality plots:")
    for plot in plots_generated:
        print(f"   • {plot}")

    return plots_generated

plot_cluster_statistics(figsize=(12, 8), save_path=None)

Plot statistics about each cluster

Source code in ShallowLearn/visualization/quicklook_viz.py
def plot_cluster_statistics(self, figsize=(12, 8), save_path=None):
    """
    Plot summary statistics about each cluster in a 2x2 panel figure.

    Panels: product count per cluster, mean cloud cover (+/- std),
    mean thumbnail brightness, and a stacked bar of satellite types.

    Parameters:
    -----------
    figsize : tuple, default (12, 8)
        Figure size.
    save_path : str, optional
        If given, the figure is saved to this path at 300 dpi.

    Returns:
    --------
    (fig, axes, cluster_stats) where cluster_stats maps each cluster's
    display name to its statistics dict.

    Raises:
    -------
    ValueError
        If no products are available.
    """
    if not self.products:
        raise ValueError("No products available")

    # Calculate statistics for each cluster
    cluster_stats = {}

    for label in np.unique(self.labels):
        mask = self.labels == label
        cluster_products = [
            self.products[i] for i in range(len(self.products)) if mask[i]
        ]
        cluster_thumbnails = np.array(self.thumbnails)[mask]

        stats = {
            "count": len(cluster_products),
            "mean_cloud_cover": np.mean([p.cloud_cover for p in cluster_products]),
            "std_cloud_cover": np.std([p.cloud_cover for p in cluster_products]),
            "mean_brightness": np.mean(cluster_thumbnails),
            "satellites": {},
        }

        # Count by satellite type
        for product in cluster_products:
            sat_type = product.satellite
            stats["satellites"][sat_type] = stats["satellites"].get(sat_type, 0) + 1

        # NOTE(review): if two labels resolve to the same display name,
        # the later entry silently overwrites the earlier one — confirm
        # names in class_dict are unique.
        cluster_name = self.class_dict.get(label, (None, f"Cluster_{label}"))[1]
        cluster_stats[cluster_name] = stats

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # Plot 1: Product count by cluster
    # `colors` is built in np.unique(labels) order, which matches the
    # insertion order of cluster_stats above, so bars and colors align.
    names = list(cluster_stats.keys())
    counts = [stats["count"] for stats in cluster_stats.values()]
    colors = [
        self.class_dict.get(label, ("#808080", ""))[0]
        for label in np.unique(self.labels)
    ]

    axes[0, 0].bar(names, counts, color=colors)
    axes[0, 0].set_title("Product Count by Cluster")
    axes[0, 0].set_ylabel("Number of Products")
    plt.setp(axes[0, 0].xaxis.get_majorticklabels(), rotation=45)

    # Plot 2: Mean cloud cover by cluster (error bars = std deviation)
    cloud_means = [stats["mean_cloud_cover"] for stats in cluster_stats.values()]
    cloud_stds = [stats["std_cloud_cover"] for stats in cluster_stats.values()]

    axes[0, 1].bar(names, cloud_means, yerr=cloud_stds, color=colors, alpha=0.7)
    axes[0, 1].set_title("Mean Cloud Cover by Cluster")
    axes[0, 1].set_ylabel("Cloud Cover (%)")
    plt.setp(axes[0, 1].xaxis.get_majorticklabels(), rotation=45)

    # Plot 3: Mean brightness by cluster
    brightness_means = [
        stats["mean_brightness"] for stats in cluster_stats.values()
    ]

    axes[1, 0].bar(names, brightness_means, color=colors)
    axes[1, 0].set_title("Mean Thumbnail Brightness by Cluster")
    axes[1, 0].set_ylabel("Brightness (0-255)")
    plt.setp(axes[1, 0].xaxis.get_majorticklabels(), rotation=45)

    # Plot 4: Satellite type distribution
    # Create a stacked bar chart for satellite types
    all_satellites = set()
    for stats in cluster_stats.values():
        all_satellites.update(stats["satellites"].keys())
    all_satellites = list(all_satellites)

    # Per-satellite counts aligned with the cluster order of `names`
    satellite_data = {}
    for sat in all_satellites:
        satellite_data[sat] = [
            stats["satellites"].get(sat, 0) for stats in cluster_stats.values()
        ]

    bottom = np.zeros(len(names))
    for i, sat in enumerate(all_satellites):
        axes[1, 1].bar(
            names, satellite_data[sat], bottom=bottom, label=sat, alpha=0.8
        )
        bottom += satellite_data[sat]

    axes[1, 1].set_title("Satellite Type Distribution by Cluster")
    axes[1, 1].set_ylabel("Number of Products")
    axes[1, 1].legend()
    plt.setp(axes[1, 1].xaxis.get_majorticklabels(), rotation=45)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig, axes, cluster_stats

plot_clusters_scatter(figsize_base=8, save_path=None)

Create scatter plot of clusters in reduced dimensional space

Source code in ShallowLearn/visualization/quicklook_viz.py
def plot_clusters_scatter(self, figsize_base=8, save_path=None):
    """
    Create a scatter plot of the clusters in reduced dimensional space.

    Parameters:
    -----------
    figsize_base : int, default 8
        Base size of the square figure.
    save_path : str, optional
        If given, the figure is saved to this path at 300 dpi.

    Returns:
    --------
    (fig, ax) : the created matplotlib figure and axis.

    Raises:
    -------
    ValueError
        If no transformed data is available yet.
    """
    if self.transformed_data is None:
        raise ValueError(
            "No transformed data available. Run process_products first."
        )

    fig, ax = create_square_figure(figsize_base)

    # Get colors and labels for legend
    unique_labels = np.unique(self.labels)
    colors = []
    legend_labels = []

    for label in unique_labels:
        # class_dict maps label -> (color, name); fall back to grey.
        color, name = self.class_dict.get(label, ("#808080", f"Cluster_{label}"))
        colors.append(color)
        legend_labels.append(f"{name} ({np.sum(self.labels == label)})")

    # Create custom colormap (one entry per cluster, in unique-label order)
    cmap = ListedColormap(colors)

    # Plot scatter on the first two reduced components
    scatter = ax.scatter(
        self.transformed_data[:, 0],
        self.transformed_data[:, 1],
        c=self.labels,
        cmap=cmap,
        s=50,
        alpha=0.7,
    )

    # Add colorbar with custom labels
    # NOTE(review): tick placement assumes label values map cleanly onto
    # the colormap range; with non-contiguous label values the tick
    # labels may not line up with the colors — confirm.
    cbar = plt.colorbar(scatter, ticks=unique_labels)
    cbar.set_ticklabels(legend_labels)

    # Set labels
    method_name = self.processor.reducer.get_name()
    ax.set_xlabel(f"{method_name} Component 1", fontsize=12)
    ax.set_ylabel(f"{method_name} Component 2", fontsize=12)
    ax.set_title(f"QuickLook Clustering Results ({method_name})", fontsize=14)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig, ax

plot_publication_quality(method_name='PCA', show_thumbnails=True, show_meshgrid=True, thumbnail_sample='cluster', figsize=(7, 4), dpi=300, save_path=None)

Create publication-quality plot with all enhancements

Parameters:

Name Type Description Default
method_name

Name of dimensionality reduction method

'PCA'
show_thumbnails

Whether to overlay thumbnails

True
show_meshgrid

Whether to show cloud cover background

True
thumbnail_sample

'cluster', 'random', or 'all'

'cluster'
figsize

Figure size

(7, 4)
dpi

Resolution for saving

300
save_path

Path to save the plot

None
Source code in ShallowLearn/visualization/quicklook_viz.py
def plot_publication_quality(
    self,
    method_name="PCA",
    show_thumbnails=True,
    show_meshgrid=True,
    thumbnail_sample="cluster",
    figsize=(7, 4),
    dpi=300,
    save_path=None,
):  # Figure size for A4 page
    """
    Create publication-quality plot with all enhancements.

    Draws the cluster scatter, optionally overlays sampled thumbnails
    with cluster-colored borders, and optionally renders a cloud cover
    contour background.

    Args:
        method_name: Name of dimensionality reduction method
        show_thumbnails: Whether to overlay thumbnails
        show_meshgrid: Whether to show cloud cover background
        thumbnail_sample: 'cluster', 'random', or 'all'
        figsize: Figure size
        dpi: Resolution used when saving (display uses dpi=100)
        save_path: Path to save the plot

    Returns:
        (fig, ax) : the created matplotlib figure and axis.

    Raises:
        ValueError: If no transformed data is available.
    """
    if self.transformed_data is None:
        raise ValueError("No data available")

    # Create figure with high DPI
    fig, ax = plt.subplots(
        figsize=figsize, dpi=100
    )  # matplotlib will scale for save

    # Add cloud cover meshgrid background if requested
    if show_meshgrid:
        contour = self.create_cloudcover_meshgrid(ax, alpha=0.2)
        if contour:
            # Add colorbar for cloud cover
            cbar = plt.colorbar(contour, ax=ax, shrink=0.8, pad=0.02)
            cbar.set_label("Cloud Cover (%)", fontsize=12, fontweight="bold")
            cbar.ax.tick_params(labelsize=10)

    # Get unique labels and colors (fixed palette, cycled if >7 clusters)
    unique_labels = np.unique(self.labels)
    colors = [
        "#2ca02c",
        "#ff7f0e",
        "#d62728",
        "#1f77b4",
        "#9467bd",
        "#8c564b",
        "#e377c2",
    ]

    # Plot scatter points; smaller/fainter when thumbnails overlay them
    scatter_handles = []
    for i, label in enumerate(unique_labels):
        mask = self.labels == label
        label_name = self.class_dict.get(label, (None, f"Cluster {label}"))[1]

        scatter = ax.scatter(
            self.transformed_data[mask, 0],
            self.transformed_data[mask, 1],
            c=colors[i % len(colors)],
            label=f"{label_name} ({sum(mask)})",
            s=80 if not show_thumbnails else 40,
            alpha=0.8 if not show_thumbnails else 0.6,
            edgecolors="white",
            linewidth=1.5,
            zorder=3,
        )
        scatter_handles.append(scatter)

    # Add thumbnails if requested
    if show_thumbnails and self.thumbnails:
        # Determine which thumbnails to show
        if thumbnail_sample == "cluster":
            # Sample evenly from each cluster
            indices = []
            max_per_cluster = max(1, 12 // len(unique_labels))  # Up to 12 total
            for label in unique_labels:
                mask = self.labels == label
                label_indices = np.where(mask)[0]
                n_samples = min(max_per_cluster, len(label_indices))
                if n_samples > 0:
                    sampled = np.random.choice(
                        label_indices, n_samples, replace=False
                    )
                    indices.extend(sampled)
        elif thumbnail_sample == "all":
            indices = range(len(self.thumbnails))
        else:  # random
            n_images = min(len(self.thumbnails), 12)
            indices = np.random.choice(
                len(self.thumbnails), n_images, replace=False
            )

        zoom = 0.12  # Larger thumbnails for better visibility
        for i in indices:
            try:
                thumbnail = self.thumbnails[i]

                # Resize thumbnail
                resized_img = resize(
                    thumbnail,
                    output_shape=(
                        int(thumbnail.shape[0] * zoom),
                        int(thumbnail.shape[1] * zoom),
                    ),
                    anti_aliasing=True,
                    preserve_range=True,
                ).astype(thumbnail.dtype)

                # Get cluster color for border
                label = self.labels[i]
                label_idx = list(unique_labels).index(label)
                border_color = colors[label_idx % len(colors)]

                imagebox = OffsetImage(resized_img, zoom=zoom)
                ab = AnnotationBbox(
                    imagebox,
                    (self.transformed_data[i, 0], self.transformed_data[i, 1]),
                    frameon=True,
                    bboxprops=dict(
                        edgecolor=border_color,
                        linewidth=2,
                        facecolor="white",
                        alpha=0.9,
                    ),
                    zorder=4,
                )
                ax.add_artist(ab)
            except Exception as e:
                # Best-effort: a bad thumbnail should not break the plot
                print(f"Error adding thumbnail {i}: {e}")

    # Set axis limits with proper padding
    x_min, x_max = (
        self.transformed_data[:, 0].min(),
        self.transformed_data[:, 0].max(),
    )
    y_min, y_max = (
        self.transformed_data[:, 1].min(),
        self.transformed_data[:, 1].max(),
    )

    x_range = x_max - x_min
    y_range = y_max - y_min
    padding_x = x_range * 0.15
    padding_y = y_range * 0.15

    ax.set_xlim(x_min - padding_x, x_max + padding_x)
    ax.set_ylim(y_min - padding_y, y_max + padding_y)

    # Publication-quality formatting
    ax.set_xlabel(f"{method_name} Component 1", fontsize=14, fontweight="bold")
    ax.set_ylabel(f"{method_name} Component 2", fontsize=14, fontweight="bold")

    # Determine satellite type for title
    satellite_types = set(p.satellite for p in self.products)
    if len(satellite_types) == 1:
        sat_name = list(satellite_types)[0].title()
    else:
        sat_name = "Multi-satellite"

    title = f"{sat_name} Thumbnail Clustering ({method_name})"
    if show_meshgrid:
        title += "\nwith Cloud Cover Background"

    ax.set_title(title, fontsize=16, fontweight="bold", pad=20)

    # Enhanced legend
    legend = ax.legend(
        handles=scatter_handles,
        loc="upper right",
        fontsize=11,
        frameon=True,
        fancybox=True,
        shadow=True,
        framealpha=0.9,
        edgecolor="black",
    )
    legend.get_frame().set_linewidth(1.5)

    # Grid and styling
    ax.grid(True, alpha=0.3, linestyle="--", linewidth=0.8)
    ax.tick_params(axis="both", labelsize=11, width=1.2, length=4)

    # Ensure square aspect ratio for better comparison
    ax.set_aspect("equal", adjustable="box")

    plt.tight_layout()

    if save_path:
        plt.savefig(
            save_path,
            dpi=dpi,
            bbox_inches="tight",
            facecolor="white",
            edgecolor="none",
        )
        print(f"✅ Saved: {save_path}")

    return fig, ax

plot_temporal_distribution(figsize=(14, 8), save_path=None)

Plot temporal distribution of products by cluster

Source code in ShallowLearn/visualization/quicklook_viz.py
def plot_temporal_distribution(self, figsize=(14, 8), save_path=None):
    """
    Plot temporal distribution of products by cluster.

    Top panel: acquisition date vs. cloud cover, colored by cluster.
    Bottom panel: stacked monthly product counts per cluster.

    Parameters:
    -----------
    figsize : tuple, default (14, 8)
        Figure size.
    save_path : str, optional
        If given, the figure is saved to this path at 300 dpi.

    Returns:
    --------
    (fig, (ax1, ax2)) on success, or (None, None) if no product dates
    could be parsed.

    Raises:
    -------
    ValueError
        If no products are available.
    """
    if not self.products:
        raise ValueError("No products available")

    # Extract dates from products. Keep the cluster label of each
    # successfully parsed product so rows stay aligned even when some
    # products have unparseable dates (truncating self.labels, as the
    # old code did, misaligned labels after any parse failure).
    dates = []
    cloud_covers = []
    parsed_labels = []

    for product, label in zip(self.products, self.labels):
        try:
            # Handle different date formats
            date_str = product.acquisition_date
            if "T" in date_str:  # ISO format
                date = pd.to_datetime(date_str.split("T")[0])
            else:
                date = pd.to_datetime(date_str.split()[0])  # Split by space
            dates.append(date)
            cloud_covers.append(product.cloud_cover)
            parsed_labels.append(label)
        except Exception as e:
            print(f"Error parsing date for {product.product_id}: {e}")
            continue

    if not dates:
        print("No valid dates found in products")
        return None, None

    # Create DataFrame for easier plotting; rows are aligned by parse order
    df = pd.DataFrame(
        {
            "date": dates,
            "cloud_cover": cloud_covers,
            "label": parsed_labels,
        }
    )

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)

    # Plot 1: Temporal distribution by cluster
    unique_labels = np.unique(self.labels)
    for label in unique_labels:
        cluster_data = df[df["label"] == label]
        if len(cluster_data) > 0:
            color, name = self.class_dict.get(
                label, ("#808080", f"Cluster_{label}")
            )
            ax1.scatter(
                cluster_data["date"],
                cluster_data["cloud_cover"],
                c=color,
                label=f"{name} ({len(cluster_data)})",
                alpha=0.7,
                s=50,
            )

    ax1.set_xlabel("Date", fontsize=12)
    ax1.set_ylabel("Cloud Cover (%)", fontsize=12)
    ax1.set_title("Temporal Distribution of Products by Cluster", fontsize=14)
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Format dates on x-axis
    ax1.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
    ax1.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
    plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45)

    # Plot 2: Monthly product count by cluster
    df["month"] = df["date"].dt.to_period("M")
    monthly_counts = df.groupby(["month", "label"]).size().unstack(fill_value=0)

    # Create stacked bar plot with cluster colors
    colors_dict = {
        label: self.class_dict.get(label, ("#808080", f"Cluster_{label}"))[0]
        for label in unique_labels
    }

    monthly_counts.plot(
        kind="bar",
        stacked=True,
        ax=ax2,
        color=[colors_dict[label] for label in monthly_counts.columns],
    )

    ax2.set_xlabel("Month", fontsize=12)
    ax2.set_ylabel("Number of Products", fontsize=12)
    ax2.set_title("Monthly Product Count by Cluster", fontsize=14)
    ax2.legend(title="Cluster")
    ax2.grid(True, alpha=0.3)
    plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig, (ax1, ax2)

plot_thumbnails_on_scatter(show_points=True, show_thumbnails=True, zoom=0.1, figsize_base=10, max_images=50, sample_method='random', add_borders=True, save_path=None)

Plot thumbnail images on their cluster coordinates with flexible options

Parameters:

Name Type Description Default
show_points

Whether to show scatter points underneath thumbnails

True
show_thumbnails

Whether to show thumbnail images

True
zoom

Zoom level for thumbnails (0.05-0.2 recommended)

0.1
figsize_base

Base size for square figure

10
max_images

Maximum number of thumbnails to show (None = all)

50
sample_method

'random', 'cluster' (sample from each cluster), or 'all'

'random'
add_borders

Whether to add colored borders to thumbnails matching clusters

True
save_path

Path to save the plot

None
Source code in ShallowLearn/visualization/quicklook_viz.py
def plot_thumbnails_on_scatter(
    self,
    show_points=True,
    show_thumbnails=True,
    zoom=0.1,
    figsize_base=10,
    max_images=50,
    sample_method="random",
    add_borders=True,
    save_path=None,
):
    """
    Plot thumbnail images on their cluster coordinates with flexible options.

    Args:
        show_points: Whether to show scatter points underneath thumbnails
        show_thumbnails: Whether to show thumbnail images
        zoom: Zoom level for thumbnails (0.05-0.2 recommended)
        figsize_base: Base size for square figure
        max_images: Maximum number of thumbnails to show (None = all)
        sample_method: 'random', 'cluster' (sample from each cluster), or 'all'
        add_borders: Whether to add colored borders to thumbnails matching clusters
        save_path: Path to save the plot

    Returns:
        (fig, ax) : the created matplotlib figure and axis.

    Raises:
        ValueError: If no transformed data is available yet.
    """
    if self.transformed_data is None:
        raise ValueError("No data available. Run process_products first.")

    fig, ax = create_square_figure(figsize_base)

    # Get unique labels and colors (fixed palette, cycled if needed)
    unique_labels = np.unique(self.labels)
    colors = ["#2ca02c", "#ff7f0e", "#d62728", "#1f77b4", "#808080"]

    # Plot scatter points if requested; fainter when thumbnails overlay them
    if show_points:
        for i, label in enumerate(unique_labels):
            mask = self.labels == label
            label_name = self.class_dict.get(label, (None, f"Cluster {label}"))[1]

            ax.scatter(
                self.transformed_data[mask, 0],
                self.transformed_data[mask, 1],
                c=colors[i % len(colors)],
                label=f"{label_name} ({sum(mask)})",
                s=50 if not show_thumbnails else 20,
                alpha=0.7 if not show_thumbnails else 0.3,
                edgecolors="black" if not show_thumbnails else "none",
                linewidth=0.5 if not show_thumbnails else 0,
            )

    # Add thumbnails if requested
    if show_thumbnails and self.thumbnails:
        # Determine which thumbnails to show
        if sample_method == "all" or max_images is None:
            indices = range(len(self.thumbnails))
        elif sample_method == "cluster":
            # Sample evenly from each cluster
            indices = []
            for label in unique_labels:
                mask = self.labels == label
                label_indices = np.where(mask)[0]
                n_samples = min(
                    max_images // len(unique_labels), len(label_indices)
                )
                if n_samples > 0:
                    sampled = np.random.choice(
                        label_indices, n_samples, replace=False
                    )
                    indices.extend(sampled)
        else:  # random
            n_images = min(len(self.thumbnails), max_images)
            indices = np.random.choice(
                len(self.thumbnails), n_images, replace=False
            )

        # Add thumbnails
        for i in indices:
            try:
                thumbnail = self.thumbnails[i]
                original_dtype = thumbnail.dtype

                # Resize thumbnail (preserve_range keeps original value scale)
                shape = thumbnail.shape
                resized_img = resize(
                    thumbnail,
                    output_shape=(int(shape[0] * zoom), int(shape[1] * zoom)),
                    anti_aliasing=True,
                    preserve_range=True,
                ).astype(original_dtype)

                # Create image box
                imagebox = OffsetImage(resized_img, zoom=zoom)

                # Add border if requested (border color matches the cluster)
                if add_borders:
                    label = self.labels[i]
                    label_idx = list(unique_labels).index(label)
                    border_color = colors[label_idx % len(colors)]
                    ab = AnnotationBbox(
                        imagebox,
                        (self.transformed_data[i, 0], self.transformed_data[i, 1]),
                        frameon=True,
                        bboxprops=dict(edgecolor=border_color, linewidth=2),
                    )
                else:
                    ab = AnnotationBbox(
                        imagebox,
                        (self.transformed_data[i, 0], self.transformed_data[i, 1]),
                        frameon=False,
                    )

                ax.add_artist(ab)
            except Exception as e:
                # Best-effort: a bad thumbnail should not break the plot
                print(f"Error adding thumbnail {i}: {e}")

    # Set axis limits with some padding
    x_min, x_max = (
        self.transformed_data[:, 0].min(),
        self.transformed_data[:, 0].max(),
    )
    y_min, y_max = (
        self.transformed_data[:, 1].min(),
        self.transformed_data[:, 1].max(),
    )

    x_padding = (x_max - x_min) * 0.1
    y_padding = (y_max - y_min) * 0.1

    ax.set_xlim(x_min - x_padding, x_max + x_padding)
    ax.set_ylim(y_min - y_padding, y_max + y_padding)

    # Labels and legend
    method_name = (
        self.processor.reducer.get_name()
        if hasattr(self.processor, "reducer")
        else "Reduced"
    )
    ax.set_xlabel(f"{method_name} Component 1", fontsize=12)
    ax.set_ylabel(f"{method_name} Component 2", fontsize=12)

    # Title based on options
    if show_thumbnails and self.thumbnails:
        n_shown = len(list(indices))
        title = f"Satellite Thumbnails in {method_name} Space ({n_shown}/{len(self.thumbnails)} shown)"
    else:
        title = f"Satellite Clustering in {method_name} Space"
    ax.set_title(title, fontsize=14)

    if show_points:
        ax.legend(loc="best")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig, ax

add_north_arrow_to_axis(ax, relative_position=(0.05, 0.05), arrow_length=0.05, text_offset=-0.02)

Adds a north arrow to an axis.

Parameters:

ax : plt.Axes — Matplotlib axis object
relative_position : Tuple[float, float], default=(0.05, 0.05) — Relative position of the arrow (0–1 range)
arrow_length : float, default=0.05 — Length of the arrow relative to axis size
text_offset : float, default=-0.02 — Text offset relative to axis size

Source code in ShallowLearn/visualization/display.py
def add_north_arrow_to_axis(ax: plt.Axes, 
                           relative_position: Tuple[float, float] = (0.05, 0.05),
                           arrow_length: float = 0.05,
                           text_offset: float = -0.02) -> None:
    """
    Draw a north arrow with an 'N' label on the given axis.

    Parameters:
    -----------
    ax : plt.Axes
        Matplotlib axis object
    relative_position : Tuple[float, float], default=(0.05, 0.05)
        Relative position of the arrow (0-1 range)
    arrow_length : float, default=0.05
        Length of the arrow relative to axis size
    text_offset : float, default=-0.02
        Text offset relative to axis size
    """
    # Convert the relative anchor position into data coordinates.
    x0, x1 = ax.get_xlim()
    y0, y1 = ax.get_ylim()
    x_span = x1 - x0
    y_span = y1 - y0
    anchor_x = x0 + x_span * relative_position[0]
    anchor_y = y0 + y_span * relative_position[1]

    # Vertical arrow pointing north, sized relative to the axis extents.
    ax.arrow(anchor_x, anchor_y, 0, arrow_length * y_span,
             head_width=0.02 * x_span,
             head_length=0.03 * y_span,
             fc='black', ec='black')

    # 'N' label offset from the arrow base.
    ax.text(anchor_x, anchor_y + text_offset * y_span, 'N',
            horizontalalignment='center', verticalalignment='center',
            fontsize=12, fontweight='bold', color='black')

create_rgb_image(img, band_indices, stretch=True)

Creates an RGB image from multispectral data using specified band indices.

Parameters:

img : np.ndarray — Input image array with shape (height, width, bands)
band_indices : List[int] — List of 3 band indices for R, G, B channels
stretch : bool, default=True — Whether to apply min-max stretch to each channel

Returns:

np.ndarray RGB image array with shape (height, width, 3) and dtype uint8

Source code in ShallowLearn/visualization/display.py
def create_rgb_image(img: np.ndarray, 
                    band_indices: List[int], 
                    stretch: bool = True) -> np.ndarray:
    """
    Creates an RGB image from multispectral data using specified band indices.

    Parameters:
    -----------
    img : np.ndarray
        Input image array with shape (height, width, bands)
    band_indices : List[int]
        List of 3 band indices for R, G, B channels
    stretch : bool, default=True
        Whether to apply min-max stretch to each channel

    Returns:
    --------
    np.ndarray
        RGB image array with shape (height, width, 3) and dtype uint8

    Raises:
    -------
    ValueError
        If band_indices does not contain exactly 3 entries, or an index
        is out of bounds for the image's band dimension.
    """
    if len(band_indices) != 3:
        raise ValueError("Exactly 3 band indices required for RGB")

    img_shape = img.shape
    rgb_channels = []

    for band_idx in band_indices:
        if band_idx >= img_shape[2]:
            raise ValueError(f"Band index {band_idx} out of bounds for image with {img_shape[2]} bands")

        channel = img[:, :, band_idx].astype(float)

        if stretch:
            # Global min-max stretch to [0, 255] in plain numpy
            # (equivalent to sklearn's minmax_scale on the flattened
            # channel, without the flatten/reshape round-trip).
            c_min = channel.min()
            c_range = channel.max() - c_min
            if c_range > 0:
                channel = (channel - c_min) * (255.0 / c_range)
            else:
                # Constant channel: maps to the bottom of the range
                channel = np.zeros_like(channel)
        else:
            # Simple scaling by the channel maximum; guard against a
            # zero/negative max, which previously divided by zero and
            # produced NaN/inf before the uint8 cast.
            c_max = np.max(channel)
            if c_max > 0:
                channel = np.clip(channel * 255.0 / c_max, 0, 255)
            else:
                channel = np.zeros_like(channel)

        rgb_channels.append(channel.astype(np.uint8))

    return np.dstack(rgb_channels)

plot_color_space(img, color_space='hsv', band_indices=None, band_mapping=None, band_names=None, plot=False, title=None, figsize=(10, 8))

Convert image to different color spaces with flexible band selection.

This function replaces ImageHelper functions like plot_hsv, plot_lab, plot_ycbcr with a unified interface.

Parameters:

img : np.ndarray — Input image array with shape (height, width, bands)
color_space : str, default='hsv' — Target color space ('hsv', 'lab', 'ycbcr')
band_indices : List[int], optional — List of 3 band indices for R, G, B channels used in conversion
band_mapping : Dict, optional — Band mapping dictionary for converting band names to indices
band_names : List[str], optional — List of 3 band names to use with band_mapping
plot : bool, default=False — Whether to display the converted image
title : str, optional — Title for the plot; if None, auto-generated based on color_space
figsize : Tuple[int, int], default=(10, 8) — Figure size if plotting

Returns:

np.ndarray or None Converted image array if plot=False, otherwise None

Raises:

ValueError If color_space is not supported

Source code in ShallowLearn/visualization/display.py
def plot_color_space(
    img: np.ndarray,
    color_space: str = 'hsv',
    band_indices: Optional[List[int]] = None,
    band_mapping: Optional[Dict] = None,
    band_names: Optional[List[str]] = None,
    plot: bool = False,
    title: Optional[str] = None,
    figsize: Tuple[int, int] = (10, 8)
) -> Union[np.ndarray, None]:
    """
    Convert a multi-band image into another color space (HSV, LAB or YCbCr).

    Unified replacement for the old ImageHelper plot_hsv / plot_lab /
    plot_ycbcr helpers. Band selection for the intermediate RGB composite
    works exactly as in ``plot_rgb_enhanced``.

    Parameters:
    -----------
    img : np.ndarray
        Input image array with shape (height, width, bands)
    color_space : str, default='hsv'
        Target color space: 'hsv', 'lab' or 'ycbcr'
    band_indices : List[int], optional
        Three band indices to use as the R, G, B channels
    band_mapping : Dict, optional
        Mapping from band names to index metadata
    band_names : List[str], optional
        Three band names resolved through ``band_mapping``
    plot : bool, default=False
        If True, show a three-panel figure of the converted channels
    title : str, optional
        Figure title; a color-space-specific default is used when None
    figsize : Tuple[int, int], default=(10, 8)
        Figure size when plotting

    Returns:
    --------
    np.ndarray or None
        The converted array when plot=False, otherwise None

    Raises:
    -------
    ValueError
        If ``color_space`` is not one of the supported names
    """
    # Build a stretched RGB composite first; every converter below
    # operates on RGB input.
    rgb = plot_rgb_enhanced(
        img,
        band_indices=band_indices,
        band_mapping=band_mapping,
        band_names=band_names,
        stretch=True,
        plot=False
    )

    # Dispatch table: converter function, default figure title, panel names.
    space = color_space.lower()
    converters = {
        'hsv': (rgb2hsv, "HSV Color Space", ['Hue', 'Saturation', 'Value']),
        'lab': (rgb2lab, "LAB Color Space", ['Lightness', 'A*', 'B*']),
        'ycbcr': (rgb2ycbcr, "YCbCr Color Space", ['Luma', 'Chroma Blue', 'Chroma Red']),
    }
    if space not in converters:
        raise ValueError(f"Unsupported color space: {color_space}")
    convert, default_title, panel_names = converters[space]
    converted = convert(rgb)

    if not plot:
        return converted

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    fig.suptitle(title if title is not None else default_title)

    # One grayscale panel per channel of the converted image.
    for channel, (axis, panel_name) in enumerate(zip(axes, panel_names)):
        axis.imshow(converted[:, :, channel], cmap='gray')
        axis.set_title(panel_name)
        axis.axis('off')

    plt.tight_layout()
    plt.show()
    return None

plot_discrete_image(arr, value_labels=None, colors=None, pixel_scale=10, title='Discrete Image', figsize=(10, 8), show=True)

Plots a discrete array with custom colors and labels.

Parameters:

arr : np.ndarray — Input discrete array
value_labels : Dict, optional — Dictionary mapping values to labels
colors : List, optional — List of colors for each unique value
pixel_scale : float, default=10 — Scale for the scale bar (pixels per km)
title : str, default="Discrete Image" — Title for the plot
figsize : Tuple[int, int], default=(10, 8) — Figure size
show : bool, default=True — Whether to display the plot

Returns:

plt.Figure or None Figure object if show=False, otherwise None

Source code in ShallowLearn/visualization/display.py
def plot_discrete_image(arr: np.ndarray, 
                       value_labels: Optional[Dict] = None,
                       colors: Optional[List] = None,
                       pixel_scale: float = 10,
                       title: str = "Discrete Image",
                       figsize: Tuple[int, int] = (10, 8),
                       show: bool = True) -> Optional[plt.Figure]:
    """
    Render a discrete label array with a listed colormap, labelled
    colorbar and a 1 km scale bar.

    Parameters:
    -----------
    arr : np.ndarray
        Discrete array; 1-D input is reshaped to a column
    value_labels : Dict, optional
        Mapping from array values to human-readable colorbar labels
    colors : List, optional
        One color per unique value; a viridis palette fills any gap
    pixel_scale : float, default=10
        Multiplier for the scale-bar length (bar spans 10 * pixel_scale
        data units labelled '1 km')
    title : str, default="Discrete Image"
        Plot title
    figsize : Tuple[int, int], default=(10, 8)
        Figure size
    show : bool, default=True
        Display the figure (returns None) instead of returning it

    Returns:
    --------
    plt.Figure or None
        The figure when show=False, otherwise None
    """
    if arr.ndim == 1:
        arr = arr.reshape(-1, 1)

    unique_values = np.unique(arr)
    n_values = len(unique_values)

    # Remap arbitrary label values onto 0..n-1 so they index the colormap.
    to_index = {value: position for position, value in enumerate(unique_values)}
    index_arr = np.vectorize(to_index.get)(arr)

    # Build the palette: fall back to viridis, partially override it when
    # too few colors were supplied, or take the user palette verbatim.
    if colors is None:
        palette = plt.get_cmap('viridis')(np.linspace(0, 1, n_values))
    elif len(colors) < n_values:
        palette = plt.get_cmap('viridis')(np.linspace(0, 1, n_values))
        for position, user_color in enumerate(colors):
            palette[position] = to_rgba(user_color)
    else:
        palette = [to_rgba(user_color) for user_color in colors[:n_values]]

    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(index_arr, cmap=ListedColormap(palette))

    # Colorbar with one tick per discrete value.
    cbar = fig.colorbar(image, ticks=np.arange(n_values), drawedges=True)
    cbar.set_label('Labels')
    if value_labels:
        cbar.set_ticklabels([value_labels.get(value, str(value)) for value in unique_values])
    else:
        cbar.set_ticklabels([str(value) for value in unique_values])

    # Scale bar spanning 10 * pixel_scale data units, labelled '1 km'
    # (assumes ~10 m ground sampling at the default — TODO confirm
    # pixel_scale semantics against callers).
    ax.add_artist(AnchoredSizeBar(ax.transData,
                                  10 * pixel_scale, '1 km', 'lower right',
                                  pad=0.25,
                                  color='white',
                                  frameon=False,
                                  size_vertical=1))

    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')

    if not show:
        return fig
    plt.show()
    return None

plot_histogram(img, channels=None, bins=50, min_value=1, channel_names=None, title='Histogram', figsize=(10, 6), show=True)

Plots histograms for specified channels.

Parameters:

img : np.ndarray Input image array channels : List[int], optional List of channel indices to plot. If None, plots all channels bins : int, default=50 Number of bins for histogram min_value : float, default=1 Minimum value threshold for filtering channel_names : List[str], optional Names for channels in legend title : str, default="Histogram" Title for the plot figsize : Tuple[int, int], default=(10, 6) Figure size show : bool, default=True Whether to display the plot

Returns:

plt.Figure or None Figure object if show=False, otherwise None

Source code in ShallowLearn/visualization/display.py
def plot_histogram(img: np.ndarray, 
                  channels: Optional[List[int]] = None,
                  bins: int = 50, 
                  min_value: float = 1,
                  channel_names: Optional[List[str]] = None,
                  title: str = "Histogram",
                  figsize: Tuple[int, int] = (10, 6),
                  show: bool = True) -> Optional[plt.Figure]:
    """
    Plots histograms for specified channels as line curves.

    All channels share one histogram range, [0, max(img)].

    Parameters:
    -----------
    img : np.ndarray
        Input image array; 2-D input is treated as a single channel
    channels : List[int], optional
        List of channel indices to plot. If None, plots all channels
    bins : int, default=50
        Number of bins for histogram
    min_value : float, default=1
        Pixels below this value are excluded (e.g. nodata zeros)
    channel_names : List[str], optional
        Names for channels in legend (positional, matching `channels`)
    title : str, default="Histogram"
        Title for the plot
    figsize : Tuple[int, int], default=(10, 6)
        Figure size
    show : bool, default=True
        Whether to display the plot

    Returns:
    --------
    plt.Figure or None
        Figure object if show=False, otherwise None
    """
    if len(img.shape) == 2:
        # Single channel image
        img = img[:, :, np.newaxis]

    if channels is None:
        channels = list(range(img.shape[2]))

    fig, ax = plt.subplots(figsize=figsize)

    # Hoist the shared histogram range; previously np.max(img) was
    # recomputed for every channel.
    max_val = np.max(img)
    bin_edges = np.linspace(0, max_val, bins + 1)
    # FIX: plot against bin centers. The old code used
    # linspace(0, max, bins) (step max/(bins-1)) as x positions, which
    # does not line up with np.histogram's bins of width max/bins and
    # stretched every curve horizontally.
    x = (bin_edges[:-1] + bin_edges[1:]) / 2

    for i, channel_idx in enumerate(channels):
        if channel_idx >= img.shape[2]:
            # Silently skip out-of-range indices (preserved behavior).
            continue

        channel_data = img[:, :, channel_idx].ravel()
        channel_data = channel_data[channel_data >= min_value]

        if channel_data.size == 0:
            continue

        histogram, _ = np.histogram(channel_data, bins=bin_edges)

        # Use channel names if provided
        if channel_names and i < len(channel_names):
            label = channel_names[i]
        else:
            label = f'Channel {channel_idx + 1}'

        ax.plot(x, histogram, label=label, alpha=0.7)

    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

    if show:
        plt.show()
        return None
    return fig

plot_rgb_enhanced(img, band_indices=None, band_mapping=None, band_names=None, stretch=True, plot=False, title='RGB Image', figsize=(10, 8))

Enhanced RGB plotting function with flexible band selection and reduced hardcoding.

This function replaces ImageHelper.plot_rgb with improved flexibility and reduced dependency on hardcoded band mappings.

Parameters:

img : np.ndarray Input image array with shape (height, width, bands) band_indices : List[int], optional List of 3 band indices for R, G, B channels. If None, defaults to [3, 2, 1] (which corresponds to typical Red, Green, Blue for Sentinel-2) band_mapping : Dict, optional Band mapping dictionary for converting band names to indices band_names : List[str], optional
List of 3 band names (e.g., ['B04', 'B03', 'B02']) to use with band_mapping stretch : bool, default=True Whether to apply min-max stretch to enhance contrast plot : bool, default=False Whether to display the image plot using matplotlib title : str, default="RGB Image" Title for the plot figsize : Tuple[int, int], default=(10, 8) Figure size if plotting

Returns:

np.ndarray or None RGB image array with shape (height, width, 3) and dtype uint8 if plot=False, otherwise None

Source code in ShallowLearn/visualization/display.py
def plot_rgb_enhanced(
    img: np.ndarray, 
    band_indices: Optional[List[int]] = None,
    band_mapping: Optional[Dict] = None,
    band_names: Optional[List[str]] = None,
    stretch: bool = True,
    plot: bool = False,
    title: str = "RGB Image",
    figsize: Tuple[int, int] = (10, 8)
) -> Union[np.ndarray, None]:
    """
    Build (and optionally display) an 8-bit RGB composite from a
    multi-band image, with flexible band selection.

    Successor to ImageHelper.plot_rgb; bands may be chosen by explicit
    indices, by names resolved through a band mapping, or left to the
    Sentinel-2 default of B04/B03/B02.

    Parameters:
    -----------
    img : np.ndarray
        Input array with shape (height, width, bands)
    band_indices : List[int], optional
        Three 0-based band indices for the R, G, B channels; defaults
        to [3, 2, 1] when neither names nor mapping are given
    band_mapping : Dict, optional
        Per-band metadata dict; each entry must expose an 'index' key
    band_names : List[str], optional
        Three band names (e.g. ['B04', 'B03', 'B02']) looked up in
        ``band_mapping``
    stretch : bool, default=True
        Min-max stretch each channel to [0, 255]; otherwise values are
        simply clipped to that range
    plot : bool, default=False
        Display the composite with matplotlib instead of returning it
    title : str, default="RGB Image"
        Plot title
    figsize : Tuple[int, int], default=(10, 8)
        Figure size when plotting

    Returns:
    --------
    np.ndarray or None
        (height, width, 3) uint8 array when plot=False, otherwise None
    """
    # Resolve the three band indices from whichever inputs were supplied.
    if band_indices is None:
        if band_names and band_mapping:
            band_indices = [band_mapping[name]['index'] for name in band_names]
        elif band_names is None and band_mapping is None:
            # Sentinel-2 convention, 0-indexed: B04=red, B03=green, B02=blue.
            band_indices = [3, 2, 1]
        else:
            raise ValueError("Either band_indices or both band_names and band_mapping must be provided")

    if len(band_indices) != 3:
        raise ValueError("Exactly 3 band indices required for RGB")

    n_bands = img.shape[2]
    for idx in band_indices:
        if idx >= n_bands:
            raise ValueError(f"Band index {idx} out of bounds for image with {n_bands} bands")

    height, width = img.shape[0], img.shape[1]

    def prepare_channel(idx: int) -> np.ndarray:
        # Extract one band as float and scale it into [0, 255].
        band = img[:, :, idx].astype(float)
        if stretch:
            return minmax_scale(
                band.flatten(),
                feature_range=(0, 255),
                axis=0,
                copy=True
            ).reshape(height, width)
        return np.clip(band, 0, 255)

    rgb = np.dstack([np.uint8(prepare_channel(idx)) for idx in band_indices])

    if not plot:
        return rgb

    plt.figure(figsize=figsize)
    plt.imshow(rgb)
    plt.title(title)
    plt.axis('off')
    plt.show()
    return None

plot_with_legend(array, value_dict, title='Classified Image', figsize=(10, 8), show=True)

Plots a 2D array with a legend using distinct colors for discrete class labels.

Parameters:

array : np.ndarray — 2D array to be plotted
value_dict : Dict — Dictionary mapping values in the array to labels
title : str, default="Classified Image" — Title for the plot
figsize : Tuple[int, int], default=(10, 8) — Figure size
show : bool, default=True — Whether to display the plot

Returns:

plt.Figure or None Figure object if show=False, otherwise None

Source code in ShallowLearn/visualization/display.py
def plot_with_legend(array: np.ndarray, 
                    value_dict: Dict,
                    title: str = "Classified Image",
                    figsize: Tuple[int, int] = (10, 8),
                    show: bool = True) -> Optional[plt.Figure]:
    """
    Plots a 2D array with a legend using distinct colors for discrete class labels.

    Parameters:
    -----------
    array : np.ndarray
        2D array to be plotted
    value_dict : Dict
        Dictionary mapping values in the array to labels
    title : str, default="Classified Image"
        Title for the plot
    figsize : Tuple[int, int], default=(10, 8)
        Figure size
    show : bool, default=True
        Whether to display the plot

    Returns:
    --------
    plt.Figure or None
        Figure object if show=False, otherwise None
    """
    n_classes = len(value_dict)
    # FIX: plt.cm.get_cmap is deprecated and was removed in Matplotlib 3.9;
    # plt.get_cmap matches the usage in plot_discrete_image in this module.
    cmap = plt.get_cmap('Set3', n_classes)

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(array, cmap=cmap)

    # Create color map index for each discrete value
    colors = [cmap(i) for i in range(n_classes)]

    # NOTE(review): the legend assumes array values span a contiguous
    # range matching value_dict's insertion order (imshow normalizes the
    # value range onto the colormap) — verify for sparse label values.
    patches = [mpatches.Patch(color=colors[i], label=label) 
              for i, (value, label) in enumerate(value_dict.items())]

    # Add legend outside the axes on the right
    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    ax.set_title(title)

    if show:
        plt.show()
        return None
    return fig

Utilities Module

ShallowLearn utilities module Contains cross-cutting utility functions for file operations, etc.

find_files_in_directory(directory, max_files=5)

Find Sentinel-2 files in directory.

Parameters:

Name Type Description Default
directory str

Directory path to search

required
max_files int

Maximum number of files to return

5

Returns:

Type Description
List[str]

List of file paths as strings

Source code in ShallowLearn/utilities/file_discovery.py
def find_files_in_directory(directory: str, max_files: int = 5) -> List[str]:
    """
    Find Sentinel-2 products (.SAFE directories or .zip archives) in a directory.

    Args:
        directory: Directory path to search
        max_files: Maximum total number of files to return

    Returns:
        List of at most ``max_files`` file paths as strings (.SAFE entries
        first, then .zip); empty list if the directory does not exist.
    """
    directory = Path(directory)
    if not directory.exists():
        print(f"❌ Directory not found: {directory}")
        return []

    # Look for .SAFE directories and .zip files
    safe_files = list(directory.glob("*.SAFE"))
    zip_files = list(directory.glob("*.zip"))
    # FIX: cap the combined result. Previously each list was truncated to
    # max_files separately, so up to 2 * max_files paths were returned,
    # contradicting the documented limit.
    all_files = (safe_files + zip_files)[:max_files]

    print(f"Found {len(safe_files)} .SAFE and {len(zip_files)} .zip files (using {len(all_files)})")
    return [str(f) for f in all_files]

find_matching_files_by_date(list1, list2)

Match files from two lists based on identical acquisition dates. Typically used for matching L1C and L2A Sentinel-2 files.

Parameters:

Name Type Description Default
list1 List[str]

First list of file paths

required
list2 List[str]

Second list of file paths

required

Returns:

Type Description
List[Tuple[datetime, str, str]]

List of tuples containing (date, file1, file2) for matching dates

Source code in ShallowLearn/utilities/file_discovery.py
def find_matching_files_by_date(list1: List[str], list2: List[str]) -> List[Tuple[datetime, str, str]]:
    """
    Pair files from two lists whose acquisition dates coincide.

    Intended for matching Sentinel-2 L1C products against their L2A
    counterparts: ``list1`` is scanned for ``MSIL1C_`` timestamps and
    ``list2`` for ``MSIL2A_`` timestamps, then entries sharing the same
    calendar date are paired.

    Args:
        list1: First list of file paths (L1C naming)
        list2: Second list of file paths (L2A naming)

    Returns:
        Tuples of (L1C acquisition datetime, L1C path, L2A path), sorted
        by date; files without a parsable timestamp are ignored.
    """
    def parse_timestamp(path: str, marker: str) -> Optional[datetime]:
        # The timestamp immediately follows the product-level marker,
        # e.g. "MSIL1C_20230101T100000".
        found = re.search(rf"{marker}(\d{{8}}T\d{{6}})", path)
        return datetime.strptime(found.group(1), "%Y%m%dT%H%M%S") if found else None

    # Index each side by calendar date, dropping unparsable names.
    # Later files with the same date overwrite earlier ones.
    by_date_l1 = {}
    for path in list1:
        if (ts := parse_timestamp(path, "MSIL1C_")) is not None:
            by_date_l1[ts.date()] = (ts, path)

    by_date_l2 = {}
    for path in list2:
        if (ts := parse_timestamp(path, "MSIL2A_")) is not None:
            by_date_l2[ts.date()] = (ts, path)

    matches = []
    for day in sorted(by_date_l1.keys() & by_date_l2.keys()):
        acquired, l1_path = by_date_l1[day]
        matches.append((acquired, l1_path, by_date_l2[day][1]))
    return matches

process_reef_data(files, reef_gdf, reef_indices, data_type='L1C', buffer_meters=100)

Process satellite files for multiple reefs separately.

Parameters:

Name Type Description Default
files

List of satellite file paths

required
reef_gdf

GeoDataFrame containing reef polygons

required
reef_indices

List of reef indices to process

required
data_type

'L1C' or 'L2A'

'L1C'
buffer_meters

Buffer size for clipping

100

Returns:

Type Description

Dict mapping reef names to lists of processed images

Source code in ShallowLearn/utilities/file_discovery.py
def process_reef_data(files, reef_gdf, reef_indices, data_type='L1C', buffer_meters=100):
    """
    Load and clip satellite scenes for each requested reef.

    For every reef index, each file in ``files`` is opened as a
    Sentinel2Image clipped (with a buffer) to that reef's polygon.
    Failures on individual files are reported and skipped.

    Args:
        files: List of satellite file paths
        reef_gdf: GeoDataFrame containing reef polygons
        reef_indices: List of reef indices to process
        data_type: 'L1C' or 'L2A' (used only in progress messages)
        buffer_meters: Buffer size for clipping

    Returns:
        Dict mapping reef names to lists of processed images; reefs with
        no successfully loaded image are omitted.
    """
    from ShallowLearn.io.satellite_data import Sentinel2Image
    from pathlib import Path

    names = reef_gdf.ORIG_NAME.to_list()
    results = {}

    for idx in reef_indices:
        # One-row GeoDataFrame for this reef only, kept 2-D for clipping.
        reef = reef_gdf.iloc[idx:idx + 1].copy()
        name = names[idx]
        area = reef.iloc[0].Area

        print(f"\n📊 Processing {data_type} Reef: {name} (Area: {area:,.0f} m²)")

        images = []
        for path in files:
            # Best-effort per file: a failing scene is logged and skipped.
            try:
                print(f"   Loading {data_type}: {Path(path).name}")
                clipped = Sentinel2Image(
                    path,
                    load_all_bands=False,
                    clip_geometry=reef,
                    buffer_meters=buffer_meters
                )
                images.append(clipped)
                print(f"   ✅ Clipped to {clipped.image.shape}")
            except Exception as e:
                print(f"   ❌ Failed: {e}")

        if images:
            results[name] = images
            print(f"   📋 {len(images)} images processed for {name}")

    return results

safe_filename(name)

Convert reef name to safe filename.

Source code in ShallowLearn/utilities/file_discovery.py
def safe_filename(name: str) -> str:
    """Sanitize a reef name for use as a filename.

    Any character other than word characters (letters, digits,
    underscore), hyphen, and dot is replaced with an underscore.
    """
    import re
    unsafe = re.compile(r'[^\w\-_.]')
    return unsafe.sub('_', str(name))