1 change: 1 addition & 0 deletions python/sedonadb/pyproject.toml
@@ -41,6 +41,7 @@ test = [
     "pandas",
     "polars",
     "pytest",
+    "pyyaml",
 ]
 geopandas = [
     "adbc-driver-manager[dbapi]",
212 changes: 211 additions & 1 deletion python/sedonadb/python/sedonadb/testing.py
@@ -18,7 +18,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Tuple, Any
+from typing import TYPE_CHECKING, Any, List, Tuple
 
 import geoarrow.pyarrow as ga
 import pyarrow as pa
@@ -657,6 +657,216 @@ def __init__(self, uri=None):
cur.execute("SET max_parallel_workers_per_gather TO 0")


class BigQuery(DBEngine):
"""A BigQuery implementation of the DBEngine using ADBC

Uses the ADBC BigQuery driver. Authentication uses Application Default
Credentials (ADC) by default — run ``gcloud auth application-default login``
once to set that up. Set the following environment variables to configure
the connection:

- SEDONADB_BIGQUERY_TEST_PROJECT_ID: GCP project identifier. Defaults to
"sedonadb-testing".
- SEDONADB_BIGQUERY_TEST_DATASET_ID: Dataset identifier. In general data is
not scanned for these tests because doing so would incur cost.
- SEDONADB_BIGQUERY_TEST_CREDENTIALS_FILE: (optional) Path to a service
account JSON key file. When omitted, ADC is used instead.

Unless modifying these tests, the cached results should allow these tests
to run without an active connection (and should allow tests to run locally
much faster as opening a connection to BigQuery is slow).
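
    Illustrative usage (assumes the ADBC BigQuery driver is installed and
    either credentials or cached results are available)::

        eng = BigQuery()
        tab = eng.execute_and_collect("SELECT ST_GEOGPOINT(1, 2) AS geog")
        eng.close()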
"""

_CACHE_DIR = Path(__file__).resolve().parent.parent.parent / "tests" / "geography"
_shared_cache: "ArrowSQLCache | None" = None

def __init__(self, cache_path: "Path | None" = None):
self._cache_path = cache_path or self._CACHE_DIR / "bigquery_cache.yml"
if cache_path is not None or BigQuery._shared_cache is None:
BigQuery._shared_cache = ArrowSQLCache("bigquery", self._cache_path)
self._file_cache = BigQuery._shared_cache
self.con = None
self._con_attempted = False

    def _ensure_con(self):
        import adbc_driver_bigquery
        import adbc_driver_bigquery.dbapi

        if self.con is not None or self._con_attempted:
            return
        self._con_attempted = True

        project_id = os.environ.get(
            "SEDONADB_BIGQUERY_TEST_PROJECT_ID", "sedonadb-testing"
        )
        dataset_id = os.environ.get(
            "SEDONADB_BIGQUERY_TEST_DATASET_ID", "sedonadb_test"
        )
        credentials_file = os.environ.get("SEDONADB_BIGQUERY_TEST_CREDENTIALS_FILE")

        db_kwargs = {
            adbc_driver_bigquery.DatabaseOptions.PROJECT_ID.value: project_id,
            adbc_driver_bigquery.DatabaseOptions.DATASET_ID.value: dataset_id,
        }

        if credentials_file:
            db_kwargs[adbc_driver_bigquery.DatabaseOptions.AUTH_TYPE.value] = (
                adbc_driver_bigquery.DatabaseOptions.AUTH_VALUE_JSON_CREDENTIAL_FILE.value
            )
            db_kwargs[adbc_driver_bigquery.DatabaseOptions.AUTH_CREDENTIALS.value] = (
                credentials_file
            )

        self.con = adbc_driver_bigquery.dbapi.connect(db_kwargs=db_kwargs)

    def close(self):
        """Close the connection and flush any new cache entries to disk"""
        self._file_cache.flush()
        if self.con:
            self.con.close()

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass

    @classmethod
    def name(cls):
        return "bigquery"

    @classmethod
    def install_hint(cls):
        return (
            "- Run `pip install adbc-driver-bigquery` to install the required driver\n"
            "- Run `gcloud auth application-default login` to authenticate\n"
            "- Set SEDONADB_BIGQUERY_TEST_PROJECT_ID to a valid BigQuery project identifier"
        )

    def val_or_null(self, arg):
        if isinstance(arg, bytes):
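            # Render bytes as a hex literal, e.g. b"\x01\x02" becomes
            # FROM_HEX('0102'), which BigQuery decodes back to BYTES.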
return f"FROM_HEX('{arg.hex()}')"
else:
return super().val_or_null(arg)

def create_table_parquet(self, name, paths) -> "BigQuery":
raise NotImplementedError("Create table from Parquet not implemented")

def create_table_arrow(self, name, obj, *, geometry_cols=None) -> "BigQuery":
raise NotImplementedError("Create table from Arrow not implemented")

def execute_and_collect(self, query) -> pa.Table:
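        # Serve cached results first so tests can run offline (and skip the
        # slow BigQuery connection) for queries that were previously recorded.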
        cached = self._file_cache.get(query)
        if cached is not None:
            return cached

        try:
            self._ensure_con()
        except Exception as e:
            raise RuntimeError(
                "Query not in cache and BigQuery connection unavailable:\n"
                + BigQuery.install_hint()
            ) from e

        if self.con is None:
            # A previous connection attempt failed; _ensure_con() returns
            # early in that case instead of retrying.
            raise RuntimeError(
                "Query not in cache and BigQuery connection unavailable:\n"
                + BigQuery.install_hint()
            )

        with self.con.cursor() as cur:
            cur.execute(query)
            result = cur.fetch_arrow_table()
            self._file_cache.put(query, result)
            return result

    def result_to_table(self, result: pa.Table) -> pa.Table:
        # BigQuery only has a GEOGRAPHY type (always WGS84 with spherical edges).
        # The ADBC driver returns geography columns as WKT strings with
        # Arrow extension metadata: ARROW:extension:name = 'google:sqlType:geography'.
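        # For example, ``SELECT ST_GEOGPOINT(1, 2)`` arrives as a string
        # column ("POINT(1 2)") tagged with that extension name; it is
        # rewrapped below as spherical-edged WKB with an OGC:CRS84 CRS.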
        cols = {}
        for name, col in zip(result.schema.names, result.columns):
            field = result.schema.field(name)
            ext_name = (field.metadata or {}).get(b"ARROW:extension:name", b"")
            if ext_name == b"google:sqlType:geography":
                col_wkb = ga.as_wkb(col.cast(pa.string()))
                cols[name] = ga.with_crs(
                    ga.wkb().with_edge_type(ga.EdgeType.SPHERICAL).wrap_array(col_wkb),
                    ga.OGC_CRS84,
                )
            else:
                cols[name] = col

        return pa.table(cols)


class ArrowSQLCache:
    """A YAML-file-backed cache for Arrow-based query results.

    Each entry stores a pa.Table as base64-encoded Arrow IPC. Queries are
    sorted alphabetically when written for stable git diffs. Results are
    nested under ``results.<engine_name>`` in the YAML output.

    Leading comment lines (e.g., a license header) are preserved across
    rewrites.
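
    The on-disk layout looks roughly like this (illustrative sketch; values
    are base64-encoded Arrow IPC streams, truncated here)::

        results:
          bigquery:
            SELECT 1: /////xABAAAQAAAA...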
"""

def __init__(self, engine_name: str, path: Path):
self._engine_name = engine_name
self._path = path
self._header_lines: list[str] = []
self._entries: dict = {}
self._dirty = False
if self._path.exists():
self._load()

def _load(self):
import yaml

with open(self._path) as f:
lines = f.readlines()

# Split leading comment lines from the YAML body
body_start = 0
for i, line in enumerate(lines):
if line.startswith("#") or line.strip() == "":
body_start = i + 1
else:
break
self._header_lines = lines[:body_start]

body = "".join(lines[body_start:])
doc = yaml.safe_load(body) if body.strip() else {}
if doc and "results" in doc:
self._entries = doc["results"].get(self._engine_name, {})
else:
self._entries = doc or {}

def get(self, query: str) -> "pa.Table | None":
entry = self._entries.get(query)
if entry is None:
return None
import base64

buf = base64.b64decode(entry)
return pa.ipc.open_stream(buf).read_all()

    def put(self, query: str, table: pa.Table):
        import base64

        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, table.schema) as writer:
            writer.write_table(table)
        self._entries[query] = base64.b64encode(sink.getvalue().to_pybytes()).decode()
        self._dirty = True

    def flush(self):
        if not self._dirty:
            return
        import yaml

        self._path.parent.mkdir(parents=True, exist_ok=True)
        doc = {"results": {self._engine_name: self._entries}}
        with open(self._path, "w") as f:
            f.writelines(self._header_lines)
            yaml.dump(doc, f, default_flow_style=False, sort_keys=True)
        self._dirty = False


def geom_or_null(arg, srid=None):
    """Format SQL expression for a geometry object or NULL"""
    if arg is None:
25 changes: 0 additions & 25 deletions python/sedonadb/tests/functions/test_functions.py
@@ -2392,31 +2392,6 @@ def test_st_reverse(eng, geom, expected):
    eng.assert_query_result(f"SELECT ST_Reverse({geom_or_null(geom)})", expected)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
@pytest.mark.parametrize(
    ("x", "y", "expected"),
    [
        (None, None, None),
        (1, None, None),
        (None, 1, None),
        (1, 1, "POINT (1 1)"),
        (1.0, 1.0, "POINT (1 1)"),
        (10, -1.5, "POINT (10 -1.5)"),
    ],
)
def test_st_geogpoint(eng, x, y, expected):
    eng = eng.create_or_skip()
    if eng == SedonaDB:
        eng.assert_query_result(
            f"SELECT ST_GeogPoint({val_or_null(x)}, {val_or_null(y)})", expected
        )
    else:
        eng.assert_query_result(
            f"SELECT ST_Point({val_or_null(x)}, {val_or_null(y)}) as geography",
            expected,
        )


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
@pytest.mark.parametrize(
    ("x", "y", "expected"),