Skip to content
52 changes: 35 additions & 17 deletions snakemake_executor_plugin_slurm/submit_string.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
from snakemake_executor_plugin_slurm_jobstep import get_cpu_setting
from types import SimpleNamespace
import shlex


def safe_quote(value):
"""
Safely quote a parameter value using shlex.quote.
Handles None values and converts to string if needed.
Returns empty quotes for empty strings.
"""
str_value = str(value)
if str_value == "":
return "''"
return shlex.quote(str_value)


def get_submit_command(job, params):
Expand All @@ -10,37 +23,41 @@ def get_submit_command(job, params):
params = SimpleNamespace(**params)

call = (
f"sbatch "
f"--parsable "
f"--job-name {params.run_uuid} "
f'--output "{params.slurm_logfile}" '
f"--export=ALL "
f'--comment "{params.comment_str}"'
"sbatch "
"--parsable "
f"--job-name {safe_quote(params.run_uuid)} "
f"--output {safe_quote(params.slurm_logfile)} "
"--export=ALL "
f"--comment {safe_quote(params.comment_str)}"
)

# No accout or partition checking is required, here.
# Checking is done in the submit function.

# here, only the string is used, as it already contains
# '-A {account_name}'
# "-A '{account_name}'"
call += f" {params.account}"
# here, only the string is used, as it already contains
# '- p {partition_name}'
# "- p '{partition_name}'"
call += f" {params.partition}"

if job.resources.get("clusters"):
call += f" --clusters {job.resources.clusters}"
call += f" --clusters {safe_quote(job.resources.clusters)}"

if job.resources.get("runtime"):
call += f" -t {job.resources.runtime}"
call += f" -t {safe_quote(job.resources.runtime)}"

if job.resources.get("constraint") or isinstance(
job.resources.get("constraint"), str
):
call += f" -C '{job.resources.get('constraint')}'"
# Both, constraint and qos are optional.
# If not set, they will not be added to the sbatch call.
# If explicitly set to an empty string,
# `--constraint ''` or `--qos ''` will be added.
constraint = job.resources.get("constraint")
if constraint is not None:
call += f" -C {safe_quote(constraint)}"

if job.resources.get("qos") or isinstance(job.resources.get("qos"), str):
call += f" --qos='{job.resources.qos}'"
qos = job.resources.get("qos")
if qos is not None:
call += f" --qos={safe_quote(qos)}"

if job.resources.get("mem_mb_per_cpu"):
call += f" --mem-per-cpu {job.resources.mem_mb_per_cpu}"
Expand Down Expand Up @@ -77,6 +94,7 @@ def get_submit_command(job, params):
# ensure that workdir is set correctly
# use short argument as this is the same in all slurm versions
# (see https://github.com/snakemake/snakemake/issues/2014)
call += f" -D '{params.workdir}'"
if params.workdir:
call += f" -D {safe_quote(params.workdir)}"

return call
10 changes: 5 additions & 5 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def test_constraint_resource(self, mock_job):
process_mock.returncode = 0
mock_popen.return_value = process_mock

assert " -C 'haswell'" in get_submit_command(job, params)
assert " -C haswell" in get_submit_command(job, params)

def test_qos_resource(self, mock_job):
"""Test that the qos resource is correctly added to the sbatch command."""
Expand All @@ -412,7 +412,7 @@ def test_qos_resource(self, mock_job):
process_mock.returncode = 0
mock_popen.return_value = process_mock

assert " --qos='normal'" in get_submit_command(job, params)
assert " --qos=normal" in get_submit_command(job, params)

def test_both_constraint_and_qos(self, mock_job):
"""Test that both constraint and qos resources can be used together."""
Expand All @@ -439,8 +439,8 @@ def test_both_constraint_and_qos(self, mock_job):

# Assert both resources are correctly included
sbatch_command = get_submit_command(job, params)
assert " --qos='high'" in sbatch_command
assert " -C 'haswell'" in sbatch_command
assert " --qos=high" in sbatch_command
assert " -C haswell" in sbatch_command

def test_no_resources(self, mock_job):
"""
Expand Down Expand Up @@ -517,7 +517,7 @@ def test_empty_qos(self, mock_job):
process_mock.communicate.return_value = ("123", "")
process_mock.returncode = 0
mock_popen.return_value = process_mock
# Assert the qoes is included (even if empty)
# Assert the qos is included (even if empty)
assert "--qos=''" in get_submit_command(job, params)

def test_taks(self, mock_job):
Expand Down