diff --git a/snakemake_executor_plugin_slurm/submit_string.py b/snakemake_executor_plugin_slurm/submit_string.py index 59c3787..25cbb20 100644 --- a/snakemake_executor_plugin_slurm/submit_string.py +++ b/snakemake_executor_plugin_slurm/submit_string.py @@ -1,5 +1,18 @@ from snakemake_executor_plugin_slurm_jobstep import get_cpu_setting from types import SimpleNamespace +import shlex + + +def safe_quote(value): + """ + Safely quote a parameter value using shlex.quote. + Handles None values and converts to string if needed. + Returns empty quotes for empty strings. + """ + str_value = str(value) + if str_value == "": + return "''" + return shlex.quote(str_value) def get_submit_command(job, params): @@ -10,37 +23,41 @@ def get_submit_command(job, params): params = SimpleNamespace(**params) call = ( - f"sbatch " - f"--parsable " - f"--job-name {params.run_uuid} " - f'--output "{params.slurm_logfile}" ' - f"--export=ALL " - f'--comment "{params.comment_str}"' + "sbatch " + "--parsable " + f"--job-name {safe_quote(params.run_uuid)} " + f"--output {safe_quote(params.slurm_logfile)} " + "--export=ALL " + f"--comment {safe_quote(params.comment_str)}" ) # No accout or partition checking is required, here. # Checking is done in the submit function. # here, only the string is used, as it already contains - # '-A {account_name}' + # "-A '{account_name}'" call += f" {params.account}" # here, only the string is used, as it already contains - # '- p {partition_name}' + # "- p '{partition_name}'" call += f" {params.partition}" if job.resources.get("clusters"): - call += f" --clusters {job.resources.clusters}" + call += f" --clusters {safe_quote(job.resources.clusters)}" if job.resources.get("runtime"): - call += f" -t {job.resources.runtime}" + call += f" -t {safe_quote(job.resources.runtime)}" - if job.resources.get("constraint") or isinstance( - job.resources.get("constraint"), str - ): - call += f" -C '{job.resources.get('constraint')}'" + # Both, constraint and qos are optional. + # If not set, they will not be added to the sbatch call. + # If explicitly set to an empty string, + # `--constraint ''` or `--qos ''` will be added. + constraint = job.resources.get("constraint") + if constraint is not None: + call += f" -C {safe_quote(constraint)}" - if job.resources.get("qos") or isinstance(job.resources.get("qos"), str): - call += f" --qos='{job.resources.qos}'" + qos = job.resources.get("qos") + if qos is not None: + call += f" --qos={safe_quote(qos)}" if job.resources.get("mem_mb_per_cpu"): call += f" --mem-per-cpu {job.resources.mem_mb_per_cpu}" @@ -77,6 +94,7 @@ def get_submit_command(job, params): # ensure that workdir is set correctly # use short argument as this is the same in all slurm versions # (see https://github.com/snakemake/snakemake/issues/2014) - call += f" -D '{params.workdir}'" + if params.workdir: + call += f" -D {safe_quote(params.workdir)}" return call diff --git a/tests/tests.py b/tests/tests.py index cf38b74..9a78689 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -388,7 +388,7 @@ def test_constraint_resource(self, mock_job): process_mock.returncode = 0 mock_popen.return_value = process_mock - assert " -C 'haswell'" in get_submit_command(job, params) + assert " -C haswell" in get_submit_command(job, params) def test_qos_resource(self, mock_job): """Test that the qos resource is correctly added to the sbatch command.""" @@ -412,7 +412,7 @@ def test_qos_resource(self, mock_job): process_mock.returncode = 0 mock_popen.return_value = process_mock - assert " --qos='normal'" in get_submit_command(job, params) + assert " --qos=normal" in get_submit_command(job, params) def test_both_constraint_and_qos(self, mock_job): """Test that both constraint and qos resources can be used together.""" @@ -439,8 +439,8 @@ def test_both_constraint_and_qos(self, mock_job): # Assert both resources are correctly included sbatch_command = get_submit_command(job, params) - assert " --qos='high'" in sbatch_command - assert " -C 'haswell'" in sbatch_command + assert " --qos=high" in sbatch_command + assert " -C haswell" in sbatch_command def test_no_resources(self, mock_job): """ @@ -517,7 +517,7 @@ def test_empty_qos(self, mock_job): process_mock.communicate.return_value = ("123", "") process_mock.returncode = 0 mock_popen.return_value = process_mock - # Assert the qoes is included (even if empty) + # Assert the qos is included (even if empty) assert "--qos=''" in get_submit_command(job, params) def test_taks(self, mock_job):