55import logging
66from math import ceil
77from typing import Any , Callable , Generic , Sequence , TypeVar , cast
8+ from uuid import UUID
89import emu_mps
9- from pasqal_cloud import SDK
10- from pasqal_cloud .batch import Batch
10+ from pasqal_cloud import SDK , Batch , BatchFilters
1111import pulser as pl
1212from pulser .devices import Device
1313from pulser .json .abstract_repr .deserializer import deserialize_device
@@ -341,7 +341,7 @@ def __init__(
341341 username : str ,
342342 password : str | None = None ,
343343 device_name : str = "FRESNEL" ,
344- batch_id : list [str ] | None = None ,
344+ batch_ids : list [str ] | None = None ,
345345 ):
346346 sdk = SDK (username = username , project_id = project_id , password = password )
347347
@@ -351,65 +351,79 @@ def __init__(
351351
352352 super ().__init__ (path = path , device = device , compiler = compiler )
353353 self ._sdk = sdk
354- self ._batch_id = batch_id
354+ self ._batch_ids : list [ str ] | None = batch_ids
355355
356356 @property
357357 def batch_ids (self ) -> list [str ] | None :
358- return self ._batch_id
358+ return self ._batch_ids
359359
360360 async def run (self ) -> list [ProcessedData ]:
361361 if len (self .sequences ) == 0 :
362362 logger .warning ("No sequences to run, did you forget to call compile()?" )
363363 return []
364364
365- batches : list [Batch ] = []
366- if self ._batch_id is None :
365+ device : pl .devices .Device = self .sequences [0 ].sequence .device
366+ # The API doesn't support jobs with more than 500 runs.
367+ # If we want more runs, we'll need to split them across several jobs.
368+ max_runs = device .max_runs if isinstance (device .max_runs , int ) else 500
369+
370+ if self ._batch_ids is None :
367371 # Enqueue jobs.
368372 self ._batch_ids = []
369373 for compiled in self .sequences :
370- logger .debug ("Executing compiled graph #%s" , id )
374+ logger .debug ("Enqueuing execution of compiled graph #%s" , compiled . graph . id )
371375 batch = self ._sdk .create_batch (
372376 compiled .sequence .to_abstract_repr (),
373- jobs = [{"runs" : 1000 }],
377+ jobs = [{"runs" : max_runs }],
374378 wait = False ,
375379 )
376380 logger .info (
377381 "Remote execution of compiled graph #%s starting, batched with id %s" ,
378- id ,
382+ compiled . graph . id ,
379383 batch .id ,
380384 )
381- batches .append (batch )
382385 self ._batch_ids .append (batch .id )
383386 logger .info (
384387 "All %s jobs enqueued for remote execution, with ids %s" ,
385- len (batches ),
388+ len (self . _batch_ids ),
386389 self ._batch_ids ,
387390 )
388- else :
389- # Get jobs back from the cloud API.
390- for batch_id in self ._batch_id :
391- batches .append (self ._sdk .get_batch (batch_id ))
391+ assert len (self ._batch_ids ) == len (self .sequences )
392392
393393 # Now wait until all batches are complete.
394- waiting = True
395- while waiting :
396- waiting = False
397- for batch in batches :
398- if batch .status in {"PENDING" , "RUNNING" }:
399- # At least one job is pending, let's wait.
400- await sleep (2 )
401- logger .debug ("Job %s is still incomplete" )
402- waiting = True
403-
404- logger .info ("All jobs complete, %s sequences executed" , len (batches ))
394+ pending_batch_ids : set [str ] = set (self ._batch_ids )
395+ completed_batches : dict [str , Batch ] = {}
396+
397+ while len (pending_batch_ids ) > 0 :
398+ await sleep (delay = 2 )
399+ # We can check up to 100 batches in a single query with the SDK, so let's do that.
400+ MAX_BATCH_LEN = 100
401+ check_ids : list [str | UUID ] = [cast (str | UUID , id ) for id in pending_batch_ids ][
402+ :MAX_BATCH_LEN
403+ ]
404+ # Update their status.
405+ check_batches = self ._sdk .get_batches (
406+ filters = BatchFilters (id = check_ids )
407+ ) # Ideally, this should be async, see https://github.com/pasqal-io/pasqal-cloud/issues/162.
408+ for batch in check_batches .results :
409+ assert isinstance (batch , Batch )
410+ if batch .status not in {"PENDING" , "RUNNING" }:
411+ logger .debug ("Job %s is now complete" , batch .id )
412+ pending_batch_ids .discard (batch .id )
413+ completed_batches [batch .id ] = batch
414+
415+ logger .info ("All jobs complete, %s sequences executed" , len (completed_batches ))
405416
406417 # At this point, all batches are complete.
407- # Now collect data. We rely upon the fact
408- # that we enqueued exactly one batch per sequence, in the same order.
418+ # Now, collect data.
419+ #
420+ # We rely upon the fact that for any `i`,
421+ # `self ._batch_ids [i ]` is the batch for `self .sequences [i ]`.
409422 processed_data : list [ProcessedData ] = []
410- for i , batch in enumerate (batches ):
411- # Note: There's only one job per batch.
423+ for i , id in enumerate (self . _batch_ids ):
424+ batch = completed_batches [ id ]
412425 assert len (batch .jobs ) == 1
426+
413427 for _ , job in batch .jobs .items ():
414428 if job .status == "DONE" :
415429 state_dict = job .result
0 commit comments