Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 124 additions & 21 deletions run_batch_of_slides.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
"""
import os
import argparse
import shutil
import torch
from typing import Any
from typing import Any, List, Sequence, Tuple, Optional

torch.multiprocessing.set_start_method('spawn', force=True)

from trident import Processor
from trident.patch_encoder_models import encoder_registry as patch_encoder_registry
Expand Down Expand Up @@ -140,6 +143,77 @@ def generate_help_text() -> str:
return parser.format_help()


def slide_outputs_complete(slide_path: str, args: argparse.Namespace, task_sequence: Sequence[str]) -> bool:
    """
    Return True if all required outputs exist on disk for the slide, for every requested task.

    Parameters
    ----------
    slide_path : str
        Path to the WSI file; only its stem (basename without extension) is used
        to locate the per-slide output files.
    args : argparse.Namespace
        Parsed CLI arguments. Reads: ``job_dir``, ``coords_dir``, ``mag``,
        ``patch_size``, ``overlap``, ``slide_encoder``, ``patch_encoder``.
    task_sequence : Sequence[str]
        Tasks to verify; recognized values are ``'seg'``, ``'coords'`` and
        ``'feat'``. Any unrecognized task name makes the slide count as
        incomplete (conservative: it will be reprocessed).

    Returns
    -------
    bool
        True only when every task in ``task_sequence`` has its output present.
    """
    slide_stem = os.path.splitext(os.path.basename(slide_path))[0]
    coords_dir = args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap'

    for task_name in task_sequence:
        if task_name == 'seg':
            # Segmentation writes one contour visualization JPEG per slide.
            if not os.path.exists(os.path.join(args.job_dir, 'contours', f'{slide_stem}.jpg')):
                return False
        elif task_name == 'coords':
            if not os.path.exists(os.path.join(args.job_dir, coords_dir, 'patches', f'{slide_stem}_patches.h5')):
                return False
        elif task_name == 'feat':
            # Patch-level and slide-level features live in different directories.
            if args.slide_encoder is None:
                features_dir = os.path.join(args.job_dir, coords_dir, f'features_{args.patch_encoder}')
            else:
                features_dir = os.path.join(args.job_dir, coords_dir, f'slide_features_{args.slide_encoder}')

            # NOTE: removed the original `if not features_dir` guard — it was
            # dead code, since os.path.join always returns a non-empty string.
            if not os.path.isdir(features_dir):
                return False
            # Features may be saved as HDF5 or as torch tensors.
            if not any(os.path.exists(os.path.join(features_dir, f'{slide_stem}.{ext}')) for ext in ('h5', 'pt')):
                return False
        else:
            # Unknown task: treat as not complete so the slide is reprocessed.
            return False
    return True


def filter_completed_slides(slide_paths: List[str], args: argparse.Namespace, task_sequence: Sequence[str]) -> List[str]:
    """
    Return only the slides that still need processing for the requested tasks.

    A slide is dropped from the result when ``slide_outputs_complete`` reports
    that every task in ``task_sequence`` already has its output on disk.
    """
    pending: List[str] = []
    for slide_path in slide_paths:
        if not slide_outputs_complete(slide_path, args, task_sequence):
            pending.append(slide_path)
    return pending


def cleanup_files(job_dir: str, cache_dir: Optional[str] = None) -> Tuple[int, int]:
    """
    Remove stale ``.lock`` files under the job directory and, if given,
    empty the cache directory. Both passes are best-effort: individual
    removal failures are ignored and simply not counted.

    Returns
    -------
    Tuple[int, int]
        Number of lock files removed and cache items removed.
    """
    # Pass 1: sweep the whole job tree for leftover '*.lock' files.
    removed_locks = 0
    if os.path.isdir(job_dir):
        for root, _dirs, filenames in os.walk(job_dir):
            lock_paths = (
                os.path.join(root, name)
                for name in filenames
                if name.endswith('.lock')
            )
            for lock_path in lock_paths:
                try:
                    os.remove(lock_path)
                except OSError:
                    continue  # a vanished or busy file is not fatal
                removed_locks += 1

    # Pass 2: delete every entry (file or subtree) inside the cache directory.
    removed_items = 0
    if cache_dir and os.path.isdir(cache_dir):
        for entry in os.listdir(cache_dir):
            entry_path = os.path.join(cache_dir, entry)
            remover = shutil.rmtree if os.path.isdir(entry_path) else os.remove
            try:
                remover(entry_path)
            except OSError:
                continue
            removed_items += 1

    return removed_locks, removed_items


def initialize_processor(args: argparse.Namespace) -> Processor:
"""
Initialize the Trident Processor with arguments set in `run_batch_of_slides`.
Expand Down Expand Up @@ -246,30 +320,46 @@ def main() -> None:
WSI caching is enabled. Supports segmentation, coordinate extraction,
and feature extraction tasks.
"""
from trident.IO import collect_valid_slides

args = parse_arguments()
args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'

# Cleanup stale lock files and cache directory
lock_count, cache_count = cleanup_files(args.job_dir, args.wsi_cache)
if lock_count:
print(f"[MAIN] Cleared {lock_count} stale lock file(s) under {args.job_dir}.")
if cache_count:
print(f"[MAIN] Cleared {cache_count} item(s) from cache directory {args.wsi_cache}.")

# Collect and filter slides once
all_slides = collect_valid_slides(
wsi_dir=args.wsi_dir,
custom_list_path=args.custom_list_of_wsis,
wsi_ext=args.wsi_ext,
search_nested=args.search_nested,
max_workers=args.max_workers
)
print(f"[MAIN] Found {len(all_slides)} valid slides in {args.wsi_dir}.")

task_sequence = ['seg', 'coords', 'feat'] if args.task == 'all' else [args.task]
pending_slides = filter_completed_slides(all_slides, args, task_sequence)

if (skipped := len(all_slides) - len(pending_slides)):
print(f"[MAIN] Skipping {skipped} slide(s) with completed outputs.")

if not pending_slides:
print('[MAIN] All requested work already complete. Nothing to process.')
return

if args.wsi_cache:
# === Parallel pipeline with caching ===

from queue import Queue
from threading import Thread

from trident.Concurrency import batch_producer, batch_consumer, cache_batch
from trident.IO import collect_valid_slides

queue = Queue(maxsize=1)
valid_slides = collect_valid_slides(
wsi_dir=args.wsi_dir,
custom_list_path=args.custom_list_of_wsis,
wsi_ext=args.wsi_ext,
search_nested=args.search_nested,
max_workers=args.max_workers
)
print(f"[MAIN] Found {len(valid_slides)} valid slides in {args.wsi_dir}.")

warm = valid_slides[:args.cache_batch_size]
warm = pending_slides[:args.cache_batch_size]
warmup_dir = os.path.join(args.wsi_cache, "batch_0")
print(f"[MAIN] Warmup caching batch: {warmup_dir}")
cache_batch(warm, warmup_dir)
Expand All @@ -288,9 +378,8 @@ def run_task_fn(processor: Processor, task_name: str) -> None:
run_task(processor, args)

producer = Thread(target=batch_producer, args=(
queue, valid_slides, args.cache_batch_size, args.cache_batch_size, args.wsi_cache
queue, pending_slides, args.cache_batch_size, args.cache_batch_size, args.wsi_cache
))

consumer = Thread(target=batch_consumer, args=(
queue, args.task, args.wsi_cache, processor_factory, run_task_fn
))
Expand All @@ -302,11 +391,25 @@ def run_task_fn(processor: Processor, task_name: str) -> None:
consumer.join()
else:
# === Sequential mode ===
processor = initialize_processor(args)
tasks = ['seg', 'coords', 'feat'] if args.task == 'all' else [args.task]
for task_name in tasks:
args.task = task_name
run_task(processor, args)
# Write pending slides to temporary CSV
import csv
import tempfile

with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, newline='') as f:
writer = csv.writer(f)
writer.writerow(['wsi'])
writer.writerows([[slide] for slide in pending_slides])
temp_csv_path = f.name

try:
args.custom_list_of_wsis = temp_csv_path
processor = initialize_processor(args)

for task_name in task_sequence:
args.task = task_name
run_task(processor, args)
finally:
os.unlink(temp_csv_path)


if __name__ == "__main__":
Expand Down
11 changes: 11 additions & 0 deletions trident/Concurrency.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import gc
import time
import torch
import shutil
from typing import List, Callable, Any
Expand Down Expand Up @@ -37,6 +38,10 @@ def cache_batch(wsis: List[str], dest_dir: str) -> List[str]:
dest_mrxs_dir = os.path.join(dest_dir, os.path.basename(mrxs_dir))
shutil.copytree(mrxs_dir, dest_mrxs_dir)

# Create completion marker
with open(os.path.join(dest_dir, '.cache_complete'), 'w') as f:
f.write('done')

return copied


Expand Down Expand Up @@ -105,6 +110,12 @@ def batch_consumer(
break

ssd_batch_dir = os.path.join(cache_dir, f"batch_{batch_id}")

# Wait for cache completion marker
marker = os.path.join(ssd_batch_dir, '.cache_complete')
while not os.path.exists(marker):
time.sleep(0.5)

print(f"[CONSUMER] Processing batch {batch_id}: {ssd_batch_dir}")

processor = processor_factory(ssd_batch_dir)
Expand Down
20 changes: 14 additions & 6 deletions trident/slide_encoder_models/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,22 @@ def _build(self, pretrained=True):
raise Exception("Please install fairscale and gigapath using `pip install fairscale git+https://github.com/prov-gigapath/prov-gigapath.git`.")

# Make sure flash_attn is correct version
try:
import flash_attn; assert flash_attn.__version__ == '2.5.8'
except:
traceback.print_exc()
raise Exception("Please install flash_attn version 2.5.8 using `pip install flash_attn==2.5.8`.")
# try:
# import flash_attn; assert flash_attn.__version__ == '2.5.8'
# except:
# traceback.print_exc()
# raise Exception("Please install flash_attn version 2.5.8 using `pip install flash_attn==2.5.8`.")

if pretrained:
model = create_model("hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536, global_pool=True)
# Try to get local weights path first
weights_path = get_weights_path('slide', self.enc_name)
if weights_path:
print(f"Loading GigaPath slide encoder from local path: {weights_path}")
model = create_model(weights_path, "gigapath_slide_enc12l768d", 1536, global_pool=True)
else:
# Fallback to downloading from Hugging Face Hub
print("Local weights not found. Downloading from Hugging Face Hub...")
model = create_model("hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536, global_pool=True)
else:
model = create_model("", "gigapath_slide_enc12l768d", 1536, global_pool=True)

Expand Down