11from __future__ import annotations
22
3+ import collections
34import json
4- from dataclasses import dataclass
5+ import matplotlib
56
67import numpy as np
78import pulser as pl
89
910
10- @dataclass
1111class ProcessedData :
1212 """
1313 Data on a single graph obtained from the Quantum Device.
@@ -37,10 +37,14 @@ class ProcessedData:
3737
3838 sequence : pl .Sequence
3939 state_dict : dict [str , int ]
40+ _dist_excitation : np .ndarray
4041 target : int
4142
42- def __post_init__ (self ) -> None :
43- self .state_dict = _convert_np_int64_to_int (data = self .state_dict )
def __init__(self, sequence: pl.Sequence, state_dict: dict[str, np.int64], target: int):
    """
    Build a ProcessedData from a sequence, its measured counts and a target.

    Arguments:
        sequence: The sequence associated with this graph.
        state_dict: Mapping from bitstring to count. Values that are
            numpy integers are converted to plain `int`.
        target: The target value attached to this graph.
    """
    self.sequence = sequence
    plain_counts = _convert_np_int64_to_int(data=state_dict)
    self.state_dict = plain_counts
    # Cache the excitation distribution once, computed from the converted
    # counts by the module-level `dist_excitation` function.
    self._dist_excitation = dist_excitation(plain_counts)
    self.target = target
4448
4549 def save_to_file (self , file_path : str ) -> None :
4650 with open (file_path , "w" ) as file :
@@ -61,6 +65,20 @@ def load_from_file(cls, file_path: str) -> "ProcessedData":
6165 target = tmp_data ["target" ],
6266 )
6367
68+ def dist_excitation (self , size : int | None = None ) -> np .ndarray :
69+ """
70+ Return the distribution of excitations for this graph.
71+
72+ Arguments:
73+ size: If specified, truncate or pad the array to this
74+ size.
75+ """
76+ if size is None or size == len (self ._dist_excitation ):
77+ return self ._dist_excitation .copy ()
78+ if size < len (self ._dist_excitation ):
79+ return np .resize (self ._dist_excitation , size )
80+ return np .pad (self ._dist_excitation , (0 , size - len (self ._dist_excitation )))
81+
6482 def draw_sequence (self ) -> None :
6583 """
6684 Draw the sequence on screen
@@ -73,8 +91,119 @@ def draw_register(self) -> None:
7391 """
7492 self .sequence .register .draw (blockade_radius = self .sequence .device .min_atom_distance + 0.01 )
7593
def draw_excitation(self) -> None:
    """
    Draw a histogram of the excitation levels.

    One bar is plotted per excitation level (labelled "0", "1", ...), with
    the bar height taken from the cached excitation distribution. The bars
    are added through `matplotlib.pyplot.bar`; whether a window actually
    appears depends on the active matplotlib backend.
    """
    # `import matplotlib` alone does not load the `pyplot` submodule, so
    # `matplotlib.pyplot` is only reachable if some other module imported it
    # first. Import it explicitly here to avoid depending on that.
    import matplotlib.pyplot as plt

    labels = [str(level) for level in range(len(self._dist_excitation))]
    plt.bar(labels, self._dist_excitation)
100+
101+
def dist_excitation(state_dict: dict[str, int], size: int | None = None) -> np.ndarray:
    """
    Calculate the distribution of excitation levels from a dictionary of
    bitstrings to their respective counts.

    Args:
        state_dict: Mapping from a bitstring (e.g. `"0110"`) to the number
            of samples in which it was observed.
        size (int | None): If specified, the highest excitation level kept
            in the output; samples with more than `size` excitations are
            left out of the histogram (but still count towards the
            normalization). Otherwise, use the length of the bitstrings,
            i.e. keep every possible level.

    Returns:
        A histogram of excitation levels with `size + 1` entries.
        - index: an excitation level (i.e. a number of `1` bits in a
            bitstring), from `0` to `size` inclusive
        - value: normalized count of samples with this excitation level.
        An empty `state_dict` yields an empty array. If every count is
        zero, the histogram is all zeros.
    """

    if len(state_dict) == 0:
        # `np.zeros(0)` rather than `np.ndarray(0)`: same empty shape, but
        # does not rely on an uninitialized-buffer constructor.
        return np.zeros(0, dtype=float)

    if size is None:
        # If size is not specified, it's the length of the bitstrings.
        # We assume that all bitstrings in `state_dict` have the same
        # length; `state_dict` is known to be non-empty at this point, so
        # pick the length of the first bitstring.
        first_bitstring = next(iter(state_dict))
        size = len(first_bitstring)

    # Make mypy realize that `size` is now always an `int`.
    assert type(size) is int

    # Levels 0..size inclusive, hence `size + 1` slots.
    result = np.zeros(size + 1, dtype=float)
    total = 0
    for bitstring, number in state_dict.items():
        total += number
        occupation = bitstring.count("1")
        # `<=` (not `<`): level `size` has its own slot, e.g. the fully
        # excited bitstring when `size` is the bitstring length.
        if occupation <= size:
            result[occupation] += number

    if total == 0:
        # No samples at all: nothing to normalize, avoid dividing by zero.
        return result

    return result / total
147+
76148
def _convert_np_int64_to_int(data: dict[str, np.int64]) -> dict[str, int]:
    """
    Utility function: return a copy of `data` in which every value that is
    a numpy integer has been turned into a plain `int`, for serialization
    purposes. Other values are kept unchanged.
    """
    plain = {}
    for key, value in data.items():
        if isinstance(value, np.integer):
            value = int(value)
        plain[key] = value
    return plain
157+
158+
def save_dataset(dataset: list[ProcessedData], file_path: str) -> None:
    """Saves a dataset to a JSON file.

    Args:
        dataset (list[ProcessedData]): The dataset to be saved, containing
            ProcessedData instances.
        file_path (str): The path where the dataset will be saved as a JSON
            file.

    Note:
        The data is stored in a format suitable for loading with load_dataset:
        a JSON list with one object per instance, holding its "sequence"
        (abstract representation), "state_dict" and "target".

        The dataset is fully serialized before the file is opened, so a
        failure during serialization leaves any existing file at
        `file_path` untouched.

    Returns:
        None
    """
    records = [
        {
            "sequence": instance.sequence.to_abstract_repr(),
            "state_dict": instance.state_dict,
            "target": instance.target,
        }
        for instance in dataset
    ]
    # Encode first: opening with "w" truncates the file immediately, so it
    # must only happen once we know there is something valid to write.
    payload = json.dumps(records)
    with open(file_path, "w") as file:
        file.write(payload)
184+
185+
def load_dataset(file_path: str) -> list[ProcessedData]:
    """Loads a dataset from a JSON file.

    Args:
        file_path (str): The path to the JSON file containing the dataset.

    Note:
        The file is expected to hold a list of objects with the keys
        "sequence", "state_dict" and "target", i.e. the format written by
        save_dataset.

    Returns:
        A list of ProcessedData instances, one per entry of the JSON file,
        in the same order.
    """
    with open(file_path) as handle:
        records = json.load(handle)

    dataset = []
    for record in records:
        # Rebuild the sequence from its abstract representation.
        sequence = pl.Sequence.from_abstract_repr(record["sequence"])
        entry = ProcessedData(
            sequence=sequence,
            state_dict=record["state_dict"],
            target=record["target"],
        )
        dataset.append(entry)
    return dataset
0 commit comments