From d243b73b6757439a9465b2a8fad6137102b2e1b6 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Tue, 17 Mar 2020 20:00:46 -0500 Subject: [PATCH 01/17] Fix int.from_bytes() with an arbitrary iterator --- vm/src/obj/objbyteinner.rs | 17 +++++++++++------ vm/src/pyobject.rs | 7 +++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index a6be440139..58030e219e 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -40,10 +40,15 @@ impl TryFromObject for PyByteInner { elements: k.try_value().unwrap() }), l @ PyList => l.get_byte_inner(vm), - obj => Err(vm.new_type_error(format!( - "a bytes-like object is required, not {}", - obj.class() - ))), + obj => { + let iter = vm.get_method_or_type_error(obj.clone(), "__iter__", || { + format!("a bytes-like object is required, not {}", obj.class()) + })?; + let iter = PyIterable::from_method(iter); + Ok(PyByteInner { + elements: iter.iter(vm)?.collect::>()?, + }) + } }) } } @@ -330,8 +335,8 @@ impl PyByteInner { where F: Fn(&[u8], &[u8]) -> bool, { - let r = PyByteInner::try_from_object(vm, other) - .map(|other| op(&self.elements, &other.elements)); + let r = PyBytesLike::try_from_object(vm, other) + .map(|other| other.with_ref(|other| op(&self.elements, other))); PyComparisonValue::from_option(r.ok()) } diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 786e1a6bf6..f0265d1c73 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -919,6 +919,13 @@ pub struct PyIterable { } impl PyIterable { + pub fn from_method(method: PyObjectRef) -> Self { + PyIterable { + method, + _item: std::marker::PhantomData, + } + } + /// Returns an iterator over this sequence of objects. /// /// This operation may fail if an exception is raised while invoking the From 0f4b3581018564eea02420a70e8e2fb4ec6b11a7 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Tue, 17 Mar 2020 20:01:07 -0500 Subject: [PATCH 02/17] Add ipaddress.py from CPython 3.6 --- Lib/ipaddress.py | 2266 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 2266 insertions(+) create mode 100644 Lib/ipaddress.py diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py new file mode 100644 index 0000000000..583f02ad54 --- /dev/null +++ b/Lib/ipaddress.py @@ -0,0 +1,2266 @@ +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +__version__ = '1.0' + + +import functools + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def ip_address(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the *address* passed isn't either a v4 or a v6 + address + + """ + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 address' % + address) + + +def ip_network(address, strict=True): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP network. Either IPv4 or + IPv6 networks may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if the network has host bits set. + + """ + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 network' % + address) + + +def ip_interface(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Interface or IPv6Interface object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + Notes: + The IPv?Interface classes describe an Address on a particular + Network, so they're basically a combination of both the Address + and Network classes. + + """ + try: + return IPv4Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 interface' % + address) + + +def v4_int_to_packed(address): + """Represent an address as 4 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The integer address packed as 4 bytes in network (big-endian) order. + + Raises: + ValueError: If the integer is negative or too large to be an + IPv4 IP address. + + """ + try: + return address.to_bytes(4, 'big') + except OverflowError: + raise ValueError("Address negative or too large for IPv4") + + +def v6_int_to_packed(address): + """Represent an address as 16 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv6 IP address. + + Returns: + The integer address packed as 16 bytes in network (big-endian) order. + + """ + try: + return address.to_bytes(16, 'big') + except OverflowError: + raise ValueError("Address negative or too large for IPv6") + + +def _split_optional_netmask(address): + """Helper to split the netmask and raise AddressValueError if needed""" + addr = str(address).split('/') + if len(addr) > 2: + raise AddressValueError("Only one '/' permitted in %r" % address) + return addr + + +def _find_address_range(addresses): + """Find a sequence of sorted deduplicated IPv#Address. + + Args: + addresses: a list of IPv#Address objects. + + Yields: + A tuple containing the first and last IP addresses in the sequence. + + """ + it = iter(addresses) + first = last = next(it) + for ip in it: + if ip._ip != last._ip + 1: + yield first, last + first = ip + last = ip + yield first, last + + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + return min(bits, (~number & (number-1)).bit_length()) + + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. + + Example: + >>> list(summarize_address_range(IPv4Address('192.0.2.0'), + ... IPv4Address('192.0.2.130'))) + ... #doctest: +NORMALIZE_WHITESPACE + [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), + IPv4Network('192.0.2.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + An iterator of the summarized IPv(4|6) network objects. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version of the first address is not 4 or 6. + + """ + if (not (isinstance(first, _BaseAddress) and + isinstance(last, _BaseAddress))): + raise TypeError('first and last must be IP addresses, not networks') + if first.version != last.version: + raise TypeError("%s and %s are not of the same version" % ( + first, last)) + if first > last: + raise ValueError('last IP address must be greater than first') + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError('unknown IP version') + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = min(_count_righthand_zero_bits(first_int, ip_bits), + (last_int - first_int + 1).bit_length() - 1) + net = ip((first_int, ip_bits - nbits)) + yield net + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: + break + + +def _collapse_addresses_internal(addresses): + """Loops through the addresses, collapsing concurrent netblocks. + + Example: + + ip1 = IPv4Network('192.0.2.0/26') + ip2 = IPv4Network('192.0.2.64/26') + ip3 = IPv4Network('192.0.2.128/26') + ip4 = IPv4Network('192.0.2.192/26') + + _collapse_addresses_internal([ip1, ip2, ip3, ip4]) -> + [IPv4Network('192.0.2.0/24')] + + This shouldn't be called directly; it is called via + collapse_addresses([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + # First merge + to_merge = list(addresses) + subnets = {} + while to_merge: + net = to_merge.pop() + supernet = net.supernet() + existing = subnets.get(supernet) + if existing is None: + subnets[supernet] = net + elif existing != net: + # Merge consecutive subnets + del subnets[supernet] + to_merge.append(supernet) + # Then iterate over resulting networks, skipping subsumed subnets + last = None + for net in sorted(subnets.values()): + if last is not None: + # Since they are sorted, last.network_address <= net.network_address + # is a given. + if last.broadcast_address >= net.broadcast_address: + continue + yield net + last = net + + +def collapse_addresses(addresses): + """Collapse a list of IP objects. + + Example: + collapse_addresses([IPv4Network('192.0.2.0/25'), + IPv4Network('192.0.2.128/25')]) -> + [IPv4Network('192.0.2.0/24')] + + Args: + addresses: An iterator of IPv4Network or IPv6Network objects. + + Returns: + An iterator of the collapsed IPv(4|6)Network objects. + + Raises: + TypeError: If passed a list of mixed version objects. + + """ + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseAddress): + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + try: + ips.append(ip.ip) + except AttributeError: + ips.append(ip.network_address) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, nets[-1])) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + + # find consecutive address ranges in the sorted sequence and summarize them + if ips: + for first, last in _find_address_range(ips): + addrs.extend(summarize_address_range(first, last)) + + return _collapse_addresses_internal(addrs + nets) + + +def get_mixed_type_key(obj): + """Return a key suitable for sorting between networks and addresses. + + Address and Network objects are not sortable by default; they're + fundamentally different so the expression + + IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') + + doesn't make any sense. There are some times however, where you may wish + to have ipaddress sort these for you anyway. If you need to do this, you + can use this function as the key= argument to sorted(). + + Args: + obj: either a Network or Address object. + Returns: + appropriate key. + + """ + if isinstance(obj, _BaseNetwork): + return obj._get_networks_key() + elif isinstance(obj, _BaseAddress): + return obj._get_address_key() + return NotImplemented + + +class _IPAddressBase: + + """The mother class.""" + + __slots__ = () + + @property + def exploded(self): + """Return the longhand version of the IP address as a string.""" + return self._explode_shorthand_ip_string() + + @property + def compressed(self): + """Return the shorthand version of the IP address as a string.""" + return str(self) + + @property + def reverse_pointer(self): + """The name of the reverse DNS pointer for the IP address, e.g.: + >>> ipaddress.ip_address("127.0.0.1").reverse_pointer + '1.0.0.127.in-addr.arpa' + >>> ipaddress.ip_address("2001:db8::1").reverse_pointer + '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa' + + """ + return self._reverse_pointer() + + @property + def version(self): + msg = '%200s has no version specified' % (type(self),) + raise NotImplementedError(msg) + + def _check_int_address(self, address): + if address < 0: + msg = "%d (< 0) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._version)) + if address > self._ALL_ONES: + msg = "%d (>= 2**%d) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._max_prefixlen, + self._version)) + + def _check_packed_address(self, address, expected_len): + address_len = len(address) + if address_len != expected_len: + msg = "%r (len %d != %d) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, address_len, + expected_len, self._version)) + + @classmethod + def _ip_int_from_prefix(cls, prefixlen): + """Turn the prefix length into a bitwise netmask + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. + + """ + return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen) + + @classmethod + def _prefix_from_ip_int(cls, ip_int): + """Return prefix length from the bitwise netmask. + + Args: + ip_int: An integer, the netmask in expanded bitwise format + + Returns: + An integer, the prefix length. + + Raises: + ValueError: If the input intermingles zeroes & ones + """ + trailing_zeroes = _count_righthand_zero_bits(ip_int, + cls._max_prefixlen) + prefixlen = cls._max_prefixlen - trailing_zeroes + leading_ones = ip_int >> trailing_zeroes + all_ones = (1 << prefixlen) - 1 + if leading_ones != all_ones: + byteslen = cls._max_prefixlen // 8 + details = ip_int.to_bytes(byteslen, 'big') + msg = 'Netmask pattern %r mixes zeroes & ones' + raise ValueError(msg % details) + return prefixlen + + @classmethod + def _report_invalid_netmask(cls, netmask_str): + msg = '%r is not a valid netmask' % netmask_str + raise NetmaskValueError(msg) from None + + @classmethod + def _prefix_from_prefix_string(cls, prefixlen_str): + """Return prefix length from a numeric string + + Args: + prefixlen_str: The string to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask + """ + # int allows a leading +/- as well as surrounding whitespace, + # so we ensure that isn't the case + if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): + cls._report_invalid_netmask(prefixlen_str) + try: + prefixlen = int(prefixlen_str) + except ValueError: + cls._report_invalid_netmask(prefixlen_str) + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen_str) + return prefixlen + + @classmethod + def _prefix_from_ip_string(cls, ip_str): + """Turn a netmask/hostmask string into a prefix length + + Args: + ip_str: The netmask/hostmask to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask/hostmask + """ + # Parse the netmask/hostmask like an IP address. + try: + ip_int = cls._ip_int_from_string(ip_str) + except AddressValueError: + cls._report_invalid_netmask(ip_str) + + # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). + # Note that the two ambiguous cases (all-ones and all-zeroes) are + # treated as netmasks. + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + pass + + # Invert the bits, and try matching a /0+1+/ hostmask instead. + ip_int ^= cls._ALL_ONES + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + cls._report_invalid_netmask(ip_str) + + def __reduce__(self): + return self.__class__, (str(self),) + + +@functools.total_ordering +class _BaseAddress(_IPAddressBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by single IP addresses. + """ + + __slots__ = () + + def __int__(self): + return self._ip + + def __eq__(self, other): + try: + return (self._ip == other._ip + and self._version == other._version) + except AttributeError: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, _BaseAddress): + return NotImplemented + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + if self._ip != other._ip: + return self._ip < other._ip + return False + + # Shorthand for Integer addition and subtraction. This is not + # meant to ever support addition/subtraction of addresses. + def __add__(self, other): + if not isinstance(other, int): + return NotImplemented + return self.__class__(int(self) + other) + + def __sub__(self, other): + if not isinstance(other, int): + return NotImplemented + return self.__class__(int(self) - other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, str(self)) + + def __str__(self): + return str(self._string_from_ip_int(self._ip)) + + def __hash__(self): + return hash(hex(int(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + def __reduce__(self): + return self.__class__, (self._ip,) + + +@functools.total_ordering +class _BaseNetwork(_IPAddressBase): + + """A generic IP network object. + + This IP class contains the version independent methods which are + used by networks. + + """ + def __init__(self, address): + self._cache = {} + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, str(self)) + + def __str__(self): + return '%s/%d' % (self.network_address, self.prefixlen) + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in range(network + 1, broadcast): + yield self._address_class(x) + + def __iter__(self): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in range(network, broadcast + 1): + yield self._address_class(x) + + def __getitem__(self, n): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + if n >= 0: + if network + n > broadcast: + raise IndexError('address out of range') + return self._address_class(network + n) + else: + n += 1 + if broadcast + n < network: + raise IndexError('address out of range') + return self._address_class(broadcast + n) + + def __lt__(self, other): + if not isinstance(other, _BaseNetwork): + return NotImplemented + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + if self.network_address != other.network_address: + return self.network_address < other.network_address + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __eq__(self, other): + try: + return (self._version == other._version and + self.network_address == other.network_address and + int(self.netmask) == int(other.netmask)) + except AttributeError: + return NotImplemented + + def __hash__(self): + return hash(int(self.network_address) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. + if self._version != other._version: + return False + # dealing with another network. + if isinstance(other, _BaseNetwork): + return False + # dealing with another address + else: + # address + return (int(self.network_address) <= int(other._ip) <= + int(self.broadcast_address)) + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network_address in other or ( + self.broadcast_address in other or ( + other.network_address in self or ( + other.broadcast_address in self))) + + @property + def broadcast_address(self): + x = self._cache.get('broadcast_address') + if x is None: + x = self._address_class(int(self.network_address) | + int(self.hostmask)) + self._cache['broadcast_address'] = x + return x + + @property + def hostmask(self): + x = self._cache.get('hostmask') + if x is None: + x = self._address_class(int(self.netmask) ^ self._ALL_ONES) + self._cache['hostmask'] = x + return x + + @property + def with_prefixlen(self): + return '%s/%d' % (self.network_address, self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self.network_address, self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self.network_address, self.hostmask) + + @property + def num_addresses(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast_address) - int(self.network_address) + 1 + + @property + def _address_class(self): + # Returning bare address objects (rather than interfaces) allows for + # more consistent behaviour across the network address, broadcast + # address and individual host addresses. + msg = '%200s has no associated address class' % (type(self),) + raise NotImplementedError(msg) + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = ip_network('192.0.2.0/28') + addr2 = ip_network('192.0.2.1/32') + list(addr1.address_exclude(addr2)) = + [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), + IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] + + or IPv6: + + addr1 = ip_network('2001:db8::1/32') + addr2 = ip_network('2001:db8::1/128') + list(addr1.address_exclude(addr2)) = + [ip_network('2001:db8::1/128'), + ip_network('2001:db8::2/127'), + ip_network('2001:db8::4/126'), + ip_network('2001:db8::8/125'), + ... + ip_network('2001:db8:8000::/33')] + + Args: + other: An IPv4Network or IPv6Network object of the same type. + + Returns: + An iterator of the IPv(4|6)Network objects which is self + minus other. + + Raises: + TypeError: If self and other are of differing address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError("%s and %s are not of the same version" % ( + self, other)) + + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + + if not (other.network_address >= self.network_address and + other.broadcast_address <= self.broadcast_address): + raise ValueError('%s not contained in %s' % (other, self)) + if other == self: + return + + # Make sure we're comparing the network of other. + other = other.__class__('%s/%s' % (other.network_address, + other.prefixlen)) + + s1, s2 = self.subnets() + while s1 != other and s2 != other: + if (other.network_address >= s1.network_address and + other.broadcast_address <= s1.broadcast_address): + yield s2 + s1, s2 = s1.subnets() + elif (other.network_address >= s2.network_address and + other.broadcast_address <= s2.broadcast_address): + yield s1 + s1, s2 = s2.subnets() + else: + # If we got here, there's a bug somewhere. + raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + if s1 == other: + yield s2 + elif s2 == other: + yield s1 + else: + # If we got here, there's a bug somewhere. + raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') + IPv6Network('2001:db8::1000/124') < + IPv6Network('2001:db8::2000/124') + 0 if self == other + eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') + IPv6Network('2001:db8::1000/124') == + IPv6Network('2001:db8::1000/124') + 1 if self > other + eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') + IPv6Network('2001:db8::2000/124') > + IPv6Network('2001:db8::1000/124') + + Raises: + TypeError if the IP versions are different. + + """ + # does this need to raise a ValueError? + if self._version != other._version: + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + # self._version == other._version below here: + if self.network_address < other.network_address: + return -1 + if self.network_address > other.network_address: + return 1 + # self.network_address == other.network_address below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network_address, self.netmask) + + def subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), yield an iterator with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. + This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. + OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError('new prefix must be longer') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError('prefix length diff must be > 0') + new_prefixlen = self._prefixlen + prefixlen_diff + + if new_prefixlen > self._max_prefixlen: + raise ValueError( + 'prefix length diff %d is invalid for netblock %s' % ( + new_prefixlen, self)) + + start = int(self.network_address) + end = int(self.broadcast_address) + 1 + step = (int(self.hostmask) + 1) >> prefixlen_diff + for new_addr in range(start, end, step): + current = self.__class__((new_addr, new_prefixlen)) + yield current + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have + a negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError('new prefix must be shorter') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = self._prefixlen - new_prefix + + new_prefixlen = self.prefixlen - prefixlen_diff + if new_prefixlen < 0: + raise ValueError( + 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % + (self.prefixlen, prefixlen_diff)) + return self.__class__(( + int(self.network_address) & (int(self.netmask) << prefixlen_diff), + new_prefixlen + )) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return (self.network_address.is_multicast and + self.broadcast_address.is_multicast) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return (self.network_address.is_reserved and + self.broadcast_address.is_reserved) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return (self.network_address.is_link_local and + self.broadcast_address.is_link_local) + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return (self.network_address.is_private and + self.broadcast_address.is_private) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return (self.network_address.is_unspecified and + self.broadcast_address.is_unspecified) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return (self.network_address.is_loopback and + self.broadcast_address.is_loopback) + + +class _BaseV4: + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 4 + # Equivalent to 255.255.255.255 or 32 bits of 1's. + _ALL_ONES = (2**IPV4LENGTH) - 1 + _DECIMAL_DIGITS = frozenset('0123456789') + + # the valid octets for host and netmasks. only useful for IPv4. + _valid_mask_octets = frozenset({255, 254, 252, 248, 240, 224, 192, 128, 0}) + + _max_prefixlen = IPV4LENGTH + # There are only a handful of valid v4 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + def _explode_shorthand_ip_string(self): + return str(self) + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, int): + prefixlen = arg + else: + try: + # Check for a netmask in prefix length form + prefixlen = cls._prefix_from_prefix_string(arg) + except NetmaskValueError: + # Check for a netmask or hostmask in dotted-quad form. + # This may raise NetmaskValueError. + prefixlen = cls._prefix_from_ip_string(arg) + netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if ip_str isn't a valid IPv4 Address. + + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + octets = ip_str.split('.') + if len(octets) != 4: + raise AddressValueError("Expected 4 octets in %r" % ip_str) + + try: + return int.from_bytes(map(cls._parse_octet, octets), 'big') + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) from None + + @classmethod + def _parse_octet(cls, octet_str): + """Convert a decimal octet into an integer. + + Args: + octet_str: A string, the number to parse. + + Returns: + The octet as an integer. + + Raises: + ValueError: if the octet isn't strictly a decimal from [0..255]. + + """ + if not octet_str: + raise ValueError("Empty octet not permitted") + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._DECIMAL_DIGITS.issuperset(octet_str): + msg = "Only decimal digits permitted in %r" + raise ValueError(msg % octet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(octet_str) > 3: + msg = "At most 3 characters permitted in %r" + raise ValueError(msg % octet_str) + # Convert to integer (we know digits are legal) + octet_int = int(octet_str, 10) + # Any octets that look like they *might* be written in octal, + # and which don't look exactly the same in both octal and + # decimal are rejected as ambiguous + if octet_int > 7 and octet_str[0] == '0': + msg = "Ambiguous (octal/decimal) value in %r not permitted" + raise ValueError(msg % octet_str) + if octet_int > 255: + raise ValueError("Octet %d (> 255) not permitted" % octet_int) + return octet_int + + @classmethod + def _string_from_ip_int(cls, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + return '.'.join(map(str, ip_int.to_bytes(4, 'big'))) + + def _is_valid_netmask(self, netmask): + """Verify that the netmask is valid. + + Args: + netmask: A string, either a prefix or dotted decimal + netmask. + + Returns: + A boolean, True if the prefix represents a valid IPv4 + netmask. + + """ + mask = netmask.split('.') + if len(mask) == 4: + try: + for x in mask: + if int(x) not in self._valid_mask_octets: + return False + except ValueError: + # Found something that isn't an integer or isn't valid + return False + for idx, y in enumerate(mask): + if idx > 0 and y > mask[idx - 1]: + return False + return True + try: + netmask = int(netmask) + except ValueError: + return False + return 0 <= netmask <= self._max_prefixlen + + def _is_hostmask(self, ip_str): + """Test if the IP string is a hostmask (rather than a netmask). + + Args: + ip_str: A string, the potential hostmask. + + Returns: + A boolean, True if the IP string is a hostmask. + + """ + bits = ip_str.split('.') + try: + parts = [x for x in map(int, bits) if x in self._valid_mask_octets] + except ValueError: + return False + if len(parts) != len(bits): + return False + if parts[0] < parts[-1]: + return True + return False + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv4 address. + + This implements the method described in RFC1035 3.5. + + """ + reverse_octets = str(self).split('.')[::-1] + return '.'.join(reverse_octets) + '.in-addr.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv4Address(_BaseV4, _BaseAddress): + + """Represent and manipulate single IPv4 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv4Address('192.0.2.1') == IPv4Address(3221225985). + or, more generally + IPv4Address(int(IPv4Address('192.0.2.1'))) == + IPv4Address('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + + """ + # Efficient constructor from integer. + if isinstance(address, int): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 4) + self._ip = int.from_bytes(address, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. + + """ + return self in self._constants._reserved_network + + @property + @functools.lru_cache() + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + @functools.lru_cache() + def is_global(self): + return self not in self._constants._public_network and not self.is_private + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is multicast. + See RFC 3171 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 5735 3. + + """ + return self == self._constants._unspecified_address + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback per RFC 3330. + + """ + return self in self._constants._loopback_network + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is link-local per RFC 3927. + + """ + return self in self._constants._linklocal_network + + +class IPv4Interface(IPv4Address): + + def __init__(self, address): + if isinstance(address, (bytes, int)): + IPv4Address.__init__(self, address) + self.network = IPv4Network(self._ip) + self._prefixlen = self._max_prefixlen + return + + if isinstance(address, tuple): + IPv4Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + + self.network = IPv4Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv4Address.__init__(self, addr[0]) + + self.network = IPv4Network(address, strict=False) + self._prefixlen = self.network._prefixlen + + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + + def __str__(self): + return '%s/%d' % (self._string_from_ip_int(self._ip), + self.network.prefixlen) + + def __eq__(self, other): + address_equal = IPv4Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv4Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return (self.network < other.network or + self.network == other.network and address_less) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv4Address(self._ip) + + @property + def with_prefixlen(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.hostmask) + + +class IPv4Network(_BaseV4, _BaseNetwork): + + """This class represents and manipulates 32-bit IPv4 network + addresses.. + + Attributes: [examples for IPv4Network('192.0.2.0/27')] + .network_address: IPv4Address('192.0.2.0') + .hostmask: IPv4Address('0.0.0.31') + .broadcast_address: IPv4Address('192.0.2.32') + .netmask: IPv4Address('255.255.255.224') + .prefixlen: 27 + + """ + # Class to use when creating address objects + _address_class = IPv4Address + + def __init__(self, address, strict=True): + + """Instantiate a new IPv4 network object. + + Args: + address: A string or integer representing the IP [& network]. + '192.0.2.0/24' + '192.0.2.0/255.255.255.0' + '192.0.0.2/0.0.0.255' + are all functionally the same in IPv4. Similarly, + '192.0.2.1' + '192.0.2.1/255.255.255.255' + '192.0.2.1/32' + are also functionally equivalent. That is to say, failing to + provide a subnetmask will create an object with a mask of /32. + + If the mask (portion after the / in the argument) is given in + dotted quad form, it is treated as a netmask if it starts with a + non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it + starts with a zero field (e.g. 0.255.255.255 == /8), with the + single exception of an all-zero mask which is treated as a + netmask == /0. If no mask is given, a default of /32 is used. + + Additionally, an integer can be passed, so + IPv4Network('192.0.2.1') == IPv4Network(3221225985) + or, more generally + IPv4Interface(int(IPv4Interface('192.0.2.1'))) == + IPv4Interface('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + NetmaskValueError: If the netmask isn't valid for + an IPv4 address. + ValueError: If strict is True and a network address is not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Constructing from a packed address or integer + if isinstance(address, (int, bytes)): + addr = address + mask = self._max_prefixlen + # Constructing from a tuple (addr, [mask]) + elif isinstance(address, tuple): + addr = address[0] + mask = address[1] if len(address) > 1 else self._max_prefixlen + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + else: + args = _split_optional_netmask(address) + addr = self._ip_int_from_string(args[0]) + mask = args[1] if len(args) == 2 else self._max_prefixlen + + self.network_address = IPv4Address(addr) + self.netmask, self._prefixlen = self._make_netmask(mask) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError('%s has host bits set' % self) + else: + self.network_address = IPv4Address(packed & + int(self.netmask)) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + @property + @functools.lru_cache() + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry. + + """ + return (not (self.network_address in IPv4Network('100.64.0.0/10') and + self.broadcast_address in IPv4Network('100.64.0.0/10')) and + not self.is_private) + + +class _IPv4Constants: + _linklocal_network = IPv4Network('169.254.0.0/16') + + _loopback_network = IPv4Network('127.0.0.0/8') + + _multicast_network = IPv4Network('224.0.0.0/4') + + _public_network = IPv4Network('100.64.0.0/10') + + _private_networks = [ + IPv4Network('0.0.0.0/8'), + IPv4Network('10.0.0.0/8'), + IPv4Network('127.0.0.0/8'), + IPv4Network('169.254.0.0/16'), + IPv4Network('172.16.0.0/12'), + IPv4Network('192.0.0.0/29'), + IPv4Network('192.0.0.170/31'), + IPv4Network('192.0.2.0/24'), + IPv4Network('192.168.0.0/16'), + IPv4Network('198.18.0.0/15'), + IPv4Network('198.51.100.0/24'), + IPv4Network('203.0.113.0/24'), + IPv4Network('240.0.0.0/4'), + IPv4Network('255.255.255.255/32'), + ] + + _reserved_network = IPv4Network('240.0.0.0/4') + + _unspecified_address = IPv4Address('0.0.0.0') + + +IPv4Address._constants = _IPv4Constants + + +class _BaseV6: + + """Base IPv6 object. + + The following methods are used by IPv6 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 6 + _ALL_ONES = (2**IPV6LENGTH) - 1 + _HEXTET_COUNT = 8 + _HEX_DIGITS = frozenset('0123456789ABCDEFabcdef') + _max_prefixlen = IPV6LENGTH + + # There are only a bunch of valid v6 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, int): + prefixlen = arg + else: + prefixlen = cls._prefix_from_prefix_string(arg) + netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn an IPv6 ip_str into an integer. + + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + An int, the IPv6 address + + Raises: + AddressValueError: if ip_str isn't a valid IPv6 Address. + + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + parts = ip_str.split(':') + + # An IPv6 address needs at least 2 colons (3 parts). + _min_parts = 3 + if len(parts) < _min_parts: + msg = "At least %d parts expected in %r" % (_min_parts, ip_str) + raise AddressValueError(msg) + + # If the address has an IPv4-style suffix, convert it to hexadecimal. + if '.' in parts[-1]: + try: + ipv4_int = IPv4Address(parts.pop())._ip + except AddressValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) from None + parts.append('%x' % ((ipv4_int >> 16) & 0xFFFF)) + parts.append('%x' % (ipv4_int & 0xFFFF)) + + # An IPv6 address can't have more than 8 colons (9 parts). + # The extra colon comes from using the "::" notation for a single + # leading or trailing zero part. + _max_parts = cls._HEXTET_COUNT + 1 + if len(parts) > _max_parts: + msg = "At most %d colons permitted in %r" % (_max_parts-1, ip_str) + raise AddressValueError(msg) + + # Disregarding the endpoints, find '::' with nothing in between. + # This indicates that a run of zeroes has been skipped. + skip_index = None + for i in range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i + + # parts_hi is the number of parts to copy from above/before the '::' + # parts_lo is the number of parts to copy from below/after the '::' + if skip_index is not None: + # If we found a '::', then check if it also covers the endpoints. + parts_hi = skip_index + parts_lo = len(parts) - skip_index - 1 + if not parts[0]: + parts_hi -= 1 + if parts_hi: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + parts_lo -= 1 + if parts_lo: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo) + if parts_skipped < 1: + msg = "Expected at most %d other parts with '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT-1, ip_str)) + else: + # Otherwise, allocate the entire address to parts_hi. The + # endpoints could still be empty, but _parse_hextet() will check + # for that. + if len(parts) != cls._HEXTET_COUNT: + msg = "Exactly %d parts expected without '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str)) + if not parts[0]: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_hi = len(parts) + parts_lo = 0 + parts_skipped = 0 + + try: + # Now, parse the hextets into a 128-bit integer. + ip_int = 0 + for i in range(parts_hi): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + ip_int <<= 16 * parts_skipped + for i in range(-parts_lo, 0): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + return ip_int + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) from None + + @classmethod + def _parse_hextet(cls, hextet_str): + """Convert an IPv6 hextet string into an integer. + + Args: + hextet_str: A string, the number to parse. + + Returns: + The hextet as an integer. + + Raises: + ValueError: if the input isn't strictly a hex number from + [0..FFFF]. + + """ + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._HEX_DIGITS.issuperset(hextet_str): + raise ValueError("Only hex digits permitted in %r" % hextet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(hextet_str) > 4: + msg = "At most 4 characters permitted in %r" + raise ValueError(msg % hextet_str) + # Length check means we can skip checking the integer value + return int(hextet_str, 16) + + @classmethod + def _compress_hextets(cls, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index, hextet in enumerate(hextets): + if hextet == '0': + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = (best_doublecolon_start + + best_doublecolon_len) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [''] + hextets[best_doublecolon_start:best_doublecolon_end] = [''] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [''] + hextets + + return hextets + + @classmethod + def _string_from_ip_int(cls, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. + + """ + if ip_int is None: + ip_int = int(cls._ip) + + if ip_int > cls._ALL_ONES: + raise ValueError('IPv6 address is too large') + + hex_str = '%032x' % ip_int + hextets = ['%x' % int(hex_str[x:x+4], 16) for x in range(0, 32, 4)] + + hextets = cls._compress_hextets(hextets) + return ':'.join(hextets) + + def _explode_shorthand_ip_string(self): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if isinstance(self, IPv6Network): + ip_str = str(self.network_address) + elif isinstance(self, IPv6Interface): + ip_str = str(self.ip) + else: + ip_str = str(self) + + ip_int = self._ip_int_from_string(ip_str) + hex_str = '%032x' % ip_int + parts = [hex_str[x:x+4] for x in range(0, 32, 4)] + if isinstance(self, (_BaseNetwork, IPv6Interface)): + return '%s/%d' % (':'.join(parts), self._prefixlen) + return ':'.join(parts) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv6 address. + + This implements the method described in RFC3596 2.5. + + """ + reverse_chars = self.exploded[::-1].replace(':', '') + return '.'.join(reverse_chars) + '.ip6.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv6Address(_BaseV6, _BaseAddress): + + """Represent and manipulate single IPv6 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:db8::') == + IPv6Address(42540766411282592856903984951653826560) + or, more generally + IPv6Address(int(IPv6Address('2001:db8::'))) == + IPv6Address('2001:db8::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + # Efficient constructor from integer. + if isinstance(address, int): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 16) + self._ip = int.from_bytes(address, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return any(self in x for x in self._constants._reserved_networks) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in self._constants._linklocal_network + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in self._constants._sitelocal_network + + @property + @functools.lru_cache() + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv6-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, true if the address is not reserved per + iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return self._ip == 0 + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return self._ip == 1 + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + if (self._ip >> 32) != 0xFFFF: + return None + return IPv4Address(self._ip & 0xFFFFFFFF) + + @property + def teredo(self): + """Tuple of embedded teredo IPs. + + Returns: + Tuple of the (server, client) IPs or None if the address + doesn't appear to be a teredo address (doesn't start with + 2001::/32) + + """ + if (self._ip >> 96) != 0x20010000: + return None + return (IPv4Address((self._ip >> 64) & 0xFFFFFFFF), + IPv4Address(~self._ip & 0xFFFFFFFF)) + + @property + def sixtofour(self): + """Return the IPv4 6to4 embedded address. + + Returns: + The IPv4 6to4-embedded address if present or None if the + address doesn't appear to contain a 6to4 embedded address. + + """ + if (self._ip >> 112) != 0x2002: + return None + return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) + + +class IPv6Interface(IPv6Address): + + def __init__(self, address): + if isinstance(address, (bytes, int)): + IPv6Address.__init__(self, address) + self.network = IPv6Network(self._ip) + self._prefixlen = self._max_prefixlen + return + if isinstance(address, tuple): + IPv6Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv6Address.__init__(self, addr[0]) + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self._prefixlen = self.network._prefixlen + self.hostmask = self.network.hostmask + + def __str__(self): + return '%s/%d' % (self._string_from_ip_int(self._ip), + self.network.prefixlen) + + def __eq__(self, other): + address_equal = IPv6Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv6Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return (self.network < other.network or + self.network == other.network and address_less) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv6Address(self._ip) + + @property + def with_prefixlen(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.hostmask) + + @property + def is_unspecified(self): + return self._ip == 0 and self.network.is_unspecified + + @property + def is_loopback(self): + return self._ip == 1 and self.network.is_loopback + + +class IPv6Network(_BaseV6, _BaseNetwork): + + """This class represents and manipulates 128-bit IPv6 networks. + + Attributes: [examples for IPv6('2001:db8::1000/124')] + .network_address: IPv6Address('2001:db8::1000') + .hostmask: IPv6Address('::f') + .broadcast_address: IPv6Address('2001:db8::100f') + .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') + .prefixlen: 124 + + """ + + # Class to use when creating address objects + _address_class = IPv6Address + + def __init__(self, address, strict=True): + """Instantiate a new IPv6 Network object. + + Args: + address: A string or integer representing the IPv6 network or the + IP and prefix/netmask. + '2001:db8::/128' + '2001:db8:0000:0000:0000:0000:0000:0000/128' + '2001:db8::' + are all functionally the same in IPv6. That is to say, + failing to provide a subnetmask will create an object with + a mask of /128. + + Additionally, an integer can be passed, so + IPv6Network('2001:db8::') == + IPv6Network(42540766411282592856903984951653826560) + or, more generally + IPv6Network(int(IPv6Network('2001:db8::'))) == + IPv6Network('2001:db8::') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 2001:db8::1000/124 and not an + IP address on a network, eg, 2001:db8::1/124. + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + NetmaskValueError: If the netmask isn't valid for + an IPv6 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Constructing from a packed address or integer + if isinstance(address, (int, bytes)): + addr = address + mask = self._max_prefixlen + # Constructing from a tuple (addr, [mask]) + elif isinstance(address, tuple): + addr = address[0] + mask = address[1] if len(address) > 1 else self._max_prefixlen + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + else: + args = _split_optional_netmask(address) + addr = self._ip_int_from_string(args[0]) + mask = args[1] if len(args) == 2 else self._max_prefixlen + + self.network_address = IPv6Address(addr) + self.netmask, self._prefixlen = self._make_netmask(mask) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError('%s has host bits set' % self) + else: + self.network_address = IPv6Address(packed & + int(self.netmask)) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the + Subnet-Router anycast address. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in range(network + 1, broadcast + 1): + yield self._address_class(x) + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return (self.network_address.is_site_local and + self.broadcast_address.is_site_local) + + +class _IPv6Constants: + + _linklocal_network = IPv6Network('fe80::/10') + + _multicast_network = IPv6Network('ff00::/8') + + _private_networks = [ + IPv6Network('::1/128'), + IPv6Network('::/128'), + IPv6Network('::ffff:0:0/96'), + IPv6Network('100::/64'), + IPv6Network('2001::/23'), + IPv6Network('2001:2::/48'), + IPv6Network('2001:db8::/32'), + IPv6Network('2001:10::/28'), + IPv6Network('fc00::/7'), + IPv6Network('fe80::/10'), + ] + + _reserved_networks = [ + IPv6Network('::/8'), IPv6Network('100::/8'), + IPv6Network('200::/7'), IPv6Network('400::/6'), + IPv6Network('800::/5'), IPv6Network('1000::/4'), + IPv6Network('4000::/3'), IPv6Network('6000::/3'), + IPv6Network('8000::/3'), IPv6Network('A000::/3'), + IPv6Network('C000::/3'), IPv6Network('E000::/4'), + IPv6Network('F000::/5'), IPv6Network('F800::/6'), + IPv6Network('FE00::/9'), + ] + + _sitelocal_network = IPv6Network('fec0::/10') + + +IPv6Address._constants = _IPv6Constants From d3cc3601d65d9bd1352c9cba6142cd55b394dfe4 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 00:05:55 -0500 Subject: [PATCH 03/17] Add ssl.py from CPython 3.6 --- Lib/ssl.py | 1237 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1237 insertions(+) create mode 100644 Lib/ssl.py diff --git a/Lib/ssl.py b/Lib/ssl.py new file mode 100644 index 0000000000..58d3e93922 --- /dev/null +++ b/Lib/ssl.py @@ -0,0 +1,1237 @@ +# Wrapper module for _ssl, providing some additional facilities +# implemented in Python. Written by Bill Janssen. + +"""This module provides some more Pythonic support for SSL. + +Object types: + + SSLSocket -- subtype of socket.socket which does SSL over the socket + +Exceptions: + + SSLError -- exception raised for I/O errors + +Functions: + + cert_time_to_seconds -- convert time string used for certificate + notBefore and notAfter functions to integer + seconds past the Epoch (the time values + returned from time.time()) + + fetch_server_certificate (HOST, PORT) -- fetch the certificate provided + by the server running on HOST at port PORT. No + validation of the certificate is performed. + +Integer constants: + +SSL_ERROR_ZERO_RETURN +SSL_ERROR_WANT_READ +SSL_ERROR_WANT_WRITE +SSL_ERROR_WANT_X509_LOOKUP +SSL_ERROR_SYSCALL +SSL_ERROR_SSL +SSL_ERROR_WANT_CONNECT + +SSL_ERROR_EOF +SSL_ERROR_INVALID_ERROR_CODE + +The following group define certificate requirements that one side is +allowing/requiring from the other side: + +CERT_NONE - no certificates from the other side are required (or will + be looked at if provided) +CERT_OPTIONAL - certificates are not required, but if provided will be + validated, and if validation fails, the connection will + also fail +CERT_REQUIRED - certificates are required, and will be validated, and + if validation fails, the connection will also fail + +The following constants identify various SSL protocol variants: + +PROTOCOL_SSLv2 +PROTOCOL_SSLv3 +PROTOCOL_SSLv23 +PROTOCOL_TLS +PROTOCOL_TLS_CLIENT +PROTOCOL_TLS_SERVER +PROTOCOL_TLSv1 +PROTOCOL_TLSv1_1 +PROTOCOL_TLSv1_2 + +The following constants identify various SSL alert message descriptions as per +http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6 + +ALERT_DESCRIPTION_CLOSE_NOTIFY +ALERT_DESCRIPTION_UNEXPECTED_MESSAGE +ALERT_DESCRIPTION_BAD_RECORD_MAC +ALERT_DESCRIPTION_RECORD_OVERFLOW +ALERT_DESCRIPTION_DECOMPRESSION_FAILURE +ALERT_DESCRIPTION_HANDSHAKE_FAILURE +ALERT_DESCRIPTION_BAD_CERTIFICATE +ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE +ALERT_DESCRIPTION_CERTIFICATE_REVOKED +ALERT_DESCRIPTION_CERTIFICATE_EXPIRED +ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN +ALERT_DESCRIPTION_ILLEGAL_PARAMETER +ALERT_DESCRIPTION_UNKNOWN_CA +ALERT_DESCRIPTION_ACCESS_DENIED +ALERT_DESCRIPTION_DECODE_ERROR +ALERT_DESCRIPTION_DECRYPT_ERROR +ALERT_DESCRIPTION_PROTOCOL_VERSION +ALERT_DESCRIPTION_INSUFFICIENT_SECURITY +ALERT_DESCRIPTION_INTERNAL_ERROR +ALERT_DESCRIPTION_USER_CANCELLED +ALERT_DESCRIPTION_NO_RENEGOTIATION +ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION +ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE +ALERT_DESCRIPTION_UNRECOGNIZED_NAME +ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE +ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE +ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY +""" + +import ipaddress +import textwrap +import re +import sys +import os +from collections import namedtuple +from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag + +import _ssl # if we can't import it, let the error propagate + +from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION +from _ssl import _SSLContext, MemoryBIO, SSLSession +from _ssl import ( + SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, + SSLSyscallError, SSLEOFError, + ) +from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj +from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes +try: + from _ssl import RAND_egd +except ImportError: + # LibreSSL does not provide RAND_egd + pass + + +from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_TLSv1_3 +from _ssl import _OPENSSL_API_VERSION + + +_IntEnum._convert( + '_SSLMethod', __name__, + lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23', + source=_ssl) + +_IntFlag._convert( + 'Options', __name__, + lambda name: name.startswith('OP_'), + source=_ssl) + +_IntEnum._convert( + 'AlertDescription', __name__, + lambda name: name.startswith('ALERT_DESCRIPTION_'), + source=_ssl) + +_IntEnum._convert( + 'SSLErrorNumber', __name__, + lambda name: name.startswith('SSL_ERROR_'), + source=_ssl) + +_IntFlag._convert( + 'VerifyFlags', __name__, + lambda name: name.startswith('VERIFY_'), + source=_ssl) + +_IntEnum._convert( + 'VerifyMode', __name__, + lambda name: name.startswith('CERT_'), + source=_ssl) + + +PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS +_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} + +_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None) + + +if sys.platform == "win32": + from _ssl import enum_certificates, enum_crls + +from socket import socket, AF_INET, SOCK_STREAM, create_connection +from socket import SOL_SOCKET, SO_TYPE +import base64 # for DER-to-PEM translation +import errno +import warnings + + +socket_error = OSError # keep that public name in module namespace + +if _ssl.HAS_TLS_UNIQUE: + CHANNEL_BINDING_TYPES = ['tls-unique'] +else: + CHANNEL_BINDING_TYPES = [] + + +# Disable weak or insecure ciphers by default +# (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL') +# Enable a better set of ciphers by default +# This list has been explicitly chosen to: +# * TLS 1.3 ChaCha20 and AES-GCM cipher suites +# * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE) +# * Prefer ECDHE over DHE for better performance +# * Prefer AEAD over CBC for better performance and security +# * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI +# (ChaCha20 needs OpenSSL 1.1.0 or patched 1.0.2) +# * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better +# performance and security +# * Then Use HIGH cipher suites as a fallback +# * Disable NULL authentication, NULL encryption, 3DES and MD5 MACs +# for security reasons +_DEFAULT_CIPHERS = ( + 'TLS13-AES-256-GCM-SHA384:TLS13-CHACHA20-POLY1305-SHA256:' + 'TLS13-AES-128-GCM-SHA256:' + 'ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:' + 'ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:' + '!aNULL:!eNULL:!MD5:!3DES' + ) + +# Restricted and more secure ciphers for the server side +# This list has been explicitly chosen to: +# * TLS 1.3 ChaCha20 and AES-GCM cipher suites +# * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE) +# * Prefer ECDHE over DHE for better performance +# * Prefer AEAD over CBC for better performance and security +# * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI +# * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better +# performance and security +# * Then Use HIGH cipher suites as a fallback +# * Disable NULL authentication, NULL encryption, MD5 MACs, DSS, RC4, and +# 3DES for security reasons +_RESTRICTED_SERVER_CIPHERS = ( + 'TLS13-AES-256-GCM-SHA384:TLS13-CHACHA20-POLY1305-SHA256:' + 'TLS13-AES-128-GCM-SHA256:' + 'ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:' + 'ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:' + '!aNULL:!eNULL:!MD5:!DSS:!RC4:!3DES' +) + + +class CertificateError(ValueError): + pass + + +def _dnsname_match(dn, hostname, max_wildcards=1): + """Matching according to RFC 6125, section 6.4.3 + + http://tools.ietf.org/html/rfc6125#section-6.4.3 + """ + pats = [] + if not dn: + return False + + leftmost, *remainder = dn.split(r'.') + + wildcards = leftmost.count('*') + if wildcards > max_wildcards: + # Issue #17980: avoid denials of service by refusing more + # than one wildcard per fragment. A survey of established + # policy among SSL implementations showed it to be a + # reasonable choice. + raise CertificateError( + "too many wildcards in certificate DNS name: " + repr(dn)) + + # speed up common case w/o wildcards + if not wildcards: + return dn.lower() == hostname.lower() + + # RFC 6125, section 6.4.3, subitem 1. + # The client SHOULD NOT attempt to match a presented identifier in which + # the wildcard character comprises a label other than the left-most label. + if leftmost == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + elif leftmost.startswith('xn--') or hostname.startswith('xn--'): + # RFC 6125, section 6.4.3, subitem 3. + # The client SHOULD NOT attempt to match a presented identifier + # where the wildcard character is embedded within an A-label or + # U-label of an internationalized domain name. + pats.append(re.escape(leftmost)) + else: + # Otherwise, '*' matches any dotless string, e.g. www* + pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) + + # add the remaining fragments, ignore any wildcards + for frag in remainder: + pats.append(re.escape(frag)) + + pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + return pat.match(hostname) + + +def _ipaddress_match(ipname, host_ip): + """Exact matching of IP addresses. + + RFC 6125 explicitly doesn't define an algorithm for this + (section 1.7.2 - "Out of Scope"). + """ + # OpenSSL may add a trailing newline to a subjectAltName's IP address + ip = ipaddress.ip_address(ipname.rstrip()) + return ip == host_ip + + +def match_hostname(cert, hostname): + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 + rules are followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError("empty or no certificate, match_hostname needs a " + "SSL socket or SSL context with either " + "CERT_OPTIONAL or CERT_REQUIRED") + try: + host_ip = ipaddress.ip_address(hostname) + except ValueError: + # Not an IP address (common case) + host_ip = None + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if host_ip is None and _dnsname_match(value, hostname): + return + dnsnames.append(value) + elif key == 'IP Address': + if host_ip is not None and _ipaddress_match(value, host_ip): + return + dnsnames.append(value) + if not dnsnames: + # The subject is only checked when there is no dNSName entry + # in subjectAltName + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_match(value, hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError("hostname %r " + "doesn't match either of %s" + % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError("hostname %r " + "doesn't match %r" + % (hostname, dnsnames[0])) + else: + raise CertificateError("no appropriate commonName or " + "subjectAltName fields were found") + + +DefaultVerifyPaths = namedtuple("DefaultVerifyPaths", + "cafile capath openssl_cafile_env openssl_cafile openssl_capath_env " + "openssl_capath") + +def get_default_verify_paths(): + """Return paths to default cafile and capath. + """ + parts = _ssl.get_default_verify_paths() + + # environment vars shadow paths + cafile = os.environ.get(parts[0], parts[1]) + capath = os.environ.get(parts[2], parts[3]) + + return DefaultVerifyPaths(cafile if os.path.isfile(cafile) else None, + capath if os.path.isdir(capath) else None, + *parts) + + +class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")): + """ASN.1 object identifier lookup + """ + __slots__ = () + + def __new__(cls, oid): + return super().__new__(cls, *_txt2obj(oid, name=False)) + + @classmethod + def fromnid(cls, nid): + """Create _ASN1Object from OpenSSL numeric ID + """ + return super().__new__(cls, *_nid2obj(nid)) + + @classmethod + def fromname(cls, name): + """Create _ASN1Object from short name, long name or OID + """ + return super().__new__(cls, *_txt2obj(name, name=True)) + + +class Purpose(_ASN1Object, _Enum): + """SSLContext purpose flags with X509v3 Extended Key Usage objects + """ + SERVER_AUTH = '1.3.6.1.5.5.7.3.1' + CLIENT_AUTH = '1.3.6.1.5.5.7.3.2' + + +class SSLContext(_SSLContext): + """An SSLContext holds various SSL-related configuration options and + data, such as certificates and possibly a private key.""" + + __slots__ = ('protocol', '__weakref__') + _windows_cert_stores = ("CA", "ROOT") + + def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs): + self = _SSLContext.__new__(cls, protocol) + if protocol != _SSLv2_IF_EXISTS: + self.set_ciphers(_DEFAULT_CIPHERS) + return self + + def __init__(self, protocol=PROTOCOL_TLS): + self.protocol = protocol + + def wrap_socket(self, sock, server_side=False, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname=None, session=None): + return SSLSocket(sock=sock, server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + server_hostname=server_hostname, + _context=self, _session=session) + + def wrap_bio(self, incoming, outgoing, server_side=False, + server_hostname=None, session=None): + sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side, + server_hostname=server_hostname) + return SSLObject(sslobj, session=session) + + def set_npn_protocols(self, npn_protocols): + protos = bytearray() + for protocol in npn_protocols: + b = bytes(protocol, 'ascii') + if len(b) == 0 or len(b) > 255: + raise SSLError('NPN protocols must be 1 to 255 in length') + protos.append(len(b)) + protos.extend(b) + + self._set_npn_protocols(protos) + + def set_alpn_protocols(self, alpn_protocols): + protos = bytearray() + for protocol in alpn_protocols: + b = bytes(protocol, 'ascii') + if len(b) == 0 or len(b) > 255: + raise SSLError('ALPN protocols must be 1 to 255 in length') + protos.append(len(b)) + protos.extend(b) + + self._set_alpn_protocols(protos) + + def _load_windows_store_certs(self, storename, purpose): + certs = bytearray() + try: + for cert, encoding, trust in enum_certificates(storename): + # CA certs are never PKCS#7 encoded + if encoding == "x509_asn": + if trust is True or purpose.oid in trust: + certs.extend(cert) + except PermissionError: + warnings.warn("unable to enumerate Windows certificate store") + if certs: + self.load_verify_locations(cadata=certs) + return certs + + def load_default_certs(self, purpose=Purpose.SERVER_AUTH): + if not isinstance(purpose, _ASN1Object): + raise TypeError(purpose) + if sys.platform == "win32": + for storename in self._windows_cert_stores: + self._load_windows_store_certs(storename, purpose) + self.set_default_verify_paths() + + @property + def options(self): + return Options(super().options) + + @options.setter + def options(self, value): + super(SSLContext, SSLContext).options.__set__(self, value) + + @property + def verify_flags(self): + return VerifyFlags(super().verify_flags) + + @verify_flags.setter + def verify_flags(self, value): + super(SSLContext, SSLContext).verify_flags.__set__(self, value) + + @property + def verify_mode(self): + value = super().verify_mode + try: + return VerifyMode(value) + except ValueError: + return value + + @verify_mode.setter + def verify_mode(self, value): + super(SSLContext, SSLContext).verify_mode.__set__(self, value) + + +def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, + capath=None, cadata=None): + """Create a SSLContext object with default settings. + + NOTE: The protocol and settings may change anytime without prior + deprecation. The values represent a fair balance between maximum + compatibility and security. + """ + if not isinstance(purpose, _ASN1Object): + raise TypeError(purpose) + + # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION, + # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE + # by default. + context = SSLContext(PROTOCOL_TLS) + + if purpose == Purpose.SERVER_AUTH: + # verify certs and host name in client mode + context.verify_mode = CERT_REQUIRED + context.check_hostname = True + elif purpose == Purpose.CLIENT_AUTH: + context.set_ciphers(_RESTRICTED_SERVER_CIPHERS) + + if cafile or capath or cadata: + context.load_verify_locations(cafile, capath, cadata) + elif context.verify_mode != CERT_NONE: + # no explicit cafile, capath or cadata but the verify mode is + # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system + # root CA certificates for the given purpose. This may fail silently. + context.load_default_certs(purpose) + return context + +def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None, + check_hostname=False, purpose=Purpose.SERVER_AUTH, + certfile=None, keyfile=None, + cafile=None, capath=None, cadata=None): + """Create a SSLContext object for Python stdlib modules + + All Python stdlib modules shall use this function to create SSLContext + objects in order to keep common settings in one place. The configuration + is less restrict than create_default_context()'s to increase backward + compatibility. + """ + if not isinstance(purpose, _ASN1Object): + raise TypeError(purpose) + + # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION, + # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE + # by default. + context = SSLContext(protocol) + + if cert_reqs is not None: + context.verify_mode = cert_reqs + context.check_hostname = check_hostname + + if keyfile and not certfile: + raise ValueError("certfile must be specified") + if certfile or keyfile: + context.load_cert_chain(certfile, keyfile) + + # load CA root certs + if cafile or capath or cadata: + context.load_verify_locations(cafile, capath, cadata) + elif context.verify_mode != CERT_NONE: + # no explicit cafile, capath or cadata but the verify mode is + # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system + # root CA certificates for the given purpose. This may fail silently. + context.load_default_certs(purpose) + + return context + +# Used by http.client if no context is explicitly passed. +_create_default_https_context = create_default_context + + +# Backwards compatibility alias, even though it's not a public name. +_create_stdlib_context = _create_unverified_context + + +class SSLObject: + """This class implements an interface on top of a low-level SSL object as + implemented by OpenSSL. This object captures the state of an SSL connection + but does not provide any network IO itself. IO needs to be performed + through separate "BIO" objects which are OpenSSL's IO abstraction layer. + + This class does not have a public constructor. Instances are returned by + ``SSLContext.wrap_bio``. This class is typically used by framework authors + that want to implement asynchronous IO for SSL through memory buffers. + + When compared to ``SSLSocket``, this object lacks the following features: + + * Any form of network IO, including methods such as ``recv`` and ``send``. + * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery. + """ + + def __init__(self, sslobj, owner=None, session=None): + self._sslobj = sslobj + # Note: _sslobj takes a weak reference to owner + self._sslobj.owner = owner or self + if session is not None: + self._sslobj.session = session + + @property + def context(self): + """The SSLContext that is currently in use.""" + return self._sslobj.context + + @context.setter + def context(self, ctx): + self._sslobj.context = ctx + + @property + def session(self): + """The SSLSession for client socket.""" + return self._sslobj.session + + @session.setter + def session(self, session): + self._sslobj.session = session + + @property + def session_reused(self): + """Was the client session reused during handshake""" + return self._sslobj.session_reused + + @property + def server_side(self): + """Whether this is a server-side socket.""" + return self._sslobj.server_side + + @property + def server_hostname(self): + """The currently set server hostname (for SNI), or ``None`` if no + server hostame is set.""" + return self._sslobj.server_hostname + + def read(self, len=1024, buffer=None): + """Read up to 'len' bytes from the SSL object and return them. + + If 'buffer' is provided, read into this buffer and return the number of + bytes read. + """ + if buffer is not None: + v = self._sslobj.read(len, buffer) + else: + v = self._sslobj.read(len) + return v + + def write(self, data): + """Write 'data' to the SSL object and return the number of bytes + written. + + The 'data' argument must support the buffer interface. + """ + return self._sslobj.write(data) + + def getpeercert(self, binary_form=False): + """Returns a formatted version of the data in the certificate provided + by the other end of the SSL channel. + + Return None if no certificate was provided, {} if a certificate was + provided, but not validated. + """ + return self._sslobj.peer_certificate(binary_form) + + def selected_npn_protocol(self): + """Return the currently selected NPN protocol as a string, or ``None`` + if a next protocol was not negotiated or if NPN is not supported by one + of the peers.""" + if _ssl.HAS_NPN: + return self._sslobj.selected_npn_protocol() + + def selected_alpn_protocol(self): + """Return the currently selected ALPN protocol as a string, or ``None`` + if a next protocol was not negotiated or if ALPN is not supported by one + of the peers.""" + if _ssl.HAS_ALPN: + return self._sslobj.selected_alpn_protocol() + + def cipher(self): + """Return the currently selected cipher as a 3-tuple ``(name, + ssl_version, secret_bits)``.""" + return self._sslobj.cipher() + + def shared_ciphers(self): + """Return a list of ciphers shared by the client during the handshake or + None if this is not a valid server connection. + """ + return self._sslobj.shared_ciphers() + + def compression(self): + """Return the current compression algorithm in use, or ``None`` if + compression was not negotiated or not supported by one of the peers.""" + return self._sslobj.compression() + + def pending(self): + """Return the number of bytes that can be read immediately.""" + return self._sslobj.pending() + + def do_handshake(self): + """Start the SSL/TLS handshake.""" + self._sslobj.do_handshake() + if self.context.check_hostname: + if not self.server_hostname: + raise ValueError("check_hostname needs server_hostname " + "argument") + match_hostname(self.getpeercert(), self.server_hostname) + + def unwrap(self): + """Start the SSL shutdown handshake.""" + return self._sslobj.shutdown() + + def get_channel_binding(self, cb_type="tls-unique"): + """Get channel binding data for current connection. Raise ValueError + if the requested `cb_type` is not supported. Return bytes of the data + or None if the data is not available (e.g. before the handshake).""" + if cb_type not in CHANNEL_BINDING_TYPES: + raise ValueError("Unsupported channel binding type") + if cb_type != "tls-unique": + raise NotImplementedError( + "{0} channel binding type not implemented" + .format(cb_type)) + return self._sslobj.tls_unique_cb() + + def version(self): + """Return a string identifying the protocol version used by the + current SSL channel. """ + return self._sslobj.version() + + def verify_client_post_handshake(self): + return self._sslobj.verify_client_post_handshake() + + +class SSLSocket(socket): + """This class implements a subtype of socket.socket that wraps + the underlying OS socket in an SSL context when necessary, and + provides read and write methods over that channel.""" + + def __init__(self, sock=None, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_TLS, ca_certs=None, + do_handshake_on_connect=True, + family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, + suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, + server_hostname=None, + _context=None, _session=None): + + if _context: + self._context = _context + else: + if server_side and not certfile: + raise ValueError("certfile must be specified for server-side " + "operations") + if keyfile and not certfile: + raise ValueError("certfile must be specified") + if certfile and not keyfile: + keyfile = certfile + self._context = SSLContext(ssl_version) + self._context.verify_mode = cert_reqs + if ca_certs: + self._context.load_verify_locations(ca_certs) + if certfile: + self._context.load_cert_chain(certfile, keyfile) + if npn_protocols: + self._context.set_npn_protocols(npn_protocols) + if ciphers: + self._context.set_ciphers(ciphers) + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get + # mixed in. + if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM: + raise NotImplementedError("only stream sockets are supported") + if server_side: + if server_hostname: + raise ValueError("server_hostname can only be specified " + "in client mode") + if _session is not None: + raise ValueError("session can only be specified in " + "client mode") + if self._context.check_hostname and not server_hostname: + raise ValueError("check_hostname requires server_hostname") + self._session = _session + self.server_side = server_side + self.server_hostname = server_hostname + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + if sock is not None: + socket.__init__(self, + family=sock.family, + type=sock.type, + proto=sock.proto, + fileno=sock.fileno()) + self.settimeout(sock.gettimeout()) + sock.detach() + elif fileno is not None: + socket.__init__(self, fileno=fileno) + else: + socket.__init__(self, family=family, type=type, proto=proto) + + # See if we are connected + try: + self.getpeername() + except OSError as e: + if e.errno != errno.ENOTCONN: + raise + connected = False + else: + connected = True + + self._closed = False + self._sslobj = None + self._connected = connected + if connected: + # create the SSL object + try: + sslobj = self._context._wrap_socket(self, server_side, + server_hostname) + self._sslobj = SSLObject(sslobj, owner=self, + session=self._session) + if do_handshake_on_connect: + timeout = self.gettimeout() + if timeout == 0.0: + # non-blocking + raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") + self.do_handshake() + + except (OSError, ValueError): + self.close() + raise + + @property + def context(self): + return self._context + + @context.setter + def context(self, ctx): + self._context = ctx + self._sslobj.context = ctx + + @property + def session(self): + """The SSLSession for client socket.""" + if self._sslobj is not None: + return self._sslobj.session + + @session.setter + def session(self, session): + self._session = session + if self._sslobj is not None: + self._sslobj.session = session + + @property + def session_reused(self): + """Was the client session reused during handshake""" + if self._sslobj is not None: + return self._sslobj.session_reused + + def dup(self): + raise NotImplementedError("Can't dup() %s instances" % + self.__class__.__name__) + + def _checkClosed(self, msg=None): + # raise an exception here if you wish to check for spurious closes + pass + + def _check_connected(self): + if not self._connected: + # getpeername() will raise ENOTCONN if the socket is really + # not connected; note that we can be connected even without + # _connected being set, e.g. if connect() first returned + # EAGAIN. + self.getpeername() + + def read(self, len=1024, buffer=None): + """Read up to LEN bytes and return them. + Return zero-length string on EOF.""" + + self._checkClosed() + if not self._sslobj: + raise ValueError("Read on closed or unwrapped SSL socket.") + try: + return self._sslobj.read(len, buffer) + except SSLError as x: + if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: + if buffer is not None: + return 0 + else: + return b'' + else: + raise + + def write(self, data): + """Write DATA to the underlying SSL channel. Returns + number of bytes of DATA actually transmitted.""" + + self._checkClosed() + if not self._sslobj: + raise ValueError("Write on closed or unwrapped SSL socket.") + return self._sslobj.write(data) + + def getpeercert(self, binary_form=False): + """Returns a formatted version of the data in the + certificate provided by the other end of the SSL channel. + Return None if no certificate was provided, {} if a + certificate was provided, but not validated.""" + + self._checkClosed() + self._check_connected() + return self._sslobj.getpeercert(binary_form) + + def selected_npn_protocol(self): + self._checkClosed() + if not self._sslobj or not _ssl.HAS_NPN: + return None + else: + return self._sslobj.selected_npn_protocol() + + def selected_alpn_protocol(self): + self._checkClosed() + if not self._sslobj or not _ssl.HAS_ALPN: + return None + else: + return self._sslobj.selected_alpn_protocol() + + def cipher(self): + self._checkClosed() + if not self._sslobj: + return None + else: + return self._sslobj.cipher() + + def shared_ciphers(self): + self._checkClosed() + if not self._sslobj: + return None + return self._sslobj.shared_ciphers() + + def compression(self): + self._checkClosed() + if not self._sslobj: + return None + else: + return self._sslobj.compression() + + def send(self, data, flags=0): + self._checkClosed() + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to send() on %s" % + self.__class__) + return self._sslobj.write(data) + else: + return socket.send(self, data, flags) + + def sendto(self, data, flags_or_addr, addr=None): + self._checkClosed() + if self._sslobj: + raise ValueError("sendto not allowed on instances of %s" % + self.__class__) + elif addr is None: + return socket.sendto(self, data, flags_or_addr) + else: + return socket.sendto(self, data, flags_or_addr, addr) + + def sendmsg(self, *args, **kwargs): + # Ensure programs don't send data unencrypted if they try to + # use this method. + raise NotImplementedError("sendmsg not allowed on instances of %s" % + self.__class__) + + def sendall(self, data, flags=0): + self._checkClosed() + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to sendall() on %s" % + self.__class__) + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + v = self.send(byte_view[count:]) + count += v + else: + return socket.sendall(self, data, flags) + + def sendfile(self, file, offset=0, count=None): + """Send a file, possibly by using os.sendfile() if this is a + clear-text socket. Return the total number of bytes sent. + """ + if self._sslobj is None: + # os.sendfile() works with plain sockets only + return super().sendfile(file, offset, count) + else: + return self._sendfile_use_send(file, offset, count) + + def recv(self, buflen=1024, flags=0): + self._checkClosed() + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv() on %s" % + self.__class__) + return self.read(buflen) + else: + return socket.recv(self, buflen, flags) + + def recv_into(self, buffer, nbytes=None, flags=0): + self._checkClosed() + if buffer and (nbytes is None): + nbytes = len(buffer) + elif nbytes is None: + nbytes = 1024 + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv_into() on %s" % + self.__class__) + return self.read(nbytes, buffer) + else: + return socket.recv_into(self, buffer, nbytes, flags) + + def recvfrom(self, buflen=1024, flags=0): + self._checkClosed() + if self._sslobj: + raise ValueError("recvfrom not allowed on instances of %s" % + self.__class__) + else: + return socket.recvfrom(self, buflen, flags) + + def recvfrom_into(self, buffer, nbytes=None, flags=0): + self._checkClosed() + if self._sslobj: + raise ValueError("recvfrom_into not allowed on instances of %s" % + self.__class__) + else: + return socket.recvfrom_into(self, buffer, nbytes, flags) + + def recvmsg(self, *args, **kwargs): + raise NotImplementedError("recvmsg not allowed on instances of %s" % + self.__class__) + + def recvmsg_into(self, *args, **kwargs): + raise NotImplementedError("recvmsg_into not allowed on instances of " + "%s" % self.__class__) + + def pending(self): + self._checkClosed() + if self._sslobj: + return self._sslobj.pending() + else: + return 0 + + def shutdown(self, how): + self._checkClosed() + self._sslobj = None + socket.shutdown(self, how) + + def unwrap(self): + if self._sslobj: + s = self._sslobj.unwrap() + self._sslobj = None + return s + else: + raise ValueError("No SSL wrapper around " + str(self)) + + def verify_client_post_handshake(self): + if self._sslobj: + return self._sslobj.verify_client_post_handshake() + else: + raise ValueError("No SSL wrapper around " + str(self)) + + def _real_close(self): + self._sslobj = None + socket._real_close(self) + + def do_handshake(self, block=False): + """Perform a TLS/SSL handshake.""" + self._check_connected() + timeout = self.gettimeout() + try: + if timeout == 0.0 and block: + self.settimeout(None) + self._sslobj.do_handshake() + finally: + self.settimeout(timeout) + + def _real_connect(self, addr, connect_ex): + if self.server_side: + raise ValueError("can't connect in server-side mode") + # Here we assume that the socket is client-side, and not + # connected at the time of the call. We connect it, then wrap it. + if self._connected: + raise ValueError("attempt to connect already-connected SSLSocket!") + sslobj = self.context._wrap_socket(self, False, self.server_hostname) + self._sslobj = SSLObject(sslobj, owner=self, + session=self._session) + try: + if connect_ex: + rc = socket.connect_ex(self, addr) + else: + rc = None + socket.connect(self, addr) + if not rc: + self._connected = True + if self.do_handshake_on_connect: + self.do_handshake() + return rc + except (OSError, ValueError): + self._sslobj = None + raise + + def connect(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + self._real_connect(addr, False) + + def connect_ex(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + return self._real_connect(addr, True) + + def accept(self): + """Accepts a new connection from a remote client, and returns + a tuple containing that new connection wrapped with a server-side + SSL channel, and the address of the remote client.""" + + newsock, addr = socket.accept(self) + newsock = self.context.wrap_socket(newsock, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs, + server_side=True) + return newsock, addr + + def get_channel_binding(self, cb_type="tls-unique"): + """Get channel binding data for current connection. Raise ValueError + if the requested `cb_type` is not supported. Return bytes of the data + or None if the data is not available (e.g. before the handshake). + """ + if self._sslobj is None: + return None + return self._sslobj.get_channel_binding(cb_type) + + def version(self): + """ + Return a string identifying the protocol version used by the + current SSL channel, or None if there is no established channel. + """ + if self._sslobj is None: + return None + return self._sslobj.version() + + +def wrap_socket(sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_TLS, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + ciphers=None): + return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, + server_side=server_side, cert_reqs=cert_reqs, + ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers) + +# some utility functions + +def cert_time_to_seconds(cert_time): + """Return the time in seconds since the Epoch, given the timestring + representing the "notBefore" or "notAfter" date from a certificate + in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale). + + "notBefore" or "notAfter" dates must use UTC (RFC 5280). + + Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec + UTC should be specified as GMT (see ASN1_TIME_print()) + """ + from time import strptime + from calendar import timegm + + months = ( + "Jan","Feb","Mar","Apr","May","Jun", + "Jul","Aug","Sep","Oct","Nov","Dec" + ) + time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT + try: + month_number = months.index(cert_time[:3].title()) + 1 + except ValueError: + raise ValueError('time data %r does not match ' + 'format "%%b%s"' % (cert_time, time_format)) + else: + # found valid month + tt = strptime(cert_time[3:], time_format) + # return an integer, the previous mktime()-based implementation + # returned a float (fractional seconds are always zero here). + return timegm((tt[0], month_number) + tt[2:6]) + +PEM_HEADER = "-----BEGIN CERTIFICATE-----" +PEM_FOOTER = "-----END CERTIFICATE-----" + +def DER_cert_to_PEM_cert(der_cert_bytes): + """Takes a certificate in binary DER format and returns the + PEM version of it as a string.""" + + f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict') + return (PEM_HEADER + '\n' + + textwrap.fill(f, 64) + '\n' + + PEM_FOOTER + '\n') + +def PEM_cert_to_DER_cert(pem_cert_string): + """Takes a certificate in ASCII PEM format and returns the + DER-encoded version of it as a byte sequence""" + + if not pem_cert_string.startswith(PEM_HEADER): + raise ValueError("Invalid PEM encoding; must start with %s" + % PEM_HEADER) + if not pem_cert_string.strip().endswith(PEM_FOOTER): + raise ValueError("Invalid PEM encoding; must end with %s" + % PEM_FOOTER) + d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] + return base64.decodebytes(d.encode('ASCII', 'strict')) + +def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None): + """Retrieve the certificate from the server at the specified address, + and return it as a PEM-encoded string. + If 'ca_certs' is specified, validate the server cert against it. + If 'ssl_version' is specified, use it in the connection attempt.""" + + host, port = addr + if ca_certs is not None: + cert_reqs = CERT_REQUIRED + else: + cert_reqs = CERT_NONE + context = _create_stdlib_context(ssl_version, + cert_reqs=cert_reqs, + cafile=ca_certs) + with create_connection(addr) as sock: + with context.wrap_socket(sock) as sslsock: + dercert = sslsock.getpeercert(True) + return DER_cert_to_PEM_cert(dercert) + +def get_protocol_name(protocol_code): + return _PROTOCOL_NAMES.get(protocol_code, '') From 255b5a631af0c2b4aa5841ac2489ccd0fbab6388 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 00:07:09 -0500 Subject: [PATCH 04/17] Edit ssl.py slightly --- Lib/ssl.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/Lib/ssl.py b/Lib/ssl.py index 58d3e93922..a578c0f552 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -100,14 +100,15 @@ import _ssl # if we can't import it, let the error propagate -from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext, MemoryBIO, SSLSession +# XXX RustPython TODO: provide more of these imports +# from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION +from _ssl import _SSLContext #, MemoryBIO, SSLSession from _ssl import ( - SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, - SSLSyscallError, SSLEOFError, + SSLError, #SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, +# SSLSyscallError, SSLEOFError, ) from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj -from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes +# from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes try: from _ssl import RAND_egd except ImportError: @@ -115,8 +116,8 @@ pass -from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_TLSv1_3 -from _ssl import _OPENSSL_API_VERSION +# from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_TLSv1_3 +# from _ssl import _OPENSSL_API_VERSION _IntEnum._convert( @@ -969,11 +970,13 @@ def sendall(self, data, flags=0): "non-zero flags not allowed in calls to sendall() on %s" % self.__class__) count = 0 - with memoryview(data) as view, view.cast("B") as byte_view: - amount = len(byte_view) - while count < amount: - v = self.send(byte_view[count:]) - count += v + # with memoryview(data) as view, view.cast("B") as byte_view: + # XXX RustPython TODO: proper memoryview implementation + byte_view = data + amount = len(byte_view) + while count < amount: + v = self.send(byte_view[count:]) + count += v else: return socket.sendall(self, data, flags) From 289311727bcb212688bc0fee539f1a1822133ca9 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 00:07:37 -0500 Subject: [PATCH 05/17] impl TryFromObject for CString --- vm/src/obj/objstr.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index ffee006c2f..5a5a5f73c5 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -31,7 +31,7 @@ use crate::function::{single_or_tuple_any, OptionalArg, PyFuncArgs}; use crate::pyhash; use crate::pyobject::{ Either, IdProtocol, IntoPyObject, ItemProtocol, PyClassImpl, PyContext, PyIterable, - PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, TypeProtocol, + PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TryIntoRef, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -1351,6 +1351,14 @@ impl IntoPyObject for &String { } } +impl TryFromObject for std::ffi::CString { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let s = PyStringRef::try_from_object(vm, obj)?; + Self::new(s.as_str().to_owned()) + .map_err(|_| vm.new_value_error("embedded null character".to_owned())) + } +} + #[derive(FromArgs)] struct SplitArgs { #[pyarg(positional_or_keyword, default = "None")] From a42e94b44fded85680b0e462753982ec1cfc437d Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 00:08:19 -0500 Subject: [PATCH 06/17] Add socket.getsockopt, impl io::{Read,Write} for PySocketRef --- vm/src/stdlib/socket.rs | 65 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 60057171f4..2969ae5a7a 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -37,7 +37,7 @@ mod c { pub use winapi::shared::ws2def::*; pub use winapi::um::winsock2::{ SD_BOTH as SHUT_RDWR, SD_RECEIVE as SHUT_RD, SD_SEND as SHUT_WR, SOCK_DGRAM, SOCK_RAW, - SOCK_RDM, SOCK_STREAM, SOL_SOCKET, SO_BROADCAST, SO_REUSEADDR, *, + SOCK_RDM, SOCK_STREAM, SOL_SOCKET, SO_BROADCAST, SO_REUSEADDR, SO_TYPE, *, }; } @@ -208,6 +208,10 @@ impl PySocket { fn close(&self) { self.sock.replace(invalid_sock()); } + #[pymethod] + fn detach(&self) -> RawSocket { + into_sock_fileno(self.sock.replace(invalid_sock())) + } #[pymethod] fn fileno(&self) -> RawSocket { @@ -278,6 +282,50 @@ impl PySocket { Ok(()) } + #[pymethod] + fn getsockopt( + &self, + level: i32, + name: i32, + buflen: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let fd = sock_fileno(&self.sock()) as _; + let buflen = buflen.unwrap_or(0); + if buflen == 0 { + let mut flag: libc::c_int = 0; + let mut flagsize = std::mem::size_of::() as _; + let ret = unsafe { + c::getsockopt( + fd, + level, + name, + &mut flag as *mut libc::c_int as *mut _, + &mut flagsize, + ) + }; + if ret < 0 { + Err(convert_sock_error(vm, io::Error::last_os_error())) + } else { + Ok(vm.new_int(flag)) + } + } else { + if buflen <= 0 || buflen > 1024 { + return Err(vm.new_os_error("getsockopt buflen out of range".to_owned())); + } + let mut buf = vec![0u8; buflen as usize]; + let mut buflen = buflen as _; + let ret = + unsafe { c::getsockopt(fd, level, name, buf.as_mut_ptr() as *mut _, &mut buflen) }; + buf.truncate(buflen as usize); + if ret < 0 { + Err(convert_sock_error(vm, io::Error::last_os_error())) + } else { + Ok(vm.ctx.new_bytes(buf)) + } + } + } + #[pymethod] fn setsockopt( &self, @@ -348,6 +396,20 @@ impl PySocket { } } +impl io::Read for PySocketRef { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + ::read(&mut self.sock.borrow_mut(), buf) + } +} +impl io::Write for PySocketRef { + fn write(&mut self, buf: &[u8]) -> io::Result { + ::write(&mut self.sock.borrow_mut(), buf) + } + fn flush(&mut self) -> io::Result<()> { + ::flush(&mut self.sock.borrow_mut()) + } +} + struct Address { host: PyStringRef, port: u16, @@ -609,6 +671,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "SO_REUSEADDR" => ctx.new_int(c::SO_REUSEADDR), "TCP_NODELAY" => ctx.new_int(c::TCP_NODELAY), "SO_BROADCAST" => ctx.new_int(c::SO_BROADCAST), + "SO_TYPE" => ctx.new_int(c::SO_TYPE), }); #[cfg(not(target_os = "redox"))] From 40835b19e6a0c08df87edc61f42e6ba5423b3e34 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 00:08:34 -0500 Subject: [PATCH 07/17] Add _ssl stdlib module --- Cargo.lock | 97 ++++++++ vm/Cargo.toml | 4 + vm/src/stdlib/mod.rs | 3 + vm/src/stdlib/ssl.rs | 574 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 678 insertions(+) create mode 100644 vm/src/stdlib/ssl.rs diff --git a/Cargo.lock b/Cargo.lock index b2749b3cb1..227d2d1a9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -402,6 +402,17 @@ dependencies = [ "memchr", ] +[[package]] +name = "derivative" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b94d2eb97732ec84b4e25eaf37db890e317b80e921f168c82cb5282473f8151" +dependencies = [ + "proc-macro2 1.0.8", + "quote 1.0.2", + "syn 1.0.14", +] + [[package]] name = "diff" version = "0.1.12" @@ -596,6 +607,21 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "fuchsia-cprng" version = "0.1.1" @@ -1008,6 +1034,28 @@ dependencies = [ "libc", ] +[[package]] +name = "num_enum" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca565a7df06f3d4b485494f25ba05da1435950f4dc263440eda7a6fa9b8e36e4" +dependencies = [ + "derivative", + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffa5a33ddddfee04c0283a7653987d634e880347e96b5b2ed64de07efb59db9d" +dependencies = [ + "proc-macro-crate", + "proc-macro2 1.0.8", + "quote 1.0.2", + "syn 1.0.14", +] + [[package]] name = "once_cell" version = "1.3.1" @@ -1020,6 +1068,33 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c" +[[package]] +name = "openssl" +version = "0.10.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "973293749822d7dd6370d6da1e523b0d1db19f06c459134c658b2a4261378b52" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "lazy_static 1.4.0", + "libc", + "openssl-sys", +] + +[[package]] +name = "openssl-sys" +version = "0.9.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1024c0a59774200a555087a6da3f253a9095a5f344e353b212ac4c8b8e450986" +dependencies = [ + "autocfg 1.0.0", + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "ordermap" version = "0.3.5" @@ -1106,6 +1181,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "proc-macro-crate" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10d4b51f154c8a7fb96fd6dad097cb74b863943ec010ac94b9fd1be8861fe1e" +dependencies = [ + "toml", +] + [[package]] name = "proc-macro-hack" version = "0.5.11" @@ -1516,6 +1600,7 @@ dependencies = [ "flame", "flamer", "flate2", + "foreign-types-shared", "gethostname", "getrandom", "hex", @@ -1538,7 +1623,10 @@ dependencies = [ "num-rational", "num-traits", "num_cpus", + "num_enum", "once_cell", + "openssl", + "openssl-sys", "paste", "pwd", "rand 0.7.3", @@ -1914,6 +2002,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "toml" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffc92d160b1eef40665be3a05630d003936a3bc7da7421277846c2613e92c71a" +dependencies = [ + "serde", +] + [[package]] name = "typenum" version = "1.11.2" diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 9fa1c76e7d..0812c2400d 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -70,6 +70,8 @@ paste = "0.1" base64 = "0.11" is-macro = "0.1" result-like = "^0.2.1" +foreign-types-shared = "0.1" +num_enum = "0.4" flame = { version = "0.2", optional = true } flamer = { version = "0.3", optional = true } @@ -89,6 +91,8 @@ subprocess = "0.2.2" num_cpus = "1" socket2 = { version = "0.3", features = ["unix"] } rustyline = "6.0" +openssl = "0.10" +openssl-sys = "0.9" [target.'cfg(not(any(target_arch = "wasm32", target_os = "redox")))'.dependencies] dns-lookup = "1.0" diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index 5e74a0051f..301bafc264 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -55,6 +55,8 @@ mod select; #[cfg(not(target_arch = "wasm32"))] pub mod signal; #[cfg(not(target_arch = "wasm32"))] +mod ssl; +#[cfg(not(target_arch = "wasm32"))] mod subprocess; #[cfg(windows)] mod winapi; @@ -123,6 +125,7 @@ pub fn get_module_inits() -> HashMap { ); modules.insert("signal".to_owned(), Box::new(signal::make_module)); modules.insert("select".to_owned(), Box::new(select::make_module)); + modules.insert("_ssl".to_owned(), Box::new(ssl::make_module)); modules.insert("_subprocess".to_owned(), Box::new(subprocess::make_module)); #[cfg(not(target_os = "redox"))] modules.insert("zlib".to_owned(), Box::new(zlib::make_module)); diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs new file mode 100644 index 0000000000..8a14428641 --- /dev/null +++ b/vm/src/stdlib/ssl.rs @@ -0,0 +1,574 @@ +use super::socket::PySocketRef; +use crate::exceptions::PyBaseExceptionRef; +use crate::function::OptionalArg; +use crate::obj::objbytearray::PyByteArrayRef; +use crate::obj::objbyteinner::PyBytesLike; +use crate::obj::objbytes::PyBytesRef; +use crate::obj::objstr::{PyString, PyStringRef}; +use crate::obj::{objtype::PyClassRef, objweakref::PyWeak}; +use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject}; +use crate::types::create_type; +use crate::VirtualMachine; + +use std::cell::{RefCell, RefMut}; +use std::convert::TryFrom; +use std::ffi::{CStr, CString}; +use std::fmt; + +use foreign_types_shared::{ForeignType, ForeignTypeRef}; +use openssl::{ + asn1::{Asn1Object, Asn1ObjectRef}, + nid::Nid, + ssl::{self, SslContextBuilder, SslVerifyMode}, +}; + +mod sys { + use libc::{c_char, c_int}; + pub use openssl_sys::*; + extern "C" { + pub fn OBJ_txt2obj(s: *const c_char, no_name: c_int) -> *mut ASN1_OBJECT; + pub fn OBJ_nid2obj(n: c_int) -> *mut ASN1_OBJECT; + pub fn TLS_server_method() -> *const SSL_METHOD; + pub fn TLS_client_method() -> *const SSL_METHOD; + pub fn SSL_CTX_get_verify_mode(ctx: *const SSL_CTX) -> c_int; + pub fn X509_get_default_cert_file_env() -> *const c_char; + pub fn X509_get_default_cert_file() -> *const c_char; + pub fn X509_get_default_cert_dir_env() -> *const c_char; + pub fn X509_get_default_cert_dir() -> *const c_char; + } +} + +#[derive(num_enum::IntoPrimitive, num_enum::TryFromPrimitive)] +#[repr(i32)] +enum SslVersion { + Ssl2, + Ssl3 = 1, + Tls, + Tls1, + // TODO: Tls1_1, Tls1_2 ? + TlsClient = 0x10, + TlsServer, +} + +#[derive(num_enum::IntoPrimitive, num_enum::TryFromPrimitive)] +#[repr(i32)] +enum CertRequirements { + None, + Optional, + Required, +} + +#[derive(Debug)] +enum SslServerOrClient { + Client, + Server, +} + +unsafe fn ptr2obj(ptr: *mut sys::ASN1_OBJECT) -> Option { + if ptr.is_null() { + None + } else { + Some(Asn1Object::from_ptr(ptr)) + } +} +fn txt2obj(s: &CStr, no_name: bool) -> Option { + unsafe { ptr2obj(sys::OBJ_txt2obj(s.as_ptr(), if no_name { 1 } else { 0 })) } +} +fn nid2obj(nid: Nid) -> Option { + unsafe { ptr2obj(sys::OBJ_nid2obj(nid.as_raw())) } +} +fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option { + unsafe { + let no_name = if no_name { 1 } else { 0 }; + let ptr = obj.as_ptr(); + let buflen = sys::OBJ_obj2txt(std::ptr::null_mut(), 0, ptr, no_name); + assert!(buflen >= 0); + if buflen == 0 { + return None; + } + let mut buf = vec![0u8; buflen as usize]; + let ret = sys::OBJ_obj2txt(buf.as_mut_ptr() as *mut libc::c_char, buflen, ptr, no_name); + assert!(ret >= 0); + let s = String::from_utf8(buf) + .unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned()); + Some(s) + } +} + +type PyNid = (libc::c_int, String, String, Option); +fn obj2py(obj: &Asn1ObjectRef) -> PyNid { + let nid = obj.nid(); + ( + nid.as_raw(), + nid.short_name().unwrap().to_owned(), + nid.long_name().unwrap().to_owned(), + obj2txt(obj, true), + ) +} + +#[derive(FromArgs)] +struct Txt2ObjArgs { + #[pyarg(positional_or_keyword)] + txt: CString, + #[pyarg(positional_or_keyword, default = "false")] + name: bool, +} +fn ssl_txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult { + txt2obj(&args.txt, !args.name) + .as_deref() + .map(obj2py) + .ok_or_else(|| { + vm.new_value_error(format!("unknown object '{}'", args.txt.to_str().unwrap())) + }) +} + +fn ssl_nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult { + nid2obj(Nid::from_raw(nid)) + .as_deref() + .map(obj2py) + .ok_or_else(|| vm.new_value_error(format!("unknown NID {}", nid))) +} + +fn ssl_get_default_verify_paths() -> (String, String, String, String) { + macro_rules! convert { + ($f:ident) => { + CStr::from_ptr(sys::$f()).to_string_lossy().into_owned() + }; + } + unsafe { + ( + convert!(X509_get_default_cert_file_env), + convert!(X509_get_default_cert_file), + convert!(X509_get_default_cert_dir_env), + convert!(X509_get_default_cert_dir), + ) + } +} + +#[pyclass(name = "_SSLContext")] +struct PySslContext { + ctx: RefCell, + check_hostname: bool, +} + +impl fmt::Debug for PySslContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("_SSLContext") + } +} + +impl PyValue for PySslContext { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_ssl", "_SSLContext") + } +} + +#[pyimpl(flags(BASETYPE))] +impl PySslContext { + fn builder(&self) -> RefMut { + self.ctx.borrow_mut() + } + // fn ctx(&self) -> Ref { + // Ref::map(self.ctx.borrow(), |ctx| unsafe { + // SslContextRef::from_ptr(ctx.as_ptr()) + // }) + // } + fn ptr(&self) -> *mut sys::SSL_CTX { + self.ctx.borrow().as_ptr() + } + + #[pyslot] + fn tp_new(cls: PyClassRef, proto_version: i32, vm: &VirtualMachine) -> PyResult> { + let proto = SslVersion::try_from(proto_version) + .map_err(|_| vm.new_value_error("invalid protocol version".to_owned()))?; + let method = match proto { + SslVersion::Ssl2 => todo!(), + SslVersion::Ssl3 => todo!(), + SslVersion::Tls => unsafe { ssl::SslMethod::from_ptr(sys::TLS_method()) }, + SslVersion::Tls1 => todo!(), + // TODO: Tls1_1, Tls1_2 ? + SslVersion::TlsClient => unsafe { ssl::SslMethod::from_ptr(sys::TLS_client_method()) }, + SslVersion::TlsServer => unsafe { ssl::SslMethod::from_ptr(sys::TLS_server_method()) }, + }; + let mut builder = + SslContextBuilder::new(method).map_err(|e| convert_openssl_error(vm, e))?; + let check_hostname = matches!(proto, SslVersion::TlsClient); + builder.set_verify(if check_hostname { + SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT + } else { + SslVerifyMode::NONE + }); + PySslContext { + ctx: RefCell::new(builder), + check_hostname, + } + .into_ref_with_type(vm, cls) + } + + #[pymethod] + fn set_ciphers(&self, cipherlist: CString, vm: &VirtualMachine) -> PyResult<()> { + self.builder() + .set_cipher_list(cipherlist.to_str().unwrap()) + .map_err(|_| { + vm.new_exception_msg(ssl_error(vm), "No cipher can be selected.".to_owned()) + }) + } + + #[pyproperty] + fn verify_mode(&self) -> i32 { + let mode = unsafe { sys::SSL_CTX_get_verify_mode(self.ptr()) }; + let mode = + SslVerifyMode::from_bits(mode).expect("bad SSL_CTX_get_verify_mode return value"); + if mode == SslVerifyMode::NONE { + CertRequirements::None.into() + } else if mode == SslVerifyMode::PEER { + CertRequirements::Optional.into() + } else if mode == SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT { + CertRequirements::Required.into() + } else { + unreachable!() + } + } + #[pyproperty(setter)] + fn set_verify_mode(&self, cert: i32, vm: &VirtualMachine) -> PyResult<()> { + let cert_req = CertRequirements::try_from(cert) + .map_err(|_| vm.new_value_error("invalid value for verify_mode".to_owned()))?; + let mode = match cert_req { + CertRequirements::None if self.check_hostname => { + return Err(vm.new_value_error( + "Cannot set verify_mode to CERT_NONE when check_hostname is enabled." + .to_owned(), + )) + } + CertRequirements::None => SslVerifyMode::NONE, + CertRequirements::Optional => SslVerifyMode::PEER, + CertRequirements::Required => SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT, + }; + self.builder().set_verify(mode); + Ok(()) + } + + #[pymethod] + fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> { + self.builder() + .set_default_verify_paths() + .map_err(|e| convert_openssl_error(vm, e)) + } + + #[pymethod] + fn load_verify_locations( + &self, + args: LoadVerifyLocationsArgs, + vm: &VirtualMachine, + ) -> PyResult<()> { + if args.cafile.is_none() && args.capath.is_none() && args.cadata.is_none() { + return Err( + vm.new_type_error("cafile, capath and cadata cannot be all omitted".to_owned()) + ); + } + + if let Some(_cadata) = args.cadata { + todo!() + } + + if args.cafile.is_some() || args.capath.is_some() { + let ret = unsafe { + sys::SSL_CTX_load_verify_locations( + self.ptr(), + args.cafile + .as_ref() + .map_or_else(std::ptr::null, |cs| cs.as_ptr()), + args.capath + .as_ref() + .map_or_else(std::ptr::null, |cs| cs.as_ptr()), + ) + }; + if ret != 1 { + let errno = std::io::Error::last_os_error().raw_os_error().unwrap(); + let err = if errno != 0 { + super::os::errno_err(vm) + } else { + convert_openssl_error(vm, openssl::error::ErrorStack::get()) + }; + return Err(err); + } + } + + Ok(()) + } + + #[pymethod] + fn _wrap_socket( + zelf: PyRef, + args: WrapSocketArgs, + vm: &VirtualMachine, + ) -> PyResult { + let server_hostname = args + .server_hostname + .map(|s| { + vm.encode( + s.into_object(), + Some(PyString::from("ascii").into_ref(vm)), + None, + ) + .and_then(|res| PyBytesRef::try_from_object(vm, res)) + }) + .transpose()?; + + let ssl = { + let ptr = zelf.ptr(); + let ctx = unsafe { ssl::SslContext::from_ptr(ptr) }; + let ssl = ssl::Ssl::new(&ctx).map_err(|e| convert_openssl_error(vm, e))?; + std::mem::forget(ctx); + ssl + }; + + let mut stream = ssl::SslStreamBuilder::new(ssl, args.sock.clone()); + + let socket_type = if args.server_side { + stream.set_accept_state(); + SslServerOrClient::Server + } else { + stream.set_connect_state(); + SslServerOrClient::Client + }; + + // TODO: use this + let _ = args.session; + + Ok(PySslSocket { + ctx: zelf, + stream: RefCell::new(Some(stream)), + socket_type, + server_hostname, + owner: RefCell::new(args.owner.as_ref().map(PyWeak::downgrade)), + }) + } +} + +#[derive(FromArgs)] +// #[allow(dead_code)] +struct WrapSocketArgs { + #[pyarg(positional_or_keyword)] + sock: PySocketRef, + #[pyarg(positional_or_keyword)] + server_side: bool, + #[pyarg(positional_or_keyword, default = "None")] + server_hostname: Option, + #[pyarg(keyword_only, default = "None")] + owner: Option, + #[pyarg(keyword_only, default = "None")] + session: Option, +} + +#[derive(FromArgs)] +struct LoadVerifyLocationsArgs { + #[pyarg(positional_or_keyword, default = "None")] + cafile: Option, + #[pyarg(positional_or_keyword, default = "None")] + capath: Option, + #[pyarg(positional_or_keyword, default = "None")] + cadata: Option, +} + +#[pyclass(name = "_SSLSocket")] +struct PySslSocket { + ctx: PyRef, + stream: RefCell>>, + socket_type: SslServerOrClient, + server_hostname: Option, + owner: RefCell>, +} + +impl fmt::Debug for PySslSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("_SSLSocket") + } +} + +impl PyValue for PySslSocket { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_ssl", "_SSLSocket") + } +} + +#[pyimpl] +impl PySslSocket { + fn stream_builder(&self) -> ssl::SslStreamBuilder { + self.stream.replace(None).unwrap() + } + fn stream(&self) -> RefMut> { + RefMut::map(self.stream.borrow_mut(), |b| { + let b = b.as_mut().unwrap(); + unsafe { &mut *(b as *mut ssl::SslStreamBuilder<_> as *mut ssl::SslStream<_>) } + }) + } + fn set_stream(&self, stream: ssl::SslStream) { + let prev = self + .stream + .replace(Some(unsafe { std::mem::transmute(stream) })); + debug_assert!(prev.is_none()); + } + + #[pyproperty] + fn owner(&self) -> Option { + self.owner.borrow().as_ref().and_then(PyWeak::upgrade) + } + #[pyproperty(setter)] + fn set_owner(&self, owner: PyObjectRef) { + *self.owner.borrow_mut() = Some(PyWeak::downgrade(&owner)) + } + #[pyproperty] + fn server_side(&self) -> bool { + matches!(self.socket_type, SslServerOrClient::Server) + } + #[pyproperty] + fn context(&self) -> PyRef { + self.ctx.clone() + } + #[pyproperty] + fn server_hostname(&self) -> Option { + self.server_hostname.clone() + } + + #[pymethod] + fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { + use crate::pyobject::Either; + // Either a stream builder or a mid-handshake stream from WANT_READ or WANT_WRITE + let mut handshaker: Either<_, ssl::MidHandshakeSslStream<_>> = + Either::A(self.stream_builder()); + loop { + let handshake_result = match handshaker { + Either::A(s) => s.handshake(), + Either::B(s) => s.handshake(), + }; + match handshake_result { + Ok(stream) => { + self.set_stream(stream); + return Ok(()); + } + Err(ssl::HandshakeError::SetupFailure(e)) => { + return Err(convert_openssl_error(vm, e)) + } + Err(ssl::HandshakeError::WouldBlock(s)) => handshaker = Either::B(s), + Err(ssl::HandshakeError::Failure(s)) => { + return Err(convert_ssl_error(vm, s.into_error())) + } + } + } + } + + #[pymethod] + fn write(&self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult { + data.with_ref(|b| self.stream().ssl_write(b)) + .map_err(|e| convert_ssl_error(vm, e)) + } + + #[pymethod] + fn read(&self, n: usize, buffer: OptionalArg, vm: &VirtualMachine) -> PyResult { + if let OptionalArg::Present(buffer) = buffer { + let mut buf = buffer.borrow_value_mut(); + let n = self + .stream() + .ssl_read(&mut buf.elements) + .map_err(|e| convert_ssl_error(vm, e))?; + Ok(vm.new_int(n)) + } else { + let mut buf = vec![0u8; n]; + buf.truncate(n); + Ok(vm.ctx.new_bytes(buf)) + } + } +} + +fn ssl_error(vm: &VirtualMachine) -> PyClassRef { + vm.class("_ssl", "SSLError") +} + +fn convert_openssl_error( + vm: &VirtualMachine, + err: openssl::error::ErrorStack, +) -> PyBaseExceptionRef { + let cls = ssl_error(vm); + match err.errors().first() { + Some(e) => { + let no = "unknown"; + let msg = format!( + "openssl error code {}, from library {}, in function {}, on line {}, with reason {}, and extra data {}", + e.code(), e.library().unwrap_or(no), e.function().unwrap_or(no), e.line(), + e.reason().unwrap_or(no), e.data().unwrap_or("none"), + ); + vm.new_exception_msg(cls, msg) + } + None => vm.new_exception_empty(cls), + } +} +fn convert_ssl_error(vm: &VirtualMachine, e: ssl::Error) -> PyBaseExceptionRef { + match e.into_io_error() { + Ok(io_err) => super::os::convert_io_error(vm, io_err), + Err(e) => convert_openssl_error(vm, e.ssl_error().unwrap().clone()), + } +} + +pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { + openssl::init(); + let ctx = &vm.ctx; + let ssl_error = create_type( + "SSLError", + &vm.ctx.types.type_type, + &vm.ctx.exceptions.os_error, + ); + py_module!(vm, "_ssl", { + "_SSLContext" => PySslContext::make_class(ctx), + "_SSLSocket" => PySslSocket::make_class(ctx), + "SSLError" => ssl_error, + "txt2obj" => ctx.new_function(ssl_txt2obj), + "nid2obj" => ctx.new_function(ssl_nid2obj), + "get_default_verify_paths" => ctx.new_function(ssl_get_default_verify_paths), + + // Constants + "PROTOCOL_SSLv2" => ctx.new_int(SslVersion::Ssl2 as u32), + "PROTOCOL_SSLv3" => ctx.new_int(SslVersion::Ssl3 as u32), + "PROTOCOL_SSLv23" => ctx.new_int(SslVersion::Tls as u32), + "PROTOCOL_TLS" => ctx.new_int(SslVersion::Tls as u32), + "PROTOCOL_TLS_CLIENT" => ctx.new_int(SslVersion::TlsClient as u32), + "PROTOCOL_TLS_SERVER" => ctx.new_int(SslVersion::TlsServer as u32), + "PROTOCOL_TLSv1" => ctx.new_int(SslVersion::Tls1 as u32), + "OP_NO_SSLv2" => ctx.new_int(sys::SSL_OP_NO_SSLv2), + "OP_NO_SSLv3" => ctx.new_int(sys::SSL_OP_NO_SSLv3), + "OP_NO_TLSv1" => ctx.new_int(sys::SSL_OP_NO_TLSv1), + // "OP_NO_TLSv1_1" => ctx.new_int(sys::SSL_OP_NO_TLSv1_1), + // "OP_NO_TLSv1_2" => ctx.new_int(sys::SSL_OP_NO_TLSv1_2), + "OP_NO_TLSv1_3" => ctx.new_int(sys::SSL_OP_NO_TLSv1_3), + "OP_CIPHER_SERVER_PREFERENCE" => ctx.new_int(sys::SSL_OP_CIPHER_SERVER_PREFERENCE), + "OP_SINGLE_DH_USE" => ctx.new_int(sys::SSL_OP_SINGLE_DH_USE), + "OP_NO_TICKET" => ctx.new_int(sys::SSL_OP_NO_TICKET), + // #ifdef SSL_OP_SINGLE_ECDH_USE + // "OP_SINGLE_ECDH_USE" => ctx.new_int(sys::SSL_OP_SINGLE_ECDH_USE), + // #endif + // #ifdef SSL_OP_NO_COMPRESSION + // "OP_NO_COMPRESSION" => ctx.new_int(sys::SSL_OP_NO_COMPRESSION), + // #endif + "HAS_TLS_UNIQUE" => ctx.new_bool(true), + "CERT_NONE" => ctx.new_int(CertRequirements::None as u32), + "CERT_OPTIONAL" => ctx.new_int(CertRequirements::Optional as u32), + "CERT_REQUIRED" => ctx.new_int(CertRequirements::Required as u32), + "VERIFY_DEFAULT" => ctx.new_int(0), + // "VERIFY_CRL_CHECK_LEAF" => sys::X509_V_FLAG_CRL_CHECK, + // "VERIFY_CRL_CHECK_CHAIN" => sys::X509_V_FLAG_CRL_CHECK|sys::X509_V_FLAG_CRL_CHECK_ALL, + // "VERIFY_X509_STRICT" => X509_V_FLAG_X509_STRICT, + "SSL_ERROR_ZERO_RETURN" => ctx.new_int(sys::SSL_ERROR_ZERO_RETURN), + "SSL_ERROR_WANT_READ" => ctx.new_int(sys::SSL_ERROR_WANT_READ), + "SSL_ERROR_WANT_WRITE" => ctx.new_int(sys::SSL_ERROR_WANT_WRITE), + // "SSL_ERROR_WANT_X509_LOOKUP" => ctx.new_int(sys::SSL_ERROR_WANT_X509_LOOKUP), + "SSL_ERROR_SYSCALL" => ctx.new_int(sys::SSL_ERROR_SYSCALL), + "SSL_ERROR_SSL" => ctx.new_int(sys::SSL_ERROR_SSL), + "SSL_ERROR_WANT_CONNECT" => ctx.new_int(sys::SSL_ERROR_WANT_CONNECT), + // "SSL_ERROR_EOF" => ctx.new_int(sys::SSL_ERROR_EOF), + // "SSL_ERROR_INVALID_ERROR_CODE" => ctx.new_int(sys::SSL_ERROR_INVALID_ERROR_CODE), + // TODO: so many more of these + "ALERT_DESCRIPTION_DECODE_ERROR" => ctx.new_int(sys::SSL_AD_DECODE_ERROR), + "ALERT_DESCRIPTION_ILLEGAL_PARAMETER" => ctx.new_int(sys::SSL_AD_ILLEGAL_PARAMETER), + "ALERT_DESCRIPTION_UNRECOGNIZED_NAME" => ctx.new_int(sys::SSL_AD_UNRECOGNIZED_NAME), + }) +} From 72f7cd65eb93850463d5c477e72e600b46573fa6 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 11:04:12 -0500 Subject: [PATCH 08/17] Install openssl on Windows in GH Actions --- .github/workflows/ci.yaml | 27 +++++++++++++++++++++++---- scripts/install-openssl.ps1 | 8 ++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) create mode 100644 scripts/install-openssl.ps1 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 846a0ca2ad..656dc6ee00 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -5,6 +5,9 @@ on: name: CI +env: + VCPKGRS_DYNAMIC: 1 + jobs: rust_tests: name: Run rust tests @@ -15,8 +18,16 @@ jobs: fail-fast: false steps: - uses: actions/checkout@master - - name: Convert symlinks to hardlink (windows only) - run: powershell.exe scripts/symlinks-to-hardlinks.ps1 + - name: Cache Windows + uses: actions/cache@v1 + with: + key: vcpkg + path: 'C:/vcpkg' + if: runner.os == 'Windows' + - name: Set up the Windows environment + run: | + powershell.exe scripts/symlinks-to-hardlinks.ps1 + powershell.exe scripts/install-openssl.ps1 if: runner.os == 'Windows' - name: Cache cargo dependencies uses: actions/cache@v1 @@ -40,8 +51,16 @@ jobs: fail-fast: false steps: - uses: actions/checkout@master - - name: Convert symlinks to hardlink (windows only) - run: powershell.exe scripts/symlinks-to-hardlinks.ps1 + - name: Cache Windows + uses: actions/cache@v1 + with: + key: vcpkg + path: 'C:/vcpkg' + if: runner.os == 'Windows' + - name: Set up the Windows environment + run: | + powershell.exe scripts/symlinks-to-hardlinks.ps1 + powershell.exe scripts/install-openssl.ps1 if: runner.os == 'Windows' - name: Cache cargo dependencies uses: actions/cache@v1 diff --git a/scripts/install-openssl.ps1 b/scripts/install-openssl.ps1 new file mode 100644 index 0000000000..beab80edb9 --- /dev/null +++ b/scripts/install-openssl.ps1 @@ -0,0 +1,8 @@ +# From the Actix Web windows workflow: +# https://github.com/actix/actix-web/blob/master/.github/workflows/windows.yml +vcpkg integrate install +vcpkg install openssl:x64-windows +Copy-Item C:\vcpkg\installed\x64-windows\bin\libcrypto-1_1-x64.dll C:\vcpkg\installed\x64-windows\bin\libcrypto.dll +Copy-Item C:\vcpkg\installed\x64-windows\bin\libssl-1_1-x64.dll C:\vcpkg\installed\x64-windows\bin\libssl.dll +Get-ChildItem C:\vcpkg\installed\x64-windows\bin +Get-ChildItem C:\vcpkg\installed\x64-windows\lib From 03ba022258e2c62dd68ead0d86c5ce64ffe9cd72 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 13:13:46 -0500 Subject: [PATCH 09/17] Don't use the matches!() macro --- vm/src/stdlib/ssl.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 8a14428641..5ae84bc57c 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -38,7 +38,7 @@ mod sys { } } -#[derive(num_enum::IntoPrimitive, num_enum::TryFromPrimitive)] +#[derive(num_enum::IntoPrimitive, num_enum::TryFromPrimitive, PartialEq)] #[repr(i32)] enum SslVersion { Ssl2, @@ -58,7 +58,7 @@ enum CertRequirements { Required, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] enum SslServerOrClient { Client, Server, @@ -192,7 +192,7 @@ impl PySslContext { }; let mut builder = SslContextBuilder::new(method).map_err(|e| convert_openssl_error(vm, e))?; - let check_hostname = matches!(proto, SslVersion::TlsClient); + let check_hostname = proto == SslVersion::TlsClient; builder.set_verify(if check_hostname { SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT } else { @@ -420,7 +420,7 @@ impl PySslSocket { } #[pyproperty] fn server_side(&self) -> bool { - matches!(self.socket_type, SslServerOrClient::Server) + self.socket_type == SslServerOrClient::Server } #[pyproperty] fn context(&self) -> PyRef { From bea6e54a3710c629ea143cf1cdadca563adf66fd Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 17:53:50 -0500 Subject: [PATCH 10/17] Make ssl work on windows --- Cargo.lock | 11 +++++++++ Lib/ssl.py | 2 +- vm/Cargo.toml | 3 ++- vm/src/obj/objset.rs | 11 +++++++++ vm/src/stdlib/ssl.rs | 57 +++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 81 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 227d2d1a9f..b952e66286 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1639,6 +1639,7 @@ dependencies = [ "rustpython-derive", "rustpython-parser", "rustyline", + "schannel", "serde", "serde_json", "sha-1", @@ -1706,6 +1707,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8" +[[package]] +name = "schannel" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "039c25b130bd8c1321ee2d7de7fde2659fa9c2744e4bb29711cfc852ea53cd19" +dependencies = [ + "lazy_static 1.4.0", + "winapi", +] + [[package]] name = "semver" version = "0.9.0" diff --git a/Lib/ssl.py b/Lib/ssl.py index a578c0f552..0904a8a140 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -158,7 +158,7 @@ if sys.platform == "win32": - from _ssl import enum_certificates, enum_crls + from _ssl import enum_certificates #, enum_crls from socket import socket, AF_INET, SOCK_STREAM, create_connection from socket import SOL_SOCKET, SO_TYPE diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 0812c2400d..4d1d0bdf6f 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -101,10 +101,11 @@ libz-sys = "1.0" [target.'cfg(windows)'.dependencies] winreg = "0.7" +schannel = "0.1" [target."cfg(windows)".dependencies.winapi] version = "0.3" -features = ["winsock2", "handleapi", "ws2def", "std", "winbase"] +features = ["winsock2", "handleapi", "ws2def", "std", "winbase", "wincrypt"] [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = "0.2" diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 369af668df..c9adfc9448 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -631,6 +631,17 @@ macro_rules! multi_args_frozenset { #[pyimpl(flags(BASETYPE))] impl PyFrozenSet { + pub fn from_iter( + vm: &VirtualMachine, + it: impl IntoIterator, + ) -> PyResult { + let mut inner = PySetInner::default(); + for elem in it { + inner.add(&elem, vm)?; + } + Ok(Self { inner }) + } + #[pyslot] fn tp_new( cls: PyClassRef, diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 5ae84bc57c..4b08f505d8 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -106,6 +106,46 @@ fn obj2py(obj: &Asn1ObjectRef) -> PyNid { ) } +#[cfg(windows)] +fn ssl_enum_certificates(store_name: PyStringRef, vm: &VirtualMachine) -> PyResult { + use crate::obj::objset::PyFrozenSet; + use schannel::{cert_context::ValidUses, cert_store::CertStore, RawPointer}; + use winapi::um::wincrypt; + // TODO: check every store for it, not just 2 of them: + // https://github.com/python/cpython/blob/3.8/Modules/_ssl.c#L5603-L5610 + let open_fns = [CertStore::open_current_user, CertStore::open_local_machine]; + let stores = open_fns + .iter() + .filter_map(|open| open(store_name.as_str()).ok()) + .collect::>(); + let certs = stores.iter().map(|s| s.certs()).flatten().map(|c| { + let cert = vm.ctx.new_bytes(c.to_der().to_owned()); + let enc_type = unsafe { + let ptr = c.as_ptr() as wincrypt::PCCERT_CONTEXT; + (*ptr).dwCertEncodingType + }; + let enc_type = match enc_type { + wincrypt::X509_ASN_ENCODING => vm.new_str("x509_asn".to_owned()), + wincrypt::PKCS_7_ASN_ENCODING => vm.new_str("pkcs_7_asn".to_owned()), + other => vm.new_int(other), + }; + let usage = match c.valid_uses()? { + ValidUses::All => vm.new_bool(true), + ValidUses::Oids(oids) => { + PyFrozenSet::from_iter(vm, oids.into_iter().map(|oid| vm.new_str(oid))) + .unwrap() + .into_ref(vm) + .into_object() + } + }; + Ok(vm.ctx.new_tuple(vec![cert, enc_type, usage])) + }); + let certs = certs + .collect::, _>>() + .map_err(|e| super::os::convert_io_error(vm, e))?; + Ok(vm.ctx.new_list(certs)) +} + #[derive(FromArgs)] struct Txt2ObjArgs { #[pyarg(positional_or_keyword)] @@ -518,7 +558,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { &vm.ctx.types.type_type, &vm.ctx.exceptions.os_error, ); - py_module!(vm, "_ssl", { + let module = py_module!(vm, "_ssl", { "_SSLContext" => PySslContext::make_class(ctx), "_SSLSocket" => PySslSocket::make_class(ctx), "SSLError" => ssl_error, @@ -570,5 +610,20 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "ALERT_DESCRIPTION_DECODE_ERROR" => ctx.new_int(sys::SSL_AD_DECODE_ERROR), "ALERT_DESCRIPTION_ILLEGAL_PARAMETER" => ctx.new_int(sys::SSL_AD_ILLEGAL_PARAMETER), "ALERT_DESCRIPTION_UNRECOGNIZED_NAME" => ctx.new_int(sys::SSL_AD_UNRECOGNIZED_NAME), + }); + + extend_module_platform_specific(&module, vm); + + module +} + +#[cfg(windows)] +fn extend_module_platform_specific(module: &PyObjectRef, vm: &VirtualMachine) { + let ctx = &vm.ctx; + extend_module!(vm, module, { + "enum_certificates" => ctx.new_function(ssl_enum_certificates), }) } + +#[cfg(not(windows))] +fn extend_module_platform_specific(_module: &PyObjectRef, _vm: &VirtualMachine) {} From 588589865e4b2349a9edd85abd52b70fee3a47f7 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 21:10:17 -0500 Subject: [PATCH 11/17] Try a few things to make cert verification work --- vm/src/stdlib/ssl.rs | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 4b08f505d8..1c8aa424b3 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -19,7 +19,7 @@ use foreign_types_shared::{ForeignType, ForeignTypeRef}; use openssl::{ asn1::{Asn1Object, Asn1ObjectRef}, nid::Nid, - ssl::{self, SslContextBuilder, SslVerifyMode}, + ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode}, }; mod sys { @@ -35,6 +35,7 @@ mod sys { pub fn X509_get_default_cert_file() -> *const c_char; pub fn X509_get_default_cert_dir_env() -> *const c_char; pub fn X509_get_default_cert_dir() -> *const c_char; + pub fn SSL_CTX_set_post_handshake_auth(ctx: *mut SSL_CTX, val: c_int); } } @@ -232,12 +233,36 @@ impl PySslContext { }; let mut builder = SslContextBuilder::new(method).map_err(|e| convert_openssl_error(vm, e))?; + let check_hostname = proto == SslVersion::TlsClient; builder.set_verify(if check_hostname { SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT } else { SslVerifyMode::NONE }); + + let mut options = SslOptions::ALL & !SslOptions::DONT_INSERT_EMPTY_FRAGMENTS; + if proto != SslVersion::Ssl2 { + options |= SslOptions::NO_SSLV2; + } + if proto != SslVersion::Ssl3 { + options |= SslOptions::NO_SSLV3; + } + options |= SslOptions::NO_COMPRESSION; + options |= SslOptions::CIPHER_SERVER_PREFERENCE; + options |= SslOptions::SINGLE_DH_USE; + options |= SslOptions::SINGLE_ECDH_USE; + builder.set_options(options); + + let mode = ssl::SslMode::ACCEPT_MOVING_WRITE_BUFFER | ssl::SslMode::AUTO_RETRY; + builder.set_mode(mode); + + unsafe { sys::SSL_CTX_set_post_handshake_auth(builder.as_ptr(), 0) }; + + builder + .set_session_id_context(b"Python") + .map_err(|e| convert_openssl_error(vm, e))?; + PySslContext { ctx: RefCell::new(builder), check_hostname, From 31905bd1ceb95e1430d1cf9c4f24e89609c2bcaf Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 22:28:05 -0500 Subject: [PATCH 12/17] Add OPENSSL_VERSION constants --- Lib/ssl.py | 2 +- vm/src/stdlib/ssl.rs | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/Lib/ssl.py b/Lib/ssl.py index 0904a8a140..baff0ad291 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -101,7 +101,7 @@ import _ssl # if we can't import it, let the error propagate # XXX RustPython TODO: provide more of these imports -# from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION +from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION from _ssl import _SSLContext #, MemoryBIO, SSLSession from _ssl import ( SSLError, #SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 1c8aa424b3..42e2896a1f 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -6,7 +6,9 @@ use crate::obj::objbyteinner::PyBytesLike; use crate::obj::objbytes::PyBytesRef; use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::{objtype::PyClassRef, objweakref::PyWeak}; -use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject}; +use crate::pyobject::{ + IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, +}; use crate::types::create_type; use crate::VirtualMachine; @@ -575,6 +577,19 @@ fn convert_ssl_error(vm: &VirtualMachine, e: ssl::Error) -> PyBaseExceptionRef { } } +fn parse_version_info(mut n: i64) -> (u8, u8, u8, u8, u8) { + let status = (n & 0xF) as u8; + n >>= 4; + let patch = (n & 0xFF) as u8; + n >>= 8; + let fix = (n & 0xFF) as u8; + n >>= 8; + let minor = (n & 0xFF) as u8; + n >>= 8; + let major = (n & 0xFF) as u8; + (major, minor, fix, patch, status) +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { openssl::init(); let ctx = &vm.ctx; @@ -592,6 +607,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "get_default_verify_paths" => ctx.new_function(ssl_get_default_verify_paths), // Constants + "OPENSSL_VERSION" => ctx.new_str(openssl::version::version().to_owned()), + "OPENSSL_VERSION_NUMBER" => ctx.new_int(openssl::version::number()), + "OPENSSL_VERSION_INFO" => parse_version_info(openssl::version::number()).into_pyobject(vm).unwrap(), "PROTOCOL_SSLv2" => ctx.new_int(SslVersion::Ssl2 as u32), "PROTOCOL_SSLv3" => ctx.new_int(SslVersion::Ssl3 as u32), "PROTOCOL_SSLv23" => ctx.new_int(SslVersion::Tls as u32), From 12c57ce13526d7a00f687f58b32255f115417d05 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Mar 2020 23:09:10 -0500 Subject: [PATCH 13/17] Add ssl.RAND_* functions --- Lib/ssl.py | 2 +- vm/src/stdlib/ssl.rs | 49 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/Lib/ssl.py b/Lib/ssl.py index baff0ad291..ce28d7b142 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -108,7 +108,7 @@ # SSLSyscallError, SSLEOFError, ) from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj -# from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes +from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes try: from _ssl import RAND_egd except ImportError: diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 42e2896a1f..612e1e73f3 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -7,7 +7,7 @@ use crate::obj::objbytes::PyBytesRef; use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::{objtype::PyClassRef, objweakref::PyWeak}; use crate::pyobject::{ - IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, }; use crate::types::create_type; use crate::VirtualMachine; @@ -25,7 +25,7 @@ use openssl::{ }; mod sys { - use libc::{c_char, c_int}; + use libc::{c_char, c_double, c_int, c_void}; pub use openssl_sys::*; extern "C" { pub fn OBJ_txt2obj(s: *const c_char, no_name: c_int) -> *mut ASN1_OBJECT; @@ -38,6 +38,8 @@ mod sys { pub fn X509_get_default_cert_dir_env() -> *const c_char; pub fn X509_get_default_cert_dir() -> *const c_char; pub fn SSL_CTX_set_post_handshake_auth(ctx: *mut SSL_CTX, val: c_int); + pub fn RAND_add(buf: *const c_void, num: c_int, randomness: c_double); + pub fn RAND_pseudo_bytes(buf: *const u8, num: c_int) -> c_int; } } @@ -188,6 +190,44 @@ fn ssl_get_default_verify_paths() -> (String, String, String, String) { } } +fn ssl_rand_status() -> i32 { + unsafe { sys::RAND_status() } +} + +fn ssl_rand_add(string: Either, entropy: f64) { + let f = |b: &[u8]| { + for buf in b.chunks(libc::c_int::max_value() as usize) { + unsafe { sys::RAND_add(buf.as_ptr() as *const _, buf.len() as _, entropy) } + } + }; + match string { + Either::A(s) => f(s.as_str().as_bytes()), + Either::B(b) => b.with_ref(f), + } +} + +fn ssl_rand_bytes(n: i32, vm: &VirtualMachine) -> PyResult> { + if n < 0 { + return Err(vm.new_value_error("num must be positive".to_owned())); + } + let mut buf = vec![0; n as usize]; + openssl::rand::rand_bytes(&mut buf) + .map(|()| buf) + .map_err(|e| convert_openssl_error(vm, e)) +} + +fn ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec, bool)> { + if n < 0 { + return Err(vm.new_value_error("num must be positive".to_owned())); + } + let mut buf = vec![0; n as usize]; + let ret = unsafe { sys::RAND_pseudo_bytes(buf.as_mut_ptr(), n) }; + match ret { + 0 | 1 => Ok((buf, ret == 1)), + _ => Err(convert_openssl_error(vm, openssl::error::ErrorStack::get())), + } +} + #[pyclass(name = "_SSLContext")] struct PySslContext { ctx: RefCell, @@ -500,7 +540,6 @@ impl PySslSocket { #[pymethod] fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { - use crate::pyobject::Either; // Either a stream builder or a mid-handshake stream from WANT_READ or WANT_WRITE let mut handshaker: Either<_, ssl::MidHandshakeSslStream<_>> = Either::A(self.stream_builder()); @@ -605,6 +644,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "txt2obj" => ctx.new_function(ssl_txt2obj), "nid2obj" => ctx.new_function(ssl_nid2obj), "get_default_verify_paths" => ctx.new_function(ssl_get_default_verify_paths), + "RAND_status" => ctx.new_function(ssl_rand_status), + "RAND_add" => ctx.new_function(ssl_rand_add), + "RAND_bytes" => ctx.new_function(ssl_rand_bytes), + "RAND_pseudo_bytes" => ctx.new_function(ssl_rand_pseudo_bytes), // Constants "OPENSSL_VERSION" => ctx.new_str(openssl::version::version().to_owned()), From 126f41e003ee47e06fa10bc39b4db187f3a16a5e Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sun, 22 Mar 2020 17:06:08 -0500 Subject: [PATCH 14/17] Make `rustpython -i script.py` work like it does in CPython --- src/main.rs | 9 ++++++++- src/shell.rs | 5 ----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/main.rs b/src/main.rs index da94fb26cc..c4ac3b38ef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -368,8 +368,15 @@ fn run_rustpython(vm: &VirtualMachine, matches: &ArgMatches) -> PyResult<()> { } else if let Some(module) = matches.value_of("m") { run_module(&vm, module)?; } else if let Some(filename) = matches.value_of("script") { - run_script(&vm, scope, filename)? + run_script(&vm, scope.clone(), filename)?; + if matches.is_present("inspect") { + shell::run_shell(&vm, scope)?; + } } else { + println!( + "Welcome to the magnificent Rust Python {} interpreter \u{1f631} \u{1f596}", + crate_version!() + ); shell::run_shell(&vm, scope)?; } diff --git a/src/shell.rs b/src/shell.rs index 7371c34efa..e3922ee882 100644 --- a/src/shell.rs +++ b/src/shell.rs @@ -41,11 +41,6 @@ fn shell_exec(vm: &VirtualMachine, source: &str, scope: Scope) -> ShellExecResul } pub fn run_shell(vm: &VirtualMachine, scope: Scope) -> PyResult<()> { - println!( - "Welcome to the magnificent Rust Python {} interpreter \u{1f631} \u{1f596}", - crate_version!() - ); - let mut repl = Readline::new(helper::ShellHelper::new(vm, scope.clone())); let mut full_input = String::new(); From 4f64afb8cfed9e6694fcbeb13d833a88d2a371e4 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sun, 22 Mar 2020 17:06:53 -0500 Subject: [PATCH 15/17] Add ctx.load_verify_locations(cadata=), and ctx.get_ca_certs --- Cargo.lock | 2 +- vm/Cargo.toml | 2 +- vm/src/stdlib/ssl.rs | 134 ++++++++++++++++++++++++++++++++++++------- 3 files changed, 114 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b952e66286..0f0eced175 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1600,7 +1600,7 @@ dependencies = [ "flame", "flamer", "flate2", - "foreign-types-shared", + "foreign-types", "gethostname", "getrandom", "hex", diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 4d1d0bdf6f..eeaca1247b 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -70,7 +70,7 @@ paste = "0.1" base64 = "0.11" is-macro = "0.1" result-like = "^0.2.1" -foreign-types-shared = "0.1" +foreign-types = "0.3" num_enum = "0.4" flame = { version = "0.2", optional = true } diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 612e1e73f3..8f2e883d4d 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -12,19 +12,22 @@ use crate::pyobject::{ use crate::types::create_type; use crate::VirtualMachine; -use std::cell::{RefCell, RefMut}; +use std::cell::{Ref, RefCell, RefMut}; use std::convert::TryFrom; use std::ffi::{CStr, CString}; use std::fmt; -use foreign_types_shared::{ForeignType, ForeignTypeRef}; +use foreign_types::{ForeignType, ForeignTypeRef}; use openssl::{ asn1::{Asn1Object, Asn1ObjectRef}, + error::ErrorStack, nid::Nid, ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode}, + x509::{X509Ref, X509}, }; mod sys { + #![allow(non_camel_case_types, unused)] use libc::{c_char, c_double, c_int, c_void}; pub use openssl_sys::*; extern "C" { @@ -40,7 +43,54 @@ mod sys { pub fn SSL_CTX_set_post_handshake_auth(ctx: *mut SSL_CTX, val: c_int); pub fn RAND_add(buf: *const c_void, num: c_int, randomness: c_double); pub fn RAND_pseudo_bytes(buf: *const u8, num: c_int) -> c_int; + pub fn X509_STORE_get0_objects(ctx: *mut X509_STORE) -> *mut stack_st_X509_OBJECT; + pub fn X509_OBJECT_free(a: *mut X509_OBJECT); } + + pub enum stack_st_X509_OBJECT {} + + pub type X509_LOOKUP_TYPE = c_int; + pub const X509_LU_NONE: X509_LOOKUP_TYPE = 0; + pub const X509_LU_X509: X509_LOOKUP_TYPE = 1; + pub const X509_LU_CRL: X509_LOOKUP_TYPE = 2; + + #[repr(C)] + pub struct X509_OBJECT { + pub r#type: X509_LOOKUP_TYPE, + pub data: X509_OBJECT_data, + } + #[repr(C)] + pub union X509_OBJECT_data { + pub ptr: *mut c_char, + pub x509: *mut X509, + pub crl: *mut X509_CRL, + pub pkey: *mut EVP_PKEY, + } +} + +// TODO: upstream this into rust-openssl +foreign_types::foreign_type! { + type CType = sys::X509_OBJECT; + fn drop = sys::X509_OBJECT_free; + + pub struct X509Object; + pub struct X509ObjectRef; +} + +impl X509ObjectRef { + fn x509(&self) -> Option<&X509Ref> { + let ptr = self.as_ptr(); + let ty = unsafe { (*ptr).r#type }; + if ty == sys::X509_LU_X509 { + Some(unsafe { X509Ref::from_ptr((*ptr).data.x509) }) + } else { + None + } + } +} + +impl openssl::stack::Stackable for X509Object { + type StackType = sys::stack_st_X509_OBJECT; } #[derive(num_enum::IntoPrimitive, num_enum::TryFromPrimitive, PartialEq)] @@ -224,7 +274,7 @@ fn ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec, bool let ret = unsafe { sys::RAND_pseudo_bytes(buf.as_mut_ptr(), n) }; match ret { 0 | 1 => Ok((buf, ret == 1)), - _ => Err(convert_openssl_error(vm, openssl::error::ErrorStack::get())), + _ => Err(convert_openssl_error(vm, ErrorStack::get())), } } @@ -251,11 +301,11 @@ impl PySslContext { fn builder(&self) -> RefMut { self.ctx.borrow_mut() } - // fn ctx(&self) -> Ref { - // Ref::map(self.ctx.borrow(), |ctx| unsafe { - // SslContextRef::from_ptr(ctx.as_ptr()) - // }) - // } + fn ctx(&self) -> Ref { + Ref::map(self.ctx.borrow(), |ctx| unsafe { + &**(ctx as *const SslContextBuilder as *const ssl::SslContext) + }) + } fn ptr(&self) -> *mut sys::SSL_CTX { self.ctx.borrow().as_ptr() } @@ -374,8 +424,23 @@ impl PySslContext { ); } - if let Some(_cadata) = args.cadata { - todo!() + if let Some(cadata) = args.cadata { + let cert = match cadata { + Either::A(s) => { + if !s.as_str().is_ascii() { + return Err(vm.new_type_error("Must be an ascii string".to_owned())); + } + X509::from_pem(s.as_str().as_bytes()) + } + Either::B(b) => b.with_ref(X509::from_der), + }; + let cert = cert.map_err(|e| convert_openssl_error(vm, e))?; + let ctx = self.ctx(); + let store = ctx.cert_store(); + let ret = unsafe { sys::X509_STORE_add_cert(store.as_ptr(), cert.as_ptr()) }; + if ret <= 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } } if args.cafile.is_some() || args.capath.is_some() { @@ -395,7 +460,7 @@ impl PySslContext { let err = if errno != 0 { super::os::errno_err(vm) } else { - convert_openssl_error(vm, openssl::error::ErrorStack::get()) + convert_openssl_error(vm, ErrorStack::get()) }; return Err(err); } @@ -404,6 +469,32 @@ impl PySslContext { Ok(()) } + #[pymethod] + fn get_ca_certs(&self, binary_form: OptionalArg, vm: &VirtualMachine) -> PyResult { + use openssl::stack::StackRef; + let binary_form = binary_form.unwrap_or(false); + let certs = unsafe { + let stack = sys::X509_STORE_get0_objects(self.ctx().cert_store().as_ptr()); + assert!(!stack.is_null()); + StackRef::::from_ptr(stack) + }; + let certs = certs + .iter() + .filter_map(|cert| { + let cert = cert.x509()?; + let obj = if binary_form { + cert.to_der() + .map(|b| vm.ctx.new_bytes(b)) + .map_err(|e| convert_openssl_error(vm, e)) + } else { + todo!() + }; + Some(obj) + }) + .collect::, _>>()?; + Ok(vm.ctx.new_list(certs)) + } + #[pymethod] fn _wrap_socket( zelf: PyRef, @@ -475,7 +566,7 @@ struct LoadVerifyLocationsArgs { #[pyarg(positional_or_keyword, default = "None")] capath: Option, #[pyarg(positional_or_keyword, default = "None")] - cadata: Option, + cadata: Option>, } #[pyclass(name = "_SSLSocket")] @@ -591,19 +682,18 @@ fn ssl_error(vm: &VirtualMachine) -> PyClassRef { vm.class("_ssl", "SSLError") } -fn convert_openssl_error( - vm: &VirtualMachine, - err: openssl::error::ErrorStack, -) -> PyBaseExceptionRef { +fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef { let cls = ssl_error(vm); match err.errors().first() { Some(e) => { - let no = "unknown"; - let msg = format!( - "openssl error code {}, from library {}, in function {}, on line {}, with reason {}, and extra data {}", - e.code(), e.library().unwrap_or(no), e.function().unwrap_or(no), e.line(), - e.reason().unwrap_or(no), e.data().unwrap_or("none"), - ); + // let no = "unknown"; + // let msg = format!( + // "openssl error code {}, from library {}, in function {}, on line {}, with reason {}, and extra data {}", + // e.code(), e.library().unwrap_or(no), e.function().unwrap_or(no), e.line(), + // e.reason().unwrap_or(no), e.data().unwrap_or("none"), + // ); + // TODO: map the error codes to code names, e.g. "CERTIFICATE_VERIFY_FAILED", just requires a big hashmap/dict + let msg = e.to_string(); vm.new_exception_msg(cls, msg) } None => vm.new_exception_empty(cls), From 4fc45d25697600454478304acd23ac4e8d9fdd68 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Tue, 24 Mar 2020 00:17:04 -0500 Subject: [PATCH 16/17] Add _SSLSocket.peer_certificate --- vm/src/stdlib/ssl.rs | 130 +++++++++++++++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 29 deletions(-) diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 8f2e883d4d..091c4d560b 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -3,11 +3,10 @@ use crate::exceptions::PyBaseExceptionRef; use crate::function::OptionalArg; use crate::obj::objbytearray::PyByteArrayRef; use crate::obj::objbyteinner::PyBytesLike; -use crate::obj::objbytes::PyBytesRef; -use crate::obj::objstr::{PyString, PyStringRef}; +use crate::obj::objstr::PyStringRef; use crate::obj::{objtype::PyClassRef, objweakref::PyWeak}; use crate::pyobject::{ - Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + Either, IntoPyObject, ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, }; use crate::types::create_type; use crate::VirtualMachine; @@ -23,12 +22,12 @@ use openssl::{ error::ErrorStack, nid::Nid, ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode}, - x509::{X509Ref, X509}, + x509::{self, X509Ref, X509}, }; mod sys { #![allow(non_camel_case_types, unused)] - use libc::{c_char, c_double, c_int, c_void}; + use libc::{c_char, c_double, c_int, c_long, c_void}; pub use openssl_sys::*; extern "C" { pub fn OBJ_txt2obj(s: *const c_char, no_name: c_int) -> *mut ASN1_OBJECT; @@ -45,6 +44,8 @@ mod sys { pub fn RAND_pseudo_bytes(buf: *const u8, num: c_int) -> c_int; pub fn X509_STORE_get0_objects(ctx: *mut X509_STORE) -> *mut stack_st_X509_OBJECT; pub fn X509_OBJECT_free(a: *mut X509_OBJECT); + pub fn SSL_is_init_finished(ssl: *const SSL) -> c_int; + pub fn X509_get_version(x: *const X509) -> c_long; } pub enum stack_st_X509_OBJECT {} @@ -317,7 +318,7 @@ impl PySslContext { let method = match proto { SslVersion::Ssl2 => todo!(), SslVersion::Ssl3 => todo!(), - SslVersion::Tls => unsafe { ssl::SslMethod::from_ptr(sys::TLS_method()) }, + SslVersion::Tls => ssl::SslMethod::tls(), SslVersion::Tls1 => todo!(), // TODO: Tls1_1, Tls1_2 ? SslVersion::TlsClient => unsafe { ssl::SslMethod::from_ptr(sys::TLS_client_method()) }, @@ -482,14 +483,7 @@ impl PySslContext { .iter() .filter_map(|cert| { let cert = cert.x509()?; - let obj = if binary_form { - cert.to_der() - .map(|b| vm.ctx.new_bytes(b)) - .map_err(|e| convert_openssl_error(vm, e)) - } else { - todo!() - }; - Some(obj) + Some(cert_to_py(vm, cert, binary_form)) }) .collect::, _>>()?; Ok(vm.ctx.new_list(certs)) @@ -501,18 +495,6 @@ impl PySslContext { args: WrapSocketArgs, vm: &VirtualMachine, ) -> PyResult { - let server_hostname = args - .server_hostname - .map(|s| { - vm.encode( - s.into_object(), - Some(PyString::from("ascii").into_ref(vm)), - None, - ) - .and_then(|res| PyBytesRef::try_from_object(vm, res)) - }) - .transpose()?; - let ssl = { let ptr = zelf.ptr(); let ctx = unsafe { ssl::SslContext::from_ptr(ptr) }; @@ -538,7 +520,7 @@ impl PySslContext { ctx: zelf, stream: RefCell::new(Some(stream)), socket_type, - server_hostname, + server_hostname: args.server_hostname, owner: RefCell::new(args.owner.as_ref().map(PyWeak::downgrade)), }) } @@ -574,7 +556,7 @@ struct PySslSocket { ctx: PyRef, stream: RefCell>>, socket_type: SslServerOrClient, - server_hostname: Option, + server_hostname: Option, owner: RefCell>, } @@ -625,10 +607,28 @@ impl PySslSocket { self.ctx.clone() } #[pyproperty] - fn server_hostname(&self) -> Option { + fn server_hostname(&self) -> Option { self.server_hostname.clone() } + #[pymethod] + fn peer_certificate( + &self, + binary: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + let binary = binary.unwrap_or(false); + let init_finished = unsafe { sys::SSL_is_init_finished(self.stream().ssl().as_ptr()) } != 0; + if !init_finished { + return Err(vm.new_value_error("handshake not done yet".to_owned())); + } + self.stream() + .ssl() + .peer_certificate() + .map(|cert| cert_to_py(vm, &cert, binary)) + .transpose() + } + #[pymethod] fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { // Either a stream builder or a mid-handshake stream from WANT_READ or WANT_WRITE @@ -706,6 +706,78 @@ fn convert_ssl_error(vm: &VirtualMachine, e: ssl::Error) -> PyBaseExceptionRef { } } +fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult { + if binary { + cert.to_der() + .map(|b| vm.ctx.new_bytes(b)) + .map_err(|e| convert_openssl_error(vm, e)) + } else { + let dict = vm.ctx.new_dict(); + + let name_to_py = |name: &x509::X509NameRef| { + name.entries() + .map(|entry| { + let txt = match obj2txt(entry.object(), false) { + Some(s) => vm.new_str(s), + None => vm.get_none(), + }; + let data = vm.new_str(entry.data().as_utf8()?.to_owned()); + Ok(vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![txt, data])])) + }) + .collect::>() + .map(|list| vm.ctx.new_tuple(list)) + .map_err(|e| convert_openssl_error(vm, e)) + }; + + dict.set_item("subject", name_to_py(cert.subject_name())?, vm)?; + dict.set_item("issuer", name_to_py(cert.issuer_name())?, vm)?; + + let version = unsafe { sys::X509_get_version(cert.as_ptr()) }; + dict.set_item("version", vm.new_int(version), vm)?; + + let serial_num = cert + .serial_number() + .to_bn() + .and_then(|bn| bn.to_hex_str()) + .map_err(|e| convert_openssl_error(vm, e))?; + dict.set_item("serialNumber", vm.new_str(serial_num.to_owned()), vm)?; + + dict.set_item("notBefore", vm.new_str(cert.not_before().to_string()), vm)?; + dict.set_item("notAfter", vm.new_str(cert.not_after().to_string()), vm)?; + + if let Some(names) = cert.subject_alt_names() { + let san = names + .iter() + .filter_map(|gen_name| { + if let Some(email) = gen_name.email() { + Some(vm.ctx.new_tuple(vec![ + vm.new_str("email".to_owned()), + vm.new_str(email.to_owned()), + ])) + } else if let Some(dnsname) = gen_name.dnsname() { + Some(vm.ctx.new_tuple(vec![ + vm.new_str("DNS".to_owned()), + vm.new_str(dnsname.to_owned()), + ])) + } else if let Some(ip) = gen_name.ipaddress() { + Some(vm.ctx.new_tuple(vec![ + vm.new_str("IP Address".to_owned()), + vm.new_str(String::from_utf8_lossy(ip).into_owned()), + ])) + } else { + // TODO: convert every type of general name: + // https://github.com/python/cpython/blob/3.6/Modules/_ssl.c#L1092-L1231 + None + } + }) + .collect(); + dict.set_item("subjectAltName", vm.ctx.new_tuple(san), vm)?; + }; + + Ok(dict.into_object()) + } +} + fn parse_version_info(mut n: i64) -> (u8, u8, u8, u8, u8) { let status = (n & 0xF) as u8; n >>= 4; From 7051b25d2aec8e3faa08d095abdf98176f0a7a96 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Fri, 27 Mar 2020 15:04:07 -0500 Subject: [PATCH 17/17] Manually error on embedded nul for set_ciphers --- vm/src/stdlib/ssl.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 091c4d560b..3e12ae9b8e 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -364,12 +364,14 @@ impl PySslContext { } #[pymethod] - fn set_ciphers(&self, cipherlist: CString, vm: &VirtualMachine) -> PyResult<()> { - self.builder() - .set_cipher_list(cipherlist.to_str().unwrap()) - .map_err(|_| { - vm.new_exception_msg(ssl_error(vm), "No cipher can be selected.".to_owned()) - }) + fn set_ciphers(&self, cipherlist: PyStringRef, vm: &VirtualMachine) -> PyResult<()> { + let ciphers = cipherlist.as_str(); + if ciphers.contains('\0') { + return Err(vm.new_value_error("embedded null character".to_owned())); + } + self.builder().set_cipher_list(ciphers).map_err(|_| { + vm.new_exception_msg(ssl_error(vm), "No cipher can be selected.".to_owned()) + }) } #[pyproperty]