Skip to content

Commit 1c0e189

Browse files
committed
Further develop the Catalog API to match Spark better
Implemented missing features. Added missing arguments. Corrected database vs schema handling.
1 parent 5c2a7f7 commit 1c0e189

2 files changed

Lines changed: 237 additions & 48 deletions

File tree

duckdb/experimental/spark/sql/catalog.py

Lines changed: 134 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55

66
class Database(NamedTuple): # noqa: D101
77
name: str
8+
catalog: str
89
description: str | None
910
locationUri: str
1011

1112

1213
class Table(NamedTuple): # noqa: D101
1314
name: str
1415
database: str | None
16+
catalog: str
1517
description: str | None
1618
tableType: str
1719
isTemporary: bool
@@ -24,56 +26,168 @@ class Column(NamedTuple): # noqa: D101
2426
nullable: bool
2527
isPartition: bool
2628
isBucket: bool
29+
isCluster: bool
2730

2831

2932
class Function(NamedTuple): # noqa: D101
3033
name: str
34+
catalog: str | None
35+
namespace: list[str] | None
3136
description: str | None
3237
className: str
3338
isTemporary: bool
3439

3540

36-
class Catalog: # noqa: D101
41+
class Catalog:
42+
"""Implements the spark catalog API.
43+
44+
Implementation notes:
45+
Spark has the concept of a catalog and inside each catalog there are schemas
46+
which contain tables. But Spark refers to the schemas as databases through
47+
the catalog API.
48+
For Duckdb, there are databases, which in turn contain schemas. DuckDBs
49+
databases therefore overlap with the concept of the spark catalog.
50+
So to summarize
51+
------------------------------
52+
| Spark | DuckDB |
53+
------------------------------
54+
| Catalog | Database |
55+
| Database/Schema | Schema |
56+
------------------------------
57+
The consequence is that this catalog API refers in several locations to a
58+
database name, which is the DuckDB schema.
59+
"""
60+
3761
def __init__(self, session: SparkSession) -> None: # noqa: D107
3862
self._session = session
3963

40-
def listDatabases(self) -> list[Database]: # noqa: D102
41-
res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall()
64+
def listDatabases(self, pattern: str | None = None) -> list[Database]:
65+
"""Returns a list of database object for all available databases."""
66+
if pattern:
67+
pattern = pattern.replace("*", "%")
68+
where_sql = " WHERE schema_name LIKE ?"
69+
params = (pattern,)
70+
else:
71+
where_sql = ""
72+
params = ()
73+
74+
sql_text = "select schema_name, database_name from duckdb_schemas()" + where_sql
75+
res = self._session.conn.sql(sql_text, params=params).fetchall()
4276

43-
def transform_to_database(x: list[str]) -> Database:
44-
return Database(name=x[0], description=None, locationUri="")
77+
def transform_to_database(x: tuple[str, ...]) -> Database:
78+
return Database(name=x[0], catalog=x[1], description=None, locationUri="")
4579

4680
databases = [transform_to_database(x) for x in res]
4781
return databases
4882

49-
def listTables(self) -> list[Table]: # noqa: D102
50-
res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall()
83+
def listTables(self, dbName: str | None = None, pattern: str | None = None) -> list[Table]:
84+
"""Returns a list of tables/views in the specified database.
85+
86+
If neither dbName nor pattern is provided, the current active database is used.
87+
"""
88+
dbName = dbName or self.currentDatabase()
89+
where_sql = ""
90+
params = (dbName,)
91+
92+
if pattern:
93+
pattern = pattern.replace("*", "%")
94+
where_sql = " and table_name LIKE ?"
95+
params += (pattern,)
96+
97+
sql_text = (
98+
"select table_name, schema_name, sql, temporary from duckdb_tables() where schema_name = ?" + where_sql
99+
)
100+
101+
res = self._session.conn.sql(sql_text, params=params).fetchall()
102+
current_catalog = self._currentCatalog()
51103

52104
def transform_to_table(x: list[str]) -> Table:
53-
return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3])
105+
return Table(
106+
name=x[0], database=x[1], catalog=current_catalog, description=x[2], tableType="", isTemporary=x[3]
107+
)
54108

55109
tables = [transform_to_table(x) for x in res]
56110
return tables
57111

58-
def listColumns(self, tableName: str, dbName: str | None = None) -> list[Column]: # noqa: D102
59-
query = f"""
60-
select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}'
61-
"""
62-
if dbName:
63-
query += f" and database_name = '{dbName}'"
64-
res = self._session.conn.sql(query).fetchall()
112+
def listColumns(self, tableName: str, dbName: str | None = None) -> list[Column]:
113+
"""Returns a list of columns for the given table/view in the specified database."""
114+
query = (
115+
"select column_name, data_type, is_nullable"
116+
" from duckdb_columns()"
117+
" where table_name = ? and schema_name = ? and database_name = ?"
118+
)
119+
dbName = dbName or self.currentDatabase()
120+
params = (tableName, dbName, self._currentCatalog())
121+
res = self._session.conn.sql(query, params=params).fetchall()
122+
123+
if len(res) == 0:
124+
from duckdb.experimental.spark.errors import AnalysisException
125+
126+
msg = f"[TABLE_OR_VIEW_NOT_FOUND] The table or view `{tableName}` cannot be found"
127+
raise AnalysisException(msg)
65128

66129
def transform_to_column(x: list[str | bool]) -> Column:
67-
return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False)
130+
return Column(
131+
name=x[0],
132+
description=None,
133+
dataType=x[1],
134+
nullable=x[2],
135+
isPartition=False,
136+
isBucket=False,
137+
isCluster=False,
138+
)
68139

69140
columns = [transform_to_column(x) for x in res]
70141
return columns
71142

72-
def listFunctions(self, dbName: str | None = None) -> list[Function]: # noqa: D102
73-
raise NotImplementedError
143+
def listFunctions(self, dbName: str | None = None, pattern: str | None = None) -> list[Function]:
144+
"""Returns a list of functions registered in the specified database."""
145+
dbName = dbName or self.currentDatabase()
146+
where_sql = ""
147+
params = (dbName,)
148+
149+
if pattern:
150+
pattern = pattern.replace("*", "%")
151+
where_sql = " AND function_name LIKE ?"
152+
params += (pattern,)
153+
154+
sql_text = (
155+
"SELECT DISTINCT database_name, schema_name, function_name, description, function_type"
156+
" FROM duckdb_functions()"
157+
" WHERE schema_name = ? " + where_sql
158+
)
159+
160+
res = self._session.conn.sql(sql_text, params=params).fetchall()
161+
162+
columns = [
163+
Function(
164+
name=x[2],
165+
catalog=x[0],
166+
namespace=[x[1]],
167+
description=x[3],
168+
className=x[4],
169+
isTemporary=x[0] == "temp",
170+
)
171+
for x in res
172+
]
173+
return columns
174+
175+
def currentDatabase(self) -> str:
176+
"""Retrieves the name of the active database/schema."""
177+
res = self._session.conn.sql("SELECT current_schema()").fetchone()
178+
return res[0]
179+
180+
def setCurrentDatabase(self, dbName: str) -> None:
181+
"""Sets the active database/schema. Equivalent to executing 'USE dbName'."""
182+
self._session.conn.sql(f"USE {_sql_quote(dbName)}")
183+
184+
def _currentCatalog(self) -> str:
185+
res = self._session.conn.sql("SELECT current_database()").fetchone()
186+
return res[0]
187+
74188

75-
def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102
76-
raise NotImplementedError
189+
def _sql_quote(value: str) -> str:
190+
return '"' + value.replace('"', '""') + '"'
77191

78192

79193
__all__ = ["Catalog", "Column", "Database", "Function", "Table"]

tests/fast/spark/test_spark_catalog.py

Lines changed: 103 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,50 +7,125 @@
77

88

99
class TestSparkCatalog:
10-
def test_list_databases(self, spark):
10+
def test_list_databases_all(self, spark):
1111
dbs = spark.catalog.listDatabases()
1212
if USE_ACTUAL_SPARK:
1313
assert all(isinstance(db, Database) for db in dbs)
1414
else:
1515
assert dbs == [
16-
Database(name="memory", description=None, locationUri=""),
17-
Database(name="system", description=None, locationUri=""),
18-
Database(name="temp", description=None, locationUri=""),
16+
Database(name="main", catalog="memory", description=None, locationUri=""),
17+
Database(name="information_schema", catalog="system", description=None, locationUri=""),
18+
Database(name="main", catalog="system", description=None, locationUri=""),
19+
Database(name="pg_catalog", catalog="system", description=None, locationUri=""),
20+
Database(name="main", catalog="temp", description=None, locationUri=""),
1921
]
2022

21-
def test_list_tables(self, spark):
22-
# empty
23+
def test_create_use_schema(self, spark):
24+
assert spark.catalog.currentDatabase() == "main"
25+
26+
spark.sql("CREATE SCHEMA my_schema1")
27+
spark.catalog.setCurrentDatabase("my_schema1")
28+
assert spark.catalog.currentDatabase() == "my_schema1"
29+
30+
dbs = spark.catalog.listDatabases("*schema1")
31+
assert len(dbs) == 1
32+
assert spark.catalog.currentDatabase() == "my_schema1"
33+
34+
# Verifying that the table goes to the right schema
35+
spark.sql("create table tbl1(a varchar)")
36+
spark.sql("create table main.tbl2(a varchar)")
37+
expected = [
38+
Table(
39+
name="tbl1",
40+
catalog="memory",
41+
database="my_schema1",
42+
description="CREATE TABLE my_schema1.tbl1(a VARCHAR);",
43+
tableType="",
44+
isTemporary=False,
45+
)
46+
]
47+
tbls = spark.catalog.listTables()
48+
assert tbls == expected
49+
50+
spark.sql("DROP TABLE my_schema1.tbl1")
51+
spark.sql("DROP SCHEMA my_schema1")
52+
assert len(spark.catalog.listDatabases("my_schema1")) == 0
53+
assert spark.catalog.currentDatabase() == "main"
54+
55+
@pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Checking duckdb specific databases")
56+
def test_list_databases_pattern(self, spark):
57+
expected = [
58+
Database(name="pg_catalog", catalog="system", description=None, locationUri=""),
59+
]
60+
dbs = spark.catalog.listDatabases("pg*")
61+
assert dbs == expected
62+
dbs = spark.catalog.listDatabases("pg_catalog")
63+
assert dbs == expected
64+
dbs = spark.catalog.listDatabases("notfound")
65+
assert dbs == []
66+
67+
def test_list_tables_empty(self, spark):
2368
tbls = spark.catalog.listTables()
2469
assert tbls == []
2570

26-
if not USE_ACTUAL_SPARK:
27-
# Skip this if we're using actual Spark because we can't create tables
28-
# with our setup.
29-
spark.sql("create table tbl(a varchar)")
30-
tbls = spark.catalog.listTables()
31-
assert tbls == [
32-
Table(
33-
name="tbl",
34-
database="memory",
35-
description="CREATE TABLE tbl(a VARCHAR);",
36-
tableType="",
37-
isTemporary=False,
38-
)
39-
]
71+
@pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Checking duckdb specific tables")
72+
def test_list_tables_create(self, spark):
73+
spark.sql("create table tbl1(a varchar)")
74+
spark.sql("create table tbl2(b varchar)")
75+
expected = [
76+
Table(
77+
name="tbl1",
78+
catalog="memory",
79+
database="main",
80+
description="CREATE TABLE tbl1(a VARCHAR);",
81+
tableType="",
82+
isTemporary=False,
83+
),
84+
Table(
85+
name="tbl2",
86+
catalog="memory",
87+
database="main",
88+
description="CREATE TABLE tbl2(b VARCHAR);",
89+
tableType="",
90+
isTemporary=False,
91+
),
92+
]
93+
tbls = spark.catalog.listTables()
94+
assert tbls == expected
95+
96+
tbls = spark.catalog.listTables(pattern="*l2")
97+
assert tbls == expected[1:]
98+
99+
tbls = spark.catalog.listTables(pattern="tbl2")
100+
assert tbls == expected[1:]
101+
102+
tbls = spark.catalog.listTables(dbName="notfound")
103+
assert tbls == []
40104

41105
@pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup")
42106
def test_list_columns(self, spark):
43107
spark.sql("create table tbl(a varchar, b bool)")
44-
columns = spark.catalog.listColumns("tbl")
45-
assert columns == [
46-
Column(name="a", description=None, dataType="VARCHAR", nullable=True, isPartition=False, isBucket=False),
47-
Column(name="b", description=None, dataType="BOOLEAN", nullable=True, isPartition=False, isBucket=False),
48-
]
49108

50-
# TODO: should this error instead? # noqa: TD002, TD003
51-
non_existant_columns = spark.catalog.listColumns("none_existant")
52-
assert non_existant_columns == []
109+
columns = spark.catalog.listColumns("tbl")
110+
kwds = dict(description=None, nullable=True, isPartition=False, isBucket=False, isCluster=False) # noqa: C408
111+
assert columns == [Column(name="a", dataType="VARCHAR", **kwds), Column(name="b", dataType="BOOLEAN", **kwds)]
53112

54113
spark.sql("create view vw as select * from tbl")
55114
view_columns = spark.catalog.listColumns("vw")
56115
assert view_columns == columns
116+
117+
from spark_namespace.errors import AnalysisException
118+
119+
with pytest.raises(AnalysisException):
120+
assert spark.catalog.listColumns("tbl", "notfound")
121+
122+
def test_list_columns_not_found(self, spark):
123+
from spark_namespace.errors import AnalysisException
124+
125+
with pytest.raises(AnalysisException):
126+
spark.catalog.listColumns("none_existant")
127+
128+
def test_list_functions(self, spark):
129+
fns = spark.catalog.listFunctions()
130+
assert len(fns)
131+
assert any(f.name == "current_database" for f in fns)

0 commit comments

Comments
 (0)