diff --git a/adafruit_azureiot/device_registration.py b/adafruit_azureiot/device_registration.py index 9071981..e5fa2c5 100644 --- a/adafruit_azureiot/device_registration.py +++ b/adafruit_azureiot/device_registration.py @@ -28,7 +28,7 @@ class DeviceRegistrationError(Exception): An error from the device registration """ - def __init__(self, message): + def __init__(self, message: str): super().__init__(message) self.message = message diff --git a/adafruit_azureiot/hmac.py b/adafruit_azureiot/hmac.py index 64e72ed..7c4132e 100644 --- a/adafruit_azureiot/hmac.py +++ b/adafruit_azureiot/hmac.py @@ -16,8 +16,13 @@ # pylint: disable=C0103, W0108, R0915, C0116, C0115 +try: + from typing import Union +except ImportError: + pass -def __translate(key, translation): + +def __translate(key: Union[bytes, bytearray], translation: bytes) -> bytes: return bytes(translation[x] for x in key) @@ -28,7 +33,7 @@ def __translate(key, translation): SHA_DIGESTSIZE = 32 -def new_shaobject(): +def new_shaobject() -> dict: """Struct. for storing SHA information.""" return { "digest": [0] * 8, @@ -40,7 +45,7 @@ def new_shaobject(): } -def sha_init(): +def sha_init() -> dict: """Initialize the SHA digest.""" sha_info = new_shaobject() sha_info["digest"] = [ @@ -73,7 +78,7 @@ def sha_init(): Gamma1 = lambda x: (S(x, 17) ^ S(x, 19) ^ R(x, 10)) -def sha_transform(sha_info): +def sha_transform(sha_info: dict) -> None: W = [] d = sha_info["data"] @@ -90,7 +95,7 @@ def sha_transform(sha_info): ss = sha_info["digest"][:] # pylint: disable=too-many-arguments, line-too-long - def RND(a, b, c, d, e, f, g, h, i, ki): + def RND(a, b, c, d, e, f, g, h, i, ki): # type: ignore[no-untyped-def] """Compress""" t0 = h + Sigma1(e) + Ch(e, f, g) + ki + W[i] t1 = Sigma0(a) + Maj(a, b, c) @@ -298,7 +303,7 @@ def RND(a, b, c, d, e, f, g, h, i, ki): sha_info["digest"] = dig -def sha_update(sha_info, buffer): +def sha_update(sha_info: dict, buffer: Union[bytes, bytearray]) -> None: """Update the SHA digest. :param dict sha_info: SHA Digest. :param str buffer: SHA buffer size. @@ -346,13 +351,13 @@ def sha_update(sha_info, buffer): sha_info["local"] = count -def getbuf(s): +def getbuf(s: Union[str, bytes, bytearray]) -> Union[bytes, bytearray]: if isinstance(s, str): return s.encode("ascii") return bytes(s) -def sha_final(sha_info): +def sha_final(sha_info: dict) -> bytes: """Finish computing the SHA Digest.""" lo_bit_count = sha_info["count_lo"] hi_bit_count = sha_info["count_hi"] @@ -393,28 +398,28 @@ class sha256: block_size = SHA_BLOCKSIZE name = "sha256" - def __init__(self, s=None): + def __init__(self, s: Union[str, bytes, bytearray] = None): """Constructs a SHA256 hash object.""" self._sha = sha_init() if s: sha_update(self._sha, getbuf(s)) - def update(self, s): + def update(self, s: Union[str, bytes, bytearray]) -> None: """Updates the hash object with a bytes-like object, s.""" sha_update(self._sha, getbuf(s)) - def digest(self): + def digest(self) -> bytes: """Returns the digest of the data passed to the update() method so far.""" return sha_final(self._sha.copy())[: self._sha["digestsize"]] - def hexdigest(self): + def hexdigest(self) -> str: """Like digest() except the digest is returned as a string object of double length, containing only hexadecimal digits. """ return "".join(["%.2x" % i for i in self.digest()]) - def copy(self): + def copy(self) -> "sha256": """Return a copy (“clone”) of the hash object.""" new = sha256() new._sha = self._sha.copy() @@ -429,7 +434,9 @@ class HMAC: blocksize = 64 # 512-bit HMAC; can be changed in subclasses. - def __init__(self, key, msg=None): + def __init__( + self, key: Union[bytes, bytearray], msg: Union[bytes, bytearray] = None + ): """Create a new HMAC object. key: key for the keyed hash object. @@ -478,15 +485,15 @@ def __init__(self, key, msg=None): self.update(msg) @property - def name(self): + def name(self) -> str: """Return the name of this object""" return "hmac-" + self.inner.name - def update(self, msg): + def update(self, msg: Union[bytes, bytearray]) -> None: """Update this hashing object with the string msg.""" self.inner.update(msg) - def copy(self): + def copy(self) -> "HMAC": """Return a separate copy of this hashing object. An update to this copy won't affect the original object. @@ -499,7 +506,7 @@ def copy(self): other.outer = self.outer.copy() return other - def _current(self): + def _current(self) -> "sha256": """Return a hash object for the current state. To be used only internally with digest() and hexdigest(). @@ -508,7 +515,7 @@ def _current(self): hmac.update(self.inner.digest()) return hmac - def digest(self): + def digest(self) -> bytes: """Return the hash value of this hashing object. This returns a string containing 8-bit data. The object is @@ -518,13 +525,15 @@ def digest(self): hmac = self._current() return hmac.digest() - def hexdigest(self): + def hexdigest(self) -> str: """Like digest(), but returns a string of hexadecimal digits instead.""" hmac = self._current() return hmac.hexdigest() -def new_hmac(key, msg=None): +def new_hmac( + key: Union[bytes, bytearray], msg: Union[bytes, bytearray] = None +) -> "HMAC": """Create a new hashing object and return it. key: The starting key for the hash. diff --git a/adafruit_azureiot/iothub_device.py b/adafruit_azureiot/iothub_device.py index b5f01ec..274b8c4 100755 --- a/adafruit_azureiot/iothub_device.py +++ b/adafruit_azureiot/iothub_device.py @@ -12,13 +12,18 @@ * Author(s): Jim Bennett, Elena Horton """ +try: + from typing import Any, Callable, Mapping, Union +except ImportError: + pass + import json import adafruit_logging as logging from .iot_error import IoTError from .iot_mqtt import IoTMQTT, IoTMQTTCallback, IoTResponse -def _validate_keys(connection_string_parts): +def _validate_keys(connection_string_parts: Mapping) -> None: """Raise ValueError if incorrect combination of keys""" host_name = connection_string_parts.get(HOST_NAME) shared_access_key_name = connection_string_parts.get(SHARED_ACCESS_KEY_NAME) @@ -67,7 +72,7 @@ def connection_status_change(self, connected: bool) -> None: self._on_connection_status_changed(connected) # pylint: disable=W0613, R0201 - def direct_method_invoked(self, method_name: str, payload) -> IoTResponse: + def direct_method_invoked(self, method_name: str, payload: str) -> IoTResponse: """Called when a direct method is invoked :param str method_name: The name of the method that was invoked :param str payload: The payload with the message @@ -91,7 +96,10 @@ def cloud_to_device_message_received(self, body: str, properties: dict) -> None: self._on_cloud_to_device_message_received(body, properties) def device_twin_desired_updated( - self, desired_property_name: str, desired_property_value, desired_version: int + self, + desired_property_name: str, + desired_property_value: Any, + desired_version: int, ) -> None: """Called when the device twin desired properties are updated :param str desired_property_name: The name of the desired property that was updated @@ -107,7 +115,7 @@ def device_twin_desired_updated( def device_twin_reported_updated( self, reported_property_name: str, - reported_property_value, + reported_property_value: Any, reported_version: int, ) -> None: """Called when the device twin reported values are updated @@ -175,21 +183,23 @@ def __init__( self._mqtt = None @property - def on_connection_status_changed(self): + def on_connection_status_changed(self) -> Callable: """A callback method that is called when the connection status is changed. This method should have the following signature: def connection_status_changed(connected: bool) -> None """ return self._on_connection_status_changed @on_connection_status_changed.setter - def on_connection_status_changed(self, new_on_connection_status_changed): + def on_connection_status_changed( + self, new_on_connection_status_changed: Callable + ) -> None: """A callback method that is called when the connection status is changed. This method should have the following signature: def connection_status_changed(connected: bool) -> None """ self._on_connection_status_changed = new_on_connection_status_changed @property - def on_direct_method_invoked(self): + def on_direct_method_invoked(self) -> Callable: """A callback method that is called when a direct method is invoked. This method should have the following signature: def direct_method_invoked(method_name: str, payload: str) -> IoTResponse: @@ -202,7 +212,7 @@ def direct_method_invoked(method_name: str, payload: str) -> IoTResponse: return self._on_direct_method_invoked @on_direct_method_invoked.setter - def on_direct_method_invoked(self, new_on_direct_method_invoked): + def on_direct_method_invoked(self, new_on_direct_method_invoked: Callable) -> None: """A callback method that is called when a direct method is invoked. This method should have the following signature: def direct_method_invoked(method_name: str, payload: str) -> IoTResponse: @@ -215,7 +225,7 @@ def direct_method_invoked(method_name: str, payload: str) -> IoTResponse: self._on_direct_method_invoked = new_on_direct_method_invoked @property - def on_cloud_to_device_message_received(self): + def on_cloud_to_device_message_received(self) -> Callable: """A callback method that is called when a cloud to device message is received. This method should have the following signature: def cloud_to_device_message_received(body: str, properties: dict) -> None: """ @@ -223,8 +233,8 @@ def cloud_to_device_message_received(body: str, properties: dict) -> None: @on_cloud_to_device_message_received.setter def on_cloud_to_device_message_received( - self, new_on_cloud_to_device_message_received - ): + self, new_on_cloud_to_device_message_received: Callable + ) -> None: """A callback method that is called when a cloud to device message is received. This method should have the following signature: def cloud_to_device_message_received(body: str, properties: dict) -> None: """ @@ -233,7 +243,7 @@ def cloud_to_device_message_received(body: str, properties: dict) -> None: ) @property - def on_device_twin_desired_updated(self): + def on_device_twin_desired_updated(self) -> Callable: """A callback method that is called when the desired properties of the devices device twin are updated. This method should have the following signature: def device_twin_desired_updated(desired_property_name: str, desired_property_value, desired_version: int) -> None: @@ -241,7 +251,9 @@ def device_twin_desired_updated(desired_property_name: str, desired_property_val return self._on_device_twin_desired_updated @on_device_twin_desired_updated.setter - def on_device_twin_desired_updated(self, new_on_device_twin_desired_updated): + def on_device_twin_desired_updated( + self, new_on_device_twin_desired_updated: Callable + ) -> None: """A callback method that is called when the desired properties of the devices device twin are updated. This method should have the following signature: def device_twin_desired_updated(desired_property_name: str, desired_property_value, desired_version: int) -> None: @@ -252,7 +264,7 @@ def device_twin_desired_updated(desired_property_name: str, desired_property_val self._mqtt.subscribe_to_twins() @property - def on_device_twin_reported_updated(self): + def on_device_twin_reported_updated(self) -> Callable: """A callback method that is called when the reported properties of the devices device twin are updated. This method should have the following signature: def device_twin_reported_updated(reported_property_name: str, reported_property_value, reported_version: int) -> None: @@ -260,7 +272,9 @@ def device_twin_reported_updated(reported_property_name: str, reported_property_ return self._on_device_twin_reported_updated @on_device_twin_reported_updated.setter - def on_device_twin_reported_updated(self, new_on_device_twin_reported_updated): + def on_device_twin_reported_updated( + self, new_on_device_twin_reported_updated: Callable + ) -> None: """A callback method that is called when the reported properties of the devices device twin are updated. This method should have the following signature: def device_twin_reported_updated(reported_property_name: str, reported_property_value, reported_version: int) -> None: @@ -327,7 +341,9 @@ def is_connected(self) -> bool: return False - def send_device_to_cloud_message(self, message, system_properties=None) -> None: + def send_device_to_cloud_message( + self, message: Union[str, dict], system_properties: dict = None + ) -> None: """Send a device to cloud message from this device to Azure IoT Hub :param message: The message data as a JSON string or a dictionary :param system_properties: System properties to send with the message @@ -339,7 +355,7 @@ def send_device_to_cloud_message(self, message, system_properties=None) -> None: self._mqtt.send_device_to_cloud_message(message, system_properties) - def update_twin(self, patch) -> None: + def update_twin(self, patch: Union[str, dict]) -> None: """Updates the reported properties in the devices device twin :param patch: The JSON patch to apply to the device twin reported properties """ diff --git a/adafruit_azureiot/quote.py b/adafruit_azureiot/quote.py index a39b7ee..2a5cfbb 100644 --- a/adafruit_azureiot/quote.py +++ b/adafruit_azureiot/quote.py @@ -11,6 +11,12 @@ safe arg. """ + +try: + from typing import Any, Union +except ImportError: + pass + _ALWAYS_SAFE = frozenset( b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"abcdefghijklmnopqrstuvwxyz" b"0123456789" b"_.-~" ) @@ -18,7 +24,7 @@ SAFE_QUOTERS = {} -def quote(bytes_val: bytes, safe="/"): +def quote(bytes_val: bytes, safe: Union[str, bytes, bytearray] = "/") -> str: """The quote function %-escapes all characters that are neither in the unreserved chars ("always safe") nor the additional chars set via the safe arg. @@ -69,17 +75,17 @@ class defaultdict: @staticmethod # pylint: disable=W0613 - def __new__(cls, default_factory=None, **kwargs): + def __new__(cls, default_factory: Any = None, **kwargs: Any) -> "defaultdict": self = super(defaultdict, cls).__new__(cls) # pylint: disable=C0103 self.d = {} return self - def __init__(self, default_factory=None, **kwargs): + def __init__(self, default_factory: Any = None, **kwargs: Any): self.d = kwargs self.default_factory = default_factory - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: try: return self.d[key] except KeyError: @@ -87,16 +93,16 @@ def __getitem__(self, key): self.d[key] = val return val - def __setitem__(self, key, val): + def __setitem__(self, key: Any, val: Any) -> None: self.d[key] = val - def __delitem__(self, key): + def __delitem__(self, key: Any) -> None: del self.d[key] - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: return key in self.d - def __missing__(self, key): + def __missing__(self, key: Any) -> Any: if self.default_factory is None: raise KeyError(key) return self.default_factory() @@ -111,12 +117,12 @@ class Quoter(defaultdict): # Keeps a cache internally, using defaultdict, for efficiency (lookups # of cached keys don't call Python code at all). - def __init__(self, safe): + def __init__(self, safe: Union[bytes, bytearray]): """safe: bytes object.""" super().__init__() self.safe = _ALWAYS_SAFE.union(safe) - def __missing__(self, b): + def __missing__(self, b: int) -> str: # Handle a cache miss. Store quoted string in cache and return. res = chr(b) if b in self.safe else "%{:02X}".format(b) self[b] = res