Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions Lib/test/test_binascii.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ def test_base64valid(self):
res += b
self.assertEqual(res, self.rawdata)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_base64invalid(self):
# Test base64 with random invalid characters sprinkled throughout
# (This requires a new version of binascii.)
Expand Down Expand Up @@ -114,8 +112,6 @@ def addnoise(line):
# empty strings. TBD: shouldn't it raise an exception instead ?
self.assertEqual(binascii.a2b_base64(self.type2test(fillers)), b'')

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_base64_strict_mode(self):
# Test base64 with strict mode on
def _assertRegexTemplate(assert_regex: str, data: bytes, non_strict_mode_expected_result: bytes):
Expand Down Expand Up @@ -159,8 +155,6 @@ def assertDiscontinuousPadding(data, non_strict_mode_expected_result: bytes):
assertDiscontinuousPadding(b'ab=c=', b'i\xb7')
assertDiscontinuousPadding(b'ab=ab==', b'i\xb6\x9b')

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_base64errors(self):
# Test base64 with invalid padding
def assertIncorrectPadding(data):
Expand Down
160 changes: 129 additions & 31 deletions stdlib/src/binascii.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
pub(super) use decl::crc32;
pub(crate) use decl::make_module;
use rustpython_vm::{builtins::PyBaseExceptionRef, convert::ToPyException, VirtualMachine};

pub(super) use decl::crc32;

/// Decodes `input` using the `base64::STANDARD` configuration with
/// `decode_allow_trailing_bits` switched on.
///
/// Returns the decoded bytes, or the `base64::DecodeError` reported by
/// `base64::decode_config`.
pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
    let lenient_config = base64::STANDARD.decode_allow_trailing_bits(true);
    base64::decode_config(input, lenient_config)
}
/// The base64 padding character `=` (ASCII 61).
const PAD: u8 = b'=';
/// Maximum line size, excluding the CRLF.
const MAXLINESIZE: usize = 76;

#[pymodule(name = "binascii")]
mod decl {
use super::decode;
use super::{MAXLINESIZE, PAD};
use crate::vm::{
builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef},
builtins::{PyIntRef, PyTypeRef},
convert::ToPyException,
function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg},
PyResult, VirtualMachine,
};
use itertools::Itertools;

const MAXLINESIZE: usize = 76;

#[pyattr(name = "Error", once)]
fn error_type(vm: &VirtualMachine) -> PyTypeRef {
pub(super) fn error_type(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx.new_exception_type(
"binascii",
"Error",
Expand Down Expand Up @@ -67,15 +65,18 @@ mod decl {
fn unhexlify(data: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
data.with_ref(|hex_bytes| {
if hex_bytes.len() % 2 != 0 {
return Err(new_binascii_error("Odd-length string".to_owned(), vm));
return Err(super::new_binascii_error(
"Odd-length string".to_owned(),
vm,
));
}

let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2);
for (n1, n2) in hex_bytes.iter().tuples() {
if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) {
unhex.push(n1 << 4 | n2);
} else {
return Err(new_binascii_error(
return Err(super::new_binascii_error(
"Non-hexadecimal digit found".to_owned(),
vm,
));
Expand Down Expand Up @@ -144,13 +145,20 @@ mod decl {
newline: bool,
}

fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(error_type(vm), msg)
/// Arguments taken by `a2b_base64`.
#[derive(FromArgs)]
struct A2bBase64Args {
// The base64 text to decode, as an ASCII buffer.
#[pyarg(any)]
s: ArgAsciiBuffer,
// When true, `a2b_base64` returns an error for leading padding, non-base64
// bytes, discontinuous padding and data following the padding, instead of
// skipping over them. Defaults to false.
#[pyarg(named, default = "false")]
strict_mode: bool,
}

#[pyfunction]
fn a2b_base64(s: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
fn a2b_base64(args: A2bBase64Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
#[rustfmt::skip]
// Converts between ASCII and base-64 characters. The index of a given number yields the
// number in ASCII while the value of said index yields the number in base-64. For example
// "=" is 61 in ASCII but 0 (since it's the pad character) in base-64, so BASE64_TABLE[61] == 0
const BASE64_TABLE: [i8; 256] = [
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
Expand All @@ -171,25 +179,92 @@ mod decl {
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
];

let A2bBase64Args { s, strict_mode } = args;
s.with_ref(|b| {
let decoded = if b.len() % 4 == 0 {
decode(b)
} else {
Err(base64::DecodeError::InvalidLength)
};
decoded.or_else(|_| {
let buf: Vec<_> = b
.iter()
.copied()
.filter(|&c| BASE64_TABLE[c as usize] != -1)
.collect();
if buf.len() % 4 != 0 {
return Err(base64::DecodeError::InvalidLength);
if b.is_empty() {
return Ok(vec![]);
}

if strict_mode && b[0] == PAD {
return Err(base64::DecodeError::InvalidByte(0, 61));
}

let mut decoded: Vec<u8> = vec![];

let mut quad_pos = 0; // position in the nibble
let mut pads = 0;
let mut left_char: u8 = 0;
let mut padding_started = false;
for (i, &el) in b.iter().enumerate() {
if el == PAD {
padding_started = true;

pads += 1;
if quad_pos >= 2 && quad_pos + pads >= 4 {
if strict_mode && i + 1 < b.len() {
// Represents excess data after padding error
return Err(base64::DecodeError::InvalidLastSymbol(i, PAD));
}

return Ok(decoded);
}

continue;
}
decode(&buf)
})

let binary_char = BASE64_TABLE[el as usize];
if binary_char >= 64 || binary_char == -1 {
if strict_mode {
// Represents non-base64 data error
return Err(base64::DecodeError::InvalidByte(i, el));
}
continue;
}

if strict_mode && padding_started {
// Represents discontinuous padding error
return Err(base64::DecodeError::InvalidByte(i, PAD));
}
pads = 0;

// Decode individual ASCII character
match quad_pos {
0 => {
quad_pos = 1;
left_char = binary_char as u8;
}
1 => {
quad_pos = 2;
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
left_char = (binary_char & 0x0f) as u8;
}
2 => {
quad_pos = 3;
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
left_char = (binary_char & 0x03) as u8;
}
3 => {
quad_pos = 0;
decoded.push((left_char << 6) | binary_char as u8);
left_char = 0;
}
_ => unsafe {
// quad_pos is only assigned in this match statement to constants
std::hint::unreachable_unchecked()
},
}
}

match quad_pos {
0 => Ok(decoded),
1 => Err(base64::DecodeError::InvalidLastSymbol(
decoded.len() / 3 * 4 + 1,
0,
)),
_ => Err(base64::DecodeError::InvalidLength),
}
})
.map_err(|err| new_binascii_error(format!("error decoding base64: {err}"), vm))
.map_err(|err| super::Base64DecodeError(err).to_pyexception(vm))
}

#[pyfunction]
Expand Down Expand Up @@ -654,3 +729,26 @@ mod decl {
})
}
}

/// Wrapper around `base64::DecodeError`; this file implements `ToPyException`
/// on it so a decode failure can be turned into a Python exception.
struct Base64DecodeError(base64::DecodeError);

/// Builds an exception of the type returned by `decl::error_type`
/// (registered as `binascii.Error`) carrying `msg` as its message.
fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
    let error_class = decl::error_type(vm);
    vm.new_exception_msg(error_class, msg)
}

impl ToPyException for Base64DecodeError {
fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
use base64::DecodeError::*;
let message = match self.0 {
InvalidByte(0, PAD) => "Leading padding not allowed".to_owned(),
InvalidByte(_, PAD) => "Discontinuous padding not allowed".to_owned(),
InvalidByte(_, _) => "Only base64 data is allowed".to_owned(),
InvalidLastSymbol(_, PAD) => "Excess data after padding".to_owned(),
InvalidLastSymbol(length, _) => {
format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length)
}
InvalidLength => "Incorrect padding".to_owned(),
};
new_binascii_error(format!("error decoding base64: {message}"), vm)
}
}