1616// under the License.
1717
1818use std:: fmt:: { self , Display , Formatter } ;
19- use std:: sync:: Arc ;
19+ use std:: sync:: { Arc , RwLock } ;
2020use std:: { any:: Any , borrow:: Cow } ;
2121
2222use arrow:: datatypes:: Schema ;
@@ -25,6 +25,7 @@ use datafusion::arrow::datatypes::SchemaRef;
2525use datafusion:: common:: Constraints ;
2626use datafusion:: datasource:: TableType ;
2727use datafusion:: logical_expr:: { Expr , TableProviderFilterPushDown , TableSource } ;
28+ use pyo3:: exceptions:: PyRuntimeError ;
2829use pyo3:: prelude:: * ;
2930
3031use datafusion:: logical_expr:: utils:: split_conjunction;
@@ -33,17 +34,13 @@ use crate::sql::logical::PyLogicalPlan;
3334
3435use super :: { data_type:: DataTypeMap , function:: SqlFunction } ;
3536
36- #[ pyclass( name = "SqlSchema" , module = "datafusion.common" , subclass) ]
37+ #[ pyclass( name = "SqlSchema" , module = "datafusion.common" , subclass, frozen ) ]
3738#[ derive( Debug , Clone ) ]
3839pub struct SqlSchema {
39- #[ pyo3( get, set) ]
40- pub name : String ,
41- #[ pyo3( get, set) ]
42- pub tables : Vec < SqlTable > ,
43- #[ pyo3( get, set) ]
44- pub views : Vec < SqlView > ,
45- #[ pyo3( get, set) ]
46- pub functions : Vec < SqlFunction > ,
40+ name : Arc < RwLock < String > > ,
41+ tables : Arc < RwLock < Vec < SqlTable > > > ,
42+ views : Arc < RwLock < Vec < SqlView > > > ,
43+ functions : Arc < RwLock < Vec < SqlFunction > > > ,
4744}
4845
4946#[ pyclass( name = "SqlTable" , module = "datafusion.common" , subclass) ]
@@ -104,28 +101,98 @@ impl SqlSchema {
104101 #[ new]
105102 pub fn new ( schema_name : & str ) -> Self {
106103 Self {
107- name : schema_name. to_owned ( ) ,
108- tables : Vec :: new ( ) ,
109- views : Vec :: new ( ) ,
110- functions : Vec :: new ( ) ,
104+ name : Arc :: new ( RwLock :: new ( schema_name. to_owned ( ) ) ) ,
105+ tables : Arc :: new ( RwLock :: new ( Vec :: new ( ) ) ) ,
106+ views : Arc :: new ( RwLock :: new ( Vec :: new ( ) ) ) ,
107+ functions : Arc :: new ( RwLock :: new ( Vec :: new ( ) ) ) ,
111108 }
112109 }
113110
111+ #[ getter]
112+ fn name ( & self ) -> PyResult < String > {
113+ Ok ( self
114+ . name
115+ . read ( )
116+ . map_err ( |_| PyRuntimeError :: new_err ( "failed to read schema name" ) ) ?
117+ . clone ( ) )
118+ }
119+
120+ #[ setter]
121+ fn set_name ( & self , value : String ) -> PyResult < ( ) > {
122+ * self
123+ . name
124+ . write ( )
125+ . map_err ( |_| PyRuntimeError :: new_err ( "failed to write schema name" ) ) ? = value;
126+ Ok ( ( ) )
127+ }
128+
129+ #[ getter]
130+ fn tables ( & self ) -> PyResult < Vec < SqlTable > > {
131+ Ok ( self
132+ . tables
133+ . read ( )
134+ . map_err ( |_| PyRuntimeError :: new_err ( "failed to read schema tables" ) ) ?
135+ . clone ( ) )
136+ }
137+
138+ #[ setter]
139+ fn set_tables ( & self , tables : Vec < SqlTable > ) -> PyResult < ( ) > {
140+ * self
141+ . tables
142+ . write ( )
143+ . map_err ( |_| PyRuntimeError :: new_err ( "failed to write schema tables" ) ) ? = tables;
144+ Ok ( ( ) )
145+ }
146+
147+ #[ getter]
148+ fn views ( & self ) -> PyResult < Vec < SqlView > > {
149+ Ok ( self
150+ . views
151+ . read ( )
152+ . map_err ( |_| PyRuntimeError :: new_err ( "failed to read schema views" ) ) ?
153+ . clone ( ) )
154+ }
155+
156+ #[ setter]
157+ fn set_views ( & self , views : Vec < SqlView > ) -> PyResult < ( ) > {
158+ * self
159+ . views
160+ . write ( )
161+ . map_err ( |_| PyRuntimeError :: new_err ( "failed to write schema views" ) ) ? = views;
162+ Ok ( ( ) )
163+ }
164+
165+ #[ getter]
166+ fn functions ( & self ) -> PyResult < Vec < SqlFunction > > {
167+ Ok ( self
168+ . functions
169+ . read ( )
170+ . map_err ( |_| PyRuntimeError :: new_err ( "failed to read schema functions" ) ) ?
171+ . clone ( ) )
172+ }
173+
174+ #[ setter]
175+ fn set_functions ( & self , functions : Vec < SqlFunction > ) -> PyResult < ( ) > {
176+ * self
177+ . functions
178+ . write ( )
179+ . map_err ( |_| PyRuntimeError :: new_err ( "failed to write schema functions" ) ) ? = functions;
180+ Ok ( ( ) )
181+ }
182+
114183 pub fn table_by_name ( & self , table_name : & str ) -> Option < SqlTable > {
115- for tbl in & self . tables {
116- if tbl. name . eq ( table_name) {
117- return Some ( tbl. clone ( ) ) ;
118- }
119- }
120- None
184+ let tables = self . tables . read ( ) . expect ( "failed to read schema tables" ) ;
185+ tables. iter ( ) . find ( |tbl| tbl. name . eq ( table_name) ) . cloned ( )
121186 }
122187
123- pub fn add_table ( & mut self , table : SqlTable ) {
124- self . tables . push ( table) ;
188+ pub fn add_table ( & self , table : SqlTable ) {
189+ let mut tables = self . tables . write ( ) . expect ( "failed to write schema tables" ) ;
190+ tables. push ( table) ;
125191 }
126192
127- pub fn drop_table ( & mut self , table_name : String ) {
128- self . tables . retain ( |x| !x. name . eq ( & table_name) ) ;
193+ pub fn drop_table ( & self , table_name : String ) {
194+ let mut tables = self . tables . write ( ) . expect ( "failed to write schema tables" ) ;
195+ tables. retain ( |x| !x. name . eq ( & table_name) ) ;
129196 }
130197}
131198
0 commit comments