|
35 | 35 | from slurmrestd_client.models.slurm_v0041_post_job_submit_request_jobs_inner_time_limit import ( |
36 | 36 | SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit, |
37 | 37 | ) |
38 | | -from fastapi import Response |
| 38 | +from fastapi import HTTPException, Response |
| 39 | +from pydantic import ConfigDict, ValidationError |
39 | 40 |
|
40 | 41 | from ..routers.compute import models as compute_models |
41 | 42 | from ..types.user import User |
42 | 43 | from ..routers.status import models as status_models |
43 | 44 |
|
44 | 45 | logger = logging.getLogger(__name__) |
45 | 46 |
|
| 47 | + |
| 48 | +class SlurmV0041PostJobSubmitRequestJobStrict(SlurmV0041PostJobSubmitRequestJob): |
| 49 | + # we reject unexpected fields to enable raising ValidationError |
| 50 | + # TODO: we could see if the autogeneration could be configured to make |
| 51 | + # this strict by default |
| 52 | + model_config = ConfigDict( |
| 53 | + populate_by_name=True, |
| 54 | + validate_assignment=True, |
| 55 | + protected_namespaces=(), |
| 56 | + extra="forbid", |
| 57 | + ) |
| 58 | + |
46 | 59 | # --------------------------------------------------------------------------- |
47 | 60 | # Slurm → IRI state mapping |
48 | 61 | # --------------------------------------------------------------------------- |
@@ -270,22 +283,27 @@ async def submit_job( |
270 | 283 | partition = partition or os.environ.get("SLURM_DEFAULT_PARTITION") |
271 | 284 | account = account or os.environ.get("SLURM_DEFAULT_ACCOUNT") |
272 | 285 |
|
273 | | - slurm_job = SlurmV0041PostJobSubmitRequestJob( |
274 | | - nodes=str(node_count), |
275 | | - time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit(set=True, number=duration_mins), |
276 | | - name=name, |
277 | | - script=executable, |
278 | | - partition=partition, |
279 | | - account=account, |
280 | | - environment=environment, |
281 | | - current_working_directory=cwd, |
282 | | - standard_output=stdout, |
283 | | - standard_error=stderr, |
284 | | - ) |
285 | | - |
286 | | - # Job array support: e.g. custom_attributes={"array": "0-19"} |
287 | | - if job_spec.attributes and "array" in job_spec.attributes.custom_attributes: |
288 | | - slurm_job.array = job_spec.attributes.custom_attributes["array"] |
| 286 | + custom_attributes = job_spec.attributes.custom_attributes if job_spec.attributes else {} |
| 287 | + |
| 288 | + try: |
| 289 | + slurm_job = SlurmV0041PostJobSubmitRequestJobStrict( |
| 290 | + nodes=str(node_count), |
| 291 | + time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit(set=True, number=duration_mins), |
| 292 | + name=name, |
| 293 | + script=executable, |
| 294 | + partition=partition, |
| 295 | + account=account, |
| 296 | + environment=environment, |
| 297 | + current_working_directory=cwd, |
| 298 | + standard_output=stdout, |
| 299 | + standard_error=stderr, |
| 300 | + **custom_attributes |
| 301 | + ) |
| 302 | + except (ValidationError, TypeError) as exc: |
| 303 | + raise HTTPException( |
| 304 | + status_code=422, |
| 305 | + detail=f"Invalid job submission parameters: {exc}", |
| 306 | + ) from exc |
289 | 307 |
|
290 | 308 | req = SlurmV0041PostJobSubmitRequest(job=slurm_job) |
291 | 309 |
|
|
0 commit comments