diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index ad1a60a..df43e8e 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -389,15 +389,6 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup json.dump({"ProbeId": probes_dict}, f, indent=4) # Step 3: GENERATION OF CONTACTS.TSV - # ensure required contact identifiers are present - for probe in probes: - if probe.contact_ids is None: - raise ValueError( - "Contacts must have unique contact ids " - "and not None for export to BIDS probe format." - "Use `probegroup.auto_generate_contact_ids`." - ) - df = probegroup.to_dataframe() index = range(sum([p.get_contact_count() for p in probes])) df.rename(columns=tsv_label_map_to_BIDS, inplace=True) diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index c19ddd6..376b9d8 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -336,7 +336,10 @@ def set_contacts( Defines the two axes of the contact plane for each electrode. The third dimension corresponds to the probe `ndim` (2d or 3d). contact_ids: array[str] | None, default: None - Defines the contact ids for the contacts. If None, contact ids are not assigned. + Defines the contact ids for the contacts. If None, contact ids are + auto-generated as the zero-indexed strings ``["0", "1", ..., str(n - 1)]`` + so a Probe always carries a stable, slice-invariant handle for each + contact. Pass an explicit array to override. shank_ids : array[str] | None, default: None Defines the shank ids for the contacts. If None, then these are assigned to a unique Shank. @@ -378,8 +381,9 @@ def set_contacts( plane_axes = np.array(plane_axes) self._contact_plane_axes = plane_axes - if contact_ids is not None: - self.set_contact_ids(contact_ids) + if contact_ids is None: + contact_ids = np.arange(n).astype(str) + self.set_contact_ids(contact_ids) if shank_ids is None: # self._shank_ids = np.zeros(n, dtype=str) @@ -566,8 +570,9 @@ def set_contact_ids(self, contact_ids: np.ndarray | list): """ contact_ids = np.asarray(contact_ids) if np.all([c == "" for c in contact_ids]): - self._contact_ids = None - return + # Backward compat: previous versions serialized "unset" as empty + # strings. A Probe now always carries contact_ids, so regenerate. + contact_ids = np.arange(self.get_contact_count()).astype(str) if contact_ids.size != self.get_contact_count(): raise ValueError( @@ -1085,10 +1090,7 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: if self._contact_sides is not None: arr["contact_sides"] = self.contact_sides - if self.contact_ids is None: - arr["contact_ids"] = [""] * self.get_contact_count() - else: - arr["contact_ids"] = self.contact_ids + arr["contact_ids"] = self.contact_ids if complete: arr["si_units"] = self.si_units diff --git a/tests/test_probe.py b/tests/test_probe.py index 631b6d3..53e6c06 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -141,6 +141,40 @@ def test_probe(): # ~ plt.show() +def test_set_contacts_auto_generates_contact_ids(): + """When contact_ids is not supplied, Probe auto-generates ['0', ..., str(n-1)].""" + probe = Probe(ndim=2, si_units="um") + positions = np.array([[0, 0], [10, 0], [20, 0], [30, 0]]) + probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) + + assert probe.contact_ids is not None + np.testing.assert_array_equal(probe.contact_ids, np.array(["0", "1", "2", "3"])) + + +def test_set_contacts_respects_explicit_contact_ids(): + """An explicit contact_ids argument is preserved verbatim.""" + probe = Probe(ndim=2, si_units="um") + positions = np.array([[0, 0], [10, 0], [20, 0]]) + probe.set_contacts( + positions=positions, + shapes="circle", + shape_params={"radius": 5}, + contact_ids=["a", "b", "c"], + ) + + np.testing.assert_array_equal(probe.contact_ids, np.array(["a", "b", "c"])) + + +def test_set_contact_ids_all_empty_strings_regenerates(): + """Backward compat: older serialized probes used empty strings for 'unset'.""" + probe = Probe(ndim=2, si_units="um") + positions = np.array([[0, 0], [10, 0], [20, 0]]) + probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) + probe.set_contact_ids(["", "", ""]) + + np.testing.assert_array_equal(probe.contact_ids, np.array(["0", "1", "2"])) + + def test_probe_equality_dunder(): probe1 = generate_dummy_probe() probe2 = generate_dummy_probe() diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index c942190..0a3ba96 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -179,13 +179,6 @@ def test_copy_preserves_device_channel_indices(probegroup): ) -def test_copy_does_not_preserve_contact_ids(probegroup): - """Probe.copy() intentionally does not copy contact_ids.""" - pg_copy = probegroup.copy() - # All contact_ids should be empty strings after copy - assert all(cid == "" for cid in pg_copy.get_global_contact_ids()) - - def test_copy_is_independent(probegroup): """Mutating the copy must not affect the original.""" original_positions = probegroup.probes[0].contact_positions.copy()