Skip to content

refactor: add mypy and fix typing issues #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,13 @@ repos:
- id: check-yaml
- id: trailing-whitespace
- id: check-merge-conflict

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
hooks:
- id: mypy
additional_dependencies:
- openfeature-sdk>=0.4.0
- opentelemetry-api
- types-protobuf
exclude: proto|tests
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class EventAttributes:
class TracingHook(Hook):
def after(
self, hook_context: HookContext, details: FlagEvaluationDetails, hints: dict
):
) -> None:
current_span = trace.get_current_span()

variant = details.variant
Expand All @@ -38,6 +38,8 @@ def after(

current_span.add_event(OTEL_EVENT_NAME, event_attributes)

def error(self, hook_context: HookContext, exception: Exception, hints: dict):
def error(
self, hook_context: HookContext, exception: Exception, hints: dict
) -> None:
current_span = trace.get_current_span()
current_span.record_exception(exception)
17 changes: 17 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[mypy]
files = hooks,providers
exclude = proto|tests
untyped_calls_exclude = flagd.proto

namespace_packages = True
explicit_package_bases = True
local_partial_types = True
pretty = True
strict = True
disallow_any_generics = False

[mypy-flagd.proto.*]
follow_imports = silent

[mypy-grpc]
ignore_missing_imports = True
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import typing

T = typing.TypeVar("T")


def str_to_bool(val: str) -> bool:
return val.lower() == "true"


def env_or_default(env_var, default, cast=None):
def env_or_default(
env_var: str, default: T, cast: typing.Optional[typing.Callable[[str], T]] = None
) -> typing.Union[str, T]:
val = os.environ.get(env_var)
if val is None:
return default
Expand All @@ -17,7 +21,7 @@ class Config:
def __init__(
self,
host: typing.Optional[str] = None,
port: typing.Optional[str] = None,
port: typing.Optional[int] = None,
tls: typing.Optional[bool] = None,
timeout: typing.Optional[int] = None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
"""

import typing
from dataclasses import dataclass
from numbers import Number

import grpc
from google.protobuf.struct_pb2 import Struct
Expand All @@ -36,17 +34,15 @@
ParseError,
TypeMismatchError,
)
from openfeature.flag_evaluation import FlagEvaluationDetails
from openfeature.flag_evaluation import FlagResolutionDetails
from openfeature.provider.metadata import Metadata
from openfeature.provider.provider import AbstractProvider

from .config import Config
from .flag_type import FlagType
from .proto.schema.v1 import schema_pb2, schema_pb2_grpc


@dataclass
class Metadata:
name: str
T = typing.TypeVar("T")


class FlagdProvider(AbstractProvider):
Expand Down Expand Up @@ -78,85 +74,85 @@ def __init__(
self.channel = channel_factory(f"{self.config.host}:{self.config.port}")
self.stub = schema_pb2_grpc.ServiceStub(self.channel)

def shutdown(self):
def shutdown(self) -> None:
self.channel.close()

def get_metadata(self):
def get_metadata(self) -> Metadata:
"""Returns provider metadata"""
return Metadata(name="FlagdProvider")

def resolve_boolean_details(
self,
key: str,
default_value: bool,
evaluation_context: EvaluationContext = None,
):
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return self._resolve(key, FlagType.BOOLEAN, default_value, evaluation_context)

def resolve_string_details(
self,
key: str,
default_value: str,
evaluation_context: EvaluationContext = None,
):
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return self._resolve(key, FlagType.STRING, default_value, evaluation_context)

def resolve_float_details(
self,
key: str,
default_value: Number,
evaluation_context: EvaluationContext = None,
):
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return self._resolve(key, FlagType.FLOAT, default_value, evaluation_context)

def resolve_integer_details(
self,
key: str,
default_value: Number,
evaluation_context: EvaluationContext = None,
):
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return self._resolve(key, FlagType.INTEGER, default_value, evaluation_context)

def resolve_object_details(
self,
key: str,
default_value: typing.Union[dict, list],
evaluation_context: EvaluationContext = None,
):
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context)

def _resolve(
self,
flag_key: str,
flag_type: FlagType,
default_value: typing.Any,
evaluation_context: EvaluationContext,
):
default_value: T,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[T]:
context = self._convert_context(evaluation_context)
call_args = {"timeout": self.config.timeout}
try:
if flag_type == FlagType.BOOLEAN:
request = schema_pb2.ResolveBooleanRequest(
request = schema_pb2.ResolveBooleanRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveBoolean(request, **call_args)
elif flag_type == FlagType.STRING:
request = schema_pb2.ResolveStringRequest(
request = schema_pb2.ResolveStringRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveString(request, **call_args)
elif flag_type == FlagType.OBJECT:
request = schema_pb2.ResolveObjectRequest(
request = schema_pb2.ResolveObjectRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveObject(request, **call_args)
elif flag_type == FlagType.FLOAT:
request = schema_pb2.ResolveFloatRequest(
request = schema_pb2.ResolveFloatRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveFloat(request, **call_args)
elif flag_type == FlagType.INTEGER:
request = schema_pb2.ResolveIntRequest(
request = schema_pb2.ResolveIntRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveInt(request, **call_args)
Expand All @@ -176,14 +172,15 @@ def _resolve(
raise GeneralError(message) from e

# Got a valid flag and valid type. Return it.
return FlagEvaluationDetails(
flag_key=flag_key,
return FlagResolutionDetails(
value=response.value,
reason=response.reason,
variant=response.variant,
)

def _convert_context(self, evaluation_context: EvaluationContext):
def _convert_context(
self, evaluation_context: typing.Optional[EvaluationContext]
) -> Struct:
s = Struct()
if evaluation_context:
try:
Expand Down