55
class Database(NamedTuple):
    """Describes one database as reported by ``Catalog.listDatabases``.

    In this API a Spark "database" is a DuckDB schema and the Spark "catalog"
    is the DuckDB database that contains it.
    """

    name: str  # DuckDB schema name
    catalog: str  # name of the DuckDB database that contains the schema
    description: str | None  # Catalog.listDatabases always reports None
    locationUri: str  # Catalog.listDatabases always reports ""
1011
1112
class Table(NamedTuple):
    """Describes one table as reported by ``Catalog.listTables``.

    In this API a Spark "database" is a DuckDB schema and the Spark "catalog"
    is a DuckDB database.
    """

    name: str  # table name
    database: str | None  # DuckDB schema that holds the table
    catalog: str  # Catalog.listTables reports the current DuckDB database here
    description: str | None  # Catalog.listTables stores the `sql` column of duckdb_tables() here
    tableType: str  # Catalog.listTables always reports ""
    isTemporary: bool  # the `temporary` column of duckdb_tables()
@@ -24,56 +26,168 @@ class Column(NamedTuple): # noqa: D101
2426 nullable : bool
2527 isPartition : bool
2628 isBucket : bool
29+ isCluster : bool
2730
2831
class Function(NamedTuple):
    """Describes one function as reported by ``Catalog.listFunctions``.

    In this API a Spark "database" is a DuckDB schema and the Spark "catalog"
    is a DuckDB database.
    """

    name: str  # function name
    catalog: str | None  # DuckDB database the function is registered in
    namespace: list[str] | None  # Catalog.listFunctions reports a one-element list with the schema name
    description: str | None  # the `description` column of duckdb_functions()
    className: str  # Catalog.listFunctions stores the `function_type` column here
    isTemporary: bool  # Catalog.listFunctions sets this when the database is "temp"
3439
3540
class Catalog:
    """Implements the Spark catalog API on top of DuckDB.

    Implementation notes:
    Spark has the concept of a catalog, and inside each catalog there are
    schemas which contain tables. The Spark catalog API, however, refers to
    those schemas as databases.
    DuckDB has databases, which in turn contain schemas. DuckDB's databases
    therefore overlap with the concept of the Spark catalog.
    So to summarize:

        ------------------------------
        | Spark           | DuckDB   |
        ------------------------------
        | Catalog         | Database |
        | Database/Schema | Schema   |
        ------------------------------

    The consequence is that this catalog API refers in several locations to a
    database name, which is the DuckDB schema.
    """

    def __init__(self, session: SparkSession) -> None:
        """Bind the catalog to *session*; all queries run on ``session.conn``."""
        self._session = session

    def listDatabases(self, pattern: str | None = None) -> list[Database]:
        """Return a ``Database`` entry for every available database (DuckDB schema).

        Args:
            pattern: Optional name filter. ``*`` is translated to the SQL
                ``LIKE`` wildcard ``%``. When omitted or empty, nothing is filtered.

        Returns:
            One ``Database`` per row of ``duckdb_schemas()``, with the schema
            name as ``name`` and the owning DuckDB database as ``catalog``.
        """
        sql_text = "select schema_name, database_name from duckdb_schemas()"
        params: tuple[str, ...] = ()
        if pattern:
            sql_text += " WHERE schema_name LIKE ?"
            params = (pattern.replace("*", "%"),)

        rows = self._session.conn.sql(sql_text, params=params).fetchall()
        return [
            Database(name=schema, catalog=database, description=None, locationUri="")
            for schema, database in rows
        ]

    def listTables(self, dbName: str | None = None, pattern: str | None = None) -> list[Table]:
        """Return the tables of the specified database (DuckDB schema).

        Args:
            dbName: Database (DuckDB schema) to list. When omitted, the current
                active database is used.
            pattern: Optional table-name filter. ``*`` is translated to the SQL
                ``LIKE`` wildcard ``%``.

        Returns:
            One ``Table`` per matching row of ``duckdb_tables()``.
        """
        dbName = dbName or self.currentDatabase()
        sql_text = "select table_name, schema_name, sql, temporary from duckdb_tables() where schema_name = ?"
        params = (dbName,)
        if pattern:
            sql_text += " and table_name LIKE ?"
            params += (pattern.replace("*", "%"),)

        rows = self._session.conn.sql(sql_text, params=params).fetchall()
        # NOTE: rows are matched on the schema name only, and every row is
        # reported under the current catalog.
        current_catalog = self._currentCatalog()
        return [
            Table(
                name=table_name,
                database=schema_name,
                catalog=current_catalog,
                description=sql,
                tableType="",
                isTemporary=temporary,
            )
            for table_name, schema_name, sql, temporary in rows
        ]

    def listColumns(self, tableName: str, dbName: str | None = None) -> list[Column]:
        """Return the columns of the given table/view in the specified database.

        Args:
            tableName: Name of the table or view.
            dbName: Database (DuckDB schema) holding the table. When omitted,
                the current active database is used.

        Returns:
            One ``Column`` per matching row of ``duckdb_columns()`` within the
            current catalog.

        Raises:
            AnalysisException: If no column is found for the table.
        """
        query = (
            "select column_name, data_type, is_nullable"
            " from duckdb_columns()"
            " where table_name = ? and schema_name = ? and database_name = ?"
        )
        dbName = dbName or self.currentDatabase()
        params = (tableName, dbName, self._currentCatalog())
        rows = self._session.conn.sql(query, params=params).fetchall()

        if not rows:
            from duckdb.experimental.spark.errors import AnalysisException

            msg = f"[TABLE_OR_VIEW_NOT_FOUND] The table or view `{tableName}` cannot be found"
            raise AnalysisException(msg)

        return [
            Column(
                name=column_name,
                description=None,
                dataType=data_type,
                nullable=is_nullable,
                isPartition=False,
                isBucket=False,
                isCluster=False,
            )
            for column_name, data_type, is_nullable in rows
        ]

    def listFunctions(self, dbName: str | None = None, pattern: str | None = None) -> list[Function]:
        """Return the functions registered in the specified database (DuckDB schema).

        Args:
            dbName: Database (DuckDB schema) to list. When omitted, the current
                active database is used.
            pattern: Optional function-name filter. ``*`` is translated to the
                SQL ``LIKE`` wildcard ``%``.

        Returns:
            One ``Function`` per distinct matching row of ``duckdb_functions()``.
        """
        dbName = dbName or self.currentDatabase()
        sql_text = (
            "SELECT DISTINCT database_name, schema_name, function_name, description, function_type"
            " FROM duckdb_functions()"
            " WHERE schema_name = ?"
        )
        params = (dbName,)
        if pattern:
            sql_text += " AND function_name LIKE ?"
            # Append to the parameters: the schema placeholder must keep its
            # value, so the tuple has to hold one entry per "?" in the query.
            params += (pattern.replace("*", "%"),)

        rows = self._session.conn.sql(sql_text, params=params).fetchall()
        return [
            Function(
                name=function_name,
                catalog=database_name,
                namespace=[schema_name],
                description=description,
                className=function_type,
                isTemporary=database_name == "temp",
            )
            for database_name, schema_name, function_name, description, function_type in rows
        ]

    def currentDatabase(self) -> str:
        """Return the name of the active database (DuckDB ``current_schema()``)."""
        row = self._session.conn.sql("SELECT current_schema()").fetchone()
        return row[0]

    def setCurrentDatabase(self, dbName: str) -> None:
        """Set the active database/schema. Equivalent to executing 'USE dbName'."""
        # The name is quoted as an identifier because USE cannot take it as a
        # bound parameter here.
        self._session.conn.sql(f"USE {_sql_quote(dbName)}")

    def _currentCatalog(self) -> str:
        """Return the name of the active catalog (DuckDB ``current_database()``)."""
        row = self._session.conn.sql("SELECT current_database()").fetchone()
        return row[0]
def _sql_quote(value: str) -> str:
    """Return *value* as a double-quoted SQL identifier.

    Embedded double quotes are escaped by doubling them, so the result can be
    placed into a statement where an identifier is expected.
    """
    # Plain concatenation instead of an f-string: reusing the outer quote
    # character inside a replacement field only parses on Python 3.12+.
    escaped = value.replace('"', '""')
    return '"' + escaped + '"'
77191
78192
79193__all__ = ["Catalog" , "Column" , "Database" , "Function" , "Table" ]
0 commit comments