Skip to content

Commit 1da32f2

Browse files
ahernankSharon-codes
authored andcommitted
add tests for cohort group metadata
1 parent daa6a95 commit 1da32f2

1 file changed

Lines changed: 97 additions & 0 deletions

File tree

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import pytest
2+
from pytest_cases import parametrize_with_cases
3+
4+
from malariagen_data import af1 as _af1
5+
from malariagen_data import ag3 as _ag3
6+
from malariagen_data.anoph.cohort_group_metadata import AnophelesCohortGroupMetadata
7+
8+
9+
@pytest.fixture
10+
def ag3_sim_api(ag3_sim_fixture):
11+
return AnophelesCohortGroupMetadata(
12+
url=ag3_sim_fixture.url,
13+
config_path=_ag3.CONFIG_PATH,
14+
gcs_url=_ag3.GCS_URL,
15+
major_version_number=_ag3.MAJOR_VERSION_NUMBER,
16+
major_version_path=_ag3.MAJOR_VERSION_PATH,
17+
pre=True,
18+
)
19+
20+
21+
@pytest.fixture
22+
def af1_sim_api(af1_sim_fixture):
23+
return AnophelesCohortGroupMetadata(
24+
url=af1_sim_fixture.url,
25+
config_path=_af1.CONFIG_PATH,
26+
gcs_url=_af1.GCS_URL,
27+
major_version_number=_af1.MAJOR_VERSION_NUMBER,
28+
major_version_path=_af1.MAJOR_VERSION_PATH,
29+
pre=False,
30+
)
31+
32+
33+
def case_ag3_sim(ag3_sim_fixture, ag3_sim_api):
34+
return ag3_sim_fixture, ag3_sim_api
35+
36+
37+
def case_af1_sim(af1_sim_fixture, af1_sim_api):
38+
return af1_sim_fixture, af1_sim_api
39+
40+
41+
def cohort_group_metadata_expected_columns():
42+
return {
43+
"cohort_id": "O",
44+
"cohort_size": "i",
45+
"country": "O",
46+
"country_alpha2": "O",
47+
"country_alpha3": "O",
48+
"taxon": "O",
49+
"year": "i",
50+
"quarter": "i",
51+
"month": "i",
52+
"admin1_name": "O",
53+
"admin1_iso": "O",
54+
"admin1_geoboundaries_shape_id": "O",
55+
"admin1_representative_longitude": "f",
56+
"admin1_representative_latitude": "f",
57+
}
58+
59+
60+
def cohort_group_geo_metadata_expected_columns():
61+
return {
62+
"cohort_id": "O",
63+
"geometry": "i",
64+
}
65+
66+
67+
def validate_cohort_group_metadata(df, expected_columns):
68+
# Check column names.
69+
expected_column_names = list(expected_columns.keys())
70+
assert df.columns.to_list() == expected_column_names
71+
72+
# Check column types.
73+
for c in df.columns:
74+
assert df[c].dtype.kind == expected_columns[c]
75+
76+
77+
@parametrize_with_cases("fixture,api", cases=".")
78+
def test_cohort_group_metadata(fixture, api: AnophelesCohortGroupMetadata):
79+
# Set up the test.
80+
cohort_name = "admin1_month"
81+
# Call function to be tested.
82+
df_cohorts = api.cohort_group_metadata(cohort_name)
83+
# Check output.
84+
validate_cohort_group_metadata(df_cohorts, cohort_group_metadata_expected_columns())
85+
# # Check values against cohort metadata
86+
# df_default = api.sample_metadata()
87+
# assert df_cohorts['cohort_id'] in df_default['cohort_admin1_month']
88+
89+
90+
@parametrize_with_cases("fixture,api", cases=".")
91+
def test_cohort_group_metadata_with_query(fixture, api: AnophelesCohortGroupMetadata):
92+
cohort_name = "admin1_month"
93+
df_cohorts = api.cohort_group_metadata(
94+
cohort_name, cohort_group_query="country == 'Burkina Faso'"
95+
)
96+
validate_cohort_group_metadata(df_cohorts, cohort_group_metadata_expected_columns())
97+
assert (df_cohorts["country"] == "Burkina Faso").all()

0 commit comments

Comments
 (0)