From fde7d007656800160a594afb198d398dffef4e71 Mon Sep 17 00:00:00 2001 From: Shane St Savage Date: Thu, 3 Jul 2025 21:32:06 -0700 Subject: [PATCH] Add --classify-batch-size arg --- src/python/ifcb_analysis/process.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/python/ifcb_analysis/process.py b/src/python/ifcb_analysis/process.py index 66659e9c..0175cc04 100644 --- a/src/python/ifcb_analysis/process.py +++ b/src/python/ifcb_analysis/process.py @@ -32,6 +32,7 @@ def process_bin( model_config: classify.KerasModelConfig, extract_images: bool = True, classify_images: bool = True, + classify_batch_size: int = None, force: bool = False, dask_log_path: str = None, dask_log_level: int = None, @@ -158,7 +159,7 @@ def process_bin( if classify_images: logging.info(f'Classifying images and saving to {class_fname}') - predictions_df = classify.predict(model_config, image_stack) + predictions_df = classify.predict(model_config, image_stack, classify_batch_size) # Since classify.predict (which calls Model.predict) is run in a for loop, memory consumption # will build up and result in an OOM error, so we excplicitly clear it out after each model run. @@ -278,6 +279,7 @@ def process( date_dirs: bool = True, extract_images: bool = True, classify_images: bool = True, + classify_batch_size: int = None, force: bool = False, use_dask: bool = True, dask_log_path: str = None, @@ -321,6 +323,7 @@ def process( [model_config] * len(bins), [extract_images] * len(bins), [classify_images] * len(bins), + [classify_batch_size] * len(bins), [force] * len(bins) ] @@ -348,7 +351,15 @@ def process( outdir = output_dir try: - process_bin(bin, outdir, model_config, extract_images, classify_images, force) + process_bin( + file=bin, + outdir=outdir, + model_config=model_config, + extract_images=extract_images, + classify_images=classify_images, + classify_batch_size=classify_batch_size, + force=force, + ) except Exception as e: logging.error(f'Error processing {bin}: {e}') @@ -363,6 +374,7 @@ def yaml_config_callback(ctx, param, value): @click.command() @click.option('--extract-images/--no-extract-images', default=True) @click.option('--classify-images/--no-classify-images', default=True) +@click.option('--classify-batch-size', type=click.INT, default=64) @click.option('--force/--no-force', default=False) @click.option('--log-level', default='INFO') @click.option('--log-file', type=click.Path(writable=True, dir_okay=False)) @@ -386,6 +398,7 @@ def cli( ctx, extract_images: bool, classify_images: bool, + classify_batch_size: int, force: bool, log_level: str, log_file: Path, @@ -439,6 +452,7 @@ def cli( date_dirs=date_dirs, extract_images=extract_images, classify_images=classify_images, + classify_batch_size=classify_batch_size, force=force, use_dask=use_dask, dask_log_path=dask_log_path,