@@ -38,6 +38,9 @@ use datafusion::physical_plan::joins::{
3838} ;
3939use datafusion:: prelude:: { SessionConfig , SessionContext } ;
4040use datafusion_common:: { NullEquality , ScalarValue } ;
41+ use datafusion_execution:: TaskContext ;
42+ use datafusion_execution:: disk_manager:: { DiskManagerBuilder , DiskManagerMode } ;
43+ use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
4144use datafusion_physical_expr:: PhysicalExprRef ;
4245use datafusion_physical_expr:: expressions:: Literal ;
4346
@@ -1125,6 +1128,138 @@ impl JoinFuzzTestCase {
11251128 }
11261129}
11271130
1131+ /// Fuzz test: compare SMJ (with spilling) against HJ (no spill) for filtered
1132+ /// outer joins under memory pressure. This exercises the deferred filtering +
1133+ /// spill read-back path that unit tests can't easily cover with random data.
1134+ #[ tokio:: test]
1135+ async fn test_filtered_join_spill_fuzz ( ) {
1136+ let join_types = [ JoinType :: Left , JoinType :: Right , JoinType :: Full ] ;
1137+
1138+ let runtime_spill = RuntimeEnvBuilder :: new ( )
1139+ . with_memory_limit ( 4096 , 1.0 )
1140+ . with_disk_manager_builder (
1141+ DiskManagerBuilder :: default ( ) . with_mode ( DiskManagerMode :: OsTmpDirectory ) ,
1142+ )
1143+ . build_arc ( )
1144+ . unwrap ( ) ;
1145+
1146+ for join_type in & join_types {
1147+ for ( left_extra, right_extra) in [ ( true , true ) , ( false , true ) , ( true , false ) ] {
1148+ let input1 = make_staggered_batches_i32 ( 1000 , left_extra) ;
1149+ let input2 = make_staggered_batches_i32 ( 1000 , right_extra) ;
1150+
1151+ let schema1 = input1[ 0 ] . schema ( ) ;
1152+ let schema2 = input2[ 0 ] . schema ( ) ;
1153+ let filter = col_lt_col_filter ( schema1. clone ( ) , schema2. clone ( ) ) ;
1154+
1155+ let on = vec ! [
1156+ (
1157+ Arc :: new( Column :: new_with_schema( "a" , & schema1) . unwrap( ) ) as _,
1158+ Arc :: new( Column :: new_with_schema( "a" , & schema2) . unwrap( ) ) as _,
1159+ ) ,
1160+ (
1161+ Arc :: new( Column :: new_with_schema( "b" , & schema1) . unwrap( ) ) as _,
1162+ Arc :: new( Column :: new_with_schema( "b" , & schema2) . unwrap( ) ) as _,
1163+ ) ,
1164+ ] ;
1165+
1166+ for batch_size in [ 2 , 49 , 100 ] {
1167+ let session_config = SessionConfig :: new ( ) . with_batch_size ( batch_size) ;
1168+
1169+ // HJ baseline (no memory limit)
1170+ let left_hj = MemorySourceConfig :: try_new_exec (
1171+ std:: slice:: from_ref ( & input1) ,
1172+ schema1. clone ( ) ,
1173+ None ,
1174+ )
1175+ . unwrap ( ) ;
1176+ let right_hj = MemorySourceConfig :: try_new_exec (
1177+ std:: slice:: from_ref ( & input2) ,
1178+ schema2. clone ( ) ,
1179+ None ,
1180+ )
1181+ . unwrap ( ) ;
1182+ let hj = Arc :: new (
1183+ HashJoinExec :: try_new (
1184+ left_hj,
1185+ right_hj,
1186+ on. clone ( ) ,
1187+ Some ( filter. clone ( ) ) ,
1188+ join_type,
1189+ None ,
1190+ PartitionMode :: Partitioned ,
1191+ NullEquality :: NullEqualsNothing ,
1192+ false ,
1193+ )
1194+ . unwrap ( ) ,
1195+ ) ;
1196+ let ctx_hj = SessionContext :: new_with_config ( session_config. clone ( ) ) ;
1197+ let hj_collected = collect ( hj, ctx_hj. task_ctx ( ) ) . await . unwrap ( ) ;
1198+
1199+ // SMJ with spilling
1200+ let left_smj = MemorySourceConfig :: try_new_exec (
1201+ std:: slice:: from_ref ( & input1) ,
1202+ schema1. clone ( ) ,
1203+ None ,
1204+ )
1205+ . unwrap ( ) ;
1206+ let right_smj = MemorySourceConfig :: try_new_exec (
1207+ std:: slice:: from_ref ( & input2) ,
1208+ schema2. clone ( ) ,
1209+ None ,
1210+ )
1211+ . unwrap ( ) ;
1212+ let smj = Arc :: new (
1213+ SortMergeJoinExec :: try_new (
1214+ left_smj,
1215+ right_smj,
1216+ on. clone ( ) ,
1217+ Some ( filter. clone ( ) ) ,
1218+ * join_type,
1219+ vec ! [ SortOptions :: default ( ) ; on. len( ) ] ,
1220+ NullEquality :: NullEqualsNothing ,
1221+ )
1222+ . unwrap ( ) ,
1223+ ) ;
1224+ let task_ctx_spill = Arc :: new (
1225+ TaskContext :: default ( )
1226+ . with_session_config ( session_config)
1227+ . with_runtime ( Arc :: clone ( & runtime_spill) ) ,
1228+ ) ;
1229+ let smj_collected = collect ( smj, task_ctx_spill) . await . unwrap ( ) ;
1230+
1231+ let hj_rows: usize = hj_collected. iter ( ) . map ( |b| b. num_rows ( ) ) . sum ( ) ;
1232+ let smj_rows: usize = smj_collected. iter ( ) . map ( |b| b. num_rows ( ) ) . sum ( ) ;
1233+
1234+ assert_eq ! (
1235+ hj_rows, smj_rows,
1236+ "Row count mismatch for {join_type:?} batch_size={batch_size} \
1237+ left_extra={left_extra} right_extra={right_extra}: \
1238+ HJ={hj_rows} SMJ={smj_rows}"
1239+ ) ;
1240+
1241+ if hj_rows > 0 {
1242+ let hj_fmt =
1243+ pretty_format_batches ( & hj_collected) . unwrap ( ) . to_string ( ) ;
1244+ let smj_fmt =
1245+ pretty_format_batches ( & smj_collected) . unwrap ( ) . to_string ( ) ;
1246+
1247+ let mut hj_sorted: Vec < & str > = hj_fmt. trim ( ) . lines ( ) . collect ( ) ;
1248+ hj_sorted. sort_unstable ( ) ;
1249+ let mut smj_sorted: Vec < & str > = smj_fmt. trim ( ) . lines ( ) . collect ( ) ;
1250+ smj_sorted. sort_unstable ( ) ;
1251+
1252+ assert_eq ! (
1253+ hj_sorted, smj_sorted,
1254+ "Content mismatch for {join_type:?} batch_size={batch_size} \
1255+ left_extra={left_extra} right_extra={right_extra}"
1256+ ) ;
1257+ }
1258+ }
1259+ }
1260+ }
1261+ }
1262+
11281263/// Return randomly sized record batches with:
11291264/// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns
11301265/// two random int32 columns 'x', 'y' as other columns
0 commit comments