
Conversation

@savitha-eng
Collaborator

Description

This PR implements a complete end-to-end genomic foundation model training pipeline using the LLAMA3 architecture, with EVO2-compatible data processing and distributed training capabilities.

Key Components:

Genomic Tokenization System:
ASCII character-level tokenization with a 256-token vocabulary covering the complete ASCII character set
Direct nucleotide encoding supporting standard nucleotides (A, T, C, G, N) and IUPAC ambiguity codes
EVO2-compatible byte-level tokenization methodology
Integrated BOS, EOS, and PAD tokens for proper sequence boundary management
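
As a rough illustration of the byte-level scheme above, a minimal tokenizer sketch follows; the class name, special-token IDs, and method signatures are illustrative assumptions rather than the recipe's actual API.

```python
# Minimal sketch of ASCII byte-level nucleotide tokenization (illustrative only;
# special-token IDs and the interface are assumptions, not the recipe's tokenizer).
import torch


class AsciiNucleotideTokenizerSketch:
    """Each character maps to its ASCII code, so the vocabulary is the 256 byte values."""

    def __init__(self, bos_token_id: int = 1, eos_token_id: int = 2, pad_token_id: int = 0):
        self.vocab_size = 256
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    def encode(self, sequence: str) -> torch.Tensor:
        # ord() already covers A/T/C/G/N and the IUPAC ambiguity codes, so no lookup table is needed
        ids = [self.bos_token_id] + [ord(c) for c in sequence.upper()] + [self.eos_token_id]
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, ids: torch.Tensor) -> str:
        special = {self.bos_token_id, self.eos_token_id, self.pad_token_id}
        return "".join(chr(int(i)) for i in ids if int(i) not in special)


tok = AsciiNucleotideTokenizerSketch()
print(tok.encode("ACGTN"))              # tensor([ 1, 65, 67, 71, 84, 78,  2])
print(tok.decode(tok.encode("ACGTN")))  # ACGTN
```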
EVO2-Style Genomic Dataloader (following the Sharded Eden Dataloader):
Efficient sequence windowing: 8192-token sequences with a 7992-token stride (200 bp overlap), following the EVO2 approach
Randomization: fixed window positions with shuffled access order for training stability
SQLite backend integration for high-performance sequence retrieval
Full FSDP and DDP distributed training compatibility
Memory-efficient streaming with configurable parameters
Currently processes a small subsample dataset (1,024 sequences expanded into 254,043 training windows), with scaling to production datasets planned
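
The tiling above amounts to fixed, overlapping window start positions per contig, with only the access order shuffled. A minimal sketch, assuming a 1,000-token minimum-window cutoff (the recipe's actual cutoff and implementation may differ):

```python
# Sketch of EVO2-style window tiling (illustrative; the minimum-window cutoff and the
# function shape are assumptions, not the recipe's exact implementation).
SEQ_LENGTH = 8192   # tokens per training window
STRIDE = 7992       # SEQ_LENGTH - 200, i.e. 200 bp overlap between consecutive windows
MIN_WINDOW = 1000   # drop trailing fragments shorter than this (assumed default)


def window_starts(contig_length: int) -> list[int]:
    """Fixed window start positions for one contig; access order is shuffled separately."""
    starts, pos = [], 0
    while pos < contig_length:
        if min(SEQ_LENGTH, contig_length - pos) >= MIN_WINDOW:
            starts.append(pos)
        pos += STRIDE
    return starts


# Example: a 100 kb contig tiles into 13 overlapping windows
print(len(window_starts(100_000)))  # 13
```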
Model Pipeline Architecture:
Dynamic RoPE implementation supporting arbitrary sequence lengths with ESM-2 style computation
Random weight initialization using a from-config approach to avoid text-domain bias
Memory-optimized LLAMA3: 8-layer, 403M parameter configuration for development/testing
FSDP integration with TransformerEngine optimization
Complete Weights & Biases integration for experiment tracking
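
A minimal sketch of the from-config initialization using the Hugging Face LLaMA classes; the hyperparameters are placeholders and do not reproduce the 403M configuration exactly.

```python
# Sketch of random-weight, from-config initialization (hyperparameters are placeholders;
# the recipe's actual 403M configuration comes from its Hydra configs).
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=256,               # ASCII byte-level vocabulary
    hidden_size=1024,             # illustrative values, not the exact 403M settings
    num_hidden_layers=8,
    num_attention_heads=16,
    max_position_embeddings=8192,
    bos_token_id=1,
    eos_token_id=2,
    pad_token_id=0,
)

# Constructing from a config (rather than from_pretrained) yields randomly initialized
# weights, avoiding any text-domain bias from pretrained checkpoints.
model = LlamaForCausalLM(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")
```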
Quick Validation Results (much more testing needed):
L0 Sanity: 4-layer model, 25 steps, 3-minute validation
L1 Pilot: 8-layer model, 200 steps, 50-minute training
Stable training convergence demonstrated (loss: 6.1 → 1.38 over 200 steps)
Consistent 23.5GB GPU memory usage with 4.1 iterations/second sustained performance
Note: Much more extensive validation is needed with full-scale models and production datasets

Usage

Quick validation run:
cd bionemo-recipes/recipes/llama_native_te_nvfsdp
torchrun --nproc_per_node=1 train.py --config-name=L0_sanity

More realistic training run (still very small scale):
cd bionemo-recipes/recipes/llama_native_te_nvfsdp
torchrun --nproc_per_node=1 train.py --config-name=L1_pilot
Configuration customization:
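The configs are selected via Hydra-style --config-name flags, so individual values can presumably be overridden on the command line as well; the dotted keys below are hypothetical examples, not confirmed config fields:
torchrun --nproc_per_node=1 train.py --config-name=L1_pilot dataset.seq_length=8192 training.num_steps=500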
Type of changes
[x] New feature (non-breaking change which adds functionality)
[ ] Bug fix (non-breaking change which fixes an issue)
[ ] Refactor
[ ] Documentation update
[ ] Other (please describe):
CI Pipeline Configuration
Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.
[ ] ciflow:skip - Skip all CI tests for this PR
[ ] ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
[ ] ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
[ ] ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
[x] ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.
Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.
For more details, see CONTRIBUTING

Note

By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.
Authorizing CI Runs
We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.
If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
/ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.
Pre-submit Checklist
[x] I have tested these changes locally
[ ] I have updated the documentation accordingly
[ ] I have added/updated tests as needed
[ ] All existing tests pass successfully

Next Steps:

  • Dataset Scaling: Scale from current subsample (1K sequences) to larger datasets (10M+ sequences)
  • Model Scaling: Increase to 16+ layers and full parameter counts for training
  • Multi-Node Training: Test and validate existing FSDP implementation across multiple GPU nodes
  • THD Sequence Packing
  • Extended Context Training: Support for longer context lengths (32K+ tokens) with iterative training strategies and memory optimization
  • Extended Validation: Comprehensive evaluation with downstream genomic tasks and comparative benchmarking against EVO2

…ration into bionemo recipes

Signed-off-by: savitha-eng <[email protected]>
@copy-pr-bot

copy-pr-bot bot commented Sep 29, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Sep 29, 2025

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


@pstjohn (Collaborator) left a comment:


leaving a first set of comments, will keep reviewing later

Collaborator:

this folder isn't a package, so this file won't do anything 🤷

from tokenizer import NucleotideASCIITokenizer


class GenomicSequenceDataset(Dataset):
Collaborator:

This dataset class isn't quite right; at a high level, if we want to be consistent with the ESM-2 recipe (and generally follow the huggingface flow), we'd want a dataset that doesn't do tokenization and just returns the full sequences. I'm not seeing a reason that couldn't essentially be a bunch of sharded parquet files with the raw sequences in them, like we do for ESM-2:

https://huggingface.co/datasets/nvidia/esm2_uniref_pretraining_data

then you'd want various map operations to tokenize the dataset from there, and possibly a sampler to shuffle it.
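
A rough sketch of that flow with Hugging Face datasets (the shard pattern and column name are assumptions):

```python
# Sketch of the suggested flow: raw sequences in sharded parquet, tokenization applied as a
# map, shuffling via a streaming shuffle buffer (paths and column names are assumptions).
from datasets import load_dataset

ds = load_dataset(
    "parquet",
    data_files="genomic_pretraining_data/*.parquet",  # hypothetical shard pattern
    split="train",
    streaming=True,
)


def tokenize(example):
    # byte-level encoding applied lazily, outside the dataset itself
    example["input_ids"] = [ord(c) for c in example["sequence"]]
    return example


ds = ds.map(tokenize)
ds = ds.shuffle(seed=42, buffer_size=10_000)  # shuffle-buffer randomization for streaming
```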


print(f"Loaded {len(self.sequences)} sequences from database")

def _create_window_mappings(self):
Collaborator:

why does the stride parameter from hf tokenizers not do this?
https://huggingface.co/docs/transformers/en/main_classes/tokenizer

Collaborator:

I think there are two potential issues:

  1. Shuffling would happen at the dataset level, so you would randomly iterate through input sequences, but each of those input sequences would then get mapped to many ordered windows by the tokenizer. As @pstjohn suggested in the Slack thread, this could be helped with some kind of local random buffer. Depending on the size of the buffer and how many chunks are returned from a given genome sequence, it may take a while before you eventually start sampling from different sequences.
  2. If you went this path, you would need to handle the fact that your dataset returns N items while the tokenizer applied to it returns M >> N items, given the one-to-many relationship between the tokenizer's inputs and outputs. If you pass stride, max_length, and return_overflowing_tokens=True to your tokenizer call wherever that happens, it would return ordered, strided window samplings for a given input, which you would then ideally want to reshuffle and repack into batches of a target size.
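
For reference, a sketch of that tokenizer call; note that in the Hugging Face API the stride argument is the number of overlapping tokens between windows (200 here), not the step size. The tokenizer and sequence objects are assumed to exist and to expose the standard interface.

```python
# Sketch of windowing via the tokenizer itself (assumes `tokenizer` exposes the standard
# Hugging Face __call__ interface and `sequence` is a raw nucleotide string).
encoded = tokenizer(
    sequence,
    max_length=8192,
    stride=200,                      # HF stride = token overlap between adjacent windows
    truncation=True,
    return_overflowing_tokens=True,  # return every window, not just the first
)
# encoded["input_ids"] is then a list of overlapping 8192-token windows for this single
# input, i.e. the 1-to-many expansion described in point 2 above.
```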


print(f"Created {total_windows:,} window mappings with stride={self.stride} (overlap={self.seq_length - self.stride}bp)")

def _randomize_window_access(self):
Collaborator:

this won't scale very well :D this just shuffles the dataset though?

Collaborator:

Can we use a shuffle buffer instead?
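
A minimal sketch of what a shuffle buffer over the window stream could look like (the buffer size and generator interface are illustrative):

```python
# Sketch of a bounded shuffle buffer over a window iterator, as an alternative to
# materializing and shuffling every window index up front (interface is illustrative).
import random
from collections.abc import Iterable, Iterator


def shuffle_buffer(items: Iterable, buffer_size: int = 10_000, seed: int = 42) -> Iterator:
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
        else:
            # swap a random buffered element out and the incoming item in
            idx = rng.randrange(buffer_size)
            yield buffer[idx]
            buffer[idx] = item
    rng.shuffle(buffer)  # drain whatever is left at the end
    yield from buffer


# Approximately shuffled stream without holding all windows in memory at once
for window in shuffle_buffer(range(1_000_000), buffer_size=10_000):
    pass
```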

seq_idx, contig_id, start_pos, window_length = self.window_mappings[actual_idx]

# Retrieve sequence from SQL database (speed requirement)
with sqlite3.connect(str(self.database_path)) as conn:
Collaborator:

yeah, we used sqlite for esm2 in bionemo2 for various reasons, but this really shouldn't be our standard approach. this won't scale very well

Collaborator (Author):

@pstjohn what would you suggest instead for better scaling? I used it because it was there in sharded_eden_dataloader.py (which I was following as a model). If there's another dataloader that I can look at as an example please let me know.

database_path: Path to SQLite database with sequences
seq_length: Window size (8192 for LLAMA3)
tokenizer: ASCII tokenizer for nucleotides
stride: Stride between windows (7992 = 8192-200 for EVO2 overlap)
Collaborator:

Nit: Can we get some sort of explanation why stride is required?

random.seed(seed)

# Step 1: Load sequences from database
self._load_sequences()
Collaborator:

What are the memory requirements for this?

# Step 3: Randomize lookups - shuffle window indices
self._randomize_window_access()

print(f"Dataset ready: {len(self.window_mappings):,} windows with EVO2-style tiling")
Collaborator:

use logger over print
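
For example, with the standard library logger in place of print (module setup shown for completeness; the window count is just the figure from the PR description):

```python
# Sketch of the module-level logger pattern being requested (standard library only).
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

num_windows = 254_043  # illustrative value taken from the PR description
logger.info("Dataset ready: %s windows with EVO2-style tiling", f"{num_windows:,}")
```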

with sqlite3.connect(str(self.database_path)) as conn:
cursor = conn.cursor()
cursor.execute("SELECT contig_id, length FROM sequences ORDER BY contig_id")
self.sequences = cursor.fetchall()
Collaborator:

so you're loading the entire sqlite into memory? If that sqlite grows you will go OOM

Collaborator:

I think it's just the id,length tuple, not the full sequences, so it doesn't OOM in practice on a 1T database.

cursor.execute("SELECT contig_id, length FROM sequences ORDER BY contig_id")
self.sequences = cursor.fetchall()

print(f"Loaded {len(self.sequences)} sequences from database")
Collaborator:

logger > print


def __len__(self) -> int:
"""Return number of windows."""
return len(self.window_mappings)
Collaborator:

I would think that the length of the dataset would be the number of rows in the SQLite database, but here I see it's the number of window mappings?

loss_mask[tokens == self.tokenizer.bos_token_id] = 0
loss_mask[tokens == self.tokenizer.eos_token_id] = 0

return {
Collaborator:

Since you're casting seq_idx to a specific data type, do you want to cast the rest of the returned values to specific data types here as well?
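
A small sketch of what explicit casting could look like; the field names mirror the quoted snippet and are not necessarily the recipe's exact return dict.

```python
# Sketch of returning explicitly typed tensors (field names are assumptions based on the
# quoted snippet, not the recipe's exact item dict).
import torch


def to_typed_item(tokens, loss_mask, seq_idx):
    return {
        "input_ids": torch.as_tensor(tokens, dtype=torch.long),
        "loss_mask": torch.as_tensor(loss_mask, dtype=torch.bool),
        "seq_idx": torch.tensor(seq_idx, dtype=torch.long),
    }
```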

# Create simple dataset for Bruno's training loop
dataset = GenomicSequenceDataset(
database_path=args.dataset.database_path,
seq_length=args.dataset.seq_length,
Collaborator:

Should seq_length be max_seq_len? If you have sequences shorter than seq_length, you're going to pad them, right?

tokenizer=tokenizer,
stride=args.dataset.get("stride", args.dataset.seq_length - 200), # EVO2 default
min_window_length=args.dataset.get("min_window_length", 1000),
seed=args.dataset.get("seed", 42),
Collaborator:

seed should ideally come from a config

# Calculate epoch length
epoch_len = len(dataloader) # Use dataloader length for distributed setting

print(f"Created genomic dataloader: {len(dataset):,} windows, {epoch_len:,} batches per epoch")
Collaborator:

logger > print please.
logger.info(f"blahblah {}")
