Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions src/virtual_stain_flow/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
input_channel_keys: Optional[Union[str, Sequence[str]]] = None,
target_channel_keys: Optional[Union[str, Sequence[str]]] = None,
transforms: Optional[Sequence[LoggableTransform]] = None,
input_transforms: Optional[Sequence[LoggableTransform]] = None,
target_transforms: Optional[Sequence[LoggableTransform]] = None,
cache_capacity: Optional[int] = None,
file_state: Optional[FileState] = None,
):
Expand All @@ -49,6 +51,10 @@ def __init__(
:param target_channel_keys: Keys for target channels in the file index.
:param transforms: Optional sequence of LoggableTransform objects to apply
to the images before returning them.
:param input_transforms: Optional sequence of LoggableTransform objects to
apply to the input image only, after `transforms`.
:param target_transforms: Optional sequence of LoggableTransform objects to
apply to the target image only, after `transforms`.
:param cache_capacity: Optional capacity for caching loaded images.
When set to None, default caching behavior of caching at most
`file_index.shape[0]` images is used. When set to -1, unbounded
Expand Down Expand Up @@ -86,6 +92,26 @@ def __init__(
raise ValueError("All transforms must be instances of LoggableTransform.")
self.transforms = transforms

self.input_transforms = self._normalize_transforms(
input_transforms, "input_transforms"
)
self.target_transforms = self._normalize_transforms(
target_transforms, "target_transforms"
)

def _normalize_transforms(
    self,
    transforms: Optional[Sequence[LoggableTransform]],
    name: str,
) -> Sequence[LoggableTransform]:
    """
    Coerce a transform argument into a validated sequence of transforms.

    :param transforms: None, a single LoggableTransform, or a sequence of
        LoggableTransform objects.
    :param name: Name of the argument being validated; only used to build
        the error message.
    :return: An empty list when `transforms` is None, a one-element list
        when a single non-sequence object was given, otherwise the sequence
        that was passed in.
    :raises ValueError: If any element is not a LoggableTransform.
    """
    # Only None means "no transforms". Testing truthiness here would let
    # falsy values (0, False, {}, or a transform object that evaluates
    # falsy) slip through as an empty pipeline without being validated.
    if transforms is None:
        return []

    if not isinstance(transforms, Sequence):
        # A bare object is treated as a one-step pipeline and validated below.
        transforms = [transforms]

    if not all(isinstance(t, LoggableTransform) for t in transforms):
        raise ValueError(
            f"All {name} must be instances of LoggableTransform."
        )
    return transforms

def get_raw_item(
self,
idx: int
Expand Down Expand Up @@ -133,14 +159,38 @@ def _apply_transforms(
image = transform.apply(img=image)
return image

def _apply_input_transforms(
    self,
    image: np.ndarray,
) -> np.ndarray:
    """
    Run `image` through every transform in `self.input_transforms`, in order.

    :param image: Image to transform.
    :return: Output of the last transform, or `image` itself when
        `self.input_transforms` is empty.
    """
    result = image
    for step in self.input_transforms:
        # Each step consumes the output of the previous one.
        result = step.apply(img=result)
    return result

def _apply_target_transforms(
    self,
    image: np.ndarray,
) -> np.ndarray:
    """
    Run `image` through every transform in `self.target_transforms`, in order.

    :param image: Image to transform.
    :return: Output of the last transform, or `image` itself when
        `self.target_transforms` is empty.
    """
    result = image
    for step in self.target_transforms:
        # Each step consumes the output of the previous one.
        result = step.apply(img=result)
    return result

def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overridden Dataset `__getitem__` method so class works with torch DataLoader.

    The raw input/target pair is first passed through the shared transforms,
    then each image goes through its own input-only / target-only transforms,
    and the results are converted to float tensors.

    :param idx: Index forwarded to `get_raw_item`.
    :return: Tuple of (input, target) float tensors.
    """
    input_image_raw, target_image_raw = self.get_raw_item(idx)

    # Shared transforms run on both images before any branch-specific ones.
    # There must be no early return above this point, otherwise the
    # input/target-specific transforms below would never be applied.
    input_image = self._apply_transforms(input_image_raw)
    target_image = self._apply_transforms(target_image_raw)

    input_image = self._apply_input_transforms(input_image)
    target_image = self._apply_target_transforms(target_image)

    return (
        torch.from_numpy(input_image).float(),
        torch.from_numpy(target_image).float(),
    )

@property
def pil_image_mode(self) -> str:
Expand Down
21 changes: 21 additions & 0 deletions src/virtual_stain_flow/datasets/crop_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
input_channel_keys: Optional[Union[str, Sequence[str]]] = None,
target_channel_keys: Optional[Union[str, Sequence[str]]] = None,
transforms: Optional[Sequence[LoggableTransform]] = None,
input_transforms: Optional[Sequence[LoggableTransform]] = None,
target_transforms: Optional[Sequence[LoggableTransform]] = None,
crop_file_state: Optional[CropFileState] = None,
):
"""
Expand All @@ -53,6 +55,10 @@ def __init__(
:param input_channel_keys: Keys for input channels in the file index.
:param target_channel_keys: Keys for target channels in the file index.
:param transforms: Optional sequence of transformations to apply to the images.
:param input_transforms: Optional sequence of LoggableTransform objects to
apply to the input image only, after `transforms`.
:param target_transforms: Optional sequence of LoggableTransform objects to
apply to the target image only, after `transforms`.
:param crop_file_state: Optional pre-initialized CropFileState object. If provided,
it takes precedence over `file_index` and `crop_specs`. Intended
to be used by only .from_config class method and similar deserialization
Expand Down Expand Up @@ -84,6 +90,13 @@ def __init__(
raise ValueError("All transforms must be instances of LoggableTransform.")
self.transforms = transforms

self.input_transforms = self._normalize_transforms(
input_transforms, "input_transforms"
)
self.target_transforms = self._normalize_transforms(
target_transforms, "target_transforms"
)

@property
def pil_image_mode(self) -> str:
    """Return the `pil_image_mode` value held by `self.manifest`."""
    manifest = self.manifest
    return manifest.pil_image_mode
Expand Down Expand Up @@ -157,6 +170,8 @@ def from_base_dataset(
cls,
base_dataset: BaseImageDataset,
transforms: Optional[Sequence[LoggableTransform]] = None,
input_transforms: Optional[Sequence[LoggableTransform]] = None,
target_transforms: Optional[Sequence[LoggableTransform]] = None,
how: Type[CropGenerator] = generate_center_crops,
**kwargs: Any
) -> 'CropImageDataset':
Expand All @@ -166,6 +181,10 @@ def from_base_dataset(
:param base_dataset: The BaseImageDataset to convert.
:param how: A function that generates crop specifications from the base dataset.
Default is `generate_center_crops`.
:param input_transforms: Optional sequence of LoggableTransform objects to
apply to the input image only, after `transforms`.
:param target_transforms: Optional sequence of LoggableTransform objects to
apply to the target image only, after `transforms`.
:param kwargs: Additional keyword arguments for the `how` function.
"""

Expand All @@ -177,6 +196,8 @@ def from_base_dataset(
return cls(
file_index=base_dataset.file_index,
transforms=transforms,
input_transforms=input_transforms,
target_transforms=target_transforms,
crop_specs=crop_specs,
pil_image_mode=base_dataset.pil_image_mode,
input_channel_keys=base_dataset.input_channel_keys,
Expand Down
12 changes: 11 additions & 1 deletion src/virtual_stain_flow/transforms/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,33 @@ To facilitate reproducible experiments, the transformations are also made serial

## Overview

This subpackage consists of three modules:
This subpackage consists of four modules:
1. **`transform_utils.py`**: Contains type definitions and validation utilities for transform objects, defining acceptable transform types and providing runtime type checking capabilities for both standard Albumentations transforms and custom LoggableTransform classes.

2. **`base_transform.py`**: Defines the abstract `LoggableTransform` base class that extends Albumentations' `ImageOnlyTransform`, adding serialization capabilities, naming conventions, and standardized logging interfaces for scientific reproducibility.

3. **`normalizations.py`**: Implements concrete normalization transforms including `MaxScaleNormalize` for scaling images to [0,1] range and `ZScoreNormalize` for statistical standardization, both inheriting from `LoggableTransform` to ensure proper integration with the package's logging and configuration systems.

4. **`channelwise.py`**: Provides `ChannelwiseTransform`, a wrapper that applies a distinct transform to each channel of a channel-first image.

---

## Usage:
See the examples for in-context use with datasets.
```python
from normalizations import MaxScaleNormalize
from channelwise import ChannelwiseTransform

scale_transform = MaxScaleNormalize(
normalization_factor='16bit',
name="Scale16BitImages",
p=1.0
)

channelwise_transform = ChannelwiseTransform(
transforms=[
MaxScaleNormalize(normalization_factor='16bit'),
MaxScaleNormalize(normalization_factor='8bit')
]
)
```
12 changes: 12 additions & 0 deletions src/virtual_stain_flow/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
"""
/transforms/__init__.py

Re-exports the transform classes of this subpackage
(`ChannelwiseTransform`, `MaxScaleNormalize`, `ZScoreNormalize`) and lists
them in `__all__`.
"""

from .channelwise import ChannelwiseTransform
from .normalizations import (
MaxScaleNormalize,
ZScoreNormalize,
)

__all__ = [
"ChannelwiseTransform",
"MaxScaleNormalize",
"ZScoreNormalize",
]
91 changes: 91 additions & 0 deletions src/virtual_stain_flow/transforms/channelwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
channelwise.py

Defines channel-specific transform wrappers.
"""

from typing import List, Optional, Sequence

import numpy as np

from .base_transform import LoggableTransform


class ChannelwiseTransform(LoggableTransform):
    """
    Wrapper that routes each channel of a channel-first image through its
    own transform.

    Entry ``i`` of ``transforms`` handles channel ``i``; a ``None`` entry
    leaves that channel as it is.
    """

    def __init__(
        self,
        transforms: Sequence[Optional[LoggableTransform]],
        name: str = "ChannelwiseTransform",
        p: float = 1.0,
        channel_axis: int = 0,
    ):
        """
        :param transforms: Non-empty sequence with one LoggableTransform
            (or None) per channel.
        :param name: Name forwarded to the LoggableTransform constructor.
        :param p: Value forwarded to the LoggableTransform constructor.
        :param channel_axis: Axis that holds the channels; only 0 is accepted.
        :raises ValueError: If `channel_axis` is not 0, if `transforms` is
            not a non-empty sequence, or if any entry is neither None nor a
            LoggableTransform.
        """
        super().__init__(name=name, p=p)

        if channel_axis != 0:
            raise ValueError("Only channel-first images with channel_axis=0 are supported.")

        is_sequence = isinstance(transforms, Sequence)
        if not (is_sequence and len(transforms) > 0):
            raise ValueError("Expected a non-empty sequence of LoggableTransform or None.")

        for entry in transforms:
            if entry is not None and not isinstance(entry, LoggableTransform):
                raise ValueError("All transforms must be instances of LoggableTransform or None.")

        # Stored as a list copy of the sequence that was passed in.
        self._transforms: List[Optional[LoggableTransform]] = list(transforms)
        self._channel_axis = channel_axis

    @property
    def transforms(self) -> List[Optional[LoggableTransform]]:
        """Per-channel transforms, in channel order."""
        return self._transforms

    @property
    def channel_axis(self) -> int:
        """Axis along which channels are indexed."""
        return self._channel_axis

    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        """
        Apply the per-channel transforms to `img` and re-stack the channels.

        :param img: 3-dimensional NumPy array whose channel count along
            `channel_axis` equals the number of configured transforms.
        :param params: Accepted for signature compatibility; not used here.
        :return: The processed channels concatenated along axis 0.
        :raises TypeError: If `img` is not a NumPy array.
        :raises ValueError: If `img` is not 3-dimensional or its channel
            count differs from the number of transforms.
        """
        if not isinstance(img, np.ndarray):
            raise TypeError(
                "Expected input image to be a NumPy array, "
                f"got {type(img).__name__} instead."
            )

        if img.ndim != 3:
            raise ValueError(
                "Expected a channel-first image with shape (C, H, W)."
            )

        channel_count = img.shape[self._channel_axis]
        if channel_count != len(self._transforms):
            raise ValueError(
                "Number of transforms must match number of channels. "
                f"Got {len(self._transforms)} transforms for {channel_count} channels."
            )

        processed = []
        for idx, op in enumerate(self._transforms):
            # Slice rather than index so the channel keeps its leading axis
            # and can be concatenated back along axis 0.
            plane = img[idx:idx + 1, ...]
            processed.append(plane if op is None else op.apply(img=plane))

        return np.concatenate(processed, axis=0)

    def __repr__(self) -> str:
        """Short summary: class name, transform name, channel count and p."""
        class_name = self.__class__.__name__
        n_channels = len(self._transforms)
        return f"{class_name}(name={self._name}, channels={n_channels}, p={self.p})"

    def to_config(self) -> dict:
        """
        Serialize this wrapper to a dict.

        :return: Dict with keys "class", "name" and "params"; the nested
            "transforms" list holds each wrapped transform's own
            `to_config()` output, or None for pass-through channels.
        """
        per_channel = []
        for entry in self._transforms:
            per_channel.append(None if entry is None else entry.to_config())

        return {
            "class": self.__class__.__name__,
            "name": self._name,
            "params": {
                "channel_axis": self._channel_axis,
                "p": self.p,
                "transforms": per_channel,
            },
        }
Loading