33from typing import (
44 Any ,
55 Callable ,
6+ DefaultDict ,
67 Dict ,
78 List ,
89 Mapping ,
910 Optional ,
1011 Sequence ,
1112 Tuple ,
1213 Union ,
13- cast ,
1414)
15+ from collections import defaultdict
1516import warnings
1617
1718import ipyleaflet # type: ignore
@@ -51,21 +52,22 @@ def __init__(
5152 # data resources, and so column names and dtype need to be
5253 # passed in as parameters.
5354 self ._aim_metadata_columns : Optional [List [str ]] = None
54- # `dtype` of `dict[str, Any]` is incompatible with `pd.read_csv`
55- self ._aim_metadata_dtype : Mapping [
56- str , Union [str , type , np .dtype , pd .api .extensions .ExtensionDtype ]
57- ] = {}
55+ self ._aim_metadata_dtype : Optional [Mapping [str , Any ]] = {}
56+
57+ # Only apply the `aim_metadata_dtype` if it is a type of `Mapping`.
5858 if isinstance (aim_metadata_dtype , Mapping ):
59- self ._aim_metadata_columns = list (aim_metadata_dtype .keys ())
60- self ._aim_metadata_dtype .update (
61- cast (
62- Mapping [
63- str ,
64- Union [str , type , np .dtype , pd .api .extensions .ExtensionDtype ],
65- ],
66- aim_metadata_dtype ,
67- )
68- )
59+ # Convert all of the column names to lowercase.
60+ prepared_aim_metadata_dtype_dict = {
61+ k .lower (): v for k , v in aim_metadata_dtype .items ()
62+ }
63+
64+ # Get all the column names from the prepared dict.
65+ self ._aim_metadata_columns = list (prepared_aim_metadata_dtype_dict .keys ())
66+
67+ # Update the _aim_metadata_dtype with the prepared dict.
68+ self ._aim_metadata_dtype .update (prepared_aim_metadata_dtype_dict )
69+
70+ # Add the sample_id to the _aim_metadata_dtype.
6971 self ._aim_metadata_dtype ["sample_id" ] = "object"
7072
7173 # Set up taxon colors.
@@ -151,7 +153,7 @@ def _parse_general_metadata(
151153 self , sample_set : str , data : Union [bytes , Exception ]
152154 ) -> pd .DataFrame :
153155 if isinstance (data , bytes ):
154- dtype = {
156+ dtype_dict = {
155157 "sample_id" : "object" ,
156158 "partner_sample_id" : "object" ,
157159 "contributor" : "object" ,
@@ -163,14 +165,9 @@ def _parse_general_metadata(
163165 "longitude" : "float64" ,
164166 "sex_call" : "object" ,
165167 }
166- # `dtype` of `dict[str, str]` is incompatible with `pd.read_csv`
167- dtype_mapping = cast (
168- Mapping [
169- str , Union [str , type , np .dtype , pd .api .extensions .ExtensionDtype ]
170- ],
171- dtype ,
172- )
173- df = pd .read_csv (io .BytesIO (data ), dtype = dtype_mapping , na_values = "" )
168+ # `dict[str, str]` is incompatible with the `dtype` of `pd.read_csv`
169+ dtype : DefaultDict [str , str ] = defaultdict (lambda : "object" , dtype_dict )
170+ df = pd .read_csv (io .BytesIO (data ), dtype = dtype , na_values = "" )
174171
175172 # Ensure all column names are lower case.
176173 df .columns = [c .lower () for c in df .columns ] # type: ignore
@@ -255,7 +252,10 @@ def _parse_sequence_qc_metadata(
255252 ) -> pd .DataFrame :
256253 if isinstance (data , bytes ):
257254 # Get the dtype of the constant columns.
258- dtype = self ._sequence_qc_metadata_dtype
255+ dtype_dict = self ._sequence_qc_metadata_dtype
256+
257+ # `dict[str, str]` is incompatible with the `dtype` of `pd.read_csv`
258+ dtype : DefaultDict [str , str ] = defaultdict (lambda : "object" , dtype_dict )
259259
260260 # Read the CSV using the dtype dict.
261261 df = pd .read_csv (io .BytesIO (data ), dtype = dtype , na_values = "" )
@@ -272,8 +272,8 @@ def _parse_sequence_qc_metadata(
272272
273273 # Add the sequence QC columns with appropriate missing values.
274274 # For each column, set the value to either NA or NaN.
275- for c , dtype in self ._sequence_qc_metadata_dtype .items ():
276- if pd .api .types .is_integer_dtype (dtype ):
275+ for c , datum_dtype in self ._sequence_qc_metadata_dtype .items ():
276+ if pd .api .types .is_integer_dtype (datum_dtype ):
277277 # Note: this creates a column with dtype int64.
278278 df [c ] = - 1
279279 else :
@@ -378,11 +378,8 @@ def _parse_surveillance_flags(
378378 "sample_id" : "object" ,
379379 "is_surveillance" : "boolean" ,
380380 }
381- # `dtype` of `dict[str, str]` is incompatible with `read_csv`
382- dtype = cast (
383- Mapping [str , Union [str , type , np .dtype , pd .api .extensions .ExtensionDtype ]],
384- dtype_dict ,
385- )
381+ # `dict[str, str]` is incompatible with the `dtype` of `pd.read_csv`
382+ dtype : DefaultDict [str , str ] = defaultdict (lambda : "object" , dtype_dict )
386383
387384 if isinstance (data , bytes ):
388385 # Read the CSV data.
@@ -516,7 +513,11 @@ def _parse_cohorts_metadata(
516513 ) -> pd .DataFrame :
517514 if isinstance (data , bytes ):
518515 # Parse CSV data.
519- dtype = self ._cohorts_metadata_dtype
516+ dtype_dict = self ._cohorts_metadata_dtype
517+
518+ # `dict[str, str]` is incompatible with the `dtype` of `pd.read_csv`
519+ dtype : DefaultDict [str , str ] = defaultdict (lambda : "object" , dtype_dict )
520+
520521 df = pd .read_csv (io .BytesIO (data ), dtype = dtype , na_values = "" )
521522
522523 # Ensure all column names are lower case.
@@ -590,14 +591,19 @@ def _parse_aim_metadata(
590591 assert self ._aim_metadata_columns is not None
591592 assert self ._aim_metadata_dtype is not None
592593 if isinstance (data , bytes ):
593- # Parse CSV data.
594- df = pd .read_csv (
595- io .BytesIO (data ), dtype = self ._aim_metadata_dtype , na_values = ""
596- )
594+ # Parse CSV data but don't apply the dtype yet.
595+ df = pd .read_csv (io .BytesIO (data ), na_values = "" )
597596
598- # Ensure all column names are lower case .
597+ # Convert all column names to lowercase .
599598 df .columns = [c .lower () for c in df .columns ] # type: ignore
600599
600+ # For each column in the DataFrame...
601+ for c in df .columns :
602+ # Apply the corresponding dtype from `_aim_metadata_dtype`.
603+ # Convert the type to a NumPy dtype.
604+ col_dtype_as_np = np .dtype (self ._aim_metadata_dtype [c ])
605+ df [c ] = df [c ].astype (col_dtype_as_np )
606+
601607 return df
602608
603609 elif isinstance (data , FileNotFoundError ):
0 commit comments