[WIP] LLAMA3 Integration with BioNeMo Recipes with ASCII Tokenization/EVO2-Style Dataloader #1205
base: main
Conversation
…ration into bionemo recipes Signed-off-by: savitha-eng <[email protected]>
leaving a first set of comments, will keep reviewing later
this folder isn't a package, so this file won't do anything 🤷
from tokenizer import NucleotideASCIITokenizer

class GenomicSequenceDataset(Dataset):
This dataset class isn't quite right. High-level: if we want to be consistent with the ESM-2 recipe (and generally follow the Hugging Face flow), we'd want a dataset that doesn't do tokenization; it just returns the full sequences. I'm not seeing a reason this couldn't essentially be a bunch of sharded parquet files with the raw sequences in them, like we do for ESM-2:
https://huggingface.co/datasets/nvidia/esm2_uniref_pretraining_data
then you'd want various map operations to tokenize the dataset from there, and possibly a sampler to shuffle it.
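For concreteness, a rough sketch of that flow with the datasets library; the parquet paths, column name, and tokenizer checkpoint below are placeholders, not the actual recipe:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Paths, column name, and tokenizer checkpoint are placeholders for illustration.
ds = load_dataset("parquet", data_files="sequences/*.parquet", split="train", streaming=True)
tokenizer = AutoTokenizer.from_pretrained("my-nucleotide-tokenizer")  # hypothetical


def tokenize(batch):
    return tokenizer(batch["sequence"], truncation=True, max_length=8192)


# Tokenization happens lazily as a map over the raw-sequence dataset,
# with a shuffle buffer instead of materializing everything up front.
tokenized = ds.map(tokenize, batched=True, remove_columns=["sequence"])
shuffled = tokenized.shuffle(seed=42, buffer_size=10_000)
```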
| print(f"Loaded {len(self.sequences)} sequences from database") | ||
|
|
||
| def _create_window_mappings(self): |
why does the stride parameter from hf tokenizers not do this?
https://huggingface.co/docs/transformers/en/main_classes/tokenizer
I think there are two potential issues:
- Shuffling would happen at the dataset level, so you would randomly iterate through input sequences, but then each of those input sequences would get mapped to many ordered windows by the tokenizer. As @pstjohn suggested in the slack thread, this could be helped with some kind of a local random buffer. Depending on the size of the buffer and how many chunks are returned from a given genome sequence, it may take a while before you eventually start sampling from different sequences.
- You would need to handle the fact that your dataset returns N items while your tokenizer applied to the dataset returns M >> N items, given the one-to-many relationship between the tokenizer and its inputs, if you went this path. If you pass stride, max_length, and return_overflowing_tokens=True to your tokenizer call wherever that happens, it would return ordered strided window samplings for a given input, which you would then want to reshuffle and repack into smaller batches of a target size, ideally (see the sketch after this list).
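For reference, a minimal sketch of that strided tokenizer call; the tokenizer checkpoint and sequence are placeholders, purely to show the API shape:

```python
from transformers import AutoTokenizer

# Placeholder tokenizer and sequence, purely to illustrate the call.
tokenizer = AutoTokenizer.from_pretrained("my-nucleotide-tokenizer")  # hypothetical
sequence = "ACGT" * 5_000  # one long nucleotide string

windows = tokenizer(
    sequence,
    max_length=8192,
    stride=200,                      # number of overlapping tokens between windows
    truncation=True,
    return_overflowing_tokens=True,  # emit every strided window, not just the first
)
# windows["input_ids"] is now a list of M window encodings for this single input;
# they come back in order, so they would still need shuffling/repacking downstream.
```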
| print(f"Created {total_windows:,} window mappings with stride={self.stride} (overlap={self.seq_length - self.stride}bp)") | ||
|
|
||
| def _randomize_window_access(self): |
this won't scale very well :D this just shuffles the dataset though?
Can we use a shuffle buffer instead?
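One possible shape for that, sketched as a plain-Python buffer over an iterator of windows (buffer size and seed are arbitrary assumptions):

```python
import random
from typing import Iterable, Iterator


def shuffle_buffer(items: Iterable, buffer_size: int = 10_000, seed: int = 42) -> Iterator:
    """Yield items in locally shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buf = []
    for item in items:
        buf.append(item)
        if len(buf) >= buffer_size:
            # Pop a random element once the buffer is full.
            yield buf.pop(rng.randrange(len(buf)))
    # Drain whatever is left at the end of the stream.
    rng.shuffle(buf)
    yield from buf


# e.g. wrap an iterator of tokenized windows:
# for window in shuffle_buffer(window_iterator): ...
```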
seq_idx, contig_id, start_pos, window_length = self.window_mappings[actual_idx]

# Retrieve sequence from SQL database (speed requirement)
with sqlite3.connect(str(self.database_path)) as conn:
yeah, we used sqlite for esm2 in bionemo2 for various reasons, but this really shouldn't be our standard approach. this won't scale very well
@pstjohn what would you suggest instead for better scaling? I used it because it was there in sharded_eden_dataloader.py (which I was following as a model). If there's another dataloader that I can look at as an example please let me know.
database_path: Path to SQLite database with sequences
seq_length: Window size (8192 for LLAMA3)
tokenizer: ASCII tokenizer for nucleotides
stride: Stride between windows (7992 = 8192 - 200 for EVO2 overlap)
Nit: Can we get some sort of explanation why stride is required?
random.seed(seed)

# Step 1: Load sequences from database
self._load_sequences()
What are the memory requirements for this?
# Step 3: Randomize lookups - shuffle window indices
self._randomize_window_access()

print(f"Dataset ready: {len(self.window_mappings):,} windows with EVO2-style tiling")
use logger over print
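For example, with the standard library logger (a minimal sketch; the count is a placeholder and the message mirrors the print above):

```python
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

num_windows = 254_043  # placeholder value for illustration
logger.info("Dataset ready: %s windows with EVO2-style tiling", f"{num_windows:,}")
```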
with sqlite3.connect(str(self.database_path)) as conn:
    cursor = conn.cursor()
    cursor.execute("SELECT contig_id, length FROM sequences ORDER BY contig_id")
    self.sequences = cursor.fetchall()
so you're loading the entire sqlite into memory? If that sqlite grows you will go OOM
I think it's just the id,length tuple, not the full sequences, so it doesn't OOM in practice on a 1T database.
cursor.execute("SELECT contig_id, length FROM sequences ORDER BY contig_id")
self.sequences = cursor.fetchall()

print(f"Loaded {len(self.sequences)} sequences from database")
logger > print
def __len__(self) -> int:
    """Return number of windows."""
    return len(self.window_mappings)
I would think that the length of the dataset would be the number of rows in the SQLite, but here I see it's the number of window mappings?
loss_mask[tokens == self.tokenizer.bos_token_id] = 0
loss_mask[tokens == self.tokenizer.eos_token_id] = 0

return {
Since you're casting seq_idx to your data type, do you want to cast the rest of the things to a specific data type here?
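A hypothetical sketch of making the returned dtypes explicit (the helper and key names are illustrative, not the actual __getitem__):

```python
import torch


def build_item(tokens: torch.Tensor, loss_mask: torch.Tensor, seq_idx: int) -> dict:
    """Cast every field to an explicit dtype before returning it from __getitem__."""
    return {
        "input_ids": tokens.to(torch.long),
        "labels": tokens.to(torch.long),
        "loss_mask": loss_mask.to(torch.bool),
        "seq_idx": torch.tensor(seq_idx, dtype=torch.long),
    }
```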
# Create simple dataset for Bruno's training loop
dataset = GenomicSequenceDataset(
    database_path=args.dataset.database_path,
    seq_length=args.dataset.seq_length,
is seq_length --> max_seq_len? Since if you have sequences < seq_len, you're going to pad them, right?
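For example, if seq_length is effectively a max sequence length, shorter windows would presumably be padded at tokenization time; a hedged sketch (tokenizer checkpoint and padding strategy are assumptions):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("my-nucleotide-tokenizer")  # hypothetical
window_sequence = "ACGTN" * 100  # a window much shorter than 8192

encoding = tokenizer(
    window_sequence,
    max_length=8192,
    padding="max_length",  # pad short windows up to the maximum length
    truncation=True,
    return_tensors="pt",
)
```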
    tokenizer=tokenizer,
    stride=args.dataset.get("stride", args.dataset.seq_length - 200),  # EVO2 default
    min_window_length=args.dataset.get("min_window_length", 1000),
    seed=args.dataset.get("seed", 42),
seed should come from a config ideally
# Calculate epoch length
epoch_len = len(dataloader)  # Use dataloader length for distributed setting

print(f"Created genomic dataloader: {len(dataset):,} windows, {epoch_len:,} batches per epoch")
logger > print please.
logger.info(f"blahblah {}")
Description
This PR implements a complete end-to-end genomic foundation model training pipeline using LLAMA3 architecture with EVO2-compatible data processing and distributed training capabilities.
Key Components:
Genomic Tokenization System:
- ASCII character-level tokenization with 256-token vocabulary covering the complete ASCII character set
- Direct nucleotide encoding supporting standard nucleotides (A, T, C, G, N) and IUPAC ambiguity codes
- EVO2-compatible byte-level tokenization methodology (illustrated in the sketch below)
- Integrated BOS, EOS, and PAD tokens for proper sequence boundary management
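To make the byte-level idea concrete, here is a rough illustrative sketch of this kind of encoding (this is not the NucleotideASCIITokenizer implementation; the special-token ids and truncation handling are made up):

```python
# Each nucleotide character maps directly to its byte/ASCII value, giving a 256-token
# vocabulary, with special tokens for sequence boundaries and padding.
BOS_ID, EOS_ID, PAD_ID = 1, 2, 0  # illustrative ids only


def encode_ascii(sequence: str, max_length: int = 8192) -> list[int]:
    ids = [BOS_ID] + [ord(base) for base in sequence.upper()] + [EOS_ID]
    ids = ids[:max_length]  # simplified truncation
    return ids + [PAD_ID] * (max_length - len(ids))


print(encode_ascii("ACGTN", max_length=16))
# [1, 65, 67, 71, 84, 78, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```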
EVO2-Style Genomic Dataloader (following the Sharded Eden Dataloader):
- Efficient sequence windowing: 8192-token sequences with 7992-token stride (200 bp overlap), following the EVO2 approach (see the tiling sketch after this list)
- Randomization: fixed window positions with shuffled access order for training stability
- SQLite backend integration for high-performance sequence retrieval
- Full FSDP and DDP distributed training compatibility
- Memory-efficient streaming with configurable parameters
- Currently processes a small subsample dataset (1,024 sequences into 254,043 training windows), with future scaling planned for production datasets
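A minimal sketch of the tiling arithmetic described above, assuming a simple per-contig loop (the min_window_length filtering is a guess at the behaviour, not the recipe's actual logic):

```python
def window_starts(contig_length: int, seq_length: int = 8192, stride: int = 7992,
                  min_window_length: int = 1000) -> list[int]:
    """Start positions of overlapping windows over a single contig (200 bp overlap)."""
    starts = []
    pos = 0
    while pos < contig_length:
        if contig_length - pos >= min_window_length:  # drop tiny trailing windows
            starts.append(pos)
        pos += stride
    return starts


print(window_starts(20_000))  # [0, 7992, 15984]
```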
Model Pipeline Architecture:
- Dynamic RoPE implementation supporting arbitrary sequence lengths with ESM-2 style computation
- Random weight initialization using a from-config approach to avoid text-domain bias (a rough sketch follows this list)
- Memory-optimized LLAMA3: 8-layer, 403M-parameter configuration for development/testing
- FSDP integration with TransformerEngine optimization
- Complete Weights & Biases integration for experiment tracking
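A hedged sketch of that from-config initialization (the hidden sizes here are placeholders for illustration, not the recipe's actual 403M configuration):

```python
from transformers import LlamaConfig, LlamaForCausalLM

# Build the model from a config (random weights) rather than from_pretrained,
# so no pretrained text-domain weights are carried over.
config = LlamaConfig(
    vocab_size=256,               # byte-level / ASCII vocabulary
    hidden_size=1024,             # placeholder sizes for illustration
    intermediate_size=4096,
    num_hidden_layers=8,
    num_attention_heads=16,
    max_position_embeddings=8192,
)
model = LlamaForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))
```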
Quick Validation Results (much more testing needed):
- L0 Sanity: 4-layer model, 25 steps, 3-minute validation
- L1 Pilot: 8-layer model, 200 steps, 50-minute training
- Stable training convergence demonstrated (loss: 6.1 → 1.38 over 200 steps)
- Consistent 23.5 GB GPU memory usage with a sustained 4.1 iterations/second
Note: Much more extensive validation is needed with full-scale models and production datasets.
Usage
Quick validation run:
cd bionemo-recipes/recipes/llama_native_te_nvfsdp
torchrun --nproc_per_node=1 train.py --config-name=L0_sanity
More realistic training run (still very small scale):
cd bionemo-recipes/recipes/llama_native_te_nvfsdp
torchrun --nproc_per_node=1 train.py --config-name=L1_pilot
Configuration customization:
Type of changes
[x] New feature (non-breaking change which adds functionality)
[ ] Bug fix (non-breaking change which fixes an issue)
[ ] Refactor
[ ] Documentation update
[ ] Other (please describe):
CI Pipeline Configuration
Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.
[ ] ciflow:skip - Skip all CI tests for this PR
[ ] ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
[ ] ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
[ ] ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
[x] ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.
Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.
For more details, see CONTRIBUTING
Note
By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.
Authorizing CI Runs
We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.
If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
/ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.
Pre-submit Checklist
[x] I have tested these changes locally
[ ] I have updated the documentation accordingly
[ ] I have added/updated tests as needed
[ ] All existing tests pass successfully
Next Steps: