33 changes: 33 additions & 0 deletions scgpt/model/generation_model.py
@@ -74,13 +74,16 @@ def __init__(
use_fast_transformer = False
self.use_fast_transformer = use_fast_transformer

# STEP 1: EMBED the input vectors: gene tokens, binned expression values, and the condition vector (here, the perturbation flags); see the sketch after this block
self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token])
self.value_encoder = ContinuousValueEncoder(d_model, dropout)
self.pert_encoder = nn.Embedding(3, d_model, padding_idx=pert_pad_id)

print("Using simple batchnorm instead of domain specific batchnorm")
self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5)
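
As a reference for reviewers, here is a minimal sketch of how the three embeddings combine. The stand-ins for GeneEncoder and ContinuousValueEncoder are assumptions (the real scGPT modules add LayerNorm, dropout, etc.), and all sizes are illustrative:

import torch
import torch.nn as nn

# Sketch only: simplified stand-ins for GeneEncoder / ContinuousValueEncoder.
d_model, ntoken, pad_id, pert_pad_id = 512, 1000, 0, 2  # assumed sizes
gene_emb = nn.Embedding(ntoken, d_model, padding_idx=pad_id)
value_proj = nn.Sequential(nn.Linear(1, d_model), nn.ReLU())  # value-encoder stand-in
pert_emb = nn.Embedding(3, d_model, padding_idx=pert_pad_id)  # 3 perturbation flags

genes = torch.randint(1, ntoken, (4, 16))   # (batch, seq_len) gene token ids
values = torch.rand(4, 16)                  # binned expression values
flags = torch.randint(0, 2, (4, 16))        # perturbation condition flags
total_embs = gene_emb(genes) + value_proj(values.unsqueeze(-1)) + pert_emb(flags)
print(total_embs.shape)  # torch.Size([4, 16, 512])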

# STEP 2: create the transformer encoder with the number of blocks and sizes given by the parameters (d_model=512, nhead=8, etc.)
# uses a standard attention encoder from fast_transformer, flash-attention, or plain PyTorch
if use_fast_transformer:
if fast_transformer_backend == "linear":
self.transformer_encoder = FastTransformerEncoderWrapper(
@@ -102,11 +105,14 @@
)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
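
For reference, the plain-PyTorch branch reduces to a standard nn.TransformerEncoder; the hyperparameter values below are assumptions for illustration, not the repository's configured defaults:

import torch.nn as nn

d_model, nhead, d_hid, nlayers, dropout = 512, 8, 512, 12, 0.2  # assumed values
encoder_layers = nn.TransformerEncoderLayer(
    d_model, nhead, d_hid, dropout, batch_first=True
)
transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)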

# STEP 3: create the several task-specific DECODERs (a sketch of an ExprDecoder-style head follows below)
# self.decoder = nn.Linear(d_model, 1)
self.decoder = ExprDecoder(
d_model,
explicit_zero_prob=explicit_zero_prob,
)
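
To make the decoder stack concrete: a hedged sketch of an ExprDecoder-style head matching the Linear/ReLU/Linear/ReLU/Linear description in this PR. The layer widths and the zero-probability head are assumptions, not the exact scGPT implementation:

import torch
import torch.nn as nn

class SketchExprDecoder(nn.Module):
    """Illustrative expression head: Linear/ReLU/Linear/ReLU/Linear (assumed widths)."""

    def __init__(self, d_model: int, explicit_zero_prob: bool = False):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(),
            nn.Linear(d_model, d_model), nn.ReLU(),
            nn.Linear(d_model, 1),
        )
        self.explicit_zero_prob = explicit_zero_prob
        if explicit_zero_prob:
            # Optional head predicting the probability that a value is zero.
            self.zero_head = nn.Sequential(
                nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 1)
            )

    def forward(self, x: torch.Tensor) -> dict:
        out = {"pred": self.fc(x).squeeze(-1)}  # (batch, seq_len)
        if self.explicit_zero_prob:
            out["zero_probs"] = torch.sigmoid(self.zero_head(x)).squeeze(-1)
        return out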

# decoder for the classification task
self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)
if do_mvc:
self.mvc_decoder = MVCDecoder(
@@ -131,13 +137,21 @@ def _encode(
input_pert_flags,
src_key_padding_mask: Tensor,
) -> Tensor:

# EMBED all input vectors
src = self.encoder(src) # (batch, seq_len, embsize)
self.cur_gene_token_embs = src
values = self.value_encoder(values) # (batch, seq_len, embsize)
perts = self.pert_encoder(input_pert_flags) # (batch, seq_len, embsize)

# SUM UP: collapse the three embeddings into a single representation by element-wise addition (see the paper)
total_embs = src + values + perts

# hotfix that makes it possible to install a later flash-attention version
# built for CUDA 12: https://github.com/bowang-lab/scGPT/issues/69 (shape sketch below)
total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
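
The permute pair exists because nn.BatchNorm1d normalizes over the channel dimension of an (N, C, L) tensor, while the transformer activations are (batch, seq_len, d_model). A self-contained shape check (sizes assumed):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(512, eps=6.1e-5)
x = torch.randn(4, 16, 512)                  # (batch, seq_len, d_model)
y = bn(x.permute(0, 2, 1)).permute(0, 2, 1)  # BatchNorm1d expects (N, C, L)
assert y.shape == x.shape                    # back to (batch, seq_len, d_model)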

# ENCODE with the full stack of encoder blocks (12 by default)
output = self.transformer_encoder(
total_embs, src_key_padding_mask=src_key_padding_mask
)
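
Note that src_key_padding_mask follows the torch.nn.TransformerEncoder convention: True marks padded positions that attention should ignore. A tiny illustration (pad token id assumed to be 0):

import torch

pad_id = 0  # assumed pad token id
genes = torch.tensor([[5, 7, 0, 0]])
src_key_padding_mask = genes.eq(pad_id)
print(src_key_padding_mask)  # tensor([[False, False,  True,  True]])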
@@ -203,10 +217,16 @@ def forward(
do_sample = True
logger.warning("Auto set do_sample to True when model is in eval mode.")

# STEP 1: ENCODE -> embedding plus all encoder blocks (12 blocks by default)
transformer_output = self._encode(
src, values, input_pert_flags, src_key_padding_mask
)

# STEP 2: DECODE -> ExprDecoder: Linear/ReLU/Linear/ReLU/Linear
# use different decoders depending on the task: masked value modelling, classification, etc.
output = {}

# Masked Language Modeling (MLM)
mlm_output = self.decoder(transformer_output)
if self.explicit_zero_prob and do_sample:
bernoulli = Bernoulli(probs=mlm_output["zero_probs"])
@@ -217,8 +237,12 @@ def forward(
output["mlm_zero_probs"] = mlm_output["zero_probs"]

cell_emb = self._get_cell_emb_from_layer(transformer_output, values)

# cell-type classification (CLS) objective
if CLS:
output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls)

# masked value prediction for cell embedding (MVC) objective
if MVC:
mvc_output = self.mvc_decoder(
cell_emb,
@@ -231,6 +255,8 @@ def forward(
output["mvc_output"] = mvc_output["pred"] # (batch, seq_len)
if self.explicit_zero_prob:
output["mvc_zero_probs"] = mvc_output["zero_probs"]

# elastic cell similarity (ECS) objective
if ECS:
# Here using a customized cosine similarity instead of F.cosine_similarity
# to avoid the PyTorch issue of similarities larger than 1.0 (pytorch/pytorch#78064); see the sketch below
@@ -246,6 +272,13 @@ def forward(

output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2)

# depending on the enabled objectives, output may contain:
# output["mlm_output"]
# output["mlm_zero_probs"]
# output["cls_output"]
# output["mvc_output"]
# output["mvc_zero_probs"]
# output["loss_ecs"]
return output
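
Downstream code can branch on which keys are present. A hypothetical read of the returned dict, using dummy tensors in place of a real forward pass (keys mirror the summary comment above):

import torch

# Dummy stand-in for the dict returned by forward().
output = {
    "mlm_output": torch.zeros(4, 16),  # (batch, seq_len)
    "cls_output": torch.zeros(4, 10),  # (batch, n_cls), present when CLS=True
    "loss_ecs": torch.tensor(0.0),     # present when ECS=True
}
pred_values = output["mlm_output"]
if "cls_output" in output:
    pred_celltype = output["cls_output"].argmax(dim=-1)  # (batch,)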

def encode_batch(