@@ -76,8 +76,8 @@ def sorting_analyzer_for_unitrefine_curation():
7676 """Makes an analyzer whose first 10 units are good normal units, and 10 which are noise. We make them
7777 noise by using a spike trains which are uncorrelated with the recording for `sorting2`."""
7878
79- recording , sorting_1 = generate_ground_truth_recording (num_channels = 4 , seed = 1 , num_units = 10 )
80- _ , sorting_2 = generate_ground_truth_recording (num_channels = 4 , seed = 2 , num_units = 10 )
79+ recording , sorting_1 = generate_ground_truth_recording (num_channels = 4 , seed = 1 , num_units = 6 )
80+ _ , sorting_2 = generate_ground_truth_recording (num_channels = 4 , seed = 2 , num_units = 6 )
8181 both_sortings = aggregate_units ([sorting_1 , sorting_2 ])
8282 analyzer = create_sorting_analyzer (sorting = both_sortings , recording = recording )
8383 analyzer .compute (["random_spikes" , "noise_levels" , "templates" ])
@@ -113,9 +113,9 @@ def trained_pipeline_path(sorting_analyzer_for_unitrefine_curation):
113113 }
114114 )
115115 train_model (
116- analyzers = [analyzer ],
116+ analyzers = [analyzer ] * 5 ,
117117 folder = trained_model_folder ,
118- labels = [[1 ] * 10 + [0 ] * 10 ] ,
118+ labels = [[1 ] * 6 + [0 ] * 6 ] * 5 ,
119119 imputation_strategies = ["median" ],
120120 scaling_techniques = ["standard_scaler" ],
121121 classifiers = ["RandomForestClassifier" ],
0 commit comments