@@ -19,15 +19,17 @@ use proc_macro2::{Span, TokenStream as TokenStream2};
19
19
use quote:: quote;
20
20
use rustpython_bytecode:: bytecode:: CodeObject ;
21
21
use rustpython_compiler:: compile;
22
+ use std:: collections:: HashMap ;
22
23
use std:: env;
23
24
use std:: fs;
24
- use std:: path:: PathBuf ;
25
+ use std:: path:: { Path , PathBuf } ;
25
26
use syn:: parse:: { Parse , ParseStream , Result as ParseResult } ;
26
- use syn:: { self , parse2, Lit , LitByteStr , Meta , Token } ;
27
+ use syn:: { self , parse2, Lit , LitByteStr , LitStr , Meta , Token } ;
27
28
28
29
enum CompilationSourceKind {
29
30
File ( PathBuf ) ,
30
31
SourceCode ( String ) ,
32
+ Dir ( PathBuf ) ,
31
33
}
32
34
33
35
struct CompilationSource {
@@ -36,14 +38,22 @@ struct CompilationSource {
36
38
}
37
39
38
40
impl CompilationSource {
39
- fn compile ( self , mode : & compile:: Mode , module_name : String ) -> Result < CodeObject , Diagnostic > {
40
- let compile = |source| {
41
- compile:: compile ( source, mode, module_name, 0 ) . map_err ( |err| {
42
- Diagnostic :: spans_error ( self . span , format ! ( "Compile error: {}" , err) )
43
- } )
44
- } ;
45
-
46
- match & self . kind {
41
+ fn compile_string (
42
+ & self ,
43
+ source : & str ,
44
+ mode : & compile:: Mode ,
45
+ module_name : String ,
46
+ ) -> Result < CodeObject , Diagnostic > {
47
+ compile:: compile ( source, mode, module_name, 0 )
48
+ . map_err ( |err| Diagnostic :: spans_error ( self . span , format ! ( "Compile error: {}" , err) ) )
49
+ }
50
+
51
+ fn compile (
52
+ & self ,
53
+ mode : & compile:: Mode ,
54
+ module_name : String ,
55
+ ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
56
+ Ok ( match & self . kind {
47
57
CompilationSourceKind :: File ( rel_path) => {
48
58
let mut path = PathBuf :: from (
49
59
env:: var_os ( "CARGO_MANIFEST_DIR" ) . expect ( "CARGO_MANIFEST_DIR is not present" ) ,
@@ -55,10 +65,59 @@ impl CompilationSource {
55
65
format ! ( "Error reading file {:?}: {}" , path, err) ,
56
66
)
57
67
} ) ?;
58
- compile ( & source)
68
+ hashmap ! { module_name. clone( ) => self . compile_string( & source, mode, module_name. clone( ) ) ?}
69
+ }
70
+ CompilationSourceKind :: SourceCode ( code) => {
71
+ hashmap ! { module_name. clone( ) => self . compile_string( code, mode, module_name. clone( ) ) ?}
72
+ }
73
+ CompilationSourceKind :: Dir ( rel_path) => {
74
+ let mut path = PathBuf :: from (
75
+ env:: var_os ( "CARGO_MANIFEST_DIR" ) . expect ( "CARGO_MANIFEST_DIR is not present" ) ,
76
+ ) ;
77
+ path. push ( rel_path) ;
78
+ self . compile_dir ( & path, String :: new ( ) , mode) ?
79
+ }
80
+ } )
81
+ }
82
+
83
+ fn compile_dir (
84
+ & self ,
85
+ path : & Path ,
86
+ parent : String ,
87
+ mode : & compile:: Mode ,
88
+ ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
89
+ let mut code_map = HashMap :: new ( ) ;
90
+ let paths = fs:: read_dir ( & path) . map_err ( |err| {
91
+ Diagnostic :: spans_error ( self . span , format ! ( "Error listing dir {:?}: {}" , path, err) )
92
+ } ) ?;
93
+ for path in paths {
94
+ let path = path. map_err ( |err| {
95
+ Diagnostic :: spans_error ( self . span , format ! ( "Failed to list file: {}" , err) )
96
+ } ) ?;
97
+ let path = path. path ( ) ;
98
+ let file_name = path. file_name ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
99
+ if path. is_dir ( ) {
100
+ code_map. extend ( self . compile_dir (
101
+ & path,
102
+ format ! ( "{}{}." , parent, file_name) ,
103
+ mode,
104
+ ) ?) ;
105
+ } else if file_name. ends_with ( ".py" ) {
106
+ let source = fs:: read_to_string ( & path) . map_err ( |err| {
107
+ Diagnostic :: spans_error (
108
+ self . span ,
109
+ format ! ( "Error reading file {:?}: {}" , path, err) ,
110
+ )
111
+ } ) ?;
112
+ let file_name_splitte: Vec < & str > = file_name. splitn ( 2 , '.' ) . collect ( ) ;
113
+ let module_name = format ! ( "{}{}" , parent, file_name_splitte[ 0 ] ) ;
114
+ code_map. insert (
115
+ module_name. clone ( ) ,
116
+ self . compile_string ( & source, mode, module_name) ?,
117
+ ) ;
59
118
}
60
- CompilationSourceKind :: SourceCode ( code) => compile ( code) ,
61
119
}
120
+ Ok ( code_map)
62
121
}
63
122
}
64
123
@@ -69,7 +128,7 @@ struct PyCompileInput {
69
128
}
70
129
71
130
impl PyCompileInput {
72
- fn compile ( & self ) -> Result < CodeObject , Diagnostic > {
131
+ fn compile ( & self ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
73
132
let mut module_name = None ;
74
133
let mut mode = None ;
75
134
let mut source: Option < CompilationSource > = None ;
@@ -122,6 +181,16 @@ impl PyCompileInput {
122
181
kind : CompilationSourceKind :: File ( path) ,
123
182
span : extract_spans ( & name_value) . unwrap ( ) ,
124
183
} ) ;
184
+ } else if name_value. ident == "dir" {
185
+ assert_source_empty ( & source) ?;
186
+ let path = match & name_value. lit {
187
+ Lit :: Str ( s) => PathBuf :: from ( s. value ( ) ) ,
188
+ _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
189
+ } ;
190
+ source = Some ( CompilationSource {
191
+ kind : CompilationSourceKind :: Dir ( path) ,
192
+ span : extract_spans ( & name_value) . unwrap ( ) ,
193
+ } ) ;
125
194
}
126
195
}
127
196
}
@@ -154,16 +223,23 @@ impl Parse for PyCompileInput {
154
223
pub fn impl_py_compile_bytecode ( input : TokenStream2 ) -> Result < TokenStream2 , Diagnostic > {
155
224
let input: PyCompileInput = parse2 ( input) ?;
156
225
157
- let code_obj = input. compile ( ) ?;
226
+ let code_map = input. compile ( ) ?;
158
227
159
- let bytes = bincode:: serialize ( & code_obj) . expect ( "Failed to serialize" ) ;
160
- let bytes = LitByteStr :: new ( & bytes, Span :: call_site ( ) ) ;
228
+ let modules = code_map. iter ( ) . map ( |( module_name, code_obj) | {
229
+ let module_name = LitStr :: new ( & module_name, Span :: call_site ( ) ) ;
230
+ let bytes = bincode:: serialize ( & code_obj) . expect ( "Failed to serialize" ) ;
231
+ let bytes = LitByteStr :: new ( & bytes, Span :: call_site ( ) ) ;
232
+ quote ! { #module_name. into( ) => bincode:: deserialize:: <:: rustpython_vm:: bytecode:: CodeObject >( #bytes)
233
+ . expect( "Deserializing CodeObject failed" ) }
234
+ } ) ;
161
235
162
236
let output = quote ! {
163
237
( {
164
238
use :: rustpython_vm:: __exports:: bincode;
165
- bincode:: deserialize:: <:: rustpython_vm:: bytecode:: CodeObject >( #bytes)
166
- . expect( "Deserializing CodeObject failed" )
239
+ use :: rustpython_vm:: __exports:: hashmap;
240
+ hashmap! {
241
+ #( #modules) , *
242
+ }
167
243
} )
168
244
} ;
169
245
0 commit comments