Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions .github/scripts/generate_zarr_v2_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""
Generate zarr v2 fixtures for backward compatibility tests.

Run this script with an old spikeinterface version and zarr<3, e.g.:
pip install "spikeinterface==0.104.0" "zarr<3"
python generate_zarr_v2_fixtures.py --output /tmp/zarr_v2_fixtures

The script saves:
- recording.zarr : a small ZarrRecordingExtractor
- sorting.zarr : a small ZarrSortingExtractor
- expected_values.json : key values used to verify correct loading
"""
import argparse
import shutil
import json
from pathlib import Path

import numpy as np
import zarr

import spikeinterface as si


def main(output_dir: Path) -> None:
    """Generate zarr v2 fixtures and a JSON file of expected values.

    Writes into ``output_dir``:
      - ``recording.zarr`` / ``sorting.zarr`` : zarr-backed extractors
      - ``analyzer.zarr`` : a SortingAnalyzer with random_spikes + templates
      - ``expected_values.json`` : values asserted by the compat tests

    Parameters
    ----------
    output_dir : Path
        Directory to write fixtures into. Created if it does not exist.
    """
    print(f"spikeinterface version : {si.__version__}")
    print(f"zarr version : {zarr.__version__}")

    output_dir.mkdir(parents=True, exist_ok=True)

    recording, sorting = si.generate_ground_truth_recording(
        durations=[10, 5], num_channels=32, num_units=10, seed=0
    )
    # save to binary to make them JSON serializable for later expected values extraction
    recording = recording.save(folder=output_dir / "recording_binary", overwrite=True)
    sorting = sorting.save(folder=output_dir / "sorting_binary", overwrite=True)

    # --- save recording ---
    recording_path = output_dir / "recording.zarr"
    recording_zarr = recording.save(format="zarr", folder=recording_path, overwrite=True)
    print(f"Saved recording -> {recording_path}")

    # --- save sorting ---
    sorting_path = output_dir / "sorting.zarr"
    sorting_zarr = sorting.save(format="zarr", folder=sorting_path, overwrite=True)
    print(f"Saved sorting -> {sorting_path}")

    # --- save SortingAnalyzer ---
    # Use the zarr-backed extractors returned by save() above so the analyzer
    # can store them as serializable provenance.
    analyzer_path = output_dir / "analyzer.zarr"
    if analyzer_path.is_dir():
        # Defensive cleanup in addition to overwrite=True below; presumably the
        # old spikeinterface version this script targets did not always fully
        # clear a pre-existing zarr folder — TODO confirm and drop if redundant.
        shutil.rmtree(analyzer_path)
    analyzer = si.create_sorting_analyzer(
        sorting_zarr, recording_zarr, format="zarr", folder=analyzer_path, overwrite=True
    )
    analyzer.compute(["random_spikes", "templates"])
    print(f"Saved analyzer -> {analyzer_path}")

    # Access the templates before writing expected values so the script fails
    # fast if the "templates" extension did not compute.
    templates_array = analyzer.get_extension("templates").get_data()

    # --- capture expected values for later assertion ---
    expected = {
        "spikeinterface_version": si.__version__,
        "zarr_version": zarr.__version__,
        "recording": {
            "num_channels": int(recording.get_num_channels()),
            "num_segments": int(recording.get_num_segments()),
            "sampling_frequency": float(recording.get_sampling_frequency()),
            "num_samples_per_segment": [
                int(recording.get_num_samples(seg)) for seg in range(recording.get_num_segments())
            ],
            "channel_ids": recording.get_channel_ids().tolist(),
            "dtype": str(recording.get_dtype()),
            # first 10 frames of segment 0 for all channels
            "traces_seg0_first10": recording.get_traces(start_frame=0, end_frame=10, segment_index=0).tolist(),
        },
        "sorting": {
            "num_segments": int(sorting.get_num_segments()),
            "sampling_frequency": float(sorting.get_sampling_frequency()),
            "unit_ids": sorting.get_unit_ids().tolist(),
            "spike_trains_seg0": {
                str(uid): sorting.get_unit_spike_train(unit_id=uid, segment_index=0).tolist()
                for uid in sorting.unit_ids
            },
        },
        "analyzer": {
            "num_units": int(analyzer.get_num_units()),
            "num_channels": int(analyzer.get_num_channels()),
            "templates_shape": list(templates_array.shape),
        },
    }

    expected_path = output_dir / "expected_values.json"
    with open(expected_path, "w") as f:
        json.dump(expected, f, indent=2)
    print(f"Saved expected -> {expected_path}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate zarr v2 fixtures for backward compatibility tests")
parser.add_argument("--output", type=Path, required=True, help="Directory to write fixtures into")
args = parser.parse_args()
main(args.output)
47 changes: 47 additions & 0 deletions .github/workflows/test_zarr_compat.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Workflow: verify that fixtures written with spikeinterface 0.104.0 + zarr v2
# can still be read by the current codebase running zarr v3.
name: Test zarr backwards compatibility

on:
  # Allow manual runs from the Actions tab.
  workflow_dispatch:
  pull_request:
    types: [synchronize, opened, reopened]
    branches:
      - main
    # Only run when zarr-related code, the compat test, or this workflow change.
    paths:
      - "src/spikeinterface/core/zarrextractors.py"
      - "src/spikeinterface/core/zarrrecordingextractor.py"
      - "src/spikeinterface/core/tests/test_zarr_backwards_compat.py"
      - ".github/workflows/test_zarr_compat.yml"
      - ".github/scripts/generate_zarr_v2_fixtures.py"

# Cancel superseded runs for the same ref to save CI time.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test-zarr-compat:
    name: zarr v2 -> v3 backwards compatibility
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      # Phase 1: generate fixtures with the OLD stack (SI 0.104.0, zarr v2).
      - name: Install SI 0.104.0 with zarr v2
        run: pip install "spikeinterface==0.104.0" "zarr<3"

      - name: Generate zarr v2 fixtures
        run: python .github/scripts/generate_zarr_v2_fixtures.py --output /tmp/zarr_v2_fixtures

      # Phase 2: reinstall the CURRENT stack (editable checkout, zarr v3)
      # on top, then read the v2 fixtures back.
      - name: Install current SI with zarr v3
        run: pip install -e ".[test_core]"

      # Guard: fail early if dependency resolution left zarr<3 installed.
      - name: Check zarr version is v3
        run: python -c "import zarr; v = zarr.__version__; print(f'zarr {v}'); assert int(v.split('.')[0]) >= 3"

      - name: Run backward compatibility tests
        env:
          # Consumed by the test module to locate the generated fixtures.
          ZARR_V2_FIXTURES_PATH: /tmp/zarr_v2_fixtures
        run: pytest src/spikeinterface/core/tests/test_zarr_backwards_compat.py -v
19 changes: 13 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ dependencies = [
"numpy>=2.0.0;python_version>='3.13'",
"threadpoolctl>=3.0.0",
"tqdm",
"zarr>=2.18,<3",
"zarr>=3,<4",
"neo>=0.14.3",
"probeinterface>=0.3.1",
"packaging",
"pydantic",
"numcodecs<0.16.0", # For supporting zarr < 3
]

[build-system]
Expand Down Expand Up @@ -127,7 +126,9 @@ test_core = [

# for github test : probeinterface and neo from master
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# FOR TESTING: use probeinterface zarrv3 branch
"probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3",
# "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

# for slurm jobs,
Expand All @@ -139,7 +140,9 @@ test_extractors = [
"pooch>=1.8.2",
"datalad>=1.0.2",
# Commenting out for release
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# FOR TESTING: use probeinterface zarrv3 branch
"probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3",
# "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",
]

Expand Down Expand Up @@ -190,7 +193,9 @@ test = [

# for github test : probeinterface and neo from master
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# FOR TESTING: use probeinterface zarrv3 branch
"probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3",
# "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

# for slurm jobs
Expand Down Expand Up @@ -219,7 +224,9 @@ docs = [
"huggingface_hub", # For automated curation

# for release we need pypi, so this needs to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version
# FOR TESTING: use probeinterface zarrv3 branch
"probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3",
# "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version
]

Expand Down
Loading