16
16
// under the License.
17
17
18
18
use std:: collections:: HashMap ;
19
+ use std:: cmp:: { max, min} ;
19
20
use std:: ffi:: CString ;
21
+ use std:: ops:: IndexMut ;
20
22
use std:: sync:: Arc ;
21
23
22
24
use arrow:: array:: { new_null_array, RecordBatch , RecordBatchIterator , RecordBatchReader } ;
@@ -28,26 +30,35 @@ use arrow::pyarrow::FromPyArrow;
28
30
use datafusion:: arrow:: datatypes:: Schema ;
29
31
use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
30
32
use datafusion:: arrow:: util:: pretty;
31
- use datafusion:: common:: UnnestOptions ;
33
+ use datafusion:: common:: { DFSchema , Statistics , UnnestOptions } ;
34
+ use datafusion:: common:: stats:: Precision ;
35
+ use datafusion:: common:: tree_node:: { Transformed , TreeNode } ;
32
36
use datafusion:: config:: { CsvOptions , ParquetColumnOptions , ParquetOptions , TableParquetOptions } ;
33
37
use datafusion:: dataframe:: { DataFrame , DataFrameWriteOptions } ;
38
+ use datafusion:: datasource:: physical_plan:: FileScanConfig ;
39
+ use datafusion:: datasource:: source:: { DataSource , DataSourceExec } ;
34
40
use datafusion:: datasource:: TableProvider ;
35
41
use datafusion:: error:: DataFusionError ;
36
42
use datafusion:: execution:: SendableRecordBatchStream ;
37
43
use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
44
+ use datafusion:: physical_plan:: ExecutionPlan ;
38
45
use datafusion:: prelude:: * ;
39
46
use datafusion_ffi:: table_provider:: FFI_TableProvider ;
47
+ use datafusion:: sql:: unparser:: plan_to_sql;
48
+ use datafusion_proto:: physical_plan:: AsExecutionPlan ;
49
+ use datafusion_proto:: protobuf:: PhysicalPlanNode ;
40
50
use futures:: { StreamExt , TryStreamExt } ;
51
+ use prost:: Message ;
41
52
use pyo3:: exceptions:: PyValueError ;
42
53
use pyo3:: prelude:: * ;
43
54
use pyo3:: pybacked:: PyBackedStr ;
44
- use pyo3:: types:: { PyCapsule , PyList , PyTuple , PyTupleMethods } ;
55
+ use pyo3:: types:: { PyBytes , PyCapsule , PyDict , PyList , PyString , PyTuple , PyTupleMethods } ;
45
56
use tokio:: task:: JoinHandle ;
46
57
47
58
use crate :: catalog:: PyTable ;
48
59
use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError } ;
49
60
use crate :: expr:: sort_expr:: to_sort_expressions;
50
- use crate :: physical_plan:: PyExecutionPlan ;
61
+ use crate :: physical_plan:: { codec , PyExecutionPlan } ;
51
62
use crate :: record_batch:: PyRecordBatchStream ;
52
63
use crate :: sql:: logical:: PyLogicalPlan ;
53
64
use crate :: utils:: {
@@ -57,6 +68,7 @@ use crate::{
57
68
errors:: PyDataFusionResult ,
58
69
expr:: { sort_expr:: PySortExpr , PyExpr } ,
59
70
} ;
71
+ use crate :: common:: df_schema:: PyDFSchema ;
60
72
61
73
// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
62
74
// - we have not decided on the table_provider approach yet
@@ -992,6 +1004,168 @@ impl PyDataFrame {
992
1004
let df = self . df . as_ref ( ) . clone ( ) . fill_null ( scalar_value, cols) ?;
993
1005
Ok ( Self :: new ( df) )
994
1006
}
1007
+
1008
+ fn distributed_plan ( & self , py : Python < ' _ > ) -> PyResult < DistributedPlan > {
1009
+ let future_plan = DistributedPlan :: try_new ( self . df . as_ref ( ) ) ;
1010
+ wait_for_future ( py, future_plan) ?. map_err ( py_datafusion_err)
1011
+ }
1012
+
1013
+ fn plan_sql ( & self , py : Python < ' _ > ) -> PyResult < PyObject > {
1014
+ let logical_plan = self . df . logical_plan ( ) ;
1015
+
1016
+ let sql = plan_to_sql ( logical_plan) . map_err ( py_datafusion_err) ?;
1017
+ Ok ( PyString :: new ( py, sql. to_string ( ) . as_ref ( ) ) . into ( ) )
1018
+ }
1019
+ }
1020
+
1021
+ #[ pyclass( get_all) ]
1022
+ #[ derive( Debug , Clone ) ]
1023
+ pub struct DistributedPlan {
1024
+ min_size : usize ,
1025
+ physical_plan : PyExecutionPlan ,
1026
+ }
1027
+
1028
+ #[ pymethods]
1029
+ impl DistributedPlan {
1030
+ #[ new]
1031
+ fn unmarshal ( state : Bound < PyDict > ) -> PyResult < Self > {
1032
+ let ctx = SessionContext :: new ( ) ;
1033
+ let serialized_plan = state
1034
+ . get_item ( "plan" ) ?
1035
+ . expect ( "missing key `plan` from state" ) ;
1036
+ let serialized_plan = serialized_plan. downcast :: < PyBytes > ( ) ?. as_bytes ( ) ;
1037
+ let min_size = state
1038
+ . get_item ( "min_size" ) ?
1039
+ . expect ( "missing key `min_size` from state" )
1040
+ . extract :: < usize > ( ) ?;
1041
+ let plan = deserialize_plan ( serialized_plan, & ctx) ?;
1042
+ Ok ( Self {
1043
+ min_size,
1044
+ physical_plan : PyExecutionPlan :: new ( plan) ,
1045
+ } )
1046
+ }
1047
+
1048
+ fn partition_count ( & self ) -> usize {
1049
+ self . physical_plan . partition_count ( )
1050
+ }
1051
+
1052
+ fn num_bytes ( & self ) -> Option < usize > {
1053
+ self . stats_field ( |stats| stats. total_byte_size )
1054
+ }
1055
+
1056
+ fn num_rows ( & self ) -> Option < usize > {
1057
+ self . stats_field ( |stats| stats. num_rows )
1058
+ }
1059
+
1060
+ fn schema ( & self ) -> PyResult < PyDFSchema > {
1061
+ DFSchema :: try_from ( self . plan ( ) . schema ( ) )
1062
+ . map ( PyDFSchema :: from)
1063
+ . map_err ( py_datafusion_err)
1064
+ }
1065
+
1066
+ fn set_desired_parallelism ( & mut self , desired_parallelism : usize ) -> PyResult < ( ) > {
1067
+ let updated_plan = self
1068
+ . plan ( )
1069
+ . clone ( )
1070
+ . transform_up ( |node| {
1071
+ if let Some ( exec) = node. as_any ( ) . downcast_ref :: < DataSourceExec > ( ) {
1072
+ // Remove redundant ranges from partition files because FileScanConfig refuses to repartition
1073
+ // if any file has a range defined (even when the range actually covers the entire file).
1074
+ // The EnforceDistribution optimizer rule adds ranges for both full and partial files,
1075
+ // so this tries to revert that in order to trigger a repartition when no files are actually split.
1076
+ // TODO: check whether EnforceDistribution is still adding redundant ranges and remove this
1077
+ // workaround if no longer needed.
1078
+ if let Some ( file_scan) =
1079
+ exec. data_source ( ) . as_any ( ) . downcast_ref :: < FileScanConfig > ( )
1080
+ {
1081
+ let mut range_free_file_scan = file_scan. clone ( ) ;
1082
+ let mut total_size: usize = 0 ;
1083
+ for group in range_free_file_scan. file_groups . iter_mut ( ) {
1084
+ for group_idx in 0 ..group. len ( ) {
1085
+ let file = group. index_mut ( group_idx) ;
1086
+ if let Some ( range) = & file. range {
1087
+ total_size += ( range. end - range. start ) as usize ;
1088
+ if range. start == 0 && range. end == file. object_meta . size as i64
1089
+ {
1090
+ file. range = None ; // remove redundant range
1091
+ }
1092
+ } else {
1093
+ total_size += file. object_meta . size as usize ;
1094
+ }
1095
+
1096
+ }
1097
+ }
1098
+ let min_size_buckets = max ( 1 , total_size. div_ceil ( self . min_size ) ) ;
1099
+ let partitions = min ( min_size_buckets, desired_parallelism) ;
1100
+ let ordering = range_free_file_scan. eq_properties ( ) . output_ordering ( ) ;
1101
+ if let Some ( repartitioned) =
1102
+ range_free_file_scan. repartitioned ( partitions, 1 , ordering) ?
1103
+ {
1104
+ return Ok ( Transformed :: yes ( Arc :: new ( DataSourceExec :: new (
1105
+ repartitioned,
1106
+ ) ) ) ) ;
1107
+ }
1108
+ }
1109
+ }
1110
+ Ok ( Transformed :: no ( node) )
1111
+ } )
1112
+ . map_err ( py_datafusion_err) ?
1113
+ . data ;
1114
+ self . physical_plan = PyExecutionPlan :: new ( updated_plan) ;
1115
+ Ok ( ( ) )
1116
+ }
1117
+ }
1118
+
1119
+ impl DistributedPlan {
1120
+ async fn try_new ( df : & DataFrame ) -> Result < Self , DataFusionError > {
1121
+ let ( mut session_state, logical_plan) = df. clone ( ) . into_parts ( ) ;
1122
+ let min_size = session_state
1123
+ . config_options ( )
1124
+ . optimizer
1125
+ . repartition_file_min_size ;
1126
+ // Create the physical plan with a single partition, to ensure that no files are split into ranges.
1127
+ // Otherwise, any subsequent repartition attempt would fail (see the comment in `set_desired_parallelism`)
1128
+ session_state
1129
+ . config_mut ( )
1130
+ . options_mut ( )
1131
+ . execution
1132
+ . target_partitions = 1 ;
1133
+ let physical_plan = session_state. create_physical_plan ( & logical_plan) . await ?;
1134
+ let physical_plan = PyExecutionPlan :: new ( physical_plan) ;
1135
+ Ok ( Self {
1136
+ min_size,
1137
+ physical_plan,
1138
+ } )
1139
+ }
1140
+
1141
+ fn plan ( & self ) -> & Arc < dyn ExecutionPlan > {
1142
+ & self . physical_plan . plan
1143
+ }
1144
+
1145
+ fn stats_field ( & self , field : fn ( Statistics ) -> Precision < usize > ) -> Option < usize > {
1146
+ if let Ok ( stats) = self . plan ( ) . partition_statistics ( None ) {
1147
+ match field ( stats) {
1148
+ Precision :: Exact ( n) => Some ( n) ,
1149
+ _ => None ,
1150
+ }
1151
+ } else {
1152
+ None
1153
+ }
1154
+ }
1155
+ }
1156
+
1157
+ fn deserialize_plan (
1158
+ serialized_plan : & [ u8 ] ,
1159
+ ctx : & SessionContext ,
1160
+ ) -> PyResult < Arc < dyn ExecutionPlan > > {
1161
+ deltalake:: ensure_initialized ( ) ;
1162
+ let node = PhysicalPlanNode :: decode ( serialized_plan)
1163
+ . map_err ( |e| DataFusionError :: External ( Box :: new ( e) ) )
1164
+ . map_err ( py_datafusion_err) ?;
1165
+ let plan = node
1166
+ . try_into_physical_plan ( ctx, ctx. runtime_env ( ) . as_ref ( ) , codec ( ) )
1167
+ . map_err ( py_datafusion_err) ?;
1168
+ Ok ( plan)
995
1169
}
996
1170
997
1171
/// Print DataFrame
0 commit comments