Skip to content

Commit 7085976

Browse files
authored
Merge pull request #5 from slaclab/passthru-custom-attr
passthrough all custom attributes to slurm job
2 parents cfc0c87 + e67226c commit 7085976

1 file changed

Lines changed: 35 additions & 17 deletions

File tree

app/s3df/compute_adapter.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,27 @@
3535
from slurmrestd_client.models.slurm_v0041_post_job_submit_request_jobs_inner_time_limit import (
3636
SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit,
3737
)
38-
from fastapi import Response
38+
from fastapi import HTTPException, Response
39+
from pydantic import ConfigDict, ValidationError
3940

4041
from ..routers.compute import models as compute_models
4142
from ..types.user import User
4243
from ..routers.status import models as status_models
4344

4445
logger = logging.getLogger(__name__)
4546

47+
48+
class SlurmV0041PostJobSubmitRequestJobStrict(SlurmV0041PostJobSubmitRequestJob):
49+
# we reject unexpected fields to enable raising ValidationError
50+
# TODO: we could see if the autogeneration could be configured to make
51+
# this strict by default
52+
model_config = ConfigDict(
53+
populate_by_name=True,
54+
validate_assignment=True,
55+
protected_namespaces=(),
56+
extra="forbid",
57+
)
58+
4659
# ---------------------------------------------------------------------------
4760
# Slurm → IRI state mapping
4861
# ---------------------------------------------------------------------------
@@ -270,22 +283,27 @@ async def submit_job(
270283
partition = partition or os.environ.get("SLURM_DEFAULT_PARTITION")
271284
account = account or os.environ.get("SLURM_DEFAULT_ACCOUNT")
272285

273-
slurm_job = SlurmV0041PostJobSubmitRequestJob(
274-
nodes=str(node_count),
275-
time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit(set=True, number=duration_mins),
276-
name=name,
277-
script=executable,
278-
partition=partition,
279-
account=account,
280-
environment=environment,
281-
current_working_directory=cwd,
282-
standard_output=stdout,
283-
standard_error=stderr,
284-
)
285-
286-
# Job array support: e.g. custom_attributes={"array": "0-19"}
287-
if job_spec.attributes and "array" in job_spec.attributes.custom_attributes:
288-
slurm_job.array = job_spec.attributes.custom_attributes["array"]
286+
custom_attributes = job_spec.attributes.custom_attributes if job_spec.attributes else {}
287+
288+
try:
289+
slurm_job = SlurmV0041PostJobSubmitRequestJobStrict(
290+
nodes=str(node_count),
291+
time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit(set=True, number=duration_mins),
292+
name=name,
293+
script=executable,
294+
partition=partition,
295+
account=account,
296+
environment=environment,
297+
current_working_directory=cwd,
298+
standard_output=stdout,
299+
standard_error=stderr,
300+
**custom_attributes
301+
)
302+
except (ValidationError, TypeError) as exc:
303+
raise HTTPException(
304+
status_code=422,
305+
detail=f"Invalid job submission parameters: {exc}",
306+
) from exc
289307

290308
req = SlurmV0041PostJobSubmitRequest(job=slurm_job)
291309

0 commit comments

Comments
 (0)