A baseline ML model to predict cell responses to drug perturbations
Award-winning solution from the NeurIPS 2023 Single-Cell Perturbations Challenge:
- $10,000 Judges' Prize for performance and methodology
- 2nd place in post-hoc analysis
- Top 2% overall (16th/1097 teams)
pip install git+https://github.com/scapeML/scape.git
import scape
# Download the data from Zenodo
scape.io.download_from_zenodo(target_dir=".")
# Train model with drug cross-validation
result = scape.api.train(
    de_file="_data/de_train.parquet",
    lfc_file="_data/lfc_train.parquet",
    cv_drug="Belinostat",
    n_genes=64,
)
# Visualize performance vs baselines
scape.util.plot_result(result._last_train_results)
ScAPE is a lightweight neural network (~9.6M parameters for the single-task version) that predicts differential gene expression in response to drug perturbations. It is built with Keras 3, so it runs on any of three backends (TensorFlow, JAX, PyTorch).
- Single or Multi-Task Learning: Predict p-values only or jointly with fold changes
- Multi-Backend Support: Choose between TensorFlow, JAX, or PyTorch
- Built-in Ensemble Methods: Simple blending for robust predictions
- Cross-Validation: Cell-type and drug-based validation strategies
- Efficient: Handles ~18,000 genes with median-based feature engineering
The model uses median-based feature engineering: for each drug and cell type, we compute median differential expression values across the dataset. This reduces ~18,000 genes to manageable drug/cell signatures while preserving biological signal.
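The median-based features described above can be sketched with pandas. This is a minimal illustration on a toy table, not ScAPE's actual implementation: the column names `cell_type` and `sm_name`, the gene columns, and the helpers `median_signatures` and `features_for` are all assumptions made for the example.

```python
import pandas as pd


def median_signatures(df, key, gene_cols):
    """Median differential-expression value per gene for each level of `key`."""
    return df.groupby(key)[gene_cols].median()


# Toy data (hypothetical values): 2 cell types x 2 drugs, 3 genes
toy = pd.DataFrame({
    "cell_type": ["NK cells", "NK cells", "T cells", "T cells"],
    "sm_name": ["Belinostat", "DrugB", "Belinostat", "DrugB"],
    "gene_a": [1.0, 0.0, 3.0, 2.0],
    "gene_b": [-1.0, 0.5, -3.0, 1.5],
    "gene_c": [0.0, 0.0, 0.2, 0.4],
})
genes = ["gene_a", "gene_b", "gene_c"]

# One signature per drug (median over cell types) and one per cell type
# (median over drugs)
drug_sig = median_signatures(toy, "sm_name", genes)
cell_sig = median_signatures(toy, "cell_type", genes)


def features_for(cell_type, drug):
    """Concatenate the drug and cell-type signatures into one feature vector."""
    return pd.concat([
        drug_sig.loc[drug].add_prefix("drug_"),
        cell_sig.loc[cell_type].add_prefix("cell_"),
    ])


x = features_for("NK cells", "Belinostat")
print(x)
```

Each (cell type, drug) pair is then described by two fixed-length signatures instead of its raw per-gene profile, which is the sense in which the medians compress the input.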
Key design choices:
# Command line
python -m scape train --n-genes 64 --cv-drug Belinostat _data/de_train.parquet _data/lfc_train.parquet
# Python API
import scape

# de_data and lfc_data are the training tables, loaded beforehand (not shown here)
model = scape.model.create_default_model(
    n_genes=64,
    df_de=de_data,
    df_lfc=lfc_data,
)
results = model.train(
    val_cells=["NK cells"],
    val_drugs=["Belinostat"],
    epochs=600,
)
Configure the model to jointly predict both p-values and fold changes:
# Multi-task configuration with optimal weights
# `optimizer` and `mrrmse` must be defined beforehand (not shown here)
model.model.compile(
    optimizer=optimizer,
    loss={"slogpval": mrrmse, "lfc": mrrmse},
    loss_weights={"slogpval": 0.8, "lfc": 0.2},
)
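To make the weighted objective concrete, here is a NumPy sketch. It assumes `mrrmse` stands for mean rowwise root mean squared error (the challenge metric) and that inputs are arrays of shape (samples, genes). It is an illustration only, not the Keras loss ScAPE uses; `weighted_multitask_loss` is a hypothetical helper that mirrors how Keras combines per-output losses as a weighted sum.

```python
import numpy as np


def mrrmse(y_true, y_pred):
    """Mean over rows of the per-row root mean squared error."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    row_rmse = np.sqrt(np.mean((y_true - y_pred) ** 2, axis=1))
    return float(row_rmse.mean())


def weighted_multitask_loss(losses, weights):
    """Weighted sum of per-output losses (hypothetical helper)."""
    return sum(weights[name] * losses[name] for name in losses)


# Toy targets (hypothetical values) scored against an all-zero prediction
y_slogpval = np.array([[3.0, 4.0], [0.0, 0.0]])
y_lfc = y_slogpval / 2
zeros = np.zeros_like(y_slogpval)

losses = {
    "slogpval": mrrmse(y_slogpval, zeros),
    "lfc": mrrmse(y_lfc, zeros),
}
total = weighted_multitask_loss(losses, {"slogpval": 0.8, "lfc": 0.2})
print(losses, total)
```

With weights of 0.8 and 0.2, the p-value output dominates the gradient while the fold-change output acts as an auxiliary signal.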
# Use JAX backend (recommended for performance)
KERAS_BACKEND=jax python -m scape train ...
# Use TensorFlow backend
KERAS_BACKEND=tensorflow python -m scape train ...
# Use PyTorch backend
KERAS_BACKEND=torch python -m scape train ...
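The same choice can be made from Python, for example in a notebook. Keras 3 reads the `KERAS_BACKEND` environment variable when it is first imported, so the variable has to be set before `keras` (and therefore before anything that imports it) is loaded. The helper below is a hypothetical convenience for this, not part of scape.

```python
import os
import sys
import warnings

# The three backends named in this document
SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch")


def select_keras_backend(name):
    """Set KERAS_BACKEND for this process (hypothetical helper, not part of scape)."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    if "keras" in sys.modules:
        # Keras picks its backend at import time, so a later change is ignored
        warnings.warn("keras is already imported; the new backend will not take effect")
    os.environ["KERAS_BACKEND"] = name
    return name


chosen = select_keras_backend("jax")
# import scape  # import only after the backend has been chosen
```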
Improve robustness with simple ensemble blending:
from sklearn.model_selection import KFold
import numpy as np
# Train multiple models with K-fold
predictions = []
for train_idx, val_idx in KFold(n_splits=5).split(all_combinations):
    model = scape.model.create_default_model(...)
    model.train(...)
    predictions.append(model.predict(test_combinations))
# Blend predictions (median)
ensemble_pred = np.median([p.values for p in predictions], axis=0)
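The blending step can be tried in isolation on toy data. The sketch below assumes each model's predictions arrive as a pandas DataFrame with the same row and column labels (possibly in a different order). `blend_median` is a hypothetical helper and the values are made up.

```python
import numpy as np
import pandas as pd


def blend_median(frames):
    """Element-wise median of several prediction tables (hypothetical helper)."""
    first = frames[0]
    # Align every table to the first one so rows and columns match up
    aligned = [f.reindex(index=first.index, columns=first.columns) for f in frames]
    stacked = np.stack([f.to_numpy() for f in aligned])  # (n_models, rows, genes)
    return pd.DataFrame(
        np.median(stacked, axis=0), index=first.index, columns=first.columns
    )


cols = ["g1", "g2"]
p1 = pd.DataFrame([[1.0, 2.0], [0.0, 0.0]], index=["a", "b"], columns=cols)
p2 = pd.DataFrame([[1.2, 2.2], [0.1, -0.1]], index=["a", "b"], columns=cols)
# Third model: rows in a different order, and an outlying prediction for "a"
p3 = pd.DataFrame([[0.5, 0.5], [10.0, -5.0]], index=["b", "a"], columns=cols)

blended = blend_median([p1, p2, p3])
print(blended)
```

The outlying row from the third model does not shift the blended value for `a`, which is the main reason to prefer the median over the mean when the number of models is small.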
# Custom architecture
config = {
    "encoder_hidden_layer_sizes": [128, 128],
    "decoder_hidden_layer_sizes": [128, 512],
    "outputs": {
        "slogpval": (64, "linear"),
        "lfc": (64, "linear"),  # Multi-task
    },
    "noise": 0.01,
    "dropout": 0.05,
}
model = scape.model.create_model(
    n_genes=64,
    df_de=de_data,
    df_lfc=lfc_data,
    config=config,
)
Track model improvement over baselines:
- Zero baseline: Always predicts 0 (competition baseline)
- Median baseline: Predicts drug-specific medians
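The two baselines are simple enough to reproduce by hand. The sketch below scores both with mean rowwise RMSE on a toy split. The column name `sm_name`, the gene columns, and all values are assumptions for illustration, and the code does not call scape.

```python
import numpy as np
import pandas as pd


def mrrmse(y_true, y_pred):
    """Mean over rows of the per-row root mean squared error."""
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2, axis=1)).mean())


genes = ["g1", "g2"]
# Toy training and validation tables (hypothetical values)
train = pd.DataFrame({
    "sm_name": ["DrugA", "DrugA", "DrugA", "DrugB", "DrugB", "DrugB"],
    "g1": [1.0, 2.0, 3.0, -1.0, -2.0, -3.0],
    "g2": [0.0, 1.0, 2.0, 4.0, 4.0, 4.0],
})
val = pd.DataFrame({
    "sm_name": ["DrugA", "DrugB"],
    "g1": [2.0, -2.0],
    "g2": [1.0, 3.0],
})
y_val = val[genes].to_numpy()

# Zero baseline: predict 0 for every gene
zero_score = mrrmse(y_val, np.zeros_like(y_val))

# Median baseline: predict each drug's per-gene median from the training set
drug_medians = train.groupby("sm_name")[genes].median()
median_pred = drug_medians.loc[val["sm_name"]].to_numpy()
median_score = mrrmse(y_val, median_pred)

print(f"zero: {zero_score:.3f}  median: {median_score:.3f}")
```

A trained model is only informative if it beats both numbers on the same validation split.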
- Quick Start Tutorial
- Training Pipeline
- Google Colab Demo
- Technical Report
- Dataset (Zenodo)
# Setup with pixi
pixi install
pixi shell -e dev
# Run tests (JAX backend recommended)
KERAS_BACKEND=jax pixi run -e dev test
# Lint & format
pixi run lint
pixi run format
@article{scape2025perturb,
    author = {Romero-Tapiador, Sergio and Rodriguez-Mier, Pablo and Garrido-Rodriguez, Martin and Tolosana, Ruben and Morales, Aythami and Saez-Rodriguez, Julio},
    title = {ScAPE: A lightweight multitask learning baseline method to predict transcriptomic responses to perturbations},
    year = {2025},
    doi = {10.1101/2025.09.08.674873},
    publisher = {Cold Spring Harbor Laboratory},
    URL = {https://www.biorxiv.org/content/early/2025/09/09/2025.09.08.674873},
    eprint = {https://www.biorxiv.org/content/early/2025/09/09/2025.09.08.674873.full.pdf},
    journal = {bioRxiv}
}