nnClass is a Python-based project that extends the nnUNet framework with classification capabilities.
- Stratified Data Splitting: Advanced data splitting with demographic stratification (age, gender)
- Multi-Modal Input: Support for multi-channel medical images (CT, PET, MRI)
- Flexible Classification: Support for both binary and multi-class classification tasks
- Comprehensive Inference: Batch processing with sliding window prediction and test-time augmentation
To set up nnUNetCLS, first install the required dependencies:

```bash
# Create and activate a new conda environment
conda create -n nnclass python=3.12 -y
conda activate nnclass

# Install dependencies
pip install wandb

# Install the current project in editable mode
pip install -e .
```

Before using nnUNetCLS, you need to configure the nnUNet environment paths. Modify the paths in `nnunetv2/paths.py` to point to your desired directories:

```python
# Edit nnunetv2/paths.py
nnUNet_raw = "/path/to/your/nnUNet_raw"
nnUNet_preprocessed = "/path/to/your/nnUNet_preprocessed"
nnUNet_results = "/path/to/your/nnUNet_results"
```

nnUNetCLS relies on nnUNet's preprocessing pipeline to standardize image spacing, normalize intensities, and extract patches. Preprocessing must be completed before training or inference.
Your dataset must follow the nnUNet folder convention:
```
nnUNet_raw/
└── Dataset<DATASET_ID>_<DATASET_NAME>/
    ├── imagesTr/                    # Training images (NIfTI format)
    │   ├── PatientID_0000.nii.gz    # First modality (e.g., CT)
    │   ├── PatientID_0001.nii.gz    # Second modality (e.g., MRI)
    │   └── ...
    ├── labelsTr/                    # Training labels (segmentation masks)
    │   ├── PatientID.nii.gz
    │   └── ...
    ├── imagesTs/                    # Test images (no labels required)
    │   ├── TestID_0000.nii.gz
    │   └── ...
    └── dataset.json                 # Dataset description file
```

Notes:
- Each modality is indexed as `_0000`, `_0001`, etc.
- Segmentation labels must have the same base name as the training images (without the modality suffix).
- `dataset.json` defines modalities, labels, and dataset splits.
Run the standard nnUNet preprocessing:
```bash
nnUNetv2_plan_and_preprocess -d <DATASET_ID> -c 3d_fullres --verify_dataset_integrity
```

For the residual encoder presets:

```bash
nnUNetv2_plan_experiment -d <DATASET_ID> -pl nnUNetPlannerResEncM  # or nnUNetPlannerResEncL / nnUNetPlannerResEncXL
```

Use `generate_cls_data.py` to create stratified train/validation/test splits from your clinical dataset:

```bash
python generate_cls_data.py \
    --input_path /path/to/clinical_data.csv \
    --output_path /path/to/output/folder \
    --identifier_column PatientID \
    --label_column diagnosis
```

Arguments:
- `--input_path, -i`: Path to the CSV/Excel file containing clinical and imaging information
- `--output_path, -o`: Directory to save classification data and splits (typically "/path/to/your/nnUNet_preprocessed")
- `--identifier_column, -id`: Column name for patient identifiers (default: 'patient_id')
- `--label_column, -label`: Column name for classification labels (default: 'label')
Required CSV columns:
- Patient identifiers (e.g., 'PatientID')
- Classification labels
- `Age_at_StudyDate`: for age-based stratification
- `Gender`: for gender-based stratification
Outputs:
- `cls_data.csv`: Classification dataset
- `test_data.csv`: Held-out test set (20% of the data)
- `splits_final.json`: 5-fold cross-validation splits with stratification
- Cases without segmentation data are filtered out automatically.
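The stratified fold assignment can be sketched as follows. This is a minimal pure-Python illustration, not the script's actual implementation; the composite key (age bin, gender, label) mirrors the stratification variables listed above, and cases within each stratum are dealt round-robin into folds:

```python
from collections import defaultdict

def stratified_folds(cases, n_folds=5):
    """Assign each case to a fold so that composite strata
    (age bin, gender, label) are spread evenly across folds.
    `cases` is a list of dicts with 'id', 'age_bin', 'gender', 'label'.
    """
    # Group case IDs by their composite stratum
    strata = defaultdict(list)
    for case in cases:
        key = (case["age_bin"], case["gender"], case["label"])
        strata[key].append(case["id"])

    # Deal each stratum's members round-robin into the folds,
    # so every fold sees a representative mix of strata
    folds = [[] for _ in range(n_folds)]
    i = 0
    for members in strata.values():
        for case_id in members:
            folds[i % n_folds].append(case_id)
            i += 1
    return folds
```

In practice, a library routine such as scikit-learn's `StratifiedKFold` on the composite key achieves the same effect.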
```bash
nnUNetv2_train 161 3d_fullres 0 -tr <TrainerName>
nnUNetv2_train 714 3d_fullres all -p nnUNetResEncUNetMPlans
```

The following trainers provide baseline classification support:
- MedNeXtTrainer
- ViTTrainer
- DenseNetTrainer
- SEResNetTrainer
- SwinViTTrainer
- nnUNetREGTrainer
- MedNeXtREGTrainer
- ViTREGTrainer
- DenseNetREGTrainer
- SEResNetREGTrainer
- SwinViTREGTrainer

All trainers are defined in `nnunetv2/training/nnUNetTrainer/nnUNetCLSTrainer.py`.
Run joint segmentation and classification inference on NIfTI images:
```bash
python nnunet_cls_infer_nii.py \
    --input_path /path/to/input/images/ \
    --output_path /path/to/output/ \
    --model_path /path/to/trained/model \
    --fold all \
    --checkpoint checkpoint_best.pth \
    --device cuda \
    --cls_mode mean
```

Arguments:
- `--input_path, -i`: Directory containing input NIfTI images (expects the `*_000X.nii.gz` naming convention)
- `--output_path, -o`: Directory to save segmentation masks and classification results
- `--model_path`: Path to the trained nnUNet model directory
- `--fold`: Fold number, or 'all' for ensemble prediction (default: 'all')
- `--checkpoint`: Checkpoint filename (default: 'checkpoint_best.pth')
- `--use_softmax`: Apply softmax to the segmentation output (default: False)
- `--device`: Computing device ('cuda' or 'cpu', default: 'cuda')
- `--cls_mode`: Classification aggregation mode ('mean' or 'weighted', default: 'mean')
Input Format: Images should follow the nnUNet naming convention:
- `PatientID_0000.nii.gz` (first modality)
- `PatientID_0001.nii.gz` (second modality)
- etc.
Outputs:
- `{PatientID}.nii.gz`: Segmentation mask for each case
- `results.csv`: Classification probabilities for all cases
Stratified Cross-Validation:
- Creates balanced splits based on age quartiles, gender, and target labels
- Ensures representative distribution across all folds
- 80/20 train-test split with 5-fold cross-validation on training data
Advanced Inference:
- Sliding window prediction with Gaussian weighting
- Test-time augmentation with mirroring
- Multi-fold model ensembling
- Memory-efficient processing for large images
- Automatic batch processing of multiple cases
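The sliding-window blending with Gaussian weighting can be sketched in 1-D as follows. This is an illustration of the technique, not the project's code (which operates on 3-D volumes); the function names and `sigma_scale` default are assumptions:

```python
import numpy as np

def gaussian_importance(patch_size, sigma_scale=0.125):
    """1-D Gaussian importance map: center voxels of a patch
    contribute more than border voxels when overlapping
    predictions are blended."""
    center = (patch_size - 1) / 2.0
    sigma = patch_size * sigma_scale
    x = np.arange(patch_size)
    w = np.exp(-((x - center) ** 2) / (2 * sigma ** 2))
    return w / w.max()

def blend(length, patch_size, step, predict):
    """Slide a window over a 1-D signal, weight each patch
    prediction by the Gaussian map, and normalize by the
    accumulated weights in the overlap regions."""
    weights = gaussian_importance(patch_size)
    out = np.zeros(length)
    norm = np.zeros(length)
    for start in range(0, length - patch_size + 1, step):
        sl = slice(start, start + patch_size)
        out[sl] += predict(sl) * weights
        norm[sl] += weights
    return out / np.maximum(norm, 1e-8)
```

Because every output voxel is divided by the accumulated weights, overlapping patches average smoothly instead of producing seams at patch borders.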
Classification Modes:
- `mean`: Average classification scores across all patches
- `weighted`: Weight classification scores by segmentation confidence
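The two aggregation modes can be illustrated as follows. This is a sketch, not the actual code in `nnunet_cls_infer_nii.py`; the confidence measure is assumed here to be a per-patch scalar such as the mean foreground probability of the patch's segmentation:

```python
import numpy as np

def aggregate_cls(patch_scores, seg_confidences, mode="mean"):
    """Aggregate per-patch classification scores into one
    case-level score vector.

    patch_scores:    (n_patches, n_classes) classification scores
    seg_confidences: (n_patches,) e.g. mean foreground probability
    """
    patch_scores = np.asarray(patch_scores, dtype=float)
    if mode == "mean":
        # Every patch contributes equally
        return patch_scores.mean(axis=0)
    if mode == "weighted":
        # Patches with more confident segmentations count more
        w = np.asarray(seg_confidences, dtype=float)
        w = w / w.sum()
        return (patch_scores * w[:, None]).sum(axis=0)
    raise ValueError(f"unknown cls_mode: {mode}")
```

The weighted mode downweights patches where the segmentation head found little or uncertain foreground, which can help when the structure of interest occupies only a few patches.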
The framework extends nnUNet with:
- Shared encoder for both segmentation and classification
- Dual output heads (segmentation + classification)
- Feature aggregation from the final encoder stage
- Support for both binary and multi-class classification
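The shared-encoder, dual-head design can be sketched as a forward pass. This is a schematic with illustrative function names, not the project's network classes; the classification head consumes globally pooled final-stage encoder features, as described above:

```python
import numpy as np

def dual_head_forward(image, encoder, seg_head, cls_head):
    """Shared-encoder forward pass: one encoder feeds both a
    segmentation head (dense, per-voxel output) and a
    classification head (global-average-pooled features)."""
    features = encoder(image)               # (C, D, H, W) feature map
    seg = seg_head(features)                # per-voxel logits
    pooled = features.mean(axis=(1, 2, 3))  # aggregate final-stage features
    cls = cls_head(pooled)                  # per-case class logits
    return seg, cls
```

Sharing the encoder means the classification task adds only a small head on top of the segmentation backbone, and both losses update the same features during training.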
This project is licensed under the Apache License 2.0.
If you use nnUNetCLS in your research, please cite the original nnUNet paper and this extension.
Note: Ensure your input data follows the nnUNet preprocessing requirements and naming conventions for optimal performance.