diff --git a/src/virtual_stain_flow/datasets/base_dataset.py b/src/virtual_stain_flow/datasets/base_dataset.py index 6fdd20b..fda86ed 100644 --- a/src/virtual_stain_flow/datasets/base_dataset.py +++ b/src/virtual_stain_flow/datasets/base_dataset.py @@ -30,6 +30,8 @@ def __init__( input_channel_keys: Optional[Union[str, Sequence[str]]] = None, target_channel_keys: Optional[Union[str, Sequence[str]]] = None, transforms: Optional[Sequence[LoggableTransform]] = None, + input_transforms: Optional[Sequence[LoggableTransform]] = None, + target_transforms: Optional[Sequence[LoggableTransform]] = None, cache_capacity: Optional[int] = None, file_state: Optional[FileState] = None, ): @@ -49,6 +51,10 @@ def __init__( :param target_channel_keys: Keys for target channels in the file index. :param transforms: Optional sequence of LoggableTransform objects to apply to the images before returning them. + :param input_transforms: Optional sequence of LoggableTransform objects to + apply to the input image only, after `transforms`. + :param target_transforms: Optional sequence of LoggableTransform objects to + apply to the target image only, after `transforms`. :param cache_capacity: Optional capacity for caching loaded images. When set to None, default caching behavior of caching at most `file_index.shape[0]` images is used. 
def _normalize_transforms(
    self,
    transforms: Optional[Sequence[LoggableTransform]],
    name: str,
) -> Sequence[LoggableTransform]:
    """
    Coerce a per-branch transform argument into a validated list.

    :param transforms: None, a single LoggableTransform, or a sequence of
        LoggableTransform objects.
    :param name: Name of the argument being validated; only used to build
        the error message.
    :return: A new list holding the transforms, empty when `transforms`
        is None.
    :raises ValueError: If any element is not a LoggableTransform.
    """
    # Only None means "not provided". Testing truthiness instead would let
    # falsy garbage (0, False, ...) silently become an empty pipeline.
    if transforms is None:
        normalized = []
    elif isinstance(transforms, Sequence):
        # Copy so later mutation of the caller's list cannot change the
        # pipeline stored on the dataset.
        normalized = list(transforms)
    else:
        # A single object is wrapped and validated below.
        normalized = [transforms]

    if not all(isinstance(t, LoggableTransform) for t in normalized):
        raise ValueError(
            f"All {name} must be instances of LoggableTransform."
        )
    return normalized


def _apply_input_transforms(
    self,
    image: np.ndarray,
) -> np.ndarray:
    """
    Run `self.input_transforms` over an image, in order.

    :param image: Image to transform.
    :return: Result of chaining each transform's `apply` call.
    """
    for transform in self.input_transforms:
        image = transform.apply(img=image)
    return image


def _apply_target_transforms(
    self,
    image: np.ndarray,
) -> np.ndarray:
    """
    Run `self.target_transforms` over an image, in order.

    :param image: Image to transform.
    :return: Result of chaining each transform's `apply` call.
    """
    for transform in self.target_transforms:
        image = transform.apply(img=image)
    return image


def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overridden Dataset `__getitem__` method so class works with torch DataLoader.

    The shared `transforms` are applied to both images first; the
    input-only and target-only transforms are applied afterwards.

    :param idx: Index of the item to fetch.
    :return: Tuple of (input, target) float tensors.
    """
    input_image_raw, target_image_raw = self.get_raw_item(idx)

    # Shared stage: same pipeline for both images, input first.
    input_image = self._apply_transforms(input_image_raw)
    target_image = self._apply_transforms(target_image_raw)

    # Branch-specific stage, run after the shared stage.
    input_image = self._apply_input_transforms(input_image)
    target_image = self._apply_target_transforms(target_image)

    return (
        torch.from_numpy(input_image).float(),
        torch.from_numpy(target_image).float(),
    )
Intended to be used by only .from_config class method and similar deserialization @@ -84,6 +90,13 @@ def __init__( raise ValueError("All transforms must be instances of LoggableTransform.") self.transforms = transforms + self.input_transforms = self._normalize_transforms( + input_transforms, "input_transforms" + ) + self.target_transforms = self._normalize_transforms( + target_transforms, "target_transforms" + ) + @property def pil_image_mode(self) -> str: return self.manifest.pil_image_mode @@ -157,6 +170,8 @@ def from_base_dataset( cls, base_dataset: BaseImageDataset, transforms: Optional[Sequence[LoggableTransform]] = None, + input_transforms: Optional[Sequence[LoggableTransform]] = None, + target_transforms: Optional[Sequence[LoggableTransform]] = None, how: Type[CropGenerator] = generate_center_crops, **kwargs: Any ) -> 'CropImageDataset': @@ -166,6 +181,10 @@ def from_base_dataset( :param base_dataset: The BaseImageDataset to convert. :param how: A function that generates crop specifications from the base dataset. Default is `generate_center_crops`. + :param input_transforms: Optional sequence of LoggableTransform objects to + apply to the input image only, after `transforms`. + :param target_transforms: Optional sequence of LoggableTransform objects to + apply to the target image only, after `transforms`. :param kwargs: Additional keyword arguments for the `how` function. 
""" @@ -177,6 +196,8 @@ def from_base_dataset( return cls( file_index=base_dataset.file_index, transforms=transforms, + input_transforms=input_transforms, + target_transforms=target_transforms, crop_specs=crop_specs, pil_image_mode=base_dataset.pil_image_mode, input_channel_keys=base_dataset.input_channel_keys, diff --git a/src/virtual_stain_flow/transforms/README.md b/src/virtual_stain_flow/transforms/README.md index df63f20..a2543c6 100644 --- a/src/virtual_stain_flow/transforms/README.md +++ b/src/virtual_stain_flow/transforms/README.md @@ -7,23 +7,33 @@ To facilitate reproducible experiments, the transformations are also made serial ## Overview -This subpackage consists of three modules: +This subpackage consists of four modules: 1. **`transform_utils.py`**: Contains type definitions and validation utilities for transform objects, defining acceptable transform types and providing runtime type checking capabilities for both standard Albumentations transforms and custom LoggableTransform classes. 2. **`base_transform.py`**: Defines the abstract `LoggableTransform` base class that extends Albumentations' `ImageOnlyTransform`, adding serialization capabilities, naming conventions, and standardized logging interfaces for scientific reproducibility. 3. **`normalizations.py`**: Implements concrete normalization transforms including `MaxScaleNormalize` for scaling images to [0,1] range and `ZScoreNormalize` for statistical standardization, both inheriting from `LoggableTransform` to ensure proper integration with the package's logging and configuration systems. +4. **`channelwise.py`**: Provides `ChannelwiseTransform`, a wrapper that applies a distinct transform to each channel of a channel-first image. 
class ChannelwiseTransform(LoggableTransform):
    """
    Apply a list of transforms to a channel-first image, one transform per channel.

    Entry ``i`` of `transforms` is applied to channel ``i`` of the image.
    A ``None`` entry leaves that channel unchanged.
    """

    def __init__(
        self,
        transforms: Sequence[Optional[LoggableTransform]],
        name: str = "ChannelwiseTransform",
        p: float = 1.0,
        channel_axis: int = 0,
    ):
        """
        :param transforms: Non-empty sequence with one entry per channel; each
            entry is a LoggableTransform or None (channel passed through).
        :param name: Name forwarded to the LoggableTransform base class.
        :param p: Probability value forwarded to the LoggableTransform base class.
        :param channel_axis: Axis holding the channels. Only 0 is supported.
        :raises ValueError: If `channel_axis` is not 0, if `transforms` is not
            a non-empty sequence, or if an entry is neither a
            LoggableTransform nor None.
        """
        super().__init__(name=name, p=p)

        if channel_axis != 0:
            raise ValueError("Only channel-first images with channel_axis=0 are supported.")

        if not isinstance(transforms, Sequence) or len(transforms) == 0:
            raise ValueError("Expected a non-empty sequence of LoggableTransform or None.")
        if not all((t is None) or isinstance(t, LoggableTransform) for t in transforms):
            raise ValueError("All transforms must be instances of LoggableTransform or None.")

        # Stored as a private copy so the caller's sequence is not aliased.
        self._transforms: List[Optional[LoggableTransform]] = list(transforms)
        self._channel_axis = channel_axis

    @property
    def transforms(self) -> List[Optional[LoggableTransform]]:
        """Per-channel transforms, in channel order (None = pass-through)."""
        return self._transforms

    @property
    def channel_axis(self) -> int:
        """Axis of the image that indexes channels."""
        return self._channel_axis

    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        """
        Transform each channel of `img` with its own transform.

        :param img: Channel-first image of shape (C, H, W), where C equals the
            number of configured transforms.
        :param params: Accepted for signature compatibility; not forwarded to
            the per-channel transforms.
        :return: Array built by concatenating the per-channel results along
            axis 0.
        :raises TypeError: If `img` is not a NumPy array.
        :raises ValueError: If `img` is not 3-dimensional, if the channel count
            does not match the number of transforms, or if a per-channel
            transform does not return a (1, H, W) array.
        """
        if not isinstance(img, np.ndarray):
            raise TypeError(
                "Expected input image to be a NumPy array, "
                f"got {type(img).__name__} instead."
            )

        if img.ndim != 3:
            raise ValueError(
                "Expected a channel-first image with shape (C, H, W)."
            )

        channel_count = img.shape[self._channel_axis]
        if channel_count != len(self._transforms):
            raise ValueError(
                "Number of transforms must match number of channels. "
                f"Got {len(self._transforms)} transforms for {channel_count} channels."
            )

        transformed_channels = []
        for channel_idx, transform in enumerate(self._transforms):
            # Slice (not index) so each channel keeps a leading axis of size 1
            # and sub-transforms still see a 3-D, channel-first array.
            channel = img[channel_idx:channel_idx + 1, ...]
            if transform is None:
                transformed_channels.append(channel)
                continue

            result = transform.apply(img=channel)
            # Without this check, sub-transforms that drop the channel axis
            # would be stacked along the height axis, yielding a wrongly
            # shaped (C*H, W) array without any error.
            if (
                not isinstance(result, np.ndarray)
                or result.ndim != 3
                or result.shape[0] != 1
            ):
                got = getattr(result, "shape", type(result).__name__)
                raise ValueError(
                    f"Transform for channel {channel_idx} must return an "
                    f"array of shape (1, H, W), got {got}."
                )
            transformed_channels.append(result)

        return np.concatenate(transformed_channels, axis=0)

    def __repr__(self) -> str:
        """Short summary: class name, transform name, channel count and p."""
        return (
            f"{self.__class__.__name__}(name={self._name}, "
            f"channels={len(self._transforms)}, p={self.p})"
        )

    def to_config(self) -> dict:
        """
        Serialize this transform and its per-channel transforms.

        :return: Dict with the class name, the transform name and a `params`
            entry holding `channel_axis`, `p` and the list of per-channel
            configs (None for pass-through channels).
        """
        return {
            "class": self.__class__.__name__,
            "name": self._name,
            "params": {
                "channel_axis": self._channel_axis,
                "p": self.p,
                "transforms": [
                    t.to_config() if t is not None else None
                    for t in self._transforms
                ],
            },
        }