Add ConformalPredictor and ConformalCalibrator for split-conformal (LAC) prediction sets#8938
txmed82 wants to merge 3 commits
…AC) prediction sets

Fixes Project-MONAI#8935 (part 1 of 2).

Brings split-conformal prediction to MONAI via the LAC score `1 - softmax[y]` (Sadinle et al. 2019). First of two PRs for Project-MONAI#8935; the second will add image-level loss-bounded calibration and a per-voxel uncertainty mask to `monai/metrics/`.

`monai/inferers/conformal_predictor.py`:
- ConformalCalibrator: collect scores on a held-out split, return the threshold qhat with a 1 - alpha marginal coverage guarantee. Works for classification (B, C) and per-voxel segmentation (B, C, spatial...); include_background=False drops background-labeled samples so the threshold isn't dominated by easy background.
- ConformalPredictor: an Inferer that wraps a network + qhat and returns the prediction-set mask { y : 1 - softmax[y] <= qhat }. Accepts a pre-calibrated qhat or calibrates in-band from a loader.

Signed-off-by: Colin Son <txmed82@users.noreply.github.com>
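As a rough illustration of the calibration step described in the commit message, the sketch below computes a split-conformal threshold from LAC scores in plain Python. It is not the PR's `_quantile_threshold`: the name `lac_threshold`, the list-based input, and the rank clamp are assumptions made for illustration (the clamp is inferred only from the test name `test_clamps_to_valid_rank`).

```python
import math


def lac_threshold(true_class_probs, alpha):
    """Return qhat, a conformal quantile of the LAC scores 1 - p[y].

    Hypothetical helper: `true_class_probs` holds the softmax probability
    that the model assigned to the true label of each calibration sample.
    """
    if not 0.0 < alpha < 1.0:
        raise ValueError("alpha must be in (0, 1)")
    scores = sorted(1.0 - p for p in true_class_probs)
    n = len(scores)
    if n == 0:
        raise ValueError("need at least one calibration score")
    # Finite-sample rank ceil((n + 1) * (1 - alpha)), clamped to n (assumption).
    rank = min(math.ceil((n + 1) * (1.0 - alpha)), n)
    return scores[rank - 1]


# Four calibration samples, target coverage 1 - alpha = 0.5.
qhat = lac_threshold([0.9, 0.6, 0.8, 0.7], alpha=0.5)
```

With very small calibration sets the unclamped rank can exceed `n`; the textbook threshold is then infinite, so a clamped value should be read as a convenience rather than a guarantee.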
Actionable comments posted: 4
🧹 Nitpick comments (4)
monai/inferers/conformal_predictor.py (2)
108-108: ⚡ Quick win
Store accumulated scores on CPU.
Per-voxel calibration can accumulate millions of scores; keeping them on GPU risks OOM without improving the final threshold computation.
Proposed fix
- self._scores.append((1.0 - true_p).detach())
+ self._scores.append((1.0 - true_p).detach().cpu())
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/inferers/conformal_predictor.py` at line 108, The accumulated scores in the _scores list are being kept on GPU memory, which can cause OOM errors during per-voxel calibration with millions of scores. In the line where self._scores.append((1.0 - true_p).detach()) is called, add .cpu() after .detach() to move the tensor to CPU memory before appending, ensuring scores are stored on CPU and not GPU.
28-242: 🏗️ Heavy lift
Complete Google-style docstrings for the new definitions.
Several definitions omit required `Args`, `Returns`, or `Raises` sections, especially raised exceptions in `_quantile_threshold`, `accumulate`, `calibrate`, `set_threshold`, and `__call__`.
As per coding guidelines, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/inferers/conformal_predictor.py` around lines 28 - 242, Complete the Google-style docstrings for all new definitions to include all required sections. For the _quantile_threshold function, add Returns and Raises sections documenting the returned quantile tensor and potential ValueError exceptions. For the accumulate method in ConformalCalibrator, add a Raises section documenting ValueError conditions. For both calibrate methods (in ConformalCalibrator and ConformalPredictor), add complete Returns sections describing the returned qhat threshold tensor, and Raises sections documenting RuntimeError and other exceptions. For the set_threshold method, add Args and Raises sections documenting the qhat parameter and TypeError. For the __call__ method in ConformalPredictor, add a Raises section documenting RuntimeError and TypeError exceptions that can be raised.
Source: Coding guidelines
tests/inferers/test_conformal_predictor.py (2)
130-132: ⚡ Quick win
Assert excluded classes have no selected voxels.
The current `.all()` checks pass even if only one voxel is wrongly included for class 1 or 2.
Proposed fix
  self.assertEqual(sets.dtype, torch.bool)
  self.assertTrue(sets[:, 0].all())
- self.assertFalse(sets[:, 1].all())
- self.assertFalse(sets[:, 2].all())
+ self.assertFalse(sets[:, 1].any())
+ self.assertFalse(sets[:, 2].any())
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/inferers/test_conformal_predictor.py` around lines 130 - 132, The assertions for excluded classes (indices 1 and 2) in the test are checking the wrong condition. The current use of assertFalse with .all() passes even if some voxels are wrongly included. To properly verify that excluded classes have no selected voxels, change the assertions from assertFalse(sets[:, 1].all()) and assertFalse(sets[:, 2].all()) to assertFalse(sets[:, 1].any()) and assertFalse(sets[:, 2].any()) respectively. This ensures that no voxels are selected for the excluded classes rather than just checking that not all voxels are selected.
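The gap between the two assertions shows up on a tiny example. This sketch uses Python's built-in `all` and `any` on a list as a stand-in for the tensor methods; the `mask` values are made up for illustration and do not come from the PR's test data.

```python
# Hypothetical prediction-set mask for an excluded class:
# one of four voxels was wrongly included.
mask = [False, True, False, False]

# assertFalse(mask.all()) only requires "not every voxel is selected",
# so it still passes even though one voxel is wrong.
weak_check_passes = not all(mask)

# assertFalse(mask.any()) requires "no voxel is selected",
# so it fails here and exposes the wrongly included voxel.
strict_check_passes = not any(mask)

print(weak_check_passes, strict_check_passes)
```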
23-174: 🏗️ Heavy lift
Add Google-style docstrings to the new test definitions.
The new test classes, helper, local module, and test methods currently lack docstrings required by the repository rule.
As per coding guidelines, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/inferers/test_conformal_predictor.py` around lines 23 - 174, Add Google-style docstrings to all test classes and their methods that currently lack them. For TestQuantileThreshold class and its test methods (test_exact_coverage_rank, test_clamps_to_valid_rank, test_rejects_empty, test_rejects_bad_alpha), for TestConformalCalibrator class including the _cal_batch helper method and its test methods (test_classification_single_batch, test_classification_multi_batch, test_segmentation_voxel_reshape, test_exclude_background_scores_foreground_only, test_exclude_background_drops_bg_voxels, test_unsupported_score_raises, test_calibrate_empty_raises), and for TestConformalPredictor class and its test methods (test_set_and_predict, test_no_threshold_raises, test_calibrate_then_predict, test_bad_network_output_raises), each docstring should include a brief summary describing what the test validates, what inputs or conditions are being tested, and any expected outcomes or exceptions being checked.
Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/inferers/conformal_predictor.py`:
- Line 23: The `__all__` list in the conformal_predictor.py module does not
follow alphabetical order as required by Ruff linting rule RUF022. Sort the
entries in the `__all__` list alphabetically so that "ConformalCalibrator"
appears before "ConformalPredictor" to comply with the linter requirement.
- Around line 176-180: The set_threshold method currently only validates that
qhat is a torch.Tensor, but does not check if it is a scalar or contains valid
probability values. Add validation after the existing type check to ensure qhat
is a scalar tensor (has a single element) and its value is within the valid
probability range [0, 1]. Raise appropriate TypeErrors or ValueErrors if these
conditions are not satisfied to prevent broadcasting issues and invalid
prediction sets downstream.
- Around line 197-218: The network is set to evaluation mode using
network.eval() at the start of the calibration process but its training state is
never restored before returning. This causes the network to remain in eval mode
after the method completes, which disables dropout and batch normalization. Save
the original training state of the network before calling network.eval(), then
after cal.calibrate() and self.set_threshold(qhat) are called, restore the
network to its original training state using network.train() if it was
originally in training mode, or leave it in eval mode if it was originally in
eval mode, before returning qhat.
- Around line 98-107: The `.clamp()` operation on the `labels_flat` variable
silently maps invalid labels (negative or out-of-range values like -1 or 255) to
valid class indices, which corrupts the non-conformity scores. Instead of
clamping, create a validity mask that identifies valid labels (where labels are
>= 0 and < probs_flat.shape[1]), then apply this mask to filter both
`probs_flat` and `labels_flat` to keep only rows with valid labels before the
gather operation. This ensures invalid labels are rejected entirely rather than
silently remapped.
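The last inline comment, about clamping versus masking invalid labels, can be illustrated without tensors. This is a plain-Python sketch under stated assumptions: the helper names `lac_scores_clamped` and `lac_scores_masked` and the toy `probs`/`labels` are hypothetical, and the PR itself operates on `probs_flat` and `labels_flat` tensors.

```python
def lac_scores_clamped(probs, labels):
    """Clamp labels into [0, C - 1] before lookup (the behavior being criticized)."""
    scores = []
    for row, y in zip(probs, labels):
        y = max(0, min(y, len(row) - 1))  # -1 -> 0, 255 -> C - 1
        scores.append(1.0 - row[y])
    return scores


def lac_scores_masked(probs, labels):
    """Keep only rows whose label is a valid class index, then score them."""
    scores = []
    for row, y in zip(probs, labels):
        if 0 <= y < len(row):  # validity mask
            scores.append(1.0 - row[y])
    return scores


# Three samples with C = 3 classes; 255 and -1 are ignore-style labels.
probs = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.3, 0.3, 0.4]]
labels = [0, 255, -1]

clamped = lac_scores_clamped(probs, labels)  # three scores, two of them spurious
masked = lac_scores_masked(probs, labels)    # one score, from the valid row only
```

Clamping silently contributes scores for samples that have no true class, which shifts the calibrated threshold; masking leaves them out of the score pool.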
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: eac20e1e-86cd-4755-a6ec-856fb018223e
📒 Files selected for processing (3)
monai/inferers/__init__.py
monai/inferers/conformal_predictor.py
tests/inferers/test_conformal_predictor.py
from monai.inferers.inferer import Inferer
from monai.utils.module import optional_import

__all__ = ["ConformalPredictor", "ConformalCalibrator"]
Sort __all__ to satisfy Ruff RUF022.
Proposed fix
-__all__ = ["ConformalPredictor", "ConformalCalibrator"]
+__all__ = ["ConformalCalibrator", "ConformalPredictor"]
🧰 Tools
🪛 Ruff (0.15.17)
[warning] 23-23: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
Source: Linters/SAST tools
def set_threshold(self, qhat: torch.Tensor) -> None:
    """Set (or update) the calibrated threshold. Lets you keep one inferer and re-calibrate."""
    if not isinstance(qhat, torch.Tensor):
        raise TypeError(f"qhat must be a torch.Tensor, got {type(qhat)}.")
    self.qhat = qhat.detach().clone()
Validate qhat as a scalar probability threshold.
A multi-element or out-of-range qhat can broadcast incorrectly in Line 241 or produce invalid prediction sets.
Proposed fix
 def set_threshold(self, qhat: torch.Tensor) -> None:
     """Set (or update) the calibrated threshold. Lets you keep one inferer and re-calibrate."""
     if not isinstance(qhat, torch.Tensor):
         raise TypeError(f"qhat must be a torch.Tensor, got {type(qhat)}.")
-    self.qhat = qhat.detach().clone()
+    if qhat.numel() != 1:
+        raise ValueError(f"qhat must be a scalar tensor, got shape {tuple(qhat.shape)}.")
+    if not torch.isfinite(qhat).all():
+        raise ValueError("qhat must be finite.")
+    qhat_value = qhat.detach().clone().reshape(())
+    if not 0.0 <= qhat_value.item() <= 1.0:
+        raise ValueError(f"qhat must be in [0, 1], got {qhat_value.item()}.")
+    self.qhat = qhat_value
🧹 Nitpick comments (1)
tests/inferers/test_conformal_predictor.py (1)
112-130: ⚡ Quick win
Add required Google-style docstrings to new test and nested `forward` definitions.
Newly added defs in Line 112, Line 123, Line 191, Line 196, Line 201, Line 217, and nested `forward` methods (Line 203 and Line 219) are missing docstrings required by this repo’s Python guideline.
As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
Also applies to: 191-231
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/inferers/test_conformal_predictor.py` around lines 112 - 130, The new test methods test_invalid_labels_are_dropped_not_clamped and test_mixed_valid_invalid_labels_keeps_valid, along with other new definitions and nested forward methods mentioned in the comment, are missing required Google-style docstrings that comply with the repository's Python guidelines. Add Google-style docstrings to each of these methods that describe what the test validates, document any parameters, return values, and exceptions raised. Ensure the docstrings follow the standard format of: brief description, Args section (if applicable), Returns section (if applicable), and Raises section (if applicable).
Source: Coding guidelines
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 5caa2d09-100a-49c3-b8dc-c17853f38d3b
📒 Files selected for processing (2)
monai/inferers/conformal_predictor.py
tests/inferers/test_conformal_predictor.py
🚧 Files skipped from review as they are similar to previous changes (1)
- monai/inferers/conformal_predictor.py
…reference Signed-off-by: Colin Son <txmed82@users.noreply.github.com>
…rop invalid labels

- set_threshold: reject non-scalar qhat (prevents silent broadcasting bugs)
- calibrate: save and restore network.training state (avoid side effect of eval())
- accumulate: drop invalid labels (negative or >= C) instead of clamping them (clamping silently remapped -1 -> 0 and 255 -> C-1, corrupting LAC scores)
- accumulate: move scores to CPU after detach (avoid GPU OOM on large 3D cal sets)
- tests: add coverage for scalar validation, train-state restore, invalid-label handling; fix assertFalse(.all()) -> .any() for excluded-class assertions
- docstrings: add Returns/Raises sections to _quantile_threshold, accumulate, calibrate, set_threshold, __call__

Signed-off-by: Colin Son <txmed82@users.noreply.github.com>
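The "save and restore network.training state" item in the commit above follows a common pattern, sketched here as a context manager. The PR's actual code is not shown on this page, so this is only one plausible shape: `eval_mode` is a hypothetical helper, and `StubModule` is a stand-in that mimics the `training` attribute and the `train(mode)` / `eval()` methods of `torch.nn.Module` so the example runs without PyTorch.

```python
from contextlib import contextmanager


class StubModule:
    """Hypothetical stand-in mimicking torch.nn.Module's train/eval interface."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)


@contextmanager
def eval_mode(network):
    """Switch `network` to eval mode, then restore its previous mode on exit."""
    was_training = network.training  # remember the caller's mode
    network.eval()
    try:
        yield network
    finally:
        network.train(was_training)  # restored even if calibration raises


net = StubModule()
with eval_mode(net):
    mode_inside = net.training  # False while calibrating
mode_after = net.training       # back to True afterwards
```

Restoring in a `finally` clause means a failure partway through calibration does not leave the caller's network stuck in eval mode.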
Fixes #8935 (part 1 of 2).
Description
Brings split-conformal prediction to MONAI via the LAC score `1 - softmax[y]` (Sadinle et al. 2019). First of two PRs for #8935; the second will add image-level loss-bounded calibration and a per-voxel uncertainty mask to `monai/metrics/`.
`monai/inferers/conformal_predictor.py`:
- `ConformalCalibrator`: collect scores on a held-out split, return the threshold `qhat` with a `1 - alpha` marginal coverage guarantee. Works for classification `(B, C)` and per-voxel segmentation `(B, C, spatial...)`; `include_background=False` drops background-labeled samples so the threshold isn't dominated by easy background.
- `ConformalPredictor`: an `Inferer` that wraps a network + `qhat` and returns the prediction-set mask `{ y : 1 - softmax[y] <= qhat }`. Accepts a pre-calibrated `qhat` or calibrates in-band from a loader.
API docs: added `ConformalPredictor` and `ConformalCalibrator` autoclass entries to `docs/source/inferers.rst`.
Verification run locally
- `pytest tests/inferers/test_conformal_predictor.py`: 15 passed.
- `pytest tests/inferers/`: 246 passed, 176 skipped (excluding `test_zarr_avg_merger.py`, which fails collection due to a missing optional `zarr` dep, unrelated to this PR).
- `isort` / `black` / `ruff`: clean on the changed files.
- `mypy monai/inferers/conformal_predictor.py`: no errors in this module.
I have not run the full `./runtests.sh --net --coverage` integration suite or the whole-repo `--quick --unittests --disttests` pass locally, so those boxes are left unchecked for CI to confirm. `make html` was likewise not run.
Types of changes
- `./runtests.sh -f -u --net --coverage`.
- `./runtests.sh --quick --unittests --disttests`.
- `make html` command in the `docs/` folder.
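For readers new to the method, the prediction-set rule stated in the description, `{ y : 1 - softmax[y] <= qhat }`, can be sketched in plain Python for a single sample. The function names and the example logits are assumptions for illustration; the PR's `ConformalPredictor` works on batched tensors and returns a boolean mask.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prediction_set(logits, qhat):
    """Return one boolean per class: True if the class is in the prediction set."""
    return [(1.0 - p) <= qhat for p in softmax(logits)]


logits = [2.0, 0.0, 0.0]  # hypothetical network output for one sample

confident = prediction_set(logits, qhat=0.5)   # only the top class qualifies
permissive = prediction_set(logits, qhat=0.95) # a looser threshold admits every class
strict = prediction_set(logits, qhat=0.1)      # a tight threshold can leave the set empty
```

A larger `qhat` (from a smaller `alpha` or a less accurate model) yields larger sets, which is how the method expresses uncertainty.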