Skip to content

Add type annotations #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion adafruit_azureiot/device_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DeviceRegistrationError(Exception):
An error from the device registration
"""

def __init__(self, message):
def __init__(self, message: str):
super().__init__(message)
self.message = message

Expand Down
51 changes: 30 additions & 21 deletions adafruit_azureiot/hmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@

# pylint: disable=C0103, W0108, R0915, C0116, C0115

try:
from typing import Union
except ImportError:
pass

def __translate(key, translation):

def __translate(key: Union[bytes, bytearray], translation: bytes) -> bytes:
return bytes(translation[x] for x in key)


Expand All @@ -28,7 +33,7 @@ def __translate(key, translation):
SHA_DIGESTSIZE = 32


def new_shaobject():
def new_shaobject() -> dict:
"""Struct. for storing SHA information."""
return {
"digest": [0] * 8,
Expand All @@ -40,7 +45,7 @@ def new_shaobject():
}


def sha_init():
def sha_init() -> dict:
"""Initialize the SHA digest."""
sha_info = new_shaobject()
sha_info["digest"] = [
Expand Down Expand Up @@ -73,7 +78,7 @@ def sha_init():
Gamma1 = lambda x: (S(x, 17) ^ S(x, 19) ^ R(x, 10))


def sha_transform(sha_info):
def sha_transform(sha_info: dict) -> None:
W = []

d = sha_info["data"]
Expand All @@ -90,7 +95,7 @@ def sha_transform(sha_info):
ss = sha_info["digest"][:]

# pylint: disable=too-many-arguments, line-too-long
def RND(a, b, c, d, e, f, g, h, i, ki):
def RND(a, b, c, d, e, f, g, h, i, ki): # type: ignore[no-untyped-def]
"""Compress"""
t0 = h + Sigma1(e) + Ch(e, f, g) + ki + W[i]
t1 = Sigma0(a) + Maj(a, b, c)
Expand Down Expand Up @@ -298,7 +303,7 @@ def RND(a, b, c, d, e, f, g, h, i, ki):
sha_info["digest"] = dig


def sha_update(sha_info, buffer):
def sha_update(sha_info: dict, buffer: Union[bytes, bytearray]) -> None:
"""Update the SHA digest.
:param dict sha_info: SHA Digest.
:param str buffer: SHA buffer size.
Expand Down Expand Up @@ -346,13 +351,13 @@ def sha_update(sha_info, buffer):
sha_info["local"] = count


def getbuf(s):
def getbuf(s: Union[str, bytes, bytearray]) -> Union[bytes, bytearray]:
if isinstance(s, str):
return s.encode("ascii")
return bytes(s)


def sha_final(sha_info):
def sha_final(sha_info: dict) -> bytes:
"""Finish computing the SHA Digest."""
lo_bit_count = sha_info["count_lo"]
hi_bit_count = sha_info["count_hi"]
Expand Down Expand Up @@ -393,28 +398,28 @@ class sha256:
block_size = SHA_BLOCKSIZE
name = "sha256"

def __init__(self, s=None):
def __init__(self, s: Union[str, bytes, bytearray] = None):
"""Constructs a SHA256 hash object."""
self._sha = sha_init()
if s:
sha_update(self._sha, getbuf(s))

def update(self, s):
def update(self, s: Union[str, bytes, bytearray]) -> None:
"""Updates the hash object with a bytes-like object, s."""
sha_update(self._sha, getbuf(s))

def digest(self):
def digest(self) -> bytes:
"""Returns the digest of the data passed to the update()
method so far."""
return sha_final(self._sha.copy())[: self._sha["digestsize"]]

def hexdigest(self):
def hexdigest(self) -> str:
"""Like digest() except the digest is returned as a string object of
double length, containing only hexadecimal digits.
"""
return "".join(["%.2x" % i for i in self.digest()])

def copy(self):
def copy(self) -> "sha256":
"""Return a copy (“clone”) of the hash object."""
new = sha256()
new._sha = self._sha.copy()
Expand All @@ -429,7 +434,9 @@ class HMAC:

blocksize = 64 # 512-bit HMAC; can be changed in subclasses.

def __init__(self, key, msg=None):
def __init__(
self, key: Union[bytes, bytearray], msg: Union[bytes, bytearray] = None
):
"""Create a new HMAC object.

key: key for the keyed hash object.
Expand Down Expand Up @@ -478,15 +485,15 @@ def __init__(self, key, msg=None):
self.update(msg)

@property
def name(self):
def name(self) -> str:
"""Return the name of this object"""
return "hmac-" + self.inner.name

def update(self, msg):
def update(self, msg: Union[bytes, bytearray]) -> None:
"""Update this hashing object with the string msg."""
self.inner.update(msg)

def copy(self):
def copy(self) -> "HMAC":
"""Return a separate copy of this hashing object.

An update to this copy won't affect the original object.
Expand All @@ -499,7 +506,7 @@ def copy(self):
other.outer = self.outer.copy()
return other

def _current(self):
def _current(self) -> "sha256":
"""Return a hash object for the current state.

To be used only internally with digest() and hexdigest().
Expand All @@ -508,7 +515,7 @@ def _current(self):
hmac.update(self.inner.digest())
return hmac

def digest(self):
def digest(self) -> bytes:
"""Return the hash value of this hashing object.

This returns a string containing 8-bit data. The object is
Expand All @@ -518,13 +525,15 @@ def digest(self):
hmac = self._current()
return hmac.digest()

def hexdigest(self):
def hexdigest(self) -> str:
"""Like digest(), but returns a string of hexadecimal digits instead."""
hmac = self._current()
return hmac.hexdigest()


def new_hmac(key, msg=None):
def new_hmac(
key: Union[bytes, bytearray], msg: Union[bytes, bytearray] = None
) -> "HMAC":
"""Create a new hashing object and return it.

key: The starting key for the hash.
Expand Down
50 changes: 33 additions & 17 deletions adafruit_azureiot/iothub_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
* Author(s): Jim Bennett, Elena Horton
"""

try:
from typing import Any, Callable, Mapping, Union
except ImportError:
pass

import json
import adafruit_logging as logging
from .iot_error import IoTError
from .iot_mqtt import IoTMQTT, IoTMQTTCallback, IoTResponse


def _validate_keys(connection_string_parts):
def _validate_keys(connection_string_parts: Mapping) -> None:
"""Raise ValueError if incorrect combination of keys"""
host_name = connection_string_parts.get(HOST_NAME)
shared_access_key_name = connection_string_parts.get(SHARED_ACCESS_KEY_NAME)
Expand Down Expand Up @@ -67,7 +72,7 @@ def connection_status_change(self, connected: bool) -> None:
self._on_connection_status_changed(connected)

# pylint: disable=W0613, R0201
def direct_method_invoked(self, method_name: str, payload) -> IoTResponse:
def direct_method_invoked(self, method_name: str, payload: str) -> IoTResponse:
"""Called when a direct method is invoked
:param str method_name: The name of the method that was invoked
:param str payload: The payload with the message
Expand All @@ -91,7 +96,10 @@ def cloud_to_device_message_received(self, body: str, properties: dict) -> None:
self._on_cloud_to_device_message_received(body, properties)

def device_twin_desired_updated(
self, desired_property_name: str, desired_property_value, desired_version: int
self,
desired_property_name: str,
desired_property_value: Any,
desired_version: int,
) -> None:
"""Called when the device twin desired properties are updated
:param str desired_property_name: The name of the desired property that was updated
Expand All @@ -107,7 +115,7 @@ def device_twin_desired_updated(
def device_twin_reported_updated(
self,
reported_property_name: str,
reported_property_value,
reported_property_value: Any,
reported_version: int,
) -> None:
"""Called when the device twin reported values are updated
Expand Down Expand Up @@ -175,21 +183,23 @@ def __init__(
self._mqtt = None

@property
def on_connection_status_changed(self):
def on_connection_status_changed(self) -> Callable:
"""A callback method that is called when the connection status is changed. This method should have the following signature:
def connection_status_changed(connected: bool) -> None
"""
return self._on_connection_status_changed

@on_connection_status_changed.setter
def on_connection_status_changed(self, new_on_connection_status_changed):
def on_connection_status_changed(
self, new_on_connection_status_changed: Callable
) -> None:
"""A callback method that is called when the connection status is changed. This method should have the following signature:
def connection_status_changed(connected: bool) -> None
"""
self._on_connection_status_changed = new_on_connection_status_changed

@property
def on_direct_method_invoked(self):
def on_direct_method_invoked(self) -> Callable:
"""A callback method that is called when a direct method is invoked. This method should have the following signature:
def direct_method_invoked(method_name: str, payload: str) -> IoTResponse:

Expand All @@ -202,7 +212,7 @@ def direct_method_invoked(method_name: str, payload: str) -> IoTResponse:
return self._on_direct_method_invoked

@on_direct_method_invoked.setter
def on_direct_method_invoked(self, new_on_direct_method_invoked):
def on_direct_method_invoked(self, new_on_direct_method_invoked: Callable) -> None:
"""A callback method that is called when a direct method is invoked. This method should have the following signature:
def direct_method_invoked(method_name: str, payload: str) -> IoTResponse:

Expand All @@ -215,16 +225,16 @@ def direct_method_invoked(method_name: str, payload: str) -> IoTResponse:
self._on_direct_method_invoked = new_on_direct_method_invoked

@property
def on_cloud_to_device_message_received(self):
def on_cloud_to_device_message_received(self) -> Callable:
"""A callback method that is called when a cloud to device message is received. This method should have the following signature:
def cloud_to_device_message_received(body: str, properties: dict) -> None:
"""
return self._on_cloud_to_device_message_received

@on_cloud_to_device_message_received.setter
def on_cloud_to_device_message_received(
self, new_on_cloud_to_device_message_received
):
self, new_on_cloud_to_device_message_received: Callable
) -> None:
"""A callback method that is called when a cloud to device message is received. This method should have the following signature:
def cloud_to_device_message_received(body: str, properties: dict) -> None:
"""
Expand All @@ -233,15 +243,17 @@ def cloud_to_device_message_received(body: str, properties: dict) -> None:
)

@property
def on_device_twin_desired_updated(self):
def on_device_twin_desired_updated(self) -> Callable:
"""A callback method that is called when the desired properties of the devices device twin are updated.
This method should have the following signature:
def device_twin_desired_updated(desired_property_name: str, desired_property_value, desired_version: int) -> None:
"""
return self._on_device_twin_desired_updated

@on_device_twin_desired_updated.setter
def on_device_twin_desired_updated(self, new_on_device_twin_desired_updated):
def on_device_twin_desired_updated(
self, new_on_device_twin_desired_updated: Callable
) -> None:
"""A callback method that is called when the desired properties of the devices device twin are updated.
This method should have the following signature:
def device_twin_desired_updated(desired_property_name: str, desired_property_value, desired_version: int) -> None:
Expand All @@ -252,15 +264,17 @@ def device_twin_desired_updated(desired_property_name: str, desired_property_val
self._mqtt.subscribe_to_twins()

@property
def on_device_twin_reported_updated(self):
def on_device_twin_reported_updated(self) -> Callable:
"""A callback method that is called when the reported properties of the devices device twin are updated.
This method should have the following signature:
def device_twin_reported_updated(reported_property_name: str, reported_property_value, reported_version: int) -> None:
"""
return self._on_device_twin_reported_updated

@on_device_twin_reported_updated.setter
def on_device_twin_reported_updated(self, new_on_device_twin_reported_updated):
def on_device_twin_reported_updated(
self, new_on_device_twin_reported_updated: Callable
) -> None:
"""A callback method that is called when the reported properties of the devices device twin are updated.
This method should have the following signature:
def device_twin_reported_updated(reported_property_name: str, reported_property_value, reported_version: int) -> None:
Expand Down Expand Up @@ -327,7 +341,9 @@ def is_connected(self) -> bool:

return False

def send_device_to_cloud_message(self, message, system_properties=None) -> None:
def send_device_to_cloud_message(
self, message: Union[str, dict], system_properties: dict = None
) -> None:
"""Send a device to cloud message from this device to Azure IoT Hub
:param message: The message data as a JSON string or a dictionary
:param system_properties: System properties to send with the message
Expand All @@ -339,7 +355,7 @@ def send_device_to_cloud_message(self, message, system_properties=None) -> None:

self._mqtt.send_device_to_cloud_message(message, system_properties)

def update_twin(self, patch) -> None:
def update_twin(self, patch: Union[str, dict]) -> None:
"""Updates the reported properties in the devices device twin
:param patch: The JSON patch to apply to the device twin reported properties
"""
Expand Down
Loading