Skip to content

Commit 09f224a

Browse files
committed
Amend dtype data type for pd.read_csv
1 parent 82fa84d commit 09f224a

1 file changed

Lines changed: 38 additions & 10 deletions

File tree

malariagen_data/anoph/sample_metadata.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
import io
22
from itertools import cycle
3-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
3+
from typing import (
4+
Any,
5+
Callable,
6+
Dict,
7+
List,
8+
Mapping,
9+
Optional,
10+
Sequence,
11+
Tuple,
12+
Union,
13+
cast,
14+
)
415
import warnings
5-
from collections import defaultdict
616

717
import ipyleaflet # type: ignore
818
import numpy as np
@@ -41,11 +51,21 @@ def __init__(
4151
# data resources, and so column names and dtype need to be
4252
# passed in as parameters.
4353
self._aim_metadata_columns: Optional[List[str]] = None
44-
# `dtype` of `dict[str, Any]` is incompatible with `read_csv`
45-
self._aim_metadata_dtype: defaultdict[str, Any] = defaultdict()
54+
# `dtype` of `dict[str, Any]` is incompatible with `pd.read_csv`
55+
self._aim_metadata_dtype: Mapping[
56+
str, Union[str, type, np.dtype, pd.api.extensions.ExtensionDtype]
57+
] = {}
4658
if isinstance(aim_metadata_dtype, Mapping):
4759
self._aim_metadata_columns = list(aim_metadata_dtype.keys())
48-
self._aim_metadata_dtype.update(aim_metadata_dtype)
60+
self._aim_metadata_dtype.update(
61+
cast(
62+
Mapping[
63+
str,
64+
Union[str, type, np.dtype, pd.api.extensions.ExtensionDtype],
65+
],
66+
aim_metadata_dtype,
67+
)
68+
)
4969
self._aim_metadata_dtype["sample_id"] = "object"
5070

5171
# Set up taxon colors.
@@ -143,9 +163,14 @@ def _parse_general_metadata(
143163
"longitude": "float64",
144164
"sex_call": "object",
145165
}
146-
# `dtype` of `dict[str, str]` is incompatible with `read_csv`
147-
dtype = defaultdict(str, dtype)
148-
df = pd.read_csv(io.BytesIO(data), dtype=dtype, na_values="")
166+
# `dtype` of `dict[str, str]` is incompatible with `pd.read_csv`
167+
dtype_mapping = cast(
168+
Mapping[
169+
str, Union[str, type, np.dtype, pd.api.extensions.ExtensionDtype]
170+
],
171+
dtype,
172+
)
173+
df = pd.read_csv(io.BytesIO(data), dtype=dtype_mapping, na_values="")
149174

150175
# Ensure all column names are lower case.
151176
df.columns = [c.lower() for c in df.columns] # type: ignore
@@ -349,12 +374,15 @@ def _parse_surveillance_flags(
349374
# Specify the expected data type for each column.
350375
# Note: "bool" is not nullable and does not support `NaN`, which is required when data is missing.
351376
# Otherwise `NaN` will be mis-translated to `True` when the dtype is applied to the DataFrame.
352-
dtype = {
377+
dtype_dict = {
353378
"sample_id": "object",
354379
"is_surveillance": "boolean",
355380
}
356381
# `dtype` of `dict[str, str]` is incompatible with `pd.read_csv`
357-
dtype = defaultdict(str, dtype)
382+
dtype = cast(
383+
Mapping[str, Union[str, type, np.dtype, pd.api.extensions.ExtensionDtype]],
384+
dtype_dict,
385+
)
358386

359387
if isinstance(data, bytes):
360388
# Read the CSV data.

0 commit comments

Comments
 (0)