@@ -23,7 +23,7 @@ use std::collections::{BTreeSet, HashMap, HashSet};
2323use std:: sync:: Arc ;
2424
2525#[ cfg( test) ]
26- use std:: sync :: Mutex ;
26+ use std:: cell :: Cell ;
2727
2828use crate :: analyzer:: type_coercion:: TypeCoercionRewriter ;
2929
@@ -36,17 +36,38 @@ pub(crate) enum NullRestrictionEvalMode {
3636}
3737
3838#[ cfg( test) ]
39- static NULL_RESTRICTION_EVAL_MODE : Mutex < NullRestrictionEvalMode > =
40- Mutex :: new ( NullRestrictionEvalMode :: Auto ) ;
39+ thread_local ! {
40+ static NULL_RESTRICTION_EVAL_MODE : Cell <NullRestrictionEvalMode > =
41+ const { Cell :: new( NullRestrictionEvalMode :: Auto ) } ;
42+ }
4143
4244#[ cfg( test) ]
4345pub ( crate ) fn set_null_restriction_eval_mode_for_test ( mode : NullRestrictionEvalMode ) {
44- * NULL_RESTRICTION_EVAL_MODE . lock ( ) . unwrap ( ) = mode;
46+ NULL_RESTRICTION_EVAL_MODE . with ( |eval_mode| eval_mode . set ( mode) ) ;
4547}
4648
4749#[ cfg( test) ]
4850fn null_restriction_eval_mode ( ) -> NullRestrictionEvalMode {
49- * NULL_RESTRICTION_EVAL_MODE . lock ( ) . unwrap ( )
51+ NULL_RESTRICTION_EVAL_MODE . with ( Cell :: get)
52+ }
53+
54+ #[ cfg( test) ]
55+ pub ( crate ) fn with_null_restriction_eval_mode_for_test < T > (
56+ mode : NullRestrictionEvalMode ,
57+ f : impl FnOnce ( ) -> T ,
58+ ) -> T {
59+ struct NullRestrictionEvalModeReset ( NullRestrictionEvalMode ) ;
60+
61+ impl Drop for NullRestrictionEvalModeReset {
62+ fn drop ( & mut self ) {
63+ set_null_restriction_eval_mode_for_test ( self . 0 ) ;
64+ }
65+ }
66+
67+ let previous_mode = null_restriction_eval_mode ( ) ;
68+ set_null_restriction_eval_mode_for_test ( mode) ;
69+ let _reset = NullRestrictionEvalModeReset ( previous_mode) ;
70+ f ( )
5071}
5172use arrow:: array:: { Array , RecordBatch , new_null_array} ;
5273use arrow:: datatypes:: { DataType , Field , Schema } ;
@@ -445,18 +466,24 @@ mod tests {
445466 let predicate = binary_expr ( col ( "a" ) , Operator :: Gt , lit ( 8i64 ) ) ;
446467 let join_cols_of_predicate = predicate. column_refs ( ) ;
447468
448- set_null_restriction_eval_mode_for_test ( NullRestrictionEvalMode :: Auto ) ;
449- let auto_result = is_restrict_null_predicate (
450- predicate. clone ( ) ,
451- join_cols_of_predicate. iter ( ) . copied ( ) ,
469+ let auto_result = with_null_restriction_eval_mode_for_test (
470+ NullRestrictionEvalMode :: Auto ,
471+ || {
472+ is_restrict_null_predicate (
473+ predicate. clone ( ) ,
474+ join_cols_of_predicate. iter ( ) . copied ( ) ,
475+ )
476+ } ,
452477 ) ?;
453478
454- set_null_restriction_eval_mode_for_test (
479+ let authoritative_result = with_null_restriction_eval_mode_for_test (
455480 NullRestrictionEvalMode :: AuthoritativeOnly ,
456- ) ;
457- let authoritative_result = is_restrict_null_predicate (
458- predicate. clone ( ) ,
459- join_cols_of_predicate. iter ( ) . copied ( ) ,
481+ || {
482+ is_restrict_null_predicate (
483+ predicate. clone ( ) ,
484+ join_cols_of_predicate. iter ( ) . copied ( ) ,
485+ )
486+ } ,
460487 ) ?;
461488
462489 assert_eq ! ( auto_result, authoritative_result) ;
@@ -470,17 +497,25 @@ mod tests {
470497 let predicate = binary_expr ( col ( "a" ) , Operator :: Gt , col ( "b" ) ) ;
471498 let column_a = Column :: from_name ( "a" ) ;
472499
473- set_null_restriction_eval_mode_for_test ( NullRestrictionEvalMode :: Auto ) ;
474- let auto_result =
475- is_restrict_null_predicate ( predicate. clone ( ) , std:: iter:: once ( & column_a) ) ?;
500+ let auto_result = with_null_restriction_eval_mode_for_test (
501+ NullRestrictionEvalMode :: Auto ,
502+ || {
503+ is_restrict_null_predicate (
504+ predicate. clone ( ) ,
505+ std:: iter:: once ( & column_a) ,
506+ )
507+ } ,
508+ ) ?;
476509
477- set_null_restriction_eval_mode_for_test (
510+ let authoritative_only_result = with_null_restriction_eval_mode_for_test (
478511 NullRestrictionEvalMode :: AuthoritativeOnly ,
479- ) ;
480- let authoritative_only_result =
481- is_restrict_null_predicate ( predicate. clone ( ) , std:: iter:: once ( & column_a) ) ?;
482-
483- set_null_restriction_eval_mode_for_test ( NullRestrictionEvalMode :: Auto ) ;
512+ || {
513+ is_restrict_null_predicate (
514+ predicate. clone ( ) ,
515+ std:: iter:: once ( & column_a) ,
516+ )
517+ } ,
518+ ) ?;
484519
485520 assert ! ( !auto_result, "{predicate}" ) ;
486521 assert ! ( !authoritative_only_result, "{predicate}" ) ;
0 commit comments