Skip to content

Commit 21af2d2

Browse files
committed
Fix bug in applying aim_metadata_dtype. Amend data types.
1 parent 09f224a commit 21af2d2

4 files changed

Lines changed: 60 additions & 48 deletions

File tree

malariagen_data/ag3.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,18 @@ def _setup_aim_palettes():
7777
"unassigned": "black",
7878
}
7979

80+
# Note: These column names will be treated as case-insensitive,
81+
# because these column names and the column names from the CSV
82+
# will be converted to lowercase before applying these dtypes.
83+
AIM_METADATA_DTYPE = {
84+
"aim_species_fraction_arab": "float64",
85+
"aim_species_fraction_colu": "float64",
86+
"aim_species_fraction_colu_no2l": "float64",
87+
"aim_species_gambcolu_arabiensis": "object",
88+
"aim_species_gambiae_coluzzii": "object",
89+
"aim_species": "object",
90+
}
91+
8092

8193
class Ag3(AnophelesDataResource):
8294
"""Provides access to data from Ag3.x releases.
@@ -162,14 +174,7 @@ def __init__(
162174
config_path=CONFIG_PATH,
163175
cohorts_analysis=cohorts_analysis,
164176
aim_analysis=aim_analysis,
165-
aim_metadata_dtype={
166-
"aim_species_fraction_arab": "float64",
167-
"aim_species_fraction_colu": "float64",
168-
"aim_species_fraction_colu_no2l": "float64",
169-
"aim_species_gambcolu_arabiensis": "object",
170-
"aim_species_gambiae_coluzzii": "object",
171-
"aim_species": "object",
172-
},
177+
aim_metadata_dtype=AIM_METADATA_DTYPE,
173178
aim_ids=("gambcolu_vs_arab", "gamb_vs_colu"),
174179
aim_palettes=AIM_PALETTES,
175180
site_filters_analysis=site_filters_analysis,

malariagen_data/anoph/frq_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,8 @@ def plot_frequencies_heatmap(
210210

211211
# Indexing.
212212
if index is None:
213-
# `list[Hashable]` is incompatible with `list`
213+
# `list[Hashable]` is incompatible with the param for `list`
214+
# Convert `df.index.names` to a `list[str]` instead.
214215
index_names_as_list = [str(name) for name in df.index.names]
215216
index = list(index_names_as_list)
216217
df = df.reset_index().copy()

malariagen_data/anoph/sample_metadata.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
from typing import (
44
Any,
55
Callable,
6+
DefaultDict,
67
Dict,
78
List,
89
Mapping,
910
Optional,
1011
Sequence,
1112
Tuple,
1213
Union,
13-
cast,
1414
)
15+
from collections import defaultdict
1516
import warnings
1617

1718
import ipyleaflet # type: ignore
@@ -51,21 +52,22 @@ def __init__(
5152
# data resources, and so column names and dtype need to be
5253
# passed in as parameters.
5354
self._aim_metadata_columns: Optional[List[str]] = None
54-
# `dtype` of `dict[str, Any]` is incompatible with `pd.read_csv`
55-
self._aim_metadata_dtype: Mapping[
56-
str, Union[str, type, np.dtype, pd.api.extensions.ExtensionDtype]
57-
] = {}
55+
self._aim_metadata_dtype: Optional[Mapping[str, Any]] = {}
56+
57+
# Only apply the `aim_metadata_dtype` if it is a type of `Mapping`.
5858
if isinstance(aim_metadata_dtype, Mapping):
59-
self._aim_metadata_columns = list(aim_metadata_dtype.keys())
60-
self._aim_metadata_dtype.update(
61-
cast(
62-
Mapping[
63-
str,
64-
Union[str, type, np.dtype, pd.api.extensions.ExtensionDtype],
65-
],
66-
aim_metadata_dtype,
67-
)
68-
)
59+
# Convert all of the column names to lowercase.
60+
prepared_aim_metadata_dtype_dict = {
61+
k.lower(): v for k, v in aim_metadata_dtype.items()
62+
}
63+
64+
# Get all the column names from the prepared dict.
65+
self._aim_metadata_columns = list(prepared_aim_metadata_dtype_dict.keys())
66+
67+
# Update the _aim_metadata_dtype with the prepared dict.
68+
self._aim_metadata_dtype.update(prepared_aim_metadata_dtype_dict)
69+
70+
# Add the sample_id to the _aim_metadata_dtype.
6971
self._aim_metadata_dtype["sample_id"] = "object"
7072

7173
# Set up taxon colors.
@@ -151,7 +153,7 @@ def _parse_general_metadata(
151153
self, sample_set: str, data: Union[bytes, Exception]
152154
) -> pd.DataFrame:
153155
if isinstance(data, bytes):
154-
dtype = {
156+
dtype_dict = {
155157
"sample_id": "object",
156158
"partner_sample_id": "object",
157159
"contributor": "object",
@@ -163,14 +165,9 @@ def _parse_general_metadata(
163165
"longitude": "float64",
164166
"sex_call": "object",
165167
}
166-
# `dtype` of `dict[str, str]` is incompatible with `pd.read_csv`
167-
dtype_mapping = cast(
168-
Mapping[
169-
str, Union[str, type, np.dtype, pd.api.extensions.ExtensionDtype]
170-
],
171-
dtype,
172-
)
173-
df = pd.read_csv(io.BytesIO(data), dtype=dtype_mapping, na_values="")
168+
# `dict[str, str]` is incompatible with the `dtype` of `pd.read_csv`
169+
dtype: DefaultDict[str, str] = defaultdict(lambda: "object", dtype_dict)
170+
df = pd.read_csv(io.BytesIO(data), dtype=dtype, na_values="")
174171

175172
# Ensure all column names are lower case.
176173
df.columns = [c.lower() for c in df.columns] # type: ignore
@@ -255,7 +252,10 @@ def _parse_sequence_qc_metadata(
255252
) -> pd.DataFrame:
256253
if isinstance(data, bytes):
257254
# Get the dtype of the constant columns.
258-
dtype = self._sequence_qc_metadata_dtype
255+
dtype_dict = self._sequence_qc_metadata_dtype
256+
257+
# `dict[str, str]` is incompatible with the `dtype` of `pd.read_csv`
258+
dtype: DefaultDict[str, str] = defaultdict(lambda: "object", dtype_dict)
259259

260260
# Read the CSV using the dtype dict.
261261
df = pd.read_csv(io.BytesIO(data), dtype=dtype, na_values="")
@@ -272,8 +272,8 @@ def _parse_sequence_qc_metadata(
272272

273273
# Add the sequence QC columns with appropriate missing values.
274274
# For each column, set the value to either NA or NaN.
275-
for c, dtype in self._sequence_qc_metadata_dtype.items():
276-
if pd.api.types.is_integer_dtype(dtype):
275+
for c, datum_dtype in self._sequence_qc_metadata_dtype.items():
276+
if pd.api.types.is_integer_dtype(datum_dtype):
277277
# Note: this creates a column with dtype int64.
278278
df[c] = -1
279279
else:
@@ -378,11 +378,8 @@ def _parse_surveillance_flags(
378378
"sample_id": "object",
379379
"is_surveillance": "boolean",
380380
}
381-
# `dtype` of `dict[str, str]` is incompatible with `read_csv`
382-
dtype = cast(
383-
Mapping[str, Union[str, type, np.dtype, pd.api.extensions.ExtensionDtype]],
384-
dtype_dict,
385-
)
381+
# `dict[str, str]` is incompatible with the `dtype` of `pd.read_csv`
382+
dtype: DefaultDict[str, str] = defaultdict(lambda: "object", dtype_dict)
386383

387384
if isinstance(data, bytes):
388385
# Read the CSV data.
@@ -516,7 +513,11 @@ def _parse_cohorts_metadata(
516513
) -> pd.DataFrame:
517514
if isinstance(data, bytes):
518515
# Parse CSV data.
519-
dtype = self._cohorts_metadata_dtype
516+
dtype_dict = self._cohorts_metadata_dtype
517+
518+
# `dict[str, str]` is incompatible with the `dtype` of `pd.read_csv`
519+
dtype: DefaultDict[str, str] = defaultdict(lambda: "object", dtype_dict)
520+
520521
df = pd.read_csv(io.BytesIO(data), dtype=dtype, na_values="")
521522

522523
# Ensure all column names are lower case.
@@ -590,14 +591,19 @@ def _parse_aim_metadata(
590591
assert self._aim_metadata_columns is not None
591592
assert self._aim_metadata_dtype is not None
592593
if isinstance(data, bytes):
593-
# Parse CSV data.
594-
df = pd.read_csv(
595-
io.BytesIO(data), dtype=self._aim_metadata_dtype, na_values=""
596-
)
594+
# Parse CSV data but don't apply the dtype yet.
595+
df = pd.read_csv(io.BytesIO(data), na_values="")
597596

598-
# Ensure all column names are lower case.
597+
# Convert all column names to lowercase.
599598
df.columns = [c.lower() for c in df.columns] # type: ignore
600599

600+
# For each column in the DataFrame...
601+
for c in df.columns:
602+
# Apply the corresponding dtype from `_aim_metadata_dtype`.
603+
# Convert the type to a NumPy dtype.
604+
col_dtype_as_np = np.dtype(self._aim_metadata_dtype[c])
605+
df[c] = df[c].astype(col_dtype_as_np)
606+
601607
return df
602608

603609
elif isinstance(data, FileNotFoundError):

tests/anoph/test_sample_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def validate_metadata(df, expected_columns):
268268

269269
# Check column types.
270270
for c in df.columns:
271-
assert df[c].dtype.kind == expected_columns[c]
271+
assert df[c].dtype.kind == expected_columns[c], c
272272

273273

274274
@parametrize_with_cases("fixture,api", cases=".")

0 commit comments

Comments
 (0)