Skip to content

Commit db9774d

Browse files
committed
more reg implementation
1 parent d12db56 commit db9774d

File tree

1 file changed

+286
-23
lines changed

1 file changed

+286
-23
lines changed

vm/src/stdlib/winreg.rs

Lines changed: 286 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,17 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef<PyModule> {
1111
mod winreg {
1212
use std::ffi::OsStr;
1313
use std::os::windows::ffi::OsStrExt;
14+
use std::ptr;
1415
use std::sync::Arc;
1516

16-
use crate::builtins::PyInt;
17+
use crate::builtins::{PyInt, PyTuple};
1718
use crate::common::lock::PyRwLock;
1819
use crate::function::FuncArgs;
1920
use crate::protocol::PyNumberMethods;
2021
use crate::types::AsNumber;
2122
use crate::{PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine};
2223

23-
use windows_sys::Win32::Foundation;
24+
use windows_sys::Win32::Foundation::{self, ERROR_MORE_DATA};
2425
use windows_sys::Win32::System::Registry;
2526

2627
use num_traits::ToPrimitive;
@@ -133,6 +134,7 @@ mod winreg {
133134
#[pymethod]
134135
fn Close(&self, vm: &VirtualMachine) -> PyResult<()> {
135136
let res = unsafe { Registry::RegCloseKey(*self.hkey.write()) };
137+
*self.hkey.write() = std::ptr::null_mut();
136138
if res == 0 {
137139
Ok(())
138140
} else {
@@ -166,6 +168,7 @@ mod winreg {
166168
#[pymethod(magic)]
167169
fn exit(zelf: PyRef<Self>, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
168170
let res = unsafe { Registry::RegCloseKey(*zelf.hkey.write()) };
171+
*zelf.hkey.write() = std::ptr::null_mut();
169172
if res == 0 {
170173
Ok(())
171174
} else {
@@ -251,22 +254,39 @@ mod winreg {
251254
// TODO: Computer name can be `None`
252255
#[pyfunction]
253256
fn ConnectRegistry(
254-
computer_name: String,
257+
computer_name: Option<String>,
255258
key: PyRef<PyHKEYObject>,
256259
vm: &VirtualMachine,
257-
) -> PyResult<()> {
258-
let wide_computer_name = to_utf16(computer_name);
259-
let res = unsafe {
260-
Registry::RegConnectRegistryW(
261-
wide_computer_name.as_ptr(),
262-
*key.hkey.read(),
263-
std::ptr::null_mut(),
264-
)
265-
};
266-
if res == 0 {
267-
Ok(())
260+
) -> PyResult<PyHKEYObject> {
261+
if let Some(computer_name) = computer_name {
262+
let mut ret_key = std::ptr::null_mut();
263+
let wide_computer_name = to_utf16(computer_name);
264+
let res = unsafe {
265+
Registry::RegConnectRegistryW(
266+
wide_computer_name.as_ptr(),
267+
*key.hkey.read(),
268+
&mut ret_key
269+
)
270+
};
271+
if res == 0 {
272+
Ok(PyHKEYObject::new(ret_key))
273+
} else {
274+
Err(vm.new_os_error(format!("error code: {}", res)))
275+
}
268276
} else {
269-
Err(vm.new_os_error(format!("error code: {}", res)))
277+
let mut ret_key = std::ptr::null_mut();
278+
let res = unsafe {
279+
Registry::RegConnectRegistryW(
280+
std::ptr::null_mut(),
281+
*key.hkey.read(),
282+
&mut ret_key
283+
)
284+
};
285+
if res == 0 {
286+
Ok(PyHKEYObject::new(ret_key))
287+
} else {
288+
Err(vm.new_os_error(format!("error code: {}", res)))
289+
}
270290
}
271291
}
272292

@@ -368,6 +388,144 @@ mod winreg {
368388
}
369389
}
370390

391+
// #[pyfunction]
392+
// fn EnumKey(key: PyRef<PyHKEYObject>, index: i32, vm: &VirtualMachine) -> PyResult<String> {
393+
// let mut tmpbuf = [0u16; 257];
394+
// let mut len = std::mem::sizeof(tmpbuf.len())/std::mem::sizeof(tmpbuf[0]);
395+
// let res = unsafe {
396+
// Registry::RegEnumKeyExW(
397+
// *key.hkey.read(),
398+
// index as u32,
399+
// tmpbuf.as_mut_ptr(),
400+
// &mut len,
401+
// std::ptr::null_mut(),
402+
// std::ptr::null_mut(),
403+
// std::ptr::null_mut(),
404+
// std::ptr::null_mut(),
405+
// )
406+
// };
407+
// if res != 0 {
408+
// return Err(vm.new_os_error(format!("error code: {}", res)));
409+
// }
410+
// let s = String::from_utf16(&tmpbuf[..len as usize])
411+
// .map_err(|e| vm.new_value_error(format!("UTF16 error: {}", e)))?;
412+
// Ok(s)
413+
// }
414+
415+
#[pyfunction]
416+
fn EnumValue(hkey: PyRef<PyHKEYObject>, index: u32, vm: &VirtualMachine) -> PyResult {
417+
// Query registry for the required buffer sizes.
418+
let mut ret_value_size: u32 = 0;
419+
let mut ret_data_size: u32 = 0;
420+
let hkey: *mut std::ffi::c_void = *hkey.hkey.read();
421+
let rc = unsafe {
422+
Registry::RegQueryInfoKeyW(
423+
hkey,
424+
ptr::null_mut(),
425+
ptr::null_mut(),
426+
ptr::null_mut(),
427+
ptr::null_mut(),
428+
ptr::null_mut(),
429+
ptr::null_mut(),
430+
ptr::null_mut(),
431+
&mut ret_value_size as *mut u32,
432+
&mut ret_data_size as *mut u32,
433+
ptr::null_mut(),
434+
ptr::null_mut(),
435+
)
436+
};
437+
if rc != 0 {
438+
return Err(vm.new_os_error(format!(
439+
"RegQueryInfoKeyW failed with error code {}",
440+
rc
441+
)));
442+
}
443+
444+
// Include room for null terminators.
445+
ret_value_size += 1;
446+
ret_data_size += 1;
447+
let mut buf_value_size = ret_value_size;
448+
let mut buf_data_size = ret_data_size;
449+
450+
// Allocate buffers.
451+
let mut ret_value_buf: Vec<u16> = vec![0; ret_value_size as usize];
452+
let mut ret_data_buf: Vec<u8> = vec![0; ret_data_size as usize];
453+
454+
// Loop to enumerate the registry value.
455+
loop {
456+
let mut current_value_size = ret_value_size;
457+
let mut current_data_size = ret_data_size;
458+
let rc = unsafe {
459+
Registry::RegEnumValueW(
460+
hkey,
461+
index,
462+
ret_value_buf.as_mut_ptr(),
463+
&mut current_value_size as *mut u32,
464+
ptr::null_mut(),
465+
{
466+
// typ will hold the registry data type.
467+
let mut t = 0u32;
468+
&mut t
469+
},
470+
ret_data_buf.as_mut_ptr(),
471+
&mut current_data_size as *mut u32,
472+
)
473+
};
474+
if rc == ERROR_MORE_DATA {
475+
// Double the buffer sizes.
476+
buf_data_size *= 2;
477+
buf_value_size *= 2;
478+
ret_data_buf.resize(buf_data_size as usize, 0);
479+
ret_value_buf.resize(buf_value_size as usize, 0);
480+
// Reset sizes for next iteration.
481+
ret_value_size = buf_value_size;
482+
ret_data_size = buf_data_size;
483+
continue;
484+
}
485+
if rc != 0 {
486+
return Err(vm.new_os_error(format!(
487+
"RegEnumValueW failed with error code {}",
488+
rc
489+
)));
490+
}
491+
492+
// At this point, current_value_size and current_data_size have been updated.
493+
// Retrieve the registry type.
494+
let mut reg_type: u32 = 0;
495+
unsafe {
496+
Registry::RegEnumValueW(
497+
hkey,
498+
index,
499+
ret_value_buf.as_mut_ptr(),
500+
&mut current_value_size as *mut u32,
501+
ptr::null_mut(),
502+
&mut reg_type as *mut u32,
503+
ret_data_buf.as_mut_ptr(),
504+
&mut current_data_size as *mut u32,
505+
)
506+
};
507+
508+
// Convert the registry value name from UTF‑16.
509+
let name_len = ret_value_buf
510+
.iter()
511+
.position(|&c| c == 0)
512+
.unwrap_or(ret_value_buf.len());
513+
let name = String::from_utf16(&ret_value_buf[..name_len])
514+
.map_err(|e| vm.new_value_error(format!("UTF16 conversion error: {}", e)))?;
515+
516+
// Slice the data buffer to the actual size returned.
517+
let data_slice = &ret_data_buf[..current_data_size as usize];
518+
let py_data = reg_to_py(vm, data_slice, reg_type)?;
519+
520+
// Return tuple (value_name, data, type)
521+
return Ok(vm.ctx.new_tuple(vec![
522+
vm.ctx.new_str(name).into(),
523+
py_data,
524+
vm.ctx.new_int(reg_type).into(),
525+
]).into());
526+
}
527+
}
528+
371529
#[pyfunction]
372530
fn FlushKey(key: PyRef<PyHKEYObject>, vm: &VirtualMachine) -> PyResult<()> {
373531
let res = unsafe { Registry::RegFlushKey(*key.hkey.read()) };
@@ -430,11 +588,11 @@ mod winreg {
430588
}
431589

432590
#[pyfunction]
433-
fn QueryInfoKey(key: PyRef<PyHKEYObject>, vm: &VirtualMachine) -> PyResult<()> {
591+
fn QueryInfoKey(key: PyRef<PyHKEYObject>, vm: &VirtualMachine) -> PyResult<PyRef<PyTuple>> {
434592
let key = *key.hkey.read();
435593
let mut lpcsubkeys: u32 = 0;
436594
let mut lpcvalues: u32 = 0;
437-
let lpftlastwritetime: *mut Foundation::FILETIME = std::ptr::null_mut();
595+
let mut lpftlastwritetime: Foundation::FILETIME = unsafe { std::mem::zeroed() };
438596
let err = unsafe {
439597
Registry::RegQueryInfoKeyW(
440598
key,
@@ -448,15 +606,16 @@ mod winreg {
448606
std::ptr::null_mut(),
449607
std::ptr::null_mut(),
450608
std::ptr::null_mut(),
451-
lpftlastwritetime,
609+
&mut lpftlastwritetime,
452610
)
453611
};
454612

455613
if err != 0 {
456-
Err(vm.new_os_error(format!("error code: {}", err)))
457-
} else {
458-
Ok(())
614+
return Err(vm.new_os_error(format!("error code: {}", err)));
459615
}
616+
let l: u64 = (lpftlastwritetime.dwHighDateTime as u64) << 32 | lpftlastwritetime.dwLowDateTime as u64;
617+
let tup: Vec<PyObjectRef> = vec![vm.ctx.new_int(lpcsubkeys).into(), vm.ctx.new_int(lpcvalues).into(), vm.ctx.new_int(l).into()];
618+
Ok(vm.ctx.new_tuple(tup))
460619
}
461620

462621
#[pyfunction]
@@ -481,7 +640,44 @@ mod winreg {
481640
Ok(())
482641
}
483642

484-
// TODO: QueryValueEx
643+
#[pyfunction]
644+
fn QueryValueEx(key: PyRef<PyHKEYObject>, name: String, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
645+
let wide_name = to_utf16(name);
646+
let mut buf_size = 0;
647+
let res = unsafe {
648+
Registry::RegQueryValueExW(
649+
*key.hkey.read(),
650+
wide_name.as_ptr(),
651+
std::ptr::null_mut(),
652+
std::ptr::null_mut(),
653+
std::ptr::null_mut(),
654+
&mut buf_size,
655+
)
656+
};
657+
// TODO: res == ERROR_MORE_DATA
658+
if res != 0 {
659+
return Err(vm.new_os_error(format!("error code: {}", res)));
660+
}
661+
let mut retBuf = Vec::with_capacity(buf_size as usize);
662+
let mut typ = 0;
663+
let res = unsafe {
664+
Registry::RegQueryValueExW(
665+
*key.hkey.read(),
666+
wide_name.as_ptr(),
667+
std::ptr::null_mut(),
668+
&mut typ,
669+
retBuf.as_mut_ptr(),
670+
&mut buf_size,
671+
)
672+
};
673+
// TODO: res == ERROR_MORE_DATA
674+
if res != 0 {
675+
return Err(vm.new_os_error(format!("error code: {}", res)));
676+
}
677+
let obj = reg_to_py(vm, retBuf.as_slice(), typ)?;
678+
Ok(obj)
679+
}
680+
485681
#[pyfunction]
486682
fn SaveKey(key: PyRef<PyHKEYObject>, file_name: String, vm: &VirtualMachine) -> PyResult<()> {
487683
let file_name = to_utf16(file_name);
@@ -530,6 +726,73 @@ mod winreg {
530726
}
531727
}
532728

729+
fn reg_to_py(vm: &VirtualMachine, ret_data: &[u8], typ: u32) -> PyResult {
730+
match typ {
731+
REG_DWORD => {
732+
// If there isn’t enough data, return 0.
733+
if ret_data.len() < std::mem::size_of::<u32>() {
734+
Ok(vm.ctx.new_int(0).into())
735+
} else {
736+
let val = u32::from_ne_bytes(ret_data[..4].try_into().unwrap());
737+
Ok(vm.ctx.new_int(val).into())
738+
}
739+
}
740+
REG_QWORD => {
741+
if ret_data.len() < std::mem::size_of::<u64>() {
742+
Ok(vm.ctx.new_int(0).into())
743+
} else {
744+
let val = u64::from_ne_bytes(ret_data[..8].try_into().unwrap());
745+
Ok(vm.ctx.new_int(val).into())
746+
}
747+
}
748+
REG_SZ | REG_EXPAND_SZ => {
749+
// Treat the data as a UTF-16 string.
750+
let u16_count = ret_data.len() / 2;
751+
let u16_slice = unsafe {
752+
std::slice::from_raw_parts(ret_data.as_ptr() as *const u16, u16_count)
753+
};
754+
// Only use characters up to the first NUL.
755+
let len = u16_slice.iter().position(|&c| c == 0).unwrap_or(u16_slice.len());
756+
let s = String::from_utf16(&u16_slice[..len])
757+
.map_err(|e| vm.new_value_error(format!("UTF16 error: {}", e)))?;
758+
Ok(vm.ctx.new_str(s).into())
759+
}
760+
REG_MULTI_SZ => {
761+
if ret_data.is_empty() {
762+
Ok(vm.ctx.new_list(vec![]).into())
763+
} else {
764+
let u16_count = ret_data.len() / 2;
765+
let u16_slice = unsafe {
766+
std::slice::from_raw_parts(ret_data.as_ptr() as *const u16, u16_count)
767+
};
768+
let mut strings: Vec<PyObjectRef> = Vec::new();
769+
let mut start = 0;
770+
for (i, &c) in u16_slice.iter().enumerate() {
771+
if c == 0 {
772+
// An empty string signals the end.
773+
if start == i {
774+
break;
775+
}
776+
let s = String::from_utf16(&u16_slice[start..i])
777+
.map_err(|e| vm.new_value_error(format!("UTF16 error: {}", e)))?;
778+
strings.push(vm.ctx.new_str(s).into());
779+
start = i + 1;
780+
}
781+
}
782+
Ok(vm.ctx.new_list(strings).into())
783+
}
784+
}
785+
// For REG_BINARY and any other unknown types, return a bytes object if data exists.
786+
_ => {
787+
if ret_data.is_empty() {
788+
Ok(vm.ctx.none())
789+
} else {
790+
Ok(vm.ctx.new_bytes(ret_data.to_vec()).into())
791+
}
792+
}
793+
}
794+
}
795+
533796
fn py2reg(value: PyObjectRef, typ: u32, vm: &VirtualMachine) -> PyResult<Option<Vec<u8>>> {
534797
match typ {
535798
REG_DWORD => {
@@ -569,7 +832,7 @@ mod winreg {
569832
fn SetValueEx(
570833
key: PyRef<PyHKEYObject>,
571834
value_name: String,
572-
reserved: u32,
835+
_reserved: u32,
573836
typ: u32,
574837
value: PyObjectRef,
575838
vm: &VirtualMachine,

0 commit comments

Comments
 (0)