Skip to content

Commit bb7c7ad

Browse files
authored
Further Improve robustness of model classifier tests (#4450)
1 parent 1a8e958 commit bb7c7ad

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

src/spikeinterface/curation/tests/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def sorting_analyzer_for_unitrefine_curation():
7676
"""Makes an analyzer whose first 10 units are good normal units, and 10 which are noise. We make them
7777
noise by using a spike trains which are uncorrelated with the recording for `sorting2`."""
7878

79-
recording, sorting_1 = generate_ground_truth_recording(num_channels=4, seed=1, num_units=10)
80-
_, sorting_2 = generate_ground_truth_recording(num_channels=4, seed=2, num_units=10)
79+
recording, sorting_1 = generate_ground_truth_recording(num_channels=4, seed=1, num_units=6)
80+
_, sorting_2 = generate_ground_truth_recording(num_channels=4, seed=2, num_units=6)
8181
both_sortings = aggregate_units([sorting_1, sorting_2])
8282
analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording)
8383
analyzer.compute(["random_spikes", "noise_levels", "templates"])
@@ -113,9 +113,9 @@ def trained_pipeline_path(sorting_analyzer_for_unitrefine_curation):
113113
}
114114
)
115115
train_model(
116-
analyzers=[analyzer],
116+
analyzers=[analyzer] * 5,
117117
folder=trained_model_folder,
118-
labels=[[1] * 10 + [0] * 10],
118+
labels=[[1] * 6 + [0] * 6] * 5,
119119
imputation_strategies=["median"],
120120
scaling_techniques=["standard_scaler"],
121121
classifiers=["RandomForestClassifier"],

src/spikeinterface/curation/tests/test_model_based_curation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_unitrefi
132132
classified_units = model_based_classification.predict_labels()
133133
predictions = classified_units["prediction"].values
134134

135-
expected_result = np.array([1] * 10 + [0] * 10)
135+
expected_result = np.array([1] * 6 + [0] * 6)
136136
assert np.all(predictions == expected_result)
137137

138138
conversion = {0: "noise", 1: "good"}
139-
expected_result_converted = np.array(["good"] * 10 + ["noise"] * 10)
139+
expected_result_converted = np.array(["good"] * 6 + ["noise"] * 6)
140140
classified_units_labelled = model_based_classification.predict_labels(label_conversion=conversion)
141141
predictions_labelled = classified_units_labelled["prediction"]
142142
assert np.all(predictions_labelled == expected_result_converted)

0 commit comments

Comments
 (0)