diff --git a/Doc/fake_ldap_module_for_documentation.py b/Doc/fake_ldap_module_for_documentation.py index 30807819..2914bc9e 100644 --- a/Doc/fake_ldap_module_for_documentation.py +++ b/Doc/fake_ldap_module_for_documentation.py @@ -1,21 +1,21 @@ """ -A module that mocks `_ldap` for the purposes of generating documentation +A module that mocks `ldap._ldap` for the purposes of generating documentation -This module provides placeholders for the contents of `_ldap`, making it -possible to generate documentation even _ldap is not compiled. +This module provides placeholders for the contents of `ldap._ldap`, making it +possible to generate documentation even if ldap._ldap is not compiled. It should also make the documentation independent of which features are available in the system OpenLDAP library. The overly long module name will show up in AttributeError messages, -hinting that this is not the actual _ldap. +hinting that this is not the actual ldap._ldap. See https://www.python-ldap.org/ for details. """ import sys -# Cause `import _ldap` to import this module instead of the actual `_ldap`. -sys.modules['_ldap'] = sys.modules[__name__] +# Cause `import ldap._ldap` to import this module instead of the actual module. +sys.modules['ldap._ldap'] = sys.modules[__name__] from constants import CONSTANTS from pkginfo import __version__ diff --git a/Lib/ldap/__init__.py b/Lib/ldap/__init__.py index b1797078..9ce66fd0 100644 --- a/Lib/ldap/__init__.py +++ b/Lib/ldap/__init__.py @@ -11,16 +11,19 @@ import os import sys +from typing import Any, Type, Optional, Union + + if __debug__: # Tracing is only supported in debugging mode import atexit import traceback _trace_level = int(os.environ.get("PYTHON_LDAP_TRACE_LEVEL", 0)) - _trace_file = os.environ.get("PYTHON_LDAP_TRACE_FILE") - if _trace_file is None: + _trace_file_path = os.environ.get("PYTHON_LDAP_TRACE_FILE") + if _trace_file_path is None: _trace_file = sys.stderr else: - _trace_file = open(_trace_file, 'a') + _trace_file = open(_trace_file_path, 'a') atexit.register(_trace_file.close) _trace_stack_limit = None else: @@ -31,10 +34,10 @@ _trace_file = sys.stderr _trace_stack_limit = None -import _ldap +import ldap._ldap as _ldap assert _ldap.__version__==__version__, \ ImportError(f'ldap {__version__} and _ldap {_ldap.__version__} version mismatch!') -from _ldap import * +from ldap._ldap import * # call into libldap to initialize it right now LIBLDAP_API_INFO = _ldap.get_option(_ldap.OPT_API_INFO) @@ -45,18 +48,21 @@ class DummyLock: """Define dummy class with methods compatible to threading.Lock""" - def __init__(self): - pass - def acquire(self): + def __init__(self) -> None: pass - def release(self): + + def acquire(self) -> bool: + return True + + def release(self) -> None: pass try: # Check if Python installation was build with thread support + # FIXME: This can be simplified, from Python 3.7 this module is mandatory import threading except ImportError: - LDAPLockBaseClass = DummyLock + LDAPLockBaseClass: Union[Type[DummyLock], Type[threading.Lock]] = DummyLock else: LDAPLockBaseClass = threading.Lock @@ -69,7 +75,11 @@ class LDAPLock: """ _min_trace_level = 3 - def __init__(self,lock_class=None,desc=''): + def __init__( + self, + lock_class: Optional[Type[Any]] = None, + desc: str = '' + ) -> None: """ lock_class Class compatible to threading.Lock @@ -79,19 +89,19 @@ def __init__(self,lock_class=None,desc=''): self._desc = desc self._lock = (lock_class or LDAPLockBaseClass)() - def acquire(self): + def acquire(self) -> bool: if __debug__: global _trace_level if _trace_level>=self._min_trace_level: _trace_file.write('***{}.acquire() {} {}\n'.format(self.__class__.__name__,repr(self),self._desc)) return self._lock.acquire() - def release(self): + def release(self) -> None: if __debug__: global _trace_level if _trace_level>=self._min_trace_level: _trace_file.write('***{}.release() {} {}\n'.format(self.__class__.__name__,repr(self),self._desc)) - return self._lock.release() + self._lock.release() # Create module-wide lock for serializing all calls into underlying LDAP lib diff --git a/Lib/ldap/_ldap.pyi b/Lib/ldap/_ldap.pyi new file mode 100644 index 00000000..dbc737cf --- /dev/null +++ b/Lib/ldap/_ldap.pyi @@ -0,0 +1,404 @@ +from typing import Any, ClassVar + +__version__: str +__license__: str +__author__: str +API_VERSION: int +AUTH_NONE: int +AUTH_SIMPLE: int +AVA_BINARY: int +AVA_NONPRINTABLE: int +AVA_NULL: int +AVA_STRING: int +CONTROL_ASSERT: str +CONTROL_MANAGEDSAIT: str +CONTROL_PAGEDRESULTS: str +CONTROL_PASSWORDPOLICYREQUEST: str +CONTROL_PASSWORDPOLICYRESPONSE: str +CONTROL_POST_READ: str +CONTROL_PRE_READ: str +CONTROL_PROXY_AUTHZ: str +CONTROL_RELAX: str +CONTROL_SORTREQUEST: str +CONTROL_SORTRESPONSE: str +CONTROL_SUBENTRIES: str +CONTROL_SYNC: str +CONTROL_SYNC_DONE: str +CONTROL_SYNC_STATE: str +CONTROL_VALUESRETURNFILTER: str +DEREF_ALWAYS: int +DEREF_FINDING: int +DEREF_NEVER: int +DEREF_SEARCHING: int +DN_FORMAT_AD_CANONICAL: int +DN_FORMAT_DCE: int +DN_FORMAT_LDAP: int +DN_FORMAT_LDAPV2: int +DN_FORMAT_LDAPV3: int +DN_FORMAT_MASK: int +DN_FORMAT_UFN: int +DN_PEDANTIC: int +DN_PRETTY: int +DN_P_NOLEADTRAILSPACES: int +DN_P_NOSPACEAFTERRDN: int +DN_SKIP: int +INIT_FD_AVAIL: int +LIBLDAP_R: int +MOD_ADD: int +MOD_BVALUES: int +MOD_DELETE: int +MOD_INCREMENT: int +MOD_REPLACE: int +MSG_ALL: int +MSG_ONE: int +MSG_RECEIVED: int +NO_LIMIT: int +OPT_API_FEATURE_INFO: int +OPT_API_INFO: int +OPT_CLIENT_CONTROLS: int +OPT_CONNECT_ASYNC: int +OPT_DEBUG_LEVEL: int +OPT_DEFBASE: int +OPT_DEREF: int +OPT_DESC: int +OPT_DIAGNOSTIC_MESSAGE: int +OPT_ERROR_NUMBER: int +OPT_ERROR_STRING: int +OPT_HOST_NAME: int +OPT_MATCHED_DN: int +OPT_NETWORK_TIMEOUT: int +OPT_OFF: int +OPT_ON: int +OPT_PROTOCOL_VERSION: int +OPT_REFERRALS: int +OPT_REFHOPLIMIT: int +OPT_RESTART: int +OPT_RESULT_CODE: int +OPT_SERVER_CONTROLS: int +OPT_SIZELIMIT: int +OPT_SUCCESS: int +OPT_TCP_USER_TIMEOUT: int +OPT_TIMELIMIT: int +OPT_TIMEOUT: int +OPT_URI: int +OPT_X_KEEPALIVE_IDLE: int +OPT_X_KEEPALIVE_INTERVAL: int +OPT_X_KEEPALIVE_PROBES: int +OPT_X_SASL_AUTHCID: int +OPT_X_SASL_AUTHZID: int +OPT_X_SASL_MECH: int +OPT_X_SASL_NOCANON: int +OPT_X_SASL_REALM: int +OPT_X_SASL_SECPROPS: int +OPT_X_SASL_SSF: int +OPT_X_SASL_SSF_EXTERNAL: int +OPT_X_SASL_SSF_MAX: int +OPT_X_SASL_SSF_MIN: int +OPT_X_SASL_USERNAME: int +OPT_X_TLS: int +OPT_X_TLS_ALLOW: int +OPT_X_TLS_CACERTDIR: int +OPT_X_TLS_CACERTFILE: int +OPT_X_TLS_CERTFILE: int +OPT_X_TLS_CIPHER: int +OPT_X_TLS_CIPHER_SUITE: int +OPT_X_TLS_CRLCHECK: int +OPT_X_TLS_CRLFILE: int +OPT_X_TLS_CRL_ALL: int +OPT_X_TLS_CRL_NONE: int +OPT_X_TLS_CRL_PEER: int +OPT_X_TLS_CTX: int +OPT_X_TLS_DEMAND: int +OPT_X_TLS_DHFILE: int +OPT_X_TLS_ECNAME: int +OPT_X_TLS_HARD: int +OPT_X_TLS_KEYFILE: int +OPT_X_TLS_NEVER: int +OPT_X_TLS_NEWCTX: int +OPT_X_TLS_PACKAGE: int +OPT_X_TLS_PEERCERT: int +OPT_X_TLS_PROTOCOL_MAX: int +OPT_X_TLS_PROTOCOL_MIN: int +OPT_X_TLS_PROTOCOL_SSL3: int +OPT_X_TLS_PROTOCOL_TLS1_0: int +OPT_X_TLS_PROTOCOL_TLS1_1: int +OPT_X_TLS_PROTOCOL_TLS1_2: int +OPT_X_TLS_PROTOCOL_TLS1_3: int +OPT_X_TLS_RANDOM_FILE: int +OPT_X_TLS_REQUIRE_CERT: int +OPT_X_TLS_REQUIRE_SAN: int +OPT_X_TLS_TRY: int +OPT_X_TLS_VERSION: int +PORT: int +REQ_ABANDON: int +REQ_ADD: int +REQ_BIND: int +REQ_COMPARE: int +REQ_DELETE: int +REQ_EXTENDED: int +REQ_MODIFY: int +REQ_MODRDN: int +REQ_SEARCH: int +REQ_UNBIND: int +RES_ADD: int +RES_ANY: int +RES_BIND: int +RES_COMPARE: int +RES_DELETE: int +RES_EXTENDED: int +RES_INTERMEDIATE: int +RES_MODIFY: int +RES_MODRDN: int +RES_SEARCH_ENTRY: int +RES_SEARCH_REFERENCE: int +RES_SEARCH_RESULT: int +RES_UNSOLICITED: int +SASL_AUTOMATIC: int +SASL_AVAIL: int +SASL_INTERACTIVE: int +SASL_QUIET: int +SCOPE_BASE: int +SCOPE_ONELEVEL: int +SCOPE_SUBORDINATE: int +SCOPE_SUBTREE: int +SYNC_INFO: str +TAG_CONTROLS: int +TAG_EXOP_REQ_OID: int +TAG_EXOP_REQ_VALUE: int +TAG_EXOP_RES_OID: int +TAG_EXOP_RES_VALUE: int +TAG_LDAPCRED: int +TAG_LDAPDN: int +TAG_MESSAGE: int +TAG_MSGID: int +TAG_NEWSUPERIOR: int +TAG_REFERRAL: int +TAG_SASL_RES_CREDS: int +TLS_AVAIL: int +URL_ERR_BADSCOPE: int +URL_ERR_MEM: int +VENDOR_VERSION: int +VERSION: int +VERSION1: int +VERSION2: int +VERSION3: int +VERSION_MAX: int +VERSION_MIN: int + +class LDAPError(Exception): ... + +class ADMINLIMIT_EXCEEDED(LDAPError): + errnum: ClassVar[int] = ... + +class AFFECTS_MULTIPLE_DSAS(LDAPError): + errnum: ClassVar[int] = ... + +class ALIAS_DEREF_PROBLEM(LDAPError): + errnum: ClassVar[int] = ... + +class ALIAS_PROBLEM(LDAPError): + errnum: ClassVar[int] = ... + +class ALREADY_EXISTS(LDAPError): + errnum: ClassVar[int] = ... + +class ASSERTION_FAILED(LDAPError): + errnum: ClassVar[int] = ... + +class AUTH_METHOD_NOT_SUPPORTED(LDAPError): + errnum: ClassVar[int] = ... + +class AUTH_UNKNOWN(LDAPError): + errnum: ClassVar[int] = ... + +class BUSY(LDAPError): + errnum: ClassVar[int] = ... + +class CANCELLED(LDAPError): + errnum: ClassVar[int] = ... + +class CANNOT_CANCEL(LDAPError): + errnum: ClassVar[int] = ... + +class CLIENT_LOOP(LDAPError): + errnum: ClassVar[int] = ... + +class COMPARE_FALSE(LDAPError): + errnum: ClassVar[int] = ... + +class COMPARE_TRUE(LDAPError): + errnum: ClassVar[int] = ... + +class CONFIDENTIALITY_REQUIRED(LDAPError): + errnum: ClassVar[int] = ... + +class CONNECT_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class CONSTRAINT_VIOLATION(LDAPError): + errnum: ClassVar[int] = ... + +class CONTROL_NOT_FOUND(LDAPError): + errnum: ClassVar[int] = ... + +class DECODING_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class ENCODING_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class FILTER_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class INAPPROPRIATE_AUTH(LDAPError): + errnum: ClassVar[int] = ... + +class INAPPROPRIATE_MATCHING(LDAPError): + errnum: ClassVar[int] = ... + +class INSUFFICIENT_ACCESS(LDAPError): + errnum: ClassVar[int] = ... + +class INVALID_CREDENTIALS(LDAPError): + errnum: ClassVar[int] = ... + +class INVALID_DN_SYNTAX(LDAPError): + errnum: ClassVar[int] = ... + +class INVALID_SYNTAX(LDAPError): + errnum: ClassVar[int] = ... + +class IS_LEAF(LDAPError): + errnum: ClassVar[int] = ... + +class LOCAL_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class LOOP_DETECT(LDAPError): + errnum: ClassVar[int] = ... + +class MORE_RESULTS_TO_RETURN(LDAPError): + errnum: ClassVar[int] = ... + +class NAMING_VIOLATION(LDAPError): + errnum: ClassVar[int] = ... + +class NOT_ALLOWED_ON_NONLEAF(LDAPError): + errnum: ClassVar[int] = ... + +class NOT_ALLOWED_ON_RDN(LDAPError): + errnum: ClassVar[int] = ... + +class NOT_SUPPORTED(LDAPError): + errnum: ClassVar[int] = ... + +class NO_MEMORY(LDAPError): + errnum: ClassVar[int] = ... + +class NO_OBJECT_CLASS_MODS(LDAPError): + errnum: ClassVar[int] = ... + +class NO_RESULTS_RETURNED(LDAPError): + errnum: ClassVar[int] = ... + +class NO_SUCH_ATTRIBUTE(LDAPError): + errnum: ClassVar[int] = ... + +class NO_SUCH_OBJECT(LDAPError): + errnum: ClassVar[int] = ... + +class NO_SUCH_OPERATION(LDAPError): + errnum: ClassVar[int] = ... + +class OBJECT_CLASS_VIOLATION(LDAPError): + errnum: ClassVar[int] = ... + +class OPERATIONS_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class OTHER(LDAPError): + errnum: ClassVar[int] = ... + +class PARAM_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class PARTIAL_RESULTS(LDAPError): + errnum: ClassVar[int] = ... + +class PROTOCOL_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class PROXIED_AUTHORIZATION_DENIED(LDAPError): + errnum: ClassVar[int] = ... + +class REFERRAL(LDAPError): + errnum: ClassVar[int] = ... + +class REFERRAL_LIMIT_EXCEEDED(LDAPError): + errnum: ClassVar[int] = ... + +class RESULTS_TOO_LARGE(LDAPError): + errnum: ClassVar[int] = ... + +class SASL_BIND_IN_PROGRESS(LDAPError): + errnum: ClassVar[int] = ... + +class SERVER_DOWN(LDAPError): + errnum: ClassVar[int] = ... + +class SIZELIMIT_EXCEEDED(LDAPError): + errnum: ClassVar[int] = ... + +class STRONG_AUTH_NOT_SUPPORTED(LDAPError): + errnum: ClassVar[int] = ... + +class STRONG_AUTH_REQUIRED(LDAPError): + errnum: ClassVar[int] = ... + +class SUCCESS(LDAPError): + errnum: ClassVar[int] = ... + +class TIMELIMIT_EXCEEDED(LDAPError): + errnum: ClassVar[int] = ... + +class TIMEOUT(LDAPError): + errnum: ClassVar[int] = ... + +class TOO_LATE(LDAPError): + errnum: ClassVar[int] = ... + +class TYPE_OR_VALUE_EXISTS(LDAPError): + errnum: ClassVar[int] = ... + +class UNAVAILABLE(LDAPError): + errnum: ClassVar[int] = ... + +class UNAVAILABLE_CRITICAL_EXTENSION(LDAPError): + errnum: ClassVar[int] = ... + +class UNDEFINED_TYPE(LDAPError): + errnum: ClassVar[int] = ... + +class UNWILLING_TO_PERFORM(LDAPError): + errnum: ClassVar[int] = ... + +class USER_CANCELLED(LDAPError): + errnum: ClassVar[int] = ... + +class VLV_ERROR(LDAPError): + errnum: ClassVar[int] = ... + +class X_PROXY_AUTHZ_FAILURE(LDAPError): + errnum: ClassVar[int] = ... + +class error(Exception): ... + +def decode_page_control(*args: Any, **kwargs: Any) -> Any: ... +def encode_assertion_control(*args: Any, **kwargs: Any) -> Any: ... +def encode_page_control(*args: Any, **kwargs: Any) -> Any: ... +def encode_valuesreturnfilter_control(*args: Any, **kwargs: Any) -> Any: ... +def get_option(*args: Any, **kwargs: Any) -> Any: ... +def initialize(*args: Any, **kwargs: Any) -> Any: ... +def initialize_fd(*args: Any, **kwargs: Any) -> Any: ... +def set_option(*args: Any, **kwargs: Any) -> Any: ... +def str2dn(*args: Any, **kwargs: Any) -> Any: ... diff --git a/Lib/ldap/async.py b/Lib/ldap/async.py index 1d4505bc..9f933b8b 100644 --- a/Lib/ldap/async.py +++ b/Lib/ldap/async.py @@ -6,7 +6,7 @@ import warnings from ldap.asyncsearch import * -from ldap.asyncsearch import __version__ +from ldap.pkginfo import __version__ warnings.warn( "'ldap.async module' is deprecated, import 'ldap.asyncsearch' instead.", diff --git a/Lib/ldap/asyncsearch.py b/Lib/ldap/asyncsearch.py index 6c6929dd..863fb99b 100644 --- a/Lib/ldap/asyncsearch.py +++ b/Lib/ldap/asyncsearch.py @@ -6,7 +6,20 @@ import ldap -from ldap import __version__ +from ldap.pkginfo import __version__ +from ldap.controls import RequestControl +from typing import ( + Any, + Dict as DictType, + Iterable, + List as ListType, + Sequence, + TextIO, + Tuple, + Optional, + Union, +) +from ldap.types import LDAPSearchResult, LDAPEntryDict import ldif @@ -24,15 +37,19 @@ class WrongResultType(Exception): - def __init__(self,receivedResultType,expectedResultTypes): + def __init__( + self, + receivedResultType: int, + expectedResultTypes: Iterable[int], + ) -> None: self.receivedResultType = receivedResultType self.expectedResultTypes = expectedResultTypes Exception.__init__(self) - def __str__(self): + def __str__(self) -> str: return 'Received wrong result type {} (expected one of {}).'.format( self.receivedResultType, - ', '.join(self.expectedResultTypes), + ', '.join([str(x) for x in self.expectedResultTypes]), ) @@ -46,23 +63,23 @@ class AsyncSearchHandler: LDAPObject instance """ - def __init__(self,l): + def __init__(self, l: ldap.ldapobject.LDAPObject) -> None: self._l = l - self._msgId = None + self._msgId: Optional[int] = None self._afterFirstResult = 1 def startSearch( self, - searchRoot, - searchScope, - filterStr, - attrList=None, - attrsOnly=0, - timeout=-1, - sizelimit=0, - serverctrls=None, - clientctrls=None - ): + searchRoot: str, + searchScope: int, + filterStr: str, + attrList: Optional[ListType[str]] = None, + attrsOnly: int = 0, + timeout: int = -1, + sizelimit: int = 0, + serverctrls: Optional[ListType[RequestControl]] = None, + clientctrls: Optional[ListType[RequestControl]] = None, + ) -> None: """ searchRoot See parameter base of method LDAPObject.search() @@ -89,26 +106,30 @@ def startSearch( attrList,attrsOnly,serverctrls,clientctrls,timeout,sizelimit ) self._afterFirstResult = 1 - return # startSearch() - def preProcessing(self): + def preProcessing(self) -> Any: """ Do anything you want after starting search but before receiving and processing results """ - def afterFirstResult(self): + def afterFirstResult(self) -> Any: """ Do anything you want right after successfully receiving but before processing first result """ - def postProcessing(self): + def postProcessing(self) -> Any: """ Do anything you want after receiving and processing all results """ - def processResults(self,ignoreResultsNumber=0,processResultsCount=0,timeout=-1): + def processResults( + self, + ignoreResultsNumber: int = 0, + processResultsCount: int = 0, + timeout: int = -1, + ) -> int: """ ignoreResultsNumber Don't process the first ignoreResultsNumber results. @@ -118,6 +139,9 @@ def processResults(self,ignoreResultsNumber=0,processResultsCount=0,timeout=-1): timeout See parameter timeout of ldap.LDAPObject.result() """ + if self._msgId is None: + raise RuntimeError('processResults() called without calling startSearch() first') + self.preProcessing() result_counter = 0 end_result_counter = ignoreResultsNumber+processResultsCount @@ -156,7 +180,11 @@ def processResults(self,ignoreResultsNumber=0,processResultsCount=0,timeout=-1): self.postProcessing() return partial # processResults() - def _processSingleResult(self,resultType,resultItem): + def _processSingleResult( + self, + resultType: int, + resultItem: LDAPSearchResult, + ) -> Any: """ Process single entry @@ -177,11 +205,15 @@ class List(AsyncSearchHandler): results. """ - def __init__(self,l): + def __init__(self, l: ldap.ldapobject.LDAPObject) -> None: AsyncSearchHandler.__init__(self,l) - self.allResults = [] + self.allResults: ListType[Tuple[int, LDAPSearchResult]] = [] - def _processSingleResult(self,resultType,resultItem): + def _processSingleResult( + self, + resultType: int, + resultItem: LDAPSearchResult, + ) -> None: self.allResults.append((resultType,resultItem)) @@ -190,11 +222,15 @@ class Dict(AsyncSearchHandler): Class for collecting all search results into a dictionary {dn:entry} """ - def __init__(self,l): + def __init__(self, l: ldap.ldapobject.LDAPObject) -> None: AsyncSearchHandler.__init__(self,l) - self.allEntries = {} + self.allEntries: DictType[str, LDAPEntryDict] = {} - def _processSingleResult(self,resultType,resultItem): + def _processSingleResult( + self, + resultType: int, + resultItem: LDAPSearchResult, + ) -> None: if resultType in ENTRY_RESULT_TYPES: # Search continuations are ignored dn,entry = resultItem @@ -207,12 +243,20 @@ class IndexedDict(Dict): and maintain case-sensitive equality indexes to entries """ - def __init__(self,l,indexed_attrs=None): + def __init__( + self, + l: ldap.ldapobject.LDAPObject, + indexed_attrs: Optional[Sequence[str]] = None, + ) -> None: Dict.__init__(self,l) self.indexed_attrs = indexed_attrs or () - self.index = {}.fromkeys(self.indexed_attrs,{}) + self.index: DictType[str, DictType[bytes, ListType[str]]] = {}.fromkeys(self.indexed_attrs,{}) - def _processSingleResult(self,resultType,resultItem): + def _processSingleResult( + self, + resultType: int, + resultItem: LDAPSearchResult, + ) -> None: if resultType in ENTRY_RESULT_TYPES: # Search continuations are ignored dn,entry = resultItem @@ -237,20 +281,26 @@ class FileWriter(AsyncSearchHandler): File object instance where the LDIF data is written to """ - def __init__(self,l,f,headerStr='',footerStr=''): + def __init__( + self, + l: ldap.ldapobject.LDAPObject, + f: TextIO, + headerStr: str = '', + footerStr: str = '', + ) -> None: AsyncSearchHandler.__init__(self,l) self._f = f self.headerStr = headerStr self.footerStr = footerStr - def preProcessing(self): + def preProcessing(self) -> None: """ The headerStr is written to output after starting search but before receiving and processing results. """ self._f.write(self.headerStr) - def postProcessing(self): + def postProcessing(self) -> None: """ The footerStr is written to output after receiving and processing results. @@ -270,14 +320,24 @@ class LDIFWriter(FileWriter): Either a file-like object or a ldif.LDIFWriter instance used for output """ - def __init__(self,l,writer_obj,headerStr='',footerStr=''): + def __init__( + self, + l: ldap.ldapobject.LDAPObject, + writer_obj: Union[TextIO, ldif.LDIFWriter], + headerStr: str = '', + footerStr: str = '', + ) -> None: if isinstance(writer_obj,ldif.LDIFWriter): self._ldif_writer = writer_obj else: self._ldif_writer = ldif.LDIFWriter(writer_obj) FileWriter.__init__(self,l,self._ldif_writer._output_file,headerStr,footerStr) - def _processSingleResult(self,resultType,resultItem): + def _processSingleResult( + self, + resultType: int, + resultItem: LDAPSearchResult, + ) -> None: if resultType in ENTRY_RESULT_TYPES: # Search continuations are ignored dn,entry = resultItem diff --git a/Lib/ldap/cidict.py b/Lib/ldap/cidict.py index f846fd29..423b56e1 100644 --- a/Lib/ldap/cidict.py +++ b/Lib/ldap/cidict.py @@ -8,52 +8,69 @@ import warnings from collections.abc import MutableMapping -from ldap import __version__ +from ldap.pkginfo import __version__ +from typing import ( + Any, + Dict, + Iterator, + List, + Mapping, + MutableMapping as MutableMappingType, + TypeVar, + Optional, +) -class cidict(MutableMapping): +T = TypeVar('T', bound=Any) + +from typing_extensions import Self + + +class cidict(MutableMappingType[str, T]): """ Case-insensitive but case-respecting dictionary. """ __slots__ = ('_keys', '_data') - def __init__(self, default=None): - self._keys = {} - self._data = {} + def __init__(self, default: Optional[Mapping[str, T]] = None) -> None: + self._keys: Dict[str, str] = {} + self._data: Dict[str, T] = {} if default: self.update(default) # MutableMapping abstract methods - def __getitem__(self, key): + def __getitem__(self, key: str) -> T: return self._data[key.lower()] - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: T) -> None: lower_key = key.lower() self._keys[lower_key] = key self._data[lower_key] = value - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: lower_key = key.lower() del self._keys[lower_key] del self._data[lower_key] - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._keys.values()) - def __len__(self): + def __len__(self) -> int: return len(self._keys) # Specializations for performance - def __contains__(self, key): + def __contains__(self, key: object) -> bool: + if not isinstance(key, str): + return False return key.lower() in self._keys - def clear(self): + def clear(self) -> None: self._keys.clear() self._data.clear() - def copy(self): + def copy(self) -> Self: inst = self.__class__.__new__(self.__class__) inst._data = self._data.copy() inst._keys = self._keys.copy() @@ -63,12 +80,12 @@ def copy(self): # Backwards compatibility - def has_key(self, key): + def has_key(self, key: str) -> bool: """Compatibility with python-ldap 2.x""" return key in self @property - def data(self): + def data(self) -> Dict[str, T]: """Compatibility with older IterableUserDict-based implementation""" warnings.warn( 'ldap.cidict.cidict.data is an internal attribute; it may be ' + @@ -79,7 +96,7 @@ def data(self): return self._data -def strlist_minus(a,b): +def strlist_minus(a: List[str], b: List[str]) -> List[str]: """ Return list of all items in a which are not in b (a - b). a,b are supposed to be lists of case-insensitive strings. @@ -89,7 +106,7 @@ def strlist_minus(a,b): category=DeprecationWarning, stacklevel=2, ) - temp = cidict() + temp: cidict[str] = cidict() for elt in b: temp[elt] = elt result = [ @@ -100,7 +117,7 @@ def strlist_minus(a,b): return result -def strlist_intersection(a,b): +def strlist_intersection(a: List[str], b: List[str]) -> List[str]: """ Return intersection of two lists of case-insensitive strings a,b. """ @@ -109,7 +126,7 @@ def strlist_intersection(a,b): category=DeprecationWarning, stacklevel=2, ) - temp = cidict() + temp: cidict[str] = cidict() for elt in a: temp[elt] = elt result = [ @@ -120,7 +137,7 @@ def strlist_intersection(a,b): return result -def strlist_union(a,b): +def strlist_union(a: List[str], b: List[str]) -> List[str]: """ Return union of two lists of case-insensitive strings a,b. """ @@ -129,9 +146,9 @@ def strlist_union(a,b): category=DeprecationWarning, stacklevel=2, ) - temp = cidict() + temp: cidict[str] = cidict() for elt in a: temp[elt] = elt for elt in b: temp[elt] = elt - return temp.values() + return [x for x in temp.values()] diff --git a/Lib/ldap/compat.py b/Lib/ldap/compat.py deleted file mode 100644 index a287ce4e..00000000 --- a/Lib/ldap/compat.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Compatibility wrappers for Py2/Py3.""" -import warnings - -warnings.warn( - "The ldap.compat module is deprecated and will be removed in the future", - DeprecationWarning, -) - -from collections import UserDict -IterableUserDict = UserDict -from urllib.parse import quote, quote_plus, unquote, urlparse -from urllib.request import urlopen -from collections.abc import MutableMapping -from shutil import which - -def reraise(exc_type, exc_value, exc_traceback): - """Re-raise an exception given information from sys.exc_info() - - Note that unlike six.reraise, this does not support replacing the - traceback. All arguments must come from a single sys.exc_info() call. - """ - # In Python 3, all exception info is contained in one object. - raise exc_value diff --git a/Lib/ldap/constants.py b/Lib/ldap/constants.py index 0e7df6e7..5b493dcf 100644 --- a/Lib/ldap/constants.py +++ b/Lib/ldap/constants.py @@ -9,6 +9,7 @@ - Provide support for building documentation without compiling python-ldap """ +from typing import Any, List, Sequence, Optional # This module cannot import anything from ldap. # When building documentation, it is used to initialize ldap.__init__. @@ -18,7 +19,15 @@ class Constant: """Base class for a definition of an OpenLDAP constant """ - def __init__(self, name, optional=False, requirements=(), doc=None): + c_template: Optional[str] = None + + def __init__( + self, + name: str, + optional: bool = False, + requirements: Sequence[str] = (), + doc: Optional[str] = None, + ) -> None: self.name = name if optional: self_requirement = f'defined(LDAP_{self.name})' @@ -46,9 +55,9 @@ class Int(Constant): class TLSInt(Int): """Definition for a TLS integer constant -- requires HAVE_TLS""" - def __init__(self, *args, **kwargs): - requrements = list(kwargs.get('requirements', ())) - kwargs['requirements'] = ['HAVE_TLS'] + requrements + def __init__(self, *args: Any, **kwargs: Any) -> None: + requirements = list(kwargs.get('requirements', ())) + kwargs['requirements'] = ['HAVE_TLS'] + requirements super().__init__(*args, **kwargs) @@ -68,7 +77,7 @@ class Feature(Constant): ]) - def __init__(self, name, c_feature, **kwargs): + def __init__(self, name: str, c_feature: str, **kwargs: Any) -> None: super().__init__(name, **kwargs) self.c_feature = c_feature @@ -379,7 +388,7 @@ class Str(Constant): ) -def print_header(): # pragma: no cover +def print_header() -> None: # pragma: no cover """Print the C header file to standard output""" print('/*') @@ -390,9 +399,9 @@ def print_header(): # pragma: no cover print(' */') print('') - current_requirements = [] + current_requirements: List[str] = [] - def pop_requirement(): + def pop_requirement() -> None: popped = current_requirements.pop() print('#endif') print() @@ -407,7 +416,8 @@ def pop_requirement(): print() print(f'#if {requirement}') - print(definition.c_template.format(self=definition)) + if definition.c_template is not None: + print(definition.c_template.format(self=definition)) while current_requirements: pop_requirement() diff --git a/Lib/ldap/controls/__init__.py b/Lib/ldap/controls/__init__.py index 73557168..e7a10427 100644 --- a/Lib/ldap/controls/__init__.py +++ b/Lib/ldap/controls/__init__.py @@ -10,7 +10,7 @@ from ldap.pkginfo import __version__ -import _ldap +import ldap._ldap as _ldap assert _ldap.__version__==__version__, \ ImportError(f'ldap {__version__} and _ldap {_ldap.__version__} version mismatch!') @@ -18,6 +18,8 @@ from pyasn1.error import PyAsn1Error +from typing import Dict, List, Tuple, Type, Optional + __all__ = [ 'KNOWN_RESPONSE_CONTROLS', @@ -38,7 +40,7 @@ ] # response control OID to class registry -KNOWN_RESPONSE_CONTROLS = {} +KNOWN_RESPONSE_CONTROLS: Dict[str, Type["ResponseControl"]] = {} class RequestControl: @@ -54,12 +56,17 @@ class RequestControl: (here it is the BER-encoded ASN.1 control value) """ - def __init__(self,controlType=None,criticality=False,encodedControlValue=None): + def __init__( + self, + controlType: Optional[str] = None, + criticality: bool = False, + encodedControlValue: Optional[bytes] = None + ) -> None: self.controlType = controlType self.criticality = criticality self.encodedControlValue = encodedControlValue - def encodeControlValue(self): + def encodeControlValue(self) -> Optional[bytes]: """ sets class attribute encodedControlValue to the BER-encoded ASN.1 control value composed by class attributes set before @@ -77,32 +84,45 @@ class ResponseControl: sets the criticality of the received control (boolean) """ - def __init__(self,controlType=None,criticality=False): + def __init__( + self, + controlType: Optional[str] = None, + criticality: bool = False + ) -> None: self.controlType = controlType self.criticality = criticality - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: """ decodes the BER-encoded ASN.1 control value and sets the appropriate class attributes """ - self.encodedControlValue = encodedControlValue + # The type hint can be removed once class LDAPControl is removed + self.encodedControlValue: Optional[bytes] = encodedControlValue -class LDAPControl(RequestControl,ResponseControl): +class LDAPControl(RequestControl, ResponseControl): """ Base class for combined request/response controls mainly for backward-compatibility to python-ldap 2.3.x """ - def __init__(self,controlType=None,criticality=False,controlValue=None,encodedControlValue=None): + def __init__( + self, + controlType: Optional[str] = None, + criticality: bool = False, + controlValue: Optional[str] = None, + encodedControlValue: Optional[bytes] = None + ) -> None: self.controlType = controlType self.criticality = criticality self.controlValue = controlValue self.encodedControlValue = encodedControlValue -def RequestControlTuples(ldapControls): +def RequestControlTuples( + ldapControls: Optional[List[RequestControl]] + ) -> Optional[List[Tuple[Optional[str], bool, Optional[bytes]]]]: """ Return list of readily encoded 3-tuples which can be directly passed to C module _ldap @@ -120,7 +140,10 @@ def RequestControlTuples(ldapControls): return result -def DecodeControlTuples(ldapControlTuples,knownLDAPControls=None): +def DecodeControlTuples( + ldapControlTuples: Optional[List[Tuple[str, bool, bytes]]], + knownLDAPControls: Optional[Dict[str, Type[ResponseControl]]] = None, + ) -> List[ResponseControl]: """ Returns list of readily decoded ResponseControl objects diff --git a/Lib/ldap/controls/deref.py b/Lib/ldap/controls/deref.py index e5b2a7ec..22fa7494 100644 --- a/Lib/ldap/controls/deref.py +++ b/Lib/ldap/controls/deref.py @@ -18,6 +18,7 @@ from pyasn1.codec.ber import encoder,decoder from pyasn1_modules.rfc2251 import LDAPDN,AttributeDescription,AttributeDescriptionList,AttributeValue +from typing import Dict, List, Tuple, Optional DEREF_CONTROL_OID = '1.3.6.1.4.1.4203.666.5.16' @@ -28,7 +29,7 @@ # For compatibility with ASN.1 declaration in I-D AttributeList = AttributeDescriptionList -class DerefSpec(univ.Sequence): +class DerefSpec(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType( 'derefAttr', @@ -40,32 +41,32 @@ class DerefSpec(univ.Sequence): ), ) -class DerefSpecs(univ.SequenceOf): +class DerefSpecs(univ.SequenceOf): # type: ignore componentType = DerefSpec() # Response types #--------------------------------------------------------------------------- -class AttributeValues(univ.SetOf): +class AttributeValues(univ.SetOf): # type: ignore componentType = AttributeValue() -class PartialAttribute(univ.Sequence): +class PartialAttribute(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('type', AttributeDescription()), namedtype.NamedType('vals', AttributeValues()), ) -class PartialAttributeList(univ.SequenceOf): +class PartialAttributeList(univ.SequenceOf): # type: ignore componentType = PartialAttribute() tagSet = univ.Sequence.tagSet.tagImplicitly( tag.Tag(tag.tagClassContext,tag.tagFormatConstructed,0) ) -class DerefRes(univ.Sequence): +class DerefRes(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('derefAttr', AttributeDescription()), namedtype.NamedType('derefVal', LDAPDN()), @@ -73,18 +74,22 @@ class DerefRes(univ.Sequence): ) -class DerefResultControlValue(univ.SequenceOf): +class DerefResultControlValue(univ.SequenceOf): # type: ignore componentType = DerefRes() class DereferenceControl(LDAPControl): controlType = DEREF_CONTROL_OID - def __init__(self,criticality=False,derefSpecs=None): + def __init__( + self, + criticality: bool = False, + derefSpecs: Optional[Dict[str, List[str]]] = None, + ) -> None: LDAPControl.__init__(self,self.controlType,criticality) self.derefSpecs = derefSpecs or {} - def _derefSpecs(self): + def _derefSpecs(self) -> DerefSpecs: deref_specs = DerefSpecs() i = 0 for deref_attr,deref_attribute_names in self.derefSpecs.items(): @@ -98,12 +103,17 @@ def _derefSpecs(self): i += 1 return deref_specs - def encodeControlValue(self): - return encoder.encode(self._derefSpecs()) + def encodeControlValue(self) -> bytes: + return encoder.encode(self._derefSpecs()) # type: ignore - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: decodedValue,_ = decoder.decode(encodedControlValue,asn1Spec=DerefResultControlValue()) - self.derefRes = {} + # Starting from the inside out: + # The innermost dict maps attribute names to lists of attribute values + # (note: the attribute values are encoded as str, not bytes) + # The tuple pairs a DN and one of the above dicts. + # The outermost dict maps the dereferenced attribute to a list of the above tuples + self.derefRes: Dict[str, List[Tuple[str, Dict[str, List[str]]]]] = {} for deref_res in decodedValue: deref_attr,deref_val,deref_vals = deref_res[0],deref_res[1],deref_res[2] partial_attrs_dict = { diff --git a/Lib/ldap/controls/libldap.py b/Lib/ldap/controls/libldap.py index 9a102379..23b7d8c8 100644 --- a/Lib/ldap/controls/libldap.py +++ b/Lib/ldap/controls/libldap.py @@ -7,7 +7,7 @@ from ldap.pkginfo import __version__ -import _ldap +import ldap._ldap as _ldap assert _ldap.__version__==__version__, \ ImportError(f'ldap {__version__} and _ldap {_ldap.__version__} version mismatch!') @@ -15,6 +15,8 @@ from ldap.controls import RequestControl,LDAPControl,KNOWN_RESPONSE_CONTROLS +from typing import Optional, Union + class AssertionControl(RequestControl): """ @@ -26,14 +28,18 @@ class AssertionControl(RequestControl): """ controlType = ldap.CONTROL_ASSERT - def __init__(self,criticality=True,filterstr='(objectClass=*)'): + + def __init__( + self, criticality: bool = True, filterstr: str = '(objectClass=*)' + ) -> None: self.criticality = criticality self.filterstr = filterstr - def encodeControlValue(self): - return _ldap.encode_assertion_control(self.filterstr) + def encodeControlValue(self) -> bytes: + return _ldap.encode_assertion_control(self.filterstr) # type: ignore -KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_ASSERT] = AssertionControl +# FIXME: This is a request control though? +#KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_ASSERT] = AssertionControl class MatchedValuesControl(RequestControl): @@ -47,14 +53,19 @@ class MatchedValuesControl(RequestControl): controlType = ldap.CONTROL_VALUESRETURNFILTER - def __init__(self,criticality=False,filterstr='(objectClass=*)'): + def __init__( + self, + criticality: bool = False, + filterstr: str = '(objectClass=*)', + ) -> None: self.criticality = criticality self.filterstr = filterstr - def encodeControlValue(self): - return _ldap.encode_valuesreturnfilter_control(self.filterstr) + def encodeControlValue(self) -> bytes: + return _ldap.encode_valuesreturnfilter_control(self.filterstr) # type: ignore -KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_VALUESRETURNFILTER] = MatchedValuesControl +# FIXME: This is a request control though? +#KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_VALUESRETURNFILTER] = MatchedValuesControl class SimplePagedResultsControl(LDAPControl): @@ -68,14 +79,19 @@ class SimplePagedResultsControl(LDAPControl): """ controlType = ldap.CONTROL_PAGEDRESULTS - def __init__(self,criticality=False,size=None,cookie=None): + def __init__( + self, + criticality: bool = False, + size: Optional[int] = None, + cookie: Optional[Union[str, bytes]] = None + ) -> None: self.criticality = criticality self.size,self.cookie = size,cookie - def encodeControlValue(self): - return _ldap.encode_page_control(self.size,self.cookie) + def encodeControlValue(self) -> bytes: + return _ldap.encode_page_control(self.size,self.cookie) # type: ignore - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self,encodedControlValue: bytes) -> None: self.size,self.cookie = _ldap.decode_page_control(encodedControlValue) KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_PAGEDRESULTS] = SimplePagedResultsControl diff --git a/Lib/ldap/controls/openldap.py b/Lib/ldap/controls/openldap.py index 24040ed7..038afb53 100644 --- a/Lib/ldap/controls/openldap.py +++ b/Lib/ldap/controls/openldap.py @@ -10,6 +10,7 @@ from pyasn1.type import univ from pyasn1.codec.ber import decoder +from typing import List, Tuple, Union __all__ = [ 'SearchNoOpControl', @@ -26,13 +27,13 @@ class SearchNoOpControl(ValueLessRequestControl,ResponseControl): """ controlType = '1.3.6.1.4.1.4203.666.5.18' - def __init__(self,criticality=False): + def __init__(self, criticality: bool = False) -> None: self.criticality = criticality - class SearchNoOpControlValue(univ.Sequence): + class SearchNoOpControlValue(univ.Sequence): # type: ignore pass - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: decodedValue,_ = decoder.decode(encodedControlValue,asn1Spec=self.SearchNoOpControlValue()) self.resultCode = int(decodedValue[0]) self.numSearchResults = int(decodedValue[1]) @@ -42,7 +43,7 @@ def decodeControlValue(self,encodedControlValue): ldap.controls.KNOWN_RESPONSE_CONTROLS[SearchNoOpControl.controlType] = SearchNoOpControl -class SearchNoOpMixIn: +class SearchNoOpMixIn(ldap.ldapobject.SimpleLDAPObject): """ Mix-in class to be used with class LDAPObject and friends. @@ -50,7 +51,13 @@ class SearchNoOpMixIn: for easily using the no-op search control. """ - def noop_search_st(self,base,scope=ldap.SCOPE_SUBTREE,filterstr='(objectClass=*)',timeout=-1): + def noop_search_st( + self, + base: str, + scope: int = ldap.SCOPE_SUBTREE, + filterstr: str = '(objectClass=*)', + timeout: int = -1, + ) -> Union[Tuple[int, int], Tuple[None, None]]: try: msg_id = self.search_ext( base, @@ -72,8 +79,8 @@ def noop_search_st(self,base,scope=ldap.SCOPE_SUBTREE,filterstr='(objectClass=*) else: noop_srch_ctrl = [ c - for c in search_response_ctrls - if c.controlType==SearchNoOpControl.controlType + for c in search_response_ctrls or [] + if isinstance(c, SearchNoOpControl) ] if noop_srch_ctrl: return noop_srch_ctrl[0].numSearchResults,noop_srch_ctrl[0].numSearchContinuations diff --git a/Lib/ldap/controls/pagedresults.py b/Lib/ldap/controls/pagedresults.py index 12ca573d..fafa71c8 100644 --- a/Lib/ldap/controls/pagedresults.py +++ b/Lib/ldap/controls/pagedresults.py @@ -18,10 +18,13 @@ from pyasn1.codec.ber import encoder,decoder from pyasn1_modules.rfc2251 import LDAPString +from typing import Union -class PagedResultsControlValue(univ.Sequence): + +class PagedResultsControlValue(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('size',univ.Integer()), + # FIXME: This should be univ.OctetString, not LDAPString()? namedtype.NamedType('cookie',LDAPString()), ) @@ -29,18 +32,29 @@ class PagedResultsControlValue(univ.Sequence): class SimplePagedResultsControl(RequestControl,ResponseControl): controlType = '1.2.840.113556.1.4.319' - def __init__(self,criticality=False,size=10,cookie=''): + def __init__( + self, + criticality: bool = False, + size: int = 10, + cookie: Union[str, bytes] = '', + ) -> None: self.criticality = criticality self.size = size - self.cookie = cookie or '' - def encodeControlValue(self): + if cookie is None: + cookie = b'' + elif isinstance(cookie, str): + self.cookie = cookie.encode('utf-8') + else: + self.cookie = cookie + + def encodeControlValue(self) -> bytes: pc = PagedResultsControlValue() pc.setComponentByName('size',univ.Integer(self.size)) pc.setComponentByName('cookie',LDAPString(self.cookie)) - return encoder.encode(pc) + return encoder.encode(pc) # type: ignore - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: decodedValue,_ = decoder.decode(encodedControlValue,asn1Spec=PagedResultsControlValue()) self.size = int(decodedValue.getComponentByName('size')) self.cookie = bytes(decodedValue.getComponentByName('cookie')) diff --git a/Lib/ldap/controls/ppolicy.py b/Lib/ldap/controls/ppolicy.py index f3a8416d..46369ea1 100644 --- a/Lib/ldap/controls/ppolicy.py +++ b/Lib/ldap/controls/ppolicy.py @@ -18,8 +18,10 @@ from pyasn1.type import tag,namedtype,namedval,univ,constraint from pyasn1.codec.der import decoder +from typing import Optional -class PasswordPolicyWarning(univ.Choice): + +class PasswordPolicyWarning(univ.Choice): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('timeBeforeExpiration',univ.Integer().subtype( implicitTag=tag.Tag(tag.tagClassContext,tag.tagFormatSimple,0) @@ -30,7 +32,7 @@ class PasswordPolicyWarning(univ.Choice): ) -class PasswordPolicyError(univ.Enumerated): +class PasswordPolicyError(univ.Enumerated): # type: ignore namedValues = namedval.NamedValues( ('passwordExpired',0), ('accountLocked',1), @@ -46,7 +48,7 @@ class PasswordPolicyError(univ.Enumerated): subtypeSpec = univ.Enumerated.subtypeSpec + constraint.SingleValueConstraint(0,1,2,3,4,5,6,7,8,9) -class PasswordPolicyResponseValue(univ.Sequence): +class PasswordPolicyResponseValue(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.OptionalNamedType( 'warning', @@ -69,24 +71,24 @@ class PasswordPolicyControl(ValueLessRequestControl,ResponseControl): Attributes ---------- - timeBeforeExpiration : int + timeBeforeExpiration : Optional[int] The time before the password expires. - graceAuthNsRemaining : int + graceAuthNsRemaining : Optional[int] The number of grace authentications remaining. - error: int + error: Optional[int] The password and authentication errors. """ controlType = '1.3.6.1.4.1.42.2.27.8.5.1' - def __init__(self,criticality=False): + def __init__(self, criticality: bool = False) -> None: self.criticality = criticality - self.timeBeforeExpiration = None - self.graceAuthNsRemaining = None - self.error = None + self.timeBeforeExpiration: Optional[int] = None + self.graceAuthNsRemaining: Optional[int] = None + self.error: Optional[int] = None - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: ppolicyValue,_ = decoder.decode(encodedControlValue,asn1Spec=PasswordPolicyResponseValue()) warning = ppolicyValue.getComponentByName('warning') if warning.hasValue(): diff --git a/Lib/ldap/controls/psearch.py b/Lib/ldap/controls/psearch.py index 32900c8b..979711ea 100644 --- a/Lib/ldap/controls/psearch.py +++ b/Lib/ldap/controls/psearch.py @@ -21,6 +21,9 @@ from pyasn1.codec.ber import encoder,decoder from pyasn1_modules.rfc2251 import LDAPDN +from typing import Dict, List, Tuple, Optional + + #--------------------------------------------------------------------------- # Constants and classes for Persistent Search Control #--------------------------------------------------------------------------- @@ -48,7 +51,7 @@ class PersistentSearchControl(RequestControl): Entry Change Notification response control """ - class PersistentSearchControlValue(univ.Sequence): + class PersistentSearchControlValue(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('changeTypes',univ.Integer()), namedtype.NamedType('changesOnly',univ.Boolean()), @@ -57,26 +60,31 @@ class PersistentSearchControlValue(univ.Sequence): controlType = "2.16.840.1.113730.3.4.3" - def __init__(self,criticality=True,changeTypes=None,changesOnly=False,returnECs=True): + def __init__( + self, + criticality: bool = True, + changeTypes: Optional[List[str]] = None, + changesOnly: bool = False, + returnECs: bool = True + ) -> None: self.criticality,self.changesOnly,self.returnECs = \ criticality,changesOnly,returnECs - self.changeTypes = changeTypes or CHANGE_TYPES_INT.values() - - def encodeControlValue(self): - if not type(self.changeTypes)==type(0): - # Assume a sequence type of integers to be OR-ed - changeTypes_int = 0 - for ct in self.changeTypes: - changeTypes_int = changeTypes_int|CHANGE_TYPES_INT.get(ct,ct) - self.changeTypes = changeTypes_int + self.changeTypes = changeTypes or CHANGE_TYPES_INT.keys() + + def encodeControlValue(self) -> bytes: + # Assume a sequence type of names of integers to be OR-ed + changeTypes_int = 0 + for ct in self.changeTypes: + changeTypes_int |= CHANGE_TYPES_INT.get(ct, 0) + p = self.PersistentSearchControlValue() - p.setComponentByName('changeTypes',univ.Integer(self.changeTypes)) + p.setComponentByName('changeTypes',univ.Integer(changeTypes_int)) p.setComponentByName('changesOnly',univ.Boolean(self.changesOnly)) p.setComponentByName('returnECs',univ.Boolean(self.returnECs)) - return encoder.encode(p) + return encoder.encode(p) # type: ignore -class ChangeType(univ.Enumerated): +class ChangeType(univ.Enumerated): # type: ignore namedValues = namedval.NamedValues( ('add',1), ('delete',2), @@ -86,7 +94,7 @@ class ChangeType(univ.Enumerated): subtypeSpec = univ.Enumerated.subtypeSpec + constraint.SingleValueConstraint(1,2,4,8) -class EntryChangeNotificationValue(univ.Sequence): +class EntryChangeNotificationValue(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('changeType',ChangeType()), namedtype.OptionalNamedType('previousDN', LDAPDN()), @@ -111,19 +119,18 @@ class EntryChangeNotificationControl(ResponseControl): controlType = "2.16.840.1.113730.3.4.7" - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: ecncValue,_ = decoder.decode(encodedControlValue,asn1Spec=EntryChangeNotificationValue()) self.changeType = int(ecncValue.getComponentByName('changeType')) previousDN = ecncValue.getComponentByName('previousDN') if previousDN.hasValue(): - self.previousDN = str(previousDN) + self.previousDN: Optional[str] = str(previousDN) else: self.previousDN = None changeNumber = ecncValue.getComponentByName('changeNumber') if changeNumber.hasValue(): - self.changeNumber = int(changeNumber) + self.changeNumber: Optional[int] = int(changeNumber) else: self.changeNumber = None - return (self.changeType,self.previousDN,self.changeNumber) KNOWN_RESPONSE_CONTROLS[EntryChangeNotificationControl.controlType] = EntryChangeNotificationControl diff --git a/Lib/ldap/controls/pwdpolicy.py b/Lib/ldap/controls/pwdpolicy.py index 54f1a700..8dba3e51 100644 --- a/Lib/ldap/controls/pwdpolicy.py +++ b/Lib/ldap/controls/pwdpolicy.py @@ -21,7 +21,7 @@ class PasswordExpiringControl(ResponseControl): """ controlType = '2.16.840.1.113730.3.4.5' - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: self.gracePeriod = int(encodedControlValue) KNOWN_RESPONSE_CONTROLS[PasswordExpiringControl.controlType] = PasswordExpiringControl @@ -33,7 +33,7 @@ class PasswordExpiredControl(ResponseControl): """ controlType = '2.16.840.1.113730.3.4.4' - def decodeControlValue(self,encodedControlValue): - self.passwordExpired = encodedControlValue=='0' + def decodeControlValue(self, encodedControlValue: bytes) -> None: + self.passwordExpired = encodedControlValue == b'0' KNOWN_RESPONSE_CONTROLS[PasswordExpiredControl.controlType] = PasswordExpiredControl diff --git a/Lib/ldap/controls/readentry.py b/Lib/ldap/controls/readentry.py index 7b2a7e89..0fcbc6aa 100644 --- a/Lib/ldap/controls/readentry.py +++ b/Lib/ldap/controls/readentry.py @@ -12,6 +12,8 @@ from pyasn1_modules.rfc2251 import AttributeDescriptionList,SearchResultEntry +from typing import Dict, List, Optional +from ldap.types import LDAPEntryDict class ReadEntryControl(LDAPControl): """ @@ -28,16 +30,22 @@ class ReadEntryControl(LDAPControl): dictionary holding the LDAP entry """ - def __init__(self,criticality=False,attrList=None): - self.criticality,self.attrList,self.entry = criticality,attrList or [],None + def __init__( + self, + criticality: bool = False, + attrList: Optional[List[str]] = None + ) -> None: + self.criticality = criticality + self.attrList = attrList or [] + self.entry: Optional[LDAPEntryDict] = None - def encodeControlValue(self): + def encodeControlValue(self) -> bytes: attributeSelection = AttributeDescriptionList() for i in range(len(self.attrList)): attributeSelection.setComponentByPosition(i,self.attrList[i]) - return encoder.encode(attributeSelection) + return encoder.encode(attributeSelection) # type: ignore - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: decodedEntry,_ = decoder.decode(encodedControlValue,asn1Spec=SearchResultEntry()) self.dn = str(decodedEntry[0]) self.entry = {} diff --git a/Lib/ldap/controls/sessiontrack.py b/Lib/ldap/controls/sessiontrack.py index a1fb8b34..8a2ffc3c 100644 --- a/Lib/ldap/controls/sessiontrack.py +++ b/Lib/ldap/controls/sessiontrack.py @@ -36,7 +36,7 @@ class SessionTrackingControl(RequestControl): String containing a specific tracking ID """ - class SessionIdentifierControlValue(univ.Sequence): + class SessionIdentifierControlValue(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('sessionSourceIp',LDAPString()), namedtype.NamedType('sessionSourceName',LDAPString()), @@ -46,16 +46,22 @@ class SessionIdentifierControlValue(univ.Sequence): controlType = SESSION_TRACKING_CONTROL_OID - def __init__(self,sessionSourceIp,sessionSourceName,formatOID,sessionTrackingIdentifier): + def __init__( + self, + sessionSourceIp: str, + sessionSourceName: str, + formatOID: str, + sessionTrackingIdentifier:str, + ) -> None: # criticality MUST be false for this control self.criticality = False self.sessionSourceIp,self.sessionSourceName,self.formatOID,self.sessionTrackingIdentifier = \ sessionSourceIp,sessionSourceName,formatOID,sessionTrackingIdentifier - def encodeControlValue(self): + def encodeControlValue(self) -> bytes: s = self.SessionIdentifierControlValue() s.setComponentByName('sessionSourceIp',LDAPString(self.sessionSourceIp)) s.setComponentByName('sessionSourceName',LDAPString(self.sessionSourceName)) s.setComponentByName('formatOID',LDAPOID(self.formatOID)) s.setComponentByName('sessionTrackingIdentifier',LDAPString(self.sessionTrackingIdentifier)) - return encoder.encode(s) + return encoder.encode(s) # type: ignore diff --git a/Lib/ldap/controls/simple.py b/Lib/ldap/controls/simple.py index 96837e2a..acefea61 100644 --- a/Lib/ldap/controls/simple.py +++ b/Lib/ldap/controls/simple.py @@ -10,6 +10,8 @@ from pyasn1.type import univ from pyasn1.codec.ber import encoder,decoder +from typing import Optional + class ValueLessRequestControl(RequestControl): """ @@ -23,11 +25,13 @@ class ValueLessRequestControl(RequestControl): criticality request control """ - def __init__(self,controlType=None,criticality=False): + def __init__( + self, controlType: Optional[str] = None, criticality: bool = False + ) -> None: self.controlType = controlType self.criticality = criticality - def encodeControlValue(self): + def encodeControlValue(self) -> None: return None @@ -39,15 +43,20 @@ class OctetStringInteger(LDAPControl): Integer to be sent as OctetString """ - def __init__(self,controlType=None,criticality=False,integerValue=None): + def __init__( + self, + controlType: Optional[str] = None, + criticality: bool = False, + integerValue: Optional[int] = None + ) -> None: self.controlType = controlType self.criticality = criticality self.integerValue = integerValue - def encodeControlValue(self): + def encodeControlValue(self) -> bytes: return struct.pack('!Q',self.integerValue) - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: self.integerValue = struct.unpack('!Q',encodedControlValue)[0] @@ -61,15 +70,20 @@ class BooleanControl(LDAPControl): Boolean (True/False or 1/0) which is the boolean controlValue. """ - def __init__(self,controlType=None,criticality=False,booleanValue=False): + def __init__( + self, + controlType: Optional[str] = None, + criticality: bool = False, + booleanValue: bool = False + ) -> None: self.controlType = controlType self.criticality = criticality self.booleanValue = booleanValue - def encodeControlValue(self): - return encoder.encode(self.booleanValue,asn1Spec=univ.Boolean()) + def encodeControlValue(self) -> bytes: + return encoder.encode(self.booleanValue,asn1Spec=univ.Boolean()) # type: ignore - def decodeControlValue(self,encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: decodedValue,_ = decoder.decode(encodedControlValue,asn1Spec=univ.Boolean()) self.booleanValue = bool(int(decodedValue)) @@ -79,10 +93,11 @@ class ManageDSAITControl(ValueLessRequestControl): Manage DSA IT Control """ - def __init__(self,criticality=False): + def __init__(self, criticality: bool = False) -> None: ValueLessRequestControl.__init__(self,ldap.CONTROL_MANAGEDSAIT,criticality=False) -KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_MANAGEDSAIT] = ManageDSAITControl +# FIXME: This is a request control though? +#KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_MANAGEDSAIT] = ManageDSAITControl class RelaxRulesControl(ValueLessRequestControl): @@ -90,10 +105,11 @@ class RelaxRulesControl(ValueLessRequestControl): Relax Rules Control """ - def __init__(self,criticality=False): + def __init__(self, criticality: bool = False) -> None: ValueLessRequestControl.__init__(self,ldap.CONTROL_RELAX,criticality=False) -KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_RELAX] = RelaxRulesControl +# FIXME: This is a request control though? +#KNOWN_RESPONSE_CONTROLS[ldap.CONTROL_RELAX] = RelaxRulesControl class ProxyAuthzControl(RequestControl): @@ -105,8 +121,8 @@ class ProxyAuthzControl(RequestControl): on behalf which the server should process the request """ - def __init__(self,criticality,authzId): - RequestControl.__init__(self,ldap.CONTROL_PROXY_AUTHZ,criticality,authzId) + def __init__(self, criticality: bool, authzId: str) -> None: + RequestControl.__init__(self,ldap.CONTROL_PROXY_AUTHZ,criticality,authzId.encode('utf-8')) class AuthorizationIdentityRequestControl(ValueLessRequestControl): @@ -115,7 +131,7 @@ class AuthorizationIdentityRequestControl(ValueLessRequestControl): """ controlType = '2.16.840.1.113730.3.4.16' - def __init__(self,criticality): + def __init__(self, criticality: bool) -> None: ValueLessRequestControl.__init__(self,self.controlType,criticality) @@ -130,8 +146,8 @@ class AuthorizationIdentityResponseControl(ResponseControl): """ controlType = '2.16.840.1.113730.3.4.15' - def decodeControlValue(self,encodedControlValue): - self.authzId = encodedControlValue + def decodeControlValue(self, encodedControlValue: bytes) -> None: + self.authzId = encodedControlValue.decode('utf-8') KNOWN_RESPONSE_CONTROLS[AuthorizationIdentityResponseControl.controlType] = AuthorizationIdentityResponseControl @@ -141,6 +157,7 @@ class GetEffectiveRightsControl(RequestControl): """ Get Effective Rights Control """ + controlType = '1.3.6.1.4.1.42.2.27.9.5.2' - def __init__(self,criticality,authzId=None): - RequestControl.__init__(self,'1.3.6.1.4.1.42.2.27.9.5.2',criticality,authzId) + def __init__(self, criticality: bool, authzId: str) -> None: + RequestControl.__init__(self,self.controlType,criticality,authzId.encode('utf-8')) diff --git a/Lib/ldap/controls/sss.py b/Lib/ldap/controls/sss.py index e6ee3686..f9e61d29 100644 --- a/Lib/ldap/controls/sss.py +++ b/Lib/ldap/controls/sss.py @@ -21,6 +21,7 @@ from pyasn1.type import univ, namedtype, tag, namedval, constraint from pyasn1.codec.ber import encoder, decoder +from typing import List, Union # SortKeyList ::= SEQUENCE OF SEQUENCE { # attributeType AttributeDescription, @@ -28,7 +29,7 @@ # reverseOrder [1] BOOLEAN DEFAULT FALSE } -class SortKeyType(univ.Sequence): +class SortKeyType(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('attributeType', univ.OctetString()), namedtype.OptionalNamedType('orderingRule', @@ -40,7 +41,7 @@ class SortKeyType(univ.Sequence): implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 1)))) -class SortKeyListType(univ.SequenceOf): +class SortKeyListType(univ.SequenceOf): # type: ignore componentType = SortKeyType() @@ -53,18 +54,18 @@ class SSSRequestControl(RequestControl): def __init__( self, - criticality=False, - ordering_rules=None, + criticality: bool = False, + ordering_rules: Union[List[str], str] = [], ): RequestControl.__init__(self,self.controlType,criticality) self.ordering_rules = ordering_rules if isinstance(ordering_rules, str): ordering_rules = [ordering_rules] for rule in ordering_rules: - rule = rule.split(':') - assert len(rule) < 3, 'syntax for ordering rule: [-][:ordering-rule]' + rule_parts = rule.split(':') + assert len(rule_parts) < 3, 'syntax for ordering rule: [-][:ordering-rule]' - def asn1(self): + def asn1(self) -> SortKeyListType: p = SortKeyListType() for i, rule in enumerate(self.ordering_rules): q = SortKeyType() @@ -83,11 +84,11 @@ def asn1(self): p.setComponentByPosition(i, q) return p - def encodeControlValue(self): - return encoder.encode(self.asn1()) + def encodeControlValue(self) -> bytes: + return encoder.encode(self.asn1()) # type: ignore -class SortResultType(univ.Sequence): +class SortResultType(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('sortResult', univ.Enumerated().subtype( namedValues=namedval.NamedValues( @@ -114,10 +115,10 @@ class SortResultType(univ.Sequence): class SSSResponseControl(ResponseControl): controlType = '1.2.840.113556.1.4.474' - def __init__(self,criticality=False): + def __init__(self, criticality: bool = False): ResponseControl.__init__(self,self.controlType,criticality) - def decodeControlValue(self, encoded): + def decodeControlValue(self, encoded: bytes) -> None: p, rest = decoder.decode(encoded, asn1Spec=SortResultType()) assert not rest, 'all data could not be decoded' sort_result = p.getComponentByName('sortResult') diff --git a/Lib/ldap/controls/vlv.py b/Lib/ldap/controls/vlv.py index 5fc7ce88..f7f48b10 100644 --- a/Lib/ldap/controls/vlv.py +++ b/Lib/ldap/controls/vlv.py @@ -18,8 +18,10 @@ from pyasn1.type import univ, namedtype, tag, namedval, constraint from pyasn1.codec.ber import encoder, decoder +from typing import Optional -class ByOffsetType(univ.Sequence): + +class ByOffsetType(univ.Sequence): # type: ignore tagSet = univ.Sequence.tagSet.tagImplicitly( tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 0)) componentType = namedtype.NamedTypes( @@ -27,7 +29,7 @@ class ByOffsetType(univ.Sequence): namedtype.NamedType('contentCount', univ.Integer())) -class TargetType(univ.Choice): +class TargetType(univ.Choice): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('byOffset', ByOffsetType()), namedtype.NamedType('greaterThanOrEqual', univ.OctetString().subtype( @@ -35,7 +37,7 @@ class TargetType(univ.Choice): tag.tagFormatSimple, 1)))) -class VirtualListViewRequestType(univ.Sequence): +class VirtualListViewRequestType(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('beforeCount', univ.Integer()), namedtype.NamedType('afterCount', univ.Integer()), @@ -48,13 +50,13 @@ class VLVRequestControl(RequestControl): def __init__( self, - criticality=False, - before_count=0, - after_count=0, - offset=None, - content_count=None, - greater_than_or_equal=None, - context_id=None, + criticality: bool = False, + before_count: int = 0, + after_count: int = 0, + offset: Optional[int] = None, + content_count: Optional[int] = None, + greater_than_or_equal: Optional[str] = None, + context_id: Optional[str] = None, ): RequestControl.__init__(self,self.controlType,criticality) assert (offset is not None and content_count is not None) or \ @@ -69,7 +71,7 @@ def __init__( self.greater_than_or_equal = greater_than_or_equal self.context_id = context_id - def encodeControlValue(self): + def encodeControlValue(self) -> bytes: p = VirtualListViewRequestType() p.setComponentByName('beforeCount', self.before_count) p.setComponentByName('afterCount', self.after_count) @@ -86,12 +88,13 @@ def encodeControlValue(self): else: raise NotImplementedError p.setComponentByName('target', target) - return encoder.encode(p) + return encoder.encode(p) # type: ignore -KNOWN_RESPONSE_CONTROLS[VLVRequestControl.controlType] = VLVRequestControl +# FIXME: This is a request control though? +#KNOWN_RESPONSE_CONTROLS[VLVRequestControl.controlType] = VLVRequestControl -class VirtualListViewResultType(univ.Enumerated): +class VirtualListViewResultType(univ.Enumerated): # type: ignore namedValues = namedval.NamedValues( ('success', 0), ('operationsError', 1), @@ -106,7 +109,7 @@ class VirtualListViewResultType(univ.Enumerated): ) -class VirtualListViewResponseType(univ.Sequence): +class VirtualListViewResponseType(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType('targetPosition', univ.Integer()), namedtype.NamedType('contentCount', univ.Integer()), @@ -118,10 +121,10 @@ class VirtualListViewResponseType(univ.Sequence): class VLVResponseControl(ResponseControl): controlType = '2.16.840.1.113730.3.4.10' - def __init__(self,criticality=False): + def __init__(self, criticality: bool = False) -> None: ResponseControl.__init__(self,self.controlType,criticality) - def decodeControlValue(self,encoded): + def decodeControlValue(self, encoded: bytes) -> None: p, rest = decoder.decode(encoded, asn1Spec=VirtualListViewResponseType()) assert not rest, 'all data could not be decoded' self.targetPosition = int(p.getComponentByName('targetPosition')) @@ -130,7 +133,7 @@ def decodeControlValue(self,encoded): self.virtualListViewResult = int(virtual_list_view_result) context_id = p.getComponentByName('contextID') if context_id.hasValue(): - self.contextID = str(context_id) + self.contextID: Optional[str] = str(context_id) else: self.contextID = None # backward compatibility class attributes diff --git a/Lib/ldap/dn.py b/Lib/ldap/dn.py index a9d96846..7ab554da 100644 --- a/Lib/ldap/dn.py +++ b/Lib/ldap/dn.py @@ -5,14 +5,16 @@ """ from ldap.pkginfo import __version__ -import _ldap +import ldap._ldap as _ldap assert _ldap.__version__==__version__, \ ImportError(f'ldap {__version__} and _ldap {_ldap.__version__} version mismatch!') import ldap.functions +from typing import List, Tuple -def escape_dn_chars(s): + +def escape_dn_chars(s: str) -> str: """ Escape all DN special characters found in s with a back-slash (see RFC 4514, section 2.4) @@ -34,21 +36,28 @@ def escape_dn_chars(s): return s -def str2dn(dn,flags=0): +def str2dn(dn: str, flags: int = 0) -> List[List[Tuple[str, str, int]]]: """ This function takes a DN as string as parameter and returns a decomposed DN. It's the inverse to dn2str(). + The decomposed DN is a list of sublists, each sublist containing one or + more tuples with the attribute type, attribute value and a flag indicating + the encoding of the value. + + For example, str2dn("dc=example+ou=example,dc=com") would yield: + [[('dc', 'example', 1), ('ou', 'example', 1)], [('dc', 'com', 1)]] + flags describes the format of the dn See also the OpenLDAP man-page ldap_str2dn(3) """ if not dn: return [] - return ldap.functions._ldap_function_call(None,_ldap.str2dn,dn,flags) + return ldap.functions._ldap_function_call(None,_ldap.str2dn,dn,flags) # type: ignore -def dn2str(dn): +def dn2str(dn: List[List[Tuple[str, str, int]]]) -> str: """ This function takes a decomposed DN as parameter and returns a single string. It's the inverse to str2dn() but will always @@ -61,7 +70,7 @@ def dn2str(dn): for rdn in dn ]) -def explode_dn(dn, notypes=False, flags=0): +def explode_dn(dn: str, notypes: bool = False, flags: int = 0) -> List[str]: """ explode_dn(dn [, notypes=False [, flags=0]]) -> list @@ -87,7 +96,7 @@ def explode_dn(dn, notypes=False, flags=0): return rdn_list -def explode_rdn(rdn, notypes=False, flags=0): +def explode_rdn(rdn: str, notypes: bool = False, flags: int = 0) -> List[str]: """ explode_rdn(rdn [, notypes=0 [, flags=0]]) -> list @@ -105,7 +114,7 @@ def explode_rdn(rdn, notypes=False, flags=0): return ['='.join((atype,escape_dn_chars(avalue or ''))) for atype,avalue,dummy in rdn_decomp] -def is_dn(s,flags=0): +def is_dn(s: str, flags: int = 0) -> bool: """ Returns True if `s' can be parsed by ldap.dn.str2dn() as a distinguished host_name (DN), otherwise False is returned. diff --git a/Lib/ldap/extop/__init__.py b/Lib/ldap/extop/__init__.py index dc9aea2f..f41d3fbc 100644 --- a/Lib/ldap/extop/__init__.py +++ b/Lib/ldap/extop/__init__.py @@ -9,7 +9,18 @@ response. """ -from ldap import __version__ +from ldap.pkginfo import __version__ + +from typing import Any, Optional + + +__all__ = [ + # dds + 'RefreshRequest', + 'RefreshResponse', + # passwd + 'PasswordModifyResponse', +] class ExtendedRequest: @@ -23,14 +34,14 @@ class ExtendedRequest: (here it is the BER-encoded ASN.1 request value) """ - def __init__(self,requestName,requestValue): + def __init__(self, requestName: str, requestValue: bytes) -> None: self.requestName = requestName self.requestValue = requestValue - def __repr__(self): - return f'{self.__class__.__name__}({self.requestName},{self.requestValue})' + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.requestName},{self.requestValue!r})' - def encodedRequestValue(self): + def encodedRequestValue(self) -> bytes: """ returns the BER-encoded ASN.1 request value composed by class attributes set before @@ -48,14 +59,18 @@ class ExtendedResponse: BER-encoded ASN.1 value of the LDAPv3 extended operation response """ - def __init__(self,responseName,encodedResponseValue): + def __init__( + self, + responseName: Optional[str], + encodedResponseValue: bytes + ) -> None: self.responseName = responseName self.responseValue = self.decodeResponseValue(encodedResponseValue) - def __repr__(self): - return f'{self.__class__.__name__}({self.responseName},{self.responseValue})' + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.responseName},{self.responseValue!r})' - def decodeResponseValue(self,value): + def decodeResponseValue(self, value: bytes) -> Any: """ decodes the BER-encoded ASN.1 extended operation response value and sets the appropriate class attributes @@ -64,5 +79,5 @@ def decodeResponseValue(self,value): # Import sub-modules -from ldap.extop.dds import * +from ldap.extop.dds import RefreshRequest, RefreshResponse from ldap.extop.passwd import PasswordModifyResponse diff --git a/Lib/ldap/extop/dds.py b/Lib/ldap/extop/dds.py index 7fab0813..b9fcb15e 100644 --- a/Lib/ldap/extop/dds.py +++ b/Lib/ldap/extop/dds.py @@ -12,13 +12,15 @@ from pyasn1.codec.der import encoder,decoder from pyasn1_modules.rfc2251 import LDAPDN +from typing import Optional + class RefreshRequest(ExtendedRequest): requestName = '1.3.6.1.4.1.1466.101.119.1' defaultRequestTtl = 86400 - class RefreshRequestValue(univ.Sequence): + class RefreshRequestValue(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType( 'entryName', @@ -34,11 +36,16 @@ class RefreshRequestValue(univ.Sequence): ), ) - def __init__(self,requestName=None,entryName=None,requestTtl=None): + def __init__( + self, + requestName: Optional[str] = None, + entryName: Optional[str] = None, + requestTtl: Optional[int] = None + ) -> None: self.entryName = entryName self.requestTtl = requestTtl or self.defaultRequestTtl - def encodedRequestValue(self): + def encodedRequestValue(self) -> bytes: p = self.RefreshRequestValue() p.setComponentByName( 'entryName', @@ -52,13 +59,13 @@ def encodedRequestValue(self): implicitTag=tag.Tag(tag.tagClassContext,tag.tagFormatSimple,1) ) ) - return encoder.encode(p) + return encoder.encode(p) # type: ignore[no-any-return] class RefreshResponse(ExtendedResponse): responseName = '1.3.6.1.4.1.1466.101.119.1' - class RefreshResponseValue(univ.Sequence): + class RefreshResponseValue(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.NamedType( 'responseTtl', @@ -68,7 +75,7 @@ class RefreshResponseValue(univ.Sequence): ) ) - def decodeResponseValue(self,value): + def decodeResponseValue(self, value: bytes) -> int: respValue,_ = decoder.decode(value,asn1Spec=self.RefreshResponseValue()) self.responseTtl = int(respValue.getComponentByName('responseTtl')) return self.responseTtl diff --git a/Lib/ldap/extop/passwd.py b/Lib/ldap/extop/passwd.py index 13e9f252..594b1b7c 100644 --- a/Lib/ldap/extop/passwd.py +++ b/Lib/ldap/extop/passwd.py @@ -15,7 +15,7 @@ class PasswordModifyResponse(ExtendedResponse): responseName = None - class PasswordModifyResponseValue(univ.Sequence): + class PasswordModifyResponseValue(univ.Sequence): # type: ignore componentType = namedtype.NamedTypes( namedtype.OptionalNamedType( 'genPasswd', @@ -26,7 +26,7 @@ class PasswordModifyResponseValue(univ.Sequence): ) ) - def decodeResponseValue(self, value): + def decodeResponseValue(self, value: bytes) -> bytes: respValue, _ = decoder.decode(value, asn1Spec=self.PasswordModifyResponseValue()) self.genPasswd = bytes(respValue.getComponentByName('genPasswd')) return self.genPasswd diff --git a/Lib/ldap/filter.py b/Lib/ldap/filter.py index 782737aa..6df2eb53 100644 --- a/Lib/ldap/filter.py +++ b/Lib/ldap/filter.py @@ -7,14 +7,16 @@ - Tested with Python 2.0+ """ -from ldap import __version__ +from ldap.pkginfo import __version__ from ldap.functions import strf_secs +from typing import Iterable, Optional, Union + import time -def escape_filter_chars(assertion_value,escape_mode=0): +def escape_filter_chars(assertion_value: str, escape_mode: int = 0) -> str: """ Replace all special characters found in assertion_value by quoted notation. @@ -46,7 +48,7 @@ def escape_filter_chars(assertion_value,escape_mode=0): return s -def filter_format(filter_template,assertion_values): +def filter_format(filter_template: str, assertion_values: Iterable[str]) -> str: """ filter_template String containing %s as placeholder for assertion values. @@ -58,11 +60,11 @@ def filter_format(filter_template,assertion_values): def time_span_filter( - filterstr='', - from_timestamp=0, - until_timestamp=None, - delta_attr='modifyTimestamp', - ): + filterstr: str = '', + from_timestamp: Union[int, float] = 0, + until_timestamp: Optional[Union[int, float]] = None, + delta_attr: str = 'modifyTimestamp', + ) -> str: """ If last_run_timestr is non-zero filterstr will be extended """ diff --git a/Lib/ldap/functions.py b/Lib/ldap/functions.py index 8658db40..b6b90b54 100644 --- a/Lib/ldap/functions.py +++ b/Lib/ldap/functions.py @@ -1,10 +1,10 @@ """ -functions.py - wraps functions of module _ldap +functions.py - wraps functions of module ldap._ldap See https://www.python-ldap.org/ for details. """ -from ldap import __version__ +from ldap.pkginfo import __version__ __all__ = [ 'open','initialize','init', @@ -14,7 +14,9 @@ 'strf_secs','strp_secs', ] -import sys,pprint,time,_ldap,ldap +import sys,pprint,time +import ldap._ldap as _ldap +import ldap from calendar import timegm from ldap import LDAPError @@ -23,12 +25,20 @@ from ldap.ldapobject import LDAPObject +from typing import Any, BinaryIO, Callable, TextIO, Optional, Union + + if __debug__: # Tracing is only supported in debugging mode import traceback -def _ldap_function_call(lock,func,*args,**kwargs): +def _ldap_function_call( + lock: Optional[ldap.LDAPLock], + func: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> Any: """ Wrapper function which locks and logs calls to function @@ -63,9 +73,14 @@ def _ldap_function_call(lock,func,*args,**kwargs): def initialize( - uri, trace_level=0, trace_file=sys.stdout, trace_stack_limit=None, - bytes_mode=None, fileno=None, **kwargs -): + uri: str, + trace_level: int = 0, + trace_file: TextIO = sys.stdout, + trace_stack_limit: int = 5, + bytes_mode: Optional[Any] = None, + fileno: Optional[Union[int, BinaryIO]] = None, + **kwargs: Any, +) -> LDAPObject: """ Return LDAPObject instance by opening LDAP connection to LDAP host specified by LDAP URL @@ -94,7 +109,7 @@ def initialize( ) -def get_option(option): +def get_option(option: int) -> Any: """ get_option(name) -> value @@ -103,16 +118,16 @@ def get_option(option): return _ldap_function_call(None,_ldap.get_option,option) -def set_option(option,invalue): +def set_option(option: int, invalue: Any) -> int: """ set_option(name, value) Set the value of an LDAP global option. """ - return _ldap_function_call(None,_ldap.set_option,option,invalue) + return _ldap_function_call(None,_ldap.set_option,option,invalue) # type: ignore -def escape_str(escape_func,s,*args): +def escape_str(escape_func: Callable[[str], str], s: str, *args: str) -> str: """ Applies escape_func() to all items of `args' and returns a string based on format string `s'. @@ -120,14 +135,14 @@ def escape_str(escape_func,s,*args): return s % tuple(escape_func(v) for v in args) -def strf_secs(secs): +def strf_secs(secs: float) -> str: """ Convert seconds since epoch to a string compliant to LDAP syntax GeneralizedTime """ return time.strftime('%Y%m%d%H%M%SZ', time.gmtime(secs)) -def strp_secs(dt_str): +def strp_secs(dt_str: str) -> int: """ Convert LDAP syntax GeneralizedTime to seconds since epoch """ diff --git a/Lib/ldap/ldapobject.py b/Lib/ldap/ldapobject.py index 7a9c17f6..3e00646f 100644 --- a/Lib/ldap/ldapobject.py +++ b/Lib/ldap/ldapobject.py @@ -1,5 +1,5 @@ """ -ldapobject.py - wraps class _ldap.LDAPObject +ldapobject.py - wraps class ldap._ldap.LDAPObject See https://www.python-ldap.org/ for details. """ @@ -7,6 +7,23 @@ from ldap.pkginfo import __version__, __author__, __license__ +from ldap.controls import RequestControl, ResponseControl + +from ldap.types import LDAPAddModList, LDAPModifyModList, LDAPEntryDict +from typing import ( + Any, + BinaryIO, + Callable, + Dict, + List, + Sequence, + TextIO, + Tuple, + Type, + Optional, + Union, +) + __all__ = [ 'LDAPObject', 'SimpleLDAPObject', @@ -19,7 +36,9 @@ # Tracing is only supported in debugging mode import traceback -import sys,time,pprint,_ldap,ldap,ldap.sasl,ldap.functions +import sys,time,pprint +import ldap._ldap as _ldap +import ldap, ldap.sasl, ldap.functions import warnings from ldap.schema import SCHEMA_ATTRS @@ -32,7 +51,7 @@ class LDAPBytesWarning(BytesWarning): """Python 2 bytes mode warning""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( "LDAPBytesWarning is deprecated and will be removed in the future", DeprecationWarning, @@ -67,9 +86,14 @@ class SimpleLDAPObject: } def __init__( - self,uri, - trace_level=0,trace_file=None,trace_stack_limit=5,bytes_mode=None, - bytes_strictness=None, fileno=None + self, + uri: str, + trace_level: int = 0, + trace_file: Optional[TextIO] = None, + trace_stack_limit: int = 5, + bytes_mode: Optional[Any] = None, + bytes_strictness: Optional[str] = None, + fileno: Optional[Union[int, BinaryIO]] = None, ): self._trace_level = trace_level or ldap._trace_level self._trace_file = trace_file or ldap._trace_file @@ -93,20 +117,20 @@ def __init__( raise ValueError("bytes_mode is *not* supported under Python 3.") @property - def bytes_mode(self): + def bytes_mode(self) -> bool: return False @property - def bytes_strictness(self): + def bytes_strictness(self) -> str: return 'error' - def _ldap_lock(self,desc=''): + def _ldap_lock(self, desc: str = '') -> ldap.LDAPLock: if ldap.LIBLDAP_R: return ldap.LDAPLock(desc='%s within %s' %(desc,repr(self))) else: return ldap._ldap_module_lock - def _ldap_call(self,func,*args,**kwargs): + def _ldap_call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """ Wrapper method mainly for serializing calls into OpenLDAP libs and trace logs @@ -147,13 +171,13 @@ def _ldap_call(self,func,*args,**kwargs): self._trace_file.write('=> result:\n%s\n' % (pprint.pformat(result))) return result - def __setattr__(self,name,value): + def __setattr__(self, name: str, value: Any) -> None: if name in self.CLASSATTR_OPTION_MAPPING: self.set_option(self.CLASSATTR_OPTION_MAPPING[name],value) else: self.__dict__[name] = value - def __getattr__(self,name): + def __getattr__(self, name: str) -> Any: if name in self.CLASSATTR_OPTION_MAPPING: return self.get_option(self.CLASSATTR_OPTION_MAPPING[name]) elif name in self.__dict__: @@ -163,15 +187,24 @@ def __getattr__(self,name): self.__class__.__name__,repr(name) )) - def fileno(self): + def fileno(self) -> int: """ Returns file description of LDAP connection. Just a convenience wrapper for LDAPObject.get_option(ldap.OPT_DESC) """ - return self.get_option(ldap.OPT_DESC) + fd = self.get_option(ldap.OPT_DESC) + if isinstance(fd, int): + return fd + else: + return -1 - def abandon_ext(self,msgid,serverctrls=None,clientctrls=None): + def abandon_ext( + self, + msgid: int, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> None: """ abandon_ext(msgid[,serverctrls=None[,clientctrls=None]]) -> None abandon(msgid) -> None @@ -181,12 +214,17 @@ def abandon_ext(self,msgid,serverctrls=None,clientctrls=None): can expect that the result of an abandoned operation will not be returned from a future call to result(). """ - return self._ldap_call(self._l.abandon_ext,msgid,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) + self._ldap_call(self._l.abandon_ext,msgid,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) - def abandon(self,msgid): + def abandon(self, msgid: int) -> None: return self.abandon_ext(msgid,None,None) - def cancel(self,cancelid,serverctrls=None,clientctrls=None): + def cancel( + self, + cancelid: int, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ cancel(cancelid[,serverctrls=None[,clientctrls=None]]) -> int Send cancels extended operation for an LDAP operation specified by cancelid. @@ -197,17 +235,28 @@ def cancel(self,cancelid,serverctrls=None,clientctrls=None): In opposite to abandon() this extended operation gets an result from the server and thus should be preferred if the server supports it. """ - return self._ldap_call(self._l.cancel,cancelid,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) + return self._ldap_call(self._l.cancel,cancelid,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore - def cancel_s(self,cancelid,serverctrls=None,clientctrls=None): + def cancel_s( + self, + cancelid: int, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> Optional[int]: msgid = self.cancel(cancelid,serverctrls,clientctrls) try: res = self.result(msgid,all=1,timeout=self.timeout) except (ldap.CANCELLED,ldap.SUCCESS): res = None - return res + return res # type: ignore - def add_ext(self,dn,modlist,serverctrls=None,clientctrls=None): + def add_ext( + self, + dn: str, + modlist: LDAPAddModList, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ add_ext(dn, modlist[,serverctrls=None[,clientctrls=None]]) -> int This function adds a new entry with a distinguished name @@ -215,14 +264,25 @@ def add_ext(self,dn,modlist,serverctrls=None,clientctrls=None): The parameter modlist is similar to the one passed to modify(), except that no operation integer need be included in the tuples. """ - return self._ldap_call(self._l.add_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) + return self._ldap_call(self._l.add_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore - def add_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None): + def add_ext_s( + self, + dn: str, + modlist: LDAPAddModList, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> Tuple[Any, Any, Any, Any]: + # FIXME: The return value could be more specific msgid = self.add_ext(dn,modlist,serverctrls,clientctrls) resp_type, resp_data, resp_msgid, resp_ctrls = self.result3(msgid,all=1,timeout=self.timeout) return resp_type, resp_data, resp_msgid, resp_ctrls - def add(self,dn,modlist): + def add( + self, + dn: str, + modlist: LDAPAddModList, + ) -> int: """ add(dn, modlist) -> int This function adds a new entry with a distinguished name @@ -232,44 +292,85 @@ def add(self,dn,modlist): """ return self.add_ext(dn,modlist,None,None) - def add_s(self,dn,modlist): + def add_s( + self, + dn: str, + modlist: LDAPAddModList, + ) -> Tuple[Any, Any, Any, Any]: return self.add_ext_s(dn,modlist,None,None) - def simple_bind(self,who=None,cred=None,serverctrls=None,clientctrls=None): + def simple_bind( + self, + who: Optional[str] = None, + cred: Optional[str] = None, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ - simple_bind([who='' [,cred='']]) -> int + simple_bind([who=''[,cred=''[,serverctrls=None[,clientctrls=None]]]]) -> int """ - return self._ldap_call(self._l.simple_bind,who,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) + return self._ldap_call(self._l.simple_bind,who,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore - def simple_bind_s(self,who=None,cred=None,serverctrls=None,clientctrls=None): + def simple_bind_s( + self, + who: Optional[str] = None, + cred: Optional[str] = None, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> Tuple[Any, Any, Any, Any]: + # FIXME: The return value could be more specific """ - simple_bind_s([who='' [,cred='']]) -> 4-tuple + simple_bind_s([who=''[,cred=''[,serverctrls=None[,clientctrls=None]]]]) -> 4-tuple """ msgid = self.simple_bind(who,cred,serverctrls,clientctrls) resp_type, resp_data, resp_msgid, resp_ctrls = self.result3(msgid,all=1,timeout=self.timeout) return resp_type, resp_data, resp_msgid, resp_ctrls - def bind(self,who,cred,method=ldap.AUTH_SIMPLE): + def bind( + self, + who: str, + cred: str, + method: int = ldap.AUTH_SIMPLE, + ) -> int: """ bind(who, cred, method) -> int """ assert method==ldap.AUTH_SIMPLE,'Only simple bind supported in LDAPObject.bind()' return self.simple_bind(who,cred) - def bind_s(self,who,cred,method=ldap.AUTH_SIMPLE): + def bind_s( + self, + who: str, + cred: str, + method: int = ldap.AUTH_SIMPLE, + ) -> None: """ bind_s(who, cred, method) -> None """ msgid = self.bind(who,cred,method) - return self.result(msgid,all=1,timeout=self.timeout) + return self.result(msgid,all=1,timeout=self.timeout) # type: ignore - def sasl_interactive_bind_s(self,who,auth,serverctrls=None,clientctrls=None,sasl_flags=ldap.SASL_QUIET): + def sasl_interactive_bind_s( + self, + who: str, + auth: "ldap.sasl.sasl", + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + sasl_flags: int = ldap.SASL_QUIET, + ) -> None: """ sasl_interactive_bind_s(who, auth [,serverctrls=None[,clientctrls=None[,sasl_flags=ldap.SASL_QUIET]]]) -> None """ - return self._ldap_call(self._l.sasl_interactive_bind_s,who,auth,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls),sasl_flags) + return self._ldap_call(self._l.sasl_interactive_bind_s,who,auth,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls),sasl_flags) # type: ignore - def sasl_non_interactive_bind_s(self,sasl_mech,serverctrls=None,clientctrls=None,sasl_flags=ldap.SASL_QUIET,authz_id=''): + def sasl_non_interactive_bind_s( + self, + sasl_mech: str, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + sasl_flags: int = ldap.SASL_QUIET, + authz_id: str = '', + ) -> None: """ Send a SASL bind request using a non-interactive SASL method (e.g. GSSAPI, EXTERNAL) """ @@ -279,25 +380,51 @@ def sasl_non_interactive_bind_s(self,sasl_mech,serverctrls=None,clientctrls=None ) self.sasl_interactive_bind_s('',auth,serverctrls,clientctrls,sasl_flags) - def sasl_external_bind_s(self,serverctrls=None,clientctrls=None,sasl_flags=ldap.SASL_QUIET,authz_id=''): + def sasl_external_bind_s( + self, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + sasl_flags: int = ldap.SASL_QUIET, + authz_id: str = '', + ) -> None: """ Send SASL bind request using SASL mech EXTERNAL """ self.sasl_non_interactive_bind_s('EXTERNAL',serverctrls,clientctrls,sasl_flags,authz_id) - def sasl_gssapi_bind_s(self,serverctrls=None,clientctrls=None,sasl_flags=ldap.SASL_QUIET,authz_id=''): + def sasl_gssapi_bind_s( + self, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + sasl_flags: int = ldap.SASL_QUIET, + authz_id: str = '', + ) -> None: """ Send SASL bind request using SASL mech GSSAPI """ self.sasl_non_interactive_bind_s('GSSAPI',serverctrls,clientctrls,sasl_flags,authz_id) - def sasl_bind_s(self,dn,mechanism,cred,serverctrls=None,clientctrls=None): + def sasl_bind_s( + self, + dn: str, + mechanism: str, + cred: str, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> Union[int, str]: """ sasl_bind_s(dn, mechanism, cred [,serverctrls=None[,clientctrls=None]]) -> int|str """ - return self._ldap_call(self._l.sasl_bind_s,dn,mechanism,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) + return self._ldap_call(self._l.sasl_bind_s,dn,mechanism,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore - def compare_ext(self,dn,attr,value,serverctrls=None,clientctrls=None): + def compare_ext( + self, + dn: str, + attr: str, + value: bytes, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ compare_ext(dn, attr, value [,serverctrls=None[,clientctrls=None]]) -> int compare_ext_s(dn, attr, value [,serverctrls=None[,clientctrls=None]]) -> bool @@ -315,9 +442,16 @@ def compare_ext(self,dn,attr,value,serverctrls=None,clientctrls=None): A design bug in the library prevents value from containing nul characters. """ - return self._ldap_call(self._l.compare_ext,dn,attr,value,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) + return self._ldap_call(self._l.compare_ext,dn,attr,value,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore - def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None): + def compare_ext_s( + self, + dn: str, + attr: str, + value: bytes, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> bool: msgid = self.compare_ext(dn,attr,value,serverctrls,clientctrls) try: ldap_res = self.result3(msgid,all=1,timeout=self.timeout) @@ -329,13 +463,28 @@ def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None): f'Compare operation returned wrong result: {ldap_res!r}' ) - def compare(self,dn,attr,value): + def compare( + self, + dn: str, + attr: str, + value: bytes, + ) -> int: return self.compare_ext(dn,attr,value,None,None) - def compare_s(self,dn,attr,value): + def compare_s( + self, + dn: str, + attr: str, + value: bytes, + ) -> bool: return self.compare_ext_s(dn,attr,value,None,None) - def delete_ext(self,dn,serverctrls=None,clientctrls=None): + def delete_ext( + self, + dn: str, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ delete(dn) -> int delete_s(dn) -> None @@ -345,20 +494,30 @@ def delete_ext(self,dn,serverctrls=None,clientctrls=None): form returns the message id of the initiated request, and the result can be obtained from a subsequent call to result(). """ - return self._ldap_call(self._l.delete_ext,dn,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) + return self._ldap_call(self._l.delete_ext,dn,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore - def delete_ext_s(self,dn,serverctrls=None,clientctrls=None): + def delete_ext_s( + self, + dn: str, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> Tuple[Any, Any, Any, Any]: msgid = self.delete_ext(dn,serverctrls,clientctrls) resp_type, resp_data, resp_msgid, resp_ctrls = self.result3(msgid,all=1,timeout=self.timeout) return resp_type, resp_data, resp_msgid, resp_ctrls - def delete(self,dn): + def delete(self, dn: str) -> int: return self.delete_ext(dn,None,None) - def delete_s(self,dn): - return self.delete_ext_s(dn,None,None) + def delete_s(self, dn: str) -> None: + self.delete_ext_s(dn,None,None) - def extop(self,extreq,serverctrls=None,clientctrls=None): + def extop( + self, + extreq: "ldap.extop.ExtendedRequest", + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ extop(extreq[,serverctrls=None[,clientctrls=None]]]) -> int extop_s(extreq[,serverctrls=None[,clientctrls=None[,extop_resp_class=None]]]]) -> @@ -372,13 +531,25 @@ def extop(self,extreq,serverctrls=None,clientctrls=None): ldap.extop.ExtendedResponse this class is used to return an object of this class instead of a raw BER value in respvalue. """ - return self._ldap_call(self._l.extop,extreq.requestName,extreq.encodedRequestValue(),RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) - - def extop_result(self,msgid=ldap.RES_ANY,all=1,timeout=None): - resulttype,msg,msgid,respctrls,respoid,respvalue = self.result4(msgid,all=1,timeout=self.timeout,add_ctrls=1,add_intermediates=1,add_extop=1) - return (respoid,respvalue) - - def extop_s(self,extreq,serverctrls=None,clientctrls=None,extop_resp_class=None): + return self._ldap_call(self._l.extop,extreq.requestName,extreq.encodedRequestValue(),RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore + + def extop_result( + self, + msgid: int = ldap.RES_ANY, + all: int = 1, + timeout: Optional[Union[int, float]] = None, + ) -> Tuple[str, bytes]: + # FIXME: The timeout argument isn't used? + resulttype,msg,rmsgid,respctrls,respoid,respvalue = self.result4(msgid,all=1,timeout=self.timeout,add_ctrls=1,add_intermediates=1,add_extop=1) + return (respoid,respvalue) # type: ignore + + def extop_s( + self, + extreq: "ldap.extop.ExtendedRequest", + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + extop_resp_class: Optional[Type["ldap.extop.ExtendedResponse"]] = None, + ) -> Union[Tuple[str, bytes], "ldap.extop.ExtendedResponse"]: msgid = self.extop(extreq,serverctrls,clientctrls) res = self.extop_result(msgid,all=1,timeout=self.timeout) if extop_resp_class: @@ -389,18 +560,34 @@ def extop_s(self,extreq,serverctrls=None,clientctrls=None,extop_resp_class=None) else: return res - def modify_ext(self,dn,modlist,serverctrls=None,clientctrls=None): + def modify_ext( + self, + dn: str, + modlist: LDAPModifyModList, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ modify_ext(dn, modlist[,serverctrls=None[,clientctrls=None]]) -> int """ - return self._ldap_call(self._l.modify_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) + return self._ldap_call(self._l.modify_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore - def modify_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None): + def modify_ext_s( + self, + dn: str, + modlist: LDAPModifyModList, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> Tuple[Any, Any, Any, Any]: msgid = self.modify_ext(dn,modlist,serverctrls,clientctrls) resp_type, resp_data, resp_msgid, resp_ctrls = self.result3(msgid,all=1,timeout=self.timeout) return resp_type, resp_data, resp_msgid, resp_ctrls - def modify(self,dn,modlist): + def modify( + self, + dn: str, + modlist: LDAPModifyModList, + ) -> int: """ modify(dn, modlist) -> int modify_s(dn, modlist) -> None @@ -423,10 +610,19 @@ def modify(self,dn,modlist): """ return self.modify_ext(dn,modlist,None,None) - def modify_s(self,dn,modlist): - return self.modify_ext_s(dn,modlist,None,None) + def modify_s( + self, + dn: str, + modlist: LDAPModifyModList, + ) -> None: + self.modify_ext_s(dn,modlist,None,None) - def modrdn(self,dn,newrdn,delold=1): + def modrdn( + self, + dn: str, + newrdn: str, + delold: int = 1, + ) -> int: """ modrdn(dn, newrdn [,delold=1]) -> int modrdn_s(dn, newrdn [,delold=1]) -> None @@ -442,24 +638,53 @@ def modrdn(self,dn,newrdn,delold=1): """ return self.rename(dn,newrdn,None,delold) - def modrdn_s(self,dn,newrdn,delold=1): + def modrdn_s( + self, + dn: str, + newrdn: str, + delold: int = 1, + ) -> None: return self.rename_s(dn,newrdn,None,delold) - def passwd(self,user,oldpw,newpw,serverctrls=None,clientctrls=None): - return self._ldap_call(self._l.passwd,user,oldpw,newpw,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) - - def passwd_s(self, user, oldpw, newpw, serverctrls=None, clientctrls=None, extract_newpw=False): + def passwd( + self, + user: str, + oldpw: str, + newpw: str, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: + return self._ldap_call(self._l.passwd,user,oldpw,newpw,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore + + def passwd_s( + self, + user: str, + oldpw: str, + newpw: str, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + extract_newpw: bool = False, + ) -> Tuple[str, Union[bytes, PasswordModifyResponse]]: msgid = self.passwd(user, oldpw, newpw, serverctrls, clientctrls) respoid, respvalue = self.extop_result(msgid, all=1, timeout=self.timeout) if respoid != PasswordModifyResponse.responseName: raise ldap.PROTOCOL_ERROR("Unexpected OID %s in extended response!" % respoid) - if extract_newpw and respvalue: - respvalue = PasswordModifyResponse(PasswordModifyResponse.responseName, respvalue) - - return respoid, respvalue - def rename(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None): + if extract_newpw and respvalue: + return respoid, PasswordModifyResponse(PasswordModifyResponse.responseName, respvalue) + else: + return respoid, respvalue + + def rename( + self, + dn: str, + newrdn: str, + newsuperior: Optional[str] = None, + delold: int = 1, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ rename(dn, newrdn [, newsuperior=None [,delold=1][,serverctrls=None[,clientctrls=None]]]) -> int rename_s(dn, newrdn [, newsuperior=None] [,delold=1][,serverctrls=None[,clientctrls=None]]) -> None @@ -474,14 +699,26 @@ def rename(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls This actually corresponds to the rename* routines in the LDAP-EXT C API library. """ - return self._ldap_call(self._l.rename,dn,newrdn,newsuperior,delold,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) - - def rename_s(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None): + return self._ldap_call(self._l.rename,dn,newrdn,newsuperior,delold,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls)) # type: ignore + + def rename_s( + self, + dn: str, + newrdn: str, + newsuperior: Optional[str] = None, + delold: int = 1, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> None: msgid = self.rename(dn,newrdn,newsuperior,delold,serverctrls,clientctrls) resp_type, resp_data, resp_msgid, resp_ctrls = self.result3(msgid,all=1,timeout=self.timeout) - return resp_type, resp_data, resp_msgid, resp_ctrls - def result(self,msgid=ldap.RES_ANY,all=1,timeout=None): + def result( + self, + msgid: int = ldap.RES_ANY, + all: int = 1, + timeout: Optional[Union[int, float]] = None, + ) -> Tuple[Optional[int], Optional[Any]]: """ result([msgid=RES_ANY [,all=1 [,timeout=None]]]) -> (result_type, result_data) @@ -535,11 +772,22 @@ def result(self,msgid=ldap.RES_ANY,all=1,timeout=None): resp_type, resp_data, resp_msgid = self.result2(msgid,all,timeout) return resp_type, resp_data - def result2(self,msgid=ldap.RES_ANY,all=1,timeout=None): + def result2( + self, + msgid: int = ldap.RES_ANY, + all: int = 1, + timeout: Optional[Union[int, float]] = None, + ) -> Tuple[Optional[int], Optional[Any], Optional[int]]: resp_type, resp_data, resp_msgid, resp_ctrls = self.result3(msgid,all,timeout) return resp_type, resp_data, resp_msgid - def result3(self,msgid=ldap.RES_ANY,all=1,timeout=None,resp_ctrl_classes=None): + def result3( + self, + msgid: int = ldap.RES_ANY, + all: int = 1, + timeout: Optional[Union[int, float]] = None, + resp_ctrl_classes: Optional[Dict[str, Type[ResponseControl]]] = None, + ) -> Tuple[Optional[int], Optional[Any], Optional[int], Optional[List[ResponseControl]]]: resp_type, resp_data, resp_msgid, decoded_resp_ctrls, retoid, retval = self.result4( msgid,all,timeout, add_ctrls=0,add_intermediates=0,add_extop=0, @@ -547,7 +795,16 @@ def result3(self,msgid=ldap.RES_ANY,all=1,timeout=None,resp_ctrl_classes=None): ) return resp_type, resp_data, resp_msgid, decoded_resp_ctrls - def result4(self,msgid=ldap.RES_ANY,all=1,timeout=None,add_ctrls=0,add_intermediates=0,add_extop=0,resp_ctrl_classes=None): + def result4( + self, + msgid: int = ldap.RES_ANY, + all: int = 1, + timeout: Optional[Union[int, float]] = None, + add_ctrls: int = 0, + add_intermediates: int = 0, + add_extop: int = 0, + resp_ctrl_classes: Optional[Dict[str, Type[ResponseControl]]] = None, + ) -> Tuple[Optional[int], Optional[Any], Optional[int], Optional[List[ResponseControl]], Optional[Any], Optional[Any]]: if timeout is None: timeout = self.timeout ldap_result = self._ldap_call(self._l.result4,msgid,all,timeout,add_ctrls,add_intermediates,add_extop) @@ -564,7 +821,18 @@ def result4(self,msgid=ldap.RES_ANY,all=1,timeout=None,add_ctrls=0,add_intermedi decoded_resp_ctrls = DecodeControlTuples(resp_ctrls,resp_ctrl_classes) return resp_type, resp_data, resp_msgid, decoded_resp_ctrls, resp_name, resp_value - def search_ext(self,base,scope,filterstr=None,attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0): + def search_ext( + self, + base: str, + scope: int, + filterstr: Optional[str] = None, + attrlist: Optional[List[str]] = None, + attrsonly: int = 0, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + timeout: Union[int, float] = -1, + sizelimit: int = 0, + ) -> int: """ search(base, scope [,filterstr='(objectClass=*)' [,attrlist=None [,attrsonly=0]]]) -> int search_s(base, scope [,filterstr='(objectClass=*)' [,attrlist=None [,attrsonly=0]]]) @@ -611,7 +879,7 @@ def search_ext(self,base,scope,filterstr=None,attrlist=None,attrsonly=0,serverct """ if filterstr is None: filterstr = '(objectClass=*)' - return self._ldap_call( + return self._ldap_call( # type: ignore self._l.search_ext, base,scope,filterstr, attrlist,attrsonly, @@ -620,29 +888,66 @@ def search_ext(self,base,scope,filterstr=None,attrlist=None,attrsonly=0,serverct timeout,sizelimit, ) - def search_ext_s(self,base,scope,filterstr=None,attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0): + def search_ext_s( + self, + base: str, + scope: int, + filterstr: Optional[str] = None, + attrlist: Optional[List[str]] = None, + attrsonly: int = 0, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + timeout: Union[int, float] = -1, + sizelimit: int = 0, + ) -> List[Tuple[str, LDAPEntryDict]]: msgid = self.search_ext(base,scope,filterstr,attrlist,attrsonly,serverctrls,clientctrls,timeout,sizelimit) - return self.result(msgid,all=1,timeout=timeout)[1] - - def search(self,base,scope,filterstr=None,attrlist=None,attrsonly=0): + return self.result(msgid,all=1,timeout=timeout)[1] # type: ignore + + def search( + self, + base: str, + scope: int, + filterstr: Optional[str] = None, + attrlist: Optional[List[str]] = None, + attrsonly: int = 0, + ) -> int: return self.search_ext(base,scope,filterstr,attrlist,attrsonly,None,None) - def search_s(self,base,scope,filterstr=None,attrlist=None,attrsonly=0): + def search_s( + self, + base: str, + scope: int, + filterstr: Optional[str] = None, + attrlist: Optional[List[str]] = None, + attrsonly: int = 0, + ) -> List[Tuple[str, LDAPEntryDict]]: return self.search_ext_s(base,scope,filterstr,attrlist,attrsonly,None,None,timeout=self.timeout) - def search_st(self,base,scope,filterstr=None,attrlist=None,attrsonly=0,timeout=-1): + def search_st( + self, + base: str, + scope: int, + filterstr: Optional[str] = None, + attrlist: Optional[List[str]] = None, + attrsonly: int = 0, + timeout: Union[int, float] = -1, + ) -> List[Tuple[str, LDAPEntryDict]]: return self.search_ext_s(base,scope,filterstr,attrlist,attrsonly,None,None,timeout) - def start_tls_s(self): + def start_tls_s(self) -> None: """ start_tls_s() -> None Negotiate TLS with server. The `version' attribute must have been set to VERSION3 before calling start_tls_s. If TLS could not be started an exception will be raised. """ - return self._ldap_call(self._l.start_tls_s) + self._ldap_call(self._l.start_tls_s) - def unbind_ext(self,serverctrls=None,clientctrls=None): + def unbind_ext( + self, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> int: """ unbind() -> int unbind_s() -> None @@ -662,9 +967,13 @@ def unbind_ext(self,serverctrls=None,clientctrls=None): del self._l except AttributeError: pass - return res + return res # type: ignore - def unbind_ext_s(self,serverctrls=None,clientctrls=None): + def unbind_ext_s( + self, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> None: msgid = self.unbind_ext(serverctrls,clientctrls) if msgid!=None: result = self.result3(msgid,all=1,timeout=self.timeout) @@ -675,29 +984,35 @@ def unbind_ext_s(self,serverctrls=None,clientctrls=None): self._trace_file.flush() except AttributeError: pass - return result - def unbind(self): + def unbind(self) -> int: return self.unbind_ext(None,None) - def unbind_s(self): + def unbind_s(self) -> None: return self.unbind_ext_s(None,None) - def whoami_s(self,serverctrls=None,clientctrls=None): - return self._ldap_call(self._l.whoami_s,serverctrls,clientctrls) + def whoami_s( + self, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + ) -> str: + return self._ldap_call(self._l.whoami_s,serverctrls,clientctrls) # type: ignore - def get_option(self,option): + def get_option(self, option: int) -> Any: result = self._ldap_call(self._l.get_option,option) if option==ldap.OPT_SERVER_CONTROLS or option==ldap.OPT_CLIENT_CONTROLS: result = DecodeControlTuples(result) return result - def set_option(self,option,invalue): + def set_option(self, option: int, invalue: Any) -> Any: if option==ldap.OPT_SERVER_CONTROLS or option==ldap.OPT_CLIENT_CONTROLS: invalue = RequestControlTuples(invalue) return self._ldap_call(self._l.set_option,option,invalue) - def search_subschemasubentry_s(self,dn=None): + def search_subschemasubentry_s( + self, + dn: Optional[str] = None, + ) -> Optional[str]: """ Returns the distinguished name of the sub schema sub entry for a part of a DIT specified by dn. @@ -705,7 +1020,7 @@ def search_subschemasubentry_s(self,dn=None): None as result indicates that the DN of the sub schema sub entry could not be determined. - Returns: None or text/bytes depending on bytes_mode. + Returns: None or the DN as a string. """ empty_dn = '' attrname = 'subschemaSubentry' @@ -722,8 +1037,8 @@ def search_subschemasubentry_s(self,dn=None): try: if r: e = ldap.cidict.cidict(r[0][1]) - search_subschemasubentry_dn = e.get(attrname,[None])[0] - if search_subschemasubentry_dn is None: + search_subschemasubentry_dn = e.get(attrname,[b''])[0] + if search_subschemasubentry_dn == b'': if dn: # Try to find sub schema sub entry in root DSE return self.search_subschemasubentry_s(dn=empty_dn) @@ -731,12 +1046,22 @@ def search_subschemasubentry_s(self,dn=None): # If dn was already root DSE we can return here return None else: - if search_subschemasubentry_dn is not None: - return search_subschemasubentry_dn.decode('utf-8') + dn_str: str = search_subschemasubentry_dn.decode('utf-8') + return dn_str except IndexError: return None - def read_s(self,dn,filterstr=None,attrlist=None,serverctrls=None,clientctrls=None,timeout=-1): + return None + + def read_s( + self, + dn: str, + filterstr: Optional[str] = None, + attrlist: Optional[List[str]] = None, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + timeout: Union[int, float] = -1, + ) -> Optional[LDAPEntryDict]: """ Reads and returns a single entry specified by `dn'. @@ -756,7 +1081,11 @@ def read_s(self,dn,filterstr=None,attrlist=None,serverctrls=None,clientctrls=Non else: return None - def read_subschemasubentry_s(self,subschemasubentry_dn,attrs=None): + def read_subschemasubentry_s( + self, + subschemasubentry_dn: str, + attrs: Optional[List[str]] = None, + ) -> Optional[LDAPEntryDict]: """ Returns the sub schema sub entry's data """ @@ -774,7 +1103,17 @@ def read_subschemasubentry_s(self,subschemasubentry_dn,attrs=None): else: return subschemasubentry - def find_unique_entry(self,base,scope=ldap.SCOPE_SUBTREE,filterstr=None,attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1): + def find_unique_entry( + self, + base: str, + scope: int, + filterstr: Optional[str] = None, + attrlist: Optional[List[str]] = None, + attrsonly: int = 0, + serverctrls: Optional[List[RequestControl]] = None, + clientctrls: Optional[List[RequestControl]] = None, + timeout: Union[int, float] = -1, + ) -> Tuple[str, LDAPEntryDict]: """ Returns a unique entry, raises exception if not unique """ @@ -793,7 +1132,11 @@ def find_unique_entry(self,base,scope=ldap.SCOPE_SUBTREE,filterstr=None,attrlist raise NO_UNIQUE_ENTRY('No or non-unique search result for %s' % (repr(filterstr))) return r[0] - def read_rootdse_s(self, filterstr=None, attrlist=None): + def read_rootdse_s( + self, + filterstr: Optional[str] = None, + attrlist: Optional[List[str]] = None, + ) -> Optional[LDAPEntryDict]: """ convenience wrapper around read_s() for reading rootDSE """ @@ -806,15 +1149,17 @@ def read_rootdse_s(self, filterstr=None, attrlist=None): ) return ldap_rootdse # read_rootdse_s() - def get_naming_contexts(self): + def get_naming_contexts(self) -> List[bytes]: """ returns all attribute values of namingContexts in rootDSE if namingContexts is not present (not readable) then empty list is returned """ name = 'namingContexts' - return self.read_rootdse_s( - attrlist=[name] - ).get(name, []) + rootdse = self.read_rootdse_s(attrlist=[name]) + if rootdse is None: + return [] + else: + return rootdse.get(name, []) class ReconnectLDAPObject(SimpleLDAPObject): @@ -844,10 +1189,17 @@ class ReconnectLDAPObject(SimpleLDAPObject): } def __init__( - self,uri, - trace_level=0,trace_file=None,trace_stack_limit=5,bytes_mode=None, - bytes_strictness=None, retry_max=1, retry_delay=60.0, fileno=None - ): + self, + uri: str, + trace_level: int = 0, + trace_file: Optional[TextIO] = None, + trace_stack_limit: int = 5, + bytes_mode: Optional[Any] = None, + bytes_strictness: Optional[str] = None, + retry_max: int = 1, + retry_delay: float = 60.0, + fileno: Optional[Union[int, BinaryIO]] = None, + ) -> None: """ Parameters like SimpleLDAPObject.__init__() with these additional arguments: @@ -858,8 +1210,8 @@ def __init__( Time span to wait between two reconnect trials """ self._uri = uri - self._options = [] - self._last_bind = None + self._options: List[Tuple[int, Any]] = [] + self._last_bind: Optional[Tuple[Union[Callable[..., Any], str], Tuple[Any, ...], Dict[str, Any]]] = None SimpleLDAPObject.__init__(self, uri, trace_level, trace_file, trace_stack_limit, bytes_mode, bytes_strictness=bytes_strictness, @@ -870,17 +1222,20 @@ def __init__( self._start_tls = 0 self._reconnects_done = 0 - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: """return data representation for pickled object""" state = { k: v for k,v in self.__dict__.items() if k not in self.__transient_attrs__ } - state['_last_bind'] = self._last_bind[0].__name__, self._last_bind[1], self._last_bind[2] + if self._last_bind is not None and not isinstance(self._last_bind[0], str): + state['_last_bind'] = self._last_bind[0].__name__, self._last_bind[1], self._last_bind[2] + else: + state['_last_bind'] = None return state - def __setstate__(self,d): + def __setstate__(self, d: Dict[str, Any]) -> None: """set up the object from pickled data""" hardfail = d.get('bytes_mode_hardfail') if hardfail: @@ -888,33 +1243,49 @@ def __setstate__(self,d): else: d.setdefault('bytes_strictness', 'warn') self.__dict__.update(d) - self._last_bind = getattr(SimpleLDAPObject, self._last_bind[0]), self._last_bind[1], self._last_bind[2] + if self._last_bind is not None and isinstance(self._last_bind[0], str): + self._last_bind = getattr(SimpleLDAPObject, self._last_bind[0]), self._last_bind[1], self._last_bind[2] self._ldap_object_lock = self._ldap_lock() self._reconnect_lock = ldap.LDAPLock(desc='reconnect lock within %s' % (repr(self))) # XXX cannot pickle file, use default trace file self._trace_file = ldap._trace_file self.reconnect(self._uri,force=True) - def _store_last_bind(self,_method,*args,**kwargs): + def _store_last_bind( + self, + _method: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> None: self._last_bind = (_method,args,kwargs) - def _apply_last_bind(self): - if self._last_bind!=None: + def _apply_last_bind(self) -> None: + if self._last_bind is not None and callable(self._last_bind[0]): func,args,kwargs = self._last_bind - func(self,*args,**kwargs) + func(self,*args,**kwargs) # type: ignore else: # Send explicit anon simple bind request to provoke ldap.SERVER_DOWN in method reconnect() SimpleLDAPObject.simple_bind_s(self, None, None) - def _restore_options(self): + def _restore_options(self) -> None: """Restore all recorded options""" for k,v in self._options: SimpleLDAPObject.set_option(self,k,v) - def passwd_s(self,*args,**kwargs): - return self._apply_method_s(SimpleLDAPObject.passwd_s,*args,**kwargs) - - def reconnect(self,uri,retry_max=1,retry_delay=60.0,force=True): + def passwd_s( + self, + *args: Any, + **kwargs: Any, + ) -> Tuple[str, Union[bytes, PasswordModifyResponse]]: + return self._apply_method_s(SimpleLDAPObject.passwd_s,*args,**kwargs) # type: ignore + + def reconnect( + self, + uri: str, + retry_max: int = 1, + retry_delay: float = 60.0, + force: bool = True + ) -> None: # Drop and clean up old connection completely # Reconnect self._reconnect_lock.acquire() @@ -966,7 +1337,12 @@ def reconnect(self,uri,retry_max=1,retry_delay=60.0,force=True): self._reconnect_lock.release() return # reconnect() - def _apply_method_s(self,func,*args,**kwargs): + def _apply_method_s( + self, + func: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> Any: self.reconnect(self._uri,retry_max=self._retry_max,retry_delay=self._retry_delay,force=False) try: return func(self,*args,**kwargs) @@ -976,26 +1352,27 @@ def _apply_method_s(self,func,*args,**kwargs): # Re-try last operation return func(self,*args,**kwargs) - def set_option(self,option,invalue): + def set_option(self, option: int, invalue: Any) -> Any: self._options.append((option,invalue)) return SimpleLDAPObject.set_option(self,option,invalue) - def bind_s(self,*args,**kwargs): + # FIXME: The following method signatures could match the SimpleLDAPObject counterpart? + def bind_s(self, *args: Any, **kwargs: Any) -> Any: res = self._apply_method_s(SimpleLDAPObject.bind_s,*args,**kwargs) self._store_last_bind(SimpleLDAPObject.bind_s,*args,**kwargs) return res - def simple_bind_s(self,*args,**kwargs): + def simple_bind_s(self, *args: Any, **kwargs: Any) -> Any: res = self._apply_method_s(SimpleLDAPObject.simple_bind_s,*args,**kwargs) self._store_last_bind(SimpleLDAPObject.simple_bind_s,*args,**kwargs) return res - def start_tls_s(self,*args,**kwargs): + def start_tls_s(self, *args: Any, **kwargs: Any) -> Any: res = self._apply_method_s(SimpleLDAPObject.start_tls_s,*args,**kwargs) self._start_tls = 1 return res - def sasl_interactive_bind_s(self,*args,**kwargs): + def sasl_interactive_bind_s(self, *args: Any, **kwargs: Any) -> Any: """ sasl_interactive_bind_s(who, auth) -> None """ @@ -1003,36 +1380,36 @@ def sasl_interactive_bind_s(self,*args,**kwargs): self._store_last_bind(SimpleLDAPObject.sasl_interactive_bind_s,*args,**kwargs) return res - def sasl_bind_s(self,*args,**kwargs): + def sasl_bind_s(self, *args: Any, **kwargs: Any) -> Any: res = self._apply_method_s(SimpleLDAPObject.sasl_bind_s,*args,**kwargs) self._store_last_bind(SimpleLDAPObject.sasl_bind_s,*args,**kwargs) return res - def add_ext_s(self,*args,**kwargs): + def add_ext_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.add_ext_s,*args,**kwargs) - def cancel_s(self,*args,**kwargs): + def cancel_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.cancel_s,*args,**kwargs) - def compare_ext_s(self,*args,**kwargs): + def compare_ext_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.compare_ext_s,*args,**kwargs) - def delete_ext_s(self,*args,**kwargs): + def delete_ext_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.delete_ext_s,*args,**kwargs) - def extop_s(self,*args,**kwargs): + def extop_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.extop_s,*args,**kwargs) - def modify_ext_s(self,*args,**kwargs): + def modify_ext_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.modify_ext_s,*args,**kwargs) - def rename_s(self,*args,**kwargs): + def rename_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.rename_s,*args,**kwargs) - def search_ext_s(self,*args,**kwargs): + def search_ext_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.search_ext_s,*args,**kwargs) - def whoami_s(self,*args,**kwargs): + def whoami_s(self, *args: Any, **kwargs: Any) -> Any: return self._apply_method_s(SimpleLDAPObject.whoami_s,*args,**kwargs) diff --git a/Lib/ldap/logger.py b/Lib/ldap/logger.py index ae66bd08..839a7d47 100644 --- a/Lib/ldap/logger.py +++ b/Lib/ldap/logger.py @@ -6,13 +6,13 @@ class logging_file_class: - def __init__(self, logging_level): + def __init__(self, logging_level: int) -> None: self._logging_level = logging_level - def write(self, msg): + def write(self, msg: str) -> None: logging.log(self._logging_level, msg[:-1]) - def flush(self): + def flush(self) -> None: return logging_file_obj = logging_file_class(logging.DEBUG) diff --git a/Lib/ldap/modlist.py b/Lib/ldap/modlist.py index bf4e4819..2843fd17 100644 --- a/Lib/ldap/modlist.py +++ b/Lib/ldap/modlist.py @@ -4,17 +4,28 @@ See https://www.python-ldap.org/ for details. """ -from ldap import __version__ +from ldap.pkginfo import __version__ import ldap +from typing import List, Optional +from ldap.types import ( + LDAPEntryDict, + LDAPAddModList, + LDAPModifyModList, + LDAPModListModifyEntry, +) -def addModlist(entry,ignore_attr_types=None): + +def addModlist( + entry: LDAPEntryDict, + ignore_attr_types: Optional[List[str]] = None, + ) -> LDAPAddModList: """Build modify list for call of method LDAPObject.add()""" - ignore_attr_types = {v.lower() for v in ignore_attr_types or []} + ignore_attr_types_set = {v.lower() for v in ignore_attr_types or []} modlist = [] for attrtype, value in entry.items(): - if attrtype.lower() in ignore_attr_types: + if attrtype.lower() in ignore_attr_types_set: # This attribute type is ignored continue # Eliminate empty attr value strings in list @@ -25,8 +36,12 @@ def addModlist(entry,ignore_attr_types=None): def modifyModlist( - old_entry,new_entry,ignore_attr_types=None,ignore_oldexistent=0,case_ignore_attr_types=None -): + old_entry: LDAPEntryDict, + new_entry: LDAPEntryDict, + ignore_attr_types: Optional[List[str]] = None, + ignore_oldexistent:int = 0, + case_ignore_attr_types: Optional[List[str]] = None, +) -> LDAPModifyModList: """ Build differential modify list for calling LDAPObject.modify()/modify_s() @@ -46,15 +61,15 @@ def modifyModlist( List of attribute type names for which comparison will be made case-insensitive """ - ignore_attr_types = {v.lower() for v in ignore_attr_types or []} - case_ignore_attr_types = {v.lower() for v in case_ignore_attr_types or []} - modlist = [] + ignore_attr_types_set = {v.lower() for v in ignore_attr_types or []} + case_ignore_attr_types_set = {v.lower() for v in case_ignore_attr_types or []} + modlist: List[LDAPModListModifyEntry] = [] attrtype_lower_map = {} for a in old_entry: attrtype_lower_map[a.lower()]=a for attrtype, value in new_entry.items(): attrtype_lower = attrtype.lower() - if attrtype_lower in ignore_attr_types: + if attrtype_lower in ignore_attr_types_set: # This attribute type is ignored continue # Filter away null-strings @@ -72,7 +87,7 @@ def modifyModlist( # Replace existing attribute replace_attr_value = len(old_value)!=len(new_value) if not replace_attr_value: - if attrtype_lower in case_ignore_attr_types: + if attrtype_lower in case_ignore_attr_types_set: old_value_set = {v.lower() for v in old_value} new_value_set = {v.lower() for v in new_value} else: @@ -89,7 +104,7 @@ def modifyModlist( # Remove all attributes of old_entry which are not present # in new_entry at all for a, val in attrtype_lower_map.items(): - if a in ignore_attr_types: + if a in ignore_attr_types_set: # This attribute type is ignored continue attrtype = val diff --git a/Lib/ldap/py.typed b/Lib/ldap/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/Lib/ldap/resiter.py b/Lib/ldap/resiter.py index dc912eb3..1aaf66aa 100644 --- a/Lib/ldap/resiter.py +++ b/Lib/ldap/resiter.py @@ -3,16 +3,25 @@ See https://www.python-ldap.org/ for details. """ +import ldap from ldap.pkginfo import __version__, __author__, __license__ +from ldap.controls import ResponseControl -class ResultProcessor: +from typing import Any, List, Tuple, Iterator, Optional + +class ResultProcessor(ldap.ldapobject.LDAPObject): """ Mix-in class used with ldap.ldapopbject.LDAPObject or derived classes. """ - def allresults(self, msgid, timeout=-1, add_ctrls=0): + def allresults( + self, + msgid: int, + timeout: int = -1, + add_ctrls: int = 0, + ) -> Iterator[Tuple[Optional[int], Optional[Any], Optional[int], Optional[List[ResponseControl]]]]: """ Generator function which returns an iterator for processing all LDAP operation results of the given msgid like retrieved with LDAPObject.result3() -> 4-tuple diff --git a/Lib/ldap/sasl.py b/Lib/ldap/sasl.py index cc0a2ead..3ec646a2 100644 --- a/Lib/ldap/sasl.py +++ b/Lib/ldap/sasl.py @@ -12,7 +12,9 @@ the examples of digest_md5 and gssapi. """ -from ldap import __version__ +from ldap.pkginfo import __version__ + +from typing import Dict, Optional, Union if __debug__: # Tracing is only supported in debugging mode @@ -38,7 +40,7 @@ class sasl: overridden """ - def __init__(self, cb_value_dict, mech): + def __init__(self, cb_value_dict: Dict[int, str], mech: Union[str, bytes]) -> None: """ The (generic) base class takes a cb_value_dictionary of question-answer pairs. Questions are specified by the respective @@ -46,11 +48,18 @@ def __init__(self, cb_value_dict, mech): the SASL mechaninsm to be uesd. """ self.cb_value_dict = cb_value_dict or {} - if not isinstance(mech, bytes): - mech = mech.encode('utf-8') - self.mech = mech - - def callback(self, cb_id, challenge, prompt, defresult): + if isinstance(mech, str): + self.mech = mech.encode('utf-8') + else: + self.mech = mech + + def callback( + self, + cb_id: int, + challenge: Union[str, bytes], + prompt: Union[str, bytes], + defresult: Optional[Union[str, bytes]], + ) -> bytes: """ The callback method will be called by the sasl_bind_s() method several times. Each time it will provide the id, which @@ -72,18 +81,22 @@ def callback(self, cb_id, challenge, prompt, defresult): # The following print command might be useful for debugging # new sasl mechanisms. So it is left here - cb_result = self.cb_value_dict.get(cb_id, defresult) or '' + cb_result: Optional[Union[str, bytes]] = self.cb_value_dict.get(cb_id) + if cb_result is None: + cb_result = defresult or '' + if __debug__: if _trace_level >= 1: - _trace_file.write("*** id=%d, challenge=%s, prompt=%s, defresult=%s\n-> %s\n" % ( + _trace_file.write("*** id=%d, challenge=%r, prompt=%r, defresult=%s\n-> %s\n" % ( cb_id, challenge, prompt, repr(defresult), - repr(self.cb_value_dict.get(cb_result)) + repr(self.cb_value_dict.get(cb_id)) )) - if not isinstance(cb_result, bytes): - cb_result = cb_result.encode('utf-8') + + if isinstance(cb_result, str): + return cb_result.encode('utf-8') return cb_result @@ -92,7 +105,7 @@ class cram_md5(sasl): This class handles SASL CRAM-MD5 authentication. """ - def __init__(self, authc_id, password, authz_id=""): + def __init__(self, authc_id: str, password: str, authz_id: str = "") -> None: auth_dict = { CB_AUTHNAME: authc_id, CB_PASS: password, @@ -106,7 +119,7 @@ class digest_md5(sasl): This class handles SASL DIGEST-MD5 authentication. """ - def __init__(self, authc_id, password, authz_id=""): + def __init__(self, authc_id: str, password: str, authz_id: str = "") -> None: auth_dict = { CB_AUTHNAME: authc_id, CB_PASS: password, @@ -120,7 +133,7 @@ class gssapi(sasl): This class handles SASL GSSAPI (i.e. Kerberos V) authentication. """ - def __init__(self, authz_id=""): + def __init__(self, authz_id: str = "") -> None: sasl.__init__(self, {CB_USER: authz_id}, "GSSAPI") @@ -130,5 +143,5 @@ class external(sasl): (i.e. X.509 client certificate) """ - def __init__(self, authz_id=""): + def __init__(self, authz_id: str = "") -> None: sasl.__init__(self, {CB_USER: authz_id}, "EXTERNAL") diff --git a/Lib/ldap/schema/__init__.py b/Lib/ldap/schema/__init__.py index 2349ae21..2d824bb1 100644 --- a/Lib/ldap/schema/__init__.py +++ b/Lib/ldap/schema/__init__.py @@ -4,7 +4,12 @@ See https://www.python-ldap.org/ for details. """ -from ldap import __version__ +from ldap.pkginfo import __version__ from ldap.schema.subentry import SubSchema,SCHEMA_ATTRS,SCHEMA_CLASS_MAPPING,SCHEMA_ATTR_MAPPING,urlfetch from ldap.schema.models import * + + +__all__ = [ + 'SCHEMA_ATTRS', +] diff --git a/Lib/ldap/schema/models.py b/Lib/ldap/schema/models.py index 3d9322c0..776d3079 100644 --- a/Lib/ldap/schema/models.py +++ b/Lib/ldap/schema/models.py @@ -6,10 +6,35 @@ import sys -import ldap.cidict +import collections +from ldap.cidict import cidict from collections import UserDict -from ldap.schema.tokenizer import split_tokens,extract_tokens +from ldap.schema.tokenizer import parse_tokens, split_tokens + +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + KeysView, + List, + Tuple, + MutableMapping, + Optional, + Union, +) + +from ldap.types import LDAPEntryDict +if TYPE_CHECKING: + EntryBase = UserDict[str, List[bytes]] + import ldap.schema.subentry +else: + # Python <= 3.8 compatibility + EntryBase = UserDict + +from ldap.schema.tokenizer import LDAPTokenDict, LDAPTokenDictValue +from ldap.schema.subentry import SCHEMA_CLASS_MAPPING, SCHEMA_ATTR_MAPPING + NOT_HUMAN_READABLE_LDAP_SYNTAXES = { '1.3.6.1.4.1.1466.115.121.1.4', # Audio @@ -34,61 +59,83 @@ class SchemaElement: String which contains the schema element description to be parsed. (Bytestrings are decoded using UTF-8) + Instance attributes: + + oid + OID assigned to the schema element + names + All NAMEs of the schema element (tuple of strings) + desc + Description text (DESC) of the schema element (string, or None if missing) + Class attributes: schema_attribute LDAP attribute type containing a certain schema element description - token_defaults - Dictionary internally used by the schema element parser - containing the defaults for certain schema description key-words + known_tokens + List used internally containing the valid tokens """ - token_defaults = { - 'DESC':(None,), - } + schema_attribute = 'SchemaElement (base class)' + known_tokens = ['DESC', 'NAME'] - def __init__(self,schema_element_str=None): + def __init__(self, schema_element_str: Optional[Union[str, bytes]] = None) -> None: if isinstance(schema_element_str, bytes): - schema_element_str = schema_element_str.decode('utf-8') - if schema_element_str: - l = split_tokens(schema_element_str) - self.set_id(l[1]) - d = extract_tokens(l,self.token_defaults) - self._set_attrs(l,d) - - def _set_attrs(self,l,d): - self.desc = d['DESC'][0] - return + schema_element_string = schema_element_str.decode('utf-8') + elif isinstance(schema_element_str, str): + schema_element_string = schema_element_str + elif schema_element_str is None: + return + else: + raise TypeError("schema_element_str must be str/bytes, was %r" % schema_element_str) + + if schema_element_string == '': + return + + tokens = split_tokens(schema_element_string) + oid, schema_element_attributes = parse_tokens(tokens, self.known_tokens) + self.set_id(oid) + self._set_attrs(tokens, schema_element_attributes) - def set_id(self,element_id): + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + self.desc = d.get('DESC', (None,))[0] + self.names = d.get('NAME', ()) + + def set_id(self, element_id: str) -> None: self.oid = element_id - def get_id(self): + def get_id(self) -> str: return self.oid - def key_attr(self,key,value,quoted=0): - assert value is None or type(value)==str,TypeError("value has to be of str, was %r" % value) - if value: - if quoted: - return " {} '{}'".format(key,value.replace("'","\\'")) - else: - return f" {key} {value}" - else: + def key_attr(self, key: str, value: Optional[str], quoted: int = 0) -> str: + if value is None: return "" + elif not isinstance(value, str): + raise TypeError("value has to be of str, was %r" % value) + elif value == "": + return "" + elif quoted: + return " {} '{}'".format(key,value.replace("'","\\'")) + else: + return f" {key} {value}" - def key_list(self,key,values,sep=' ',quoted=0): - assert type(values)==tuple,TypeError("values has to be a tuple, was %r" % values) + def key_list( + self, key: str, values: Tuple[str, ...], sep: str = ' ', quoted: int = 0 + ) -> str: + assert isinstance(values, tuple),TypeError("values has to be a tuple, was %r" % values) if not values: return '' + if quoted: quoted_values = [ "'%s'" % value.replace("'","\\'") for value in values ] else: - quoted_values = values - if len(values)==1: + quoted_values = list(values) + + if len(quoted_values)==1: return ' {} {}'.format(key,quoted_values[0]) else: return ' {} ( {} )'.format(key,sep.join(quoted_values)) - def __str__(self): + def __str__(self) -> str: result = [str(self.oid)] result.append(self.key_attr('DESC',self.desc,quoted=1)) return '( %s )' % ''.join(result) @@ -110,8 +157,8 @@ class ObjectClass(SchemaElement): desc Description text (DESC) of the object class (string, or None if missing) obsolete - Integer flag (0 or 1) indicating whether the object class is marked - as OBSOLETE in the schema + Boolean indicating whether the object class is marked as OBSOLETE in the + schema must NAMEs or OIDs of all attributes an entry of the object class must have (tuple of strings) @@ -134,53 +181,55 @@ class ObjectClass(SchemaElement): element """ schema_attribute = 'objectClasses' - token_defaults = { - 'NAME':(()), - 'DESC':(None,), - 'OBSOLETE':None, - 'SUP':(()), - 'STRUCTURAL':None, - 'AUXILIARY':None, - 'ABSTRACT':None, - 'MUST':(()), - 'MAY':(), - 'X-ORIGIN':() - } - - def _set_attrs(self,l,d): - self.obsolete = d['OBSOLETE']!=None - self.names = d['NAME'] - self.desc = d['DESC'][0] - self.must = d['MUST'] - self.may = d['MAY'] - self.x_origin = d['X-ORIGIN'] + known_tokens = [ + 'NAME', + 'DESC', + 'OBSOLETE', + 'SUP', + 'STRUCTURAL', + 'AUXILIARY', + 'ABSTRACT', + 'MUST', + 'MAY', + 'X-ORIGIN', + ] + + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + super()._set_attrs(l, d) + self.obsolete = 'OBSOLETE' in d + self.must = d.get('MUST', ()) + self.may = d.get('MAY', ()) + self.x_origin = d.get('X-ORIGIN', ()) + # Default is STRUCTURAL, see RFC2552 or draft-ietf-ldapbis-syntaxes self.kind = 0 - if d['ABSTRACT']!=None: + if 'ABSTRACT' in d: self.kind = 1 - elif d['AUXILIARY']!=None: + elif 'AUXILIARY' in d: self.kind = 2 - if self.kind==0 and not d['SUP'] and self.oid!='2.5.6.0': + + if self.kind==0 and len(d.get('SUP', ())) == 0 and self.oid!='2.5.6.0': # STRUCTURAL object classes are sub-classes of 'top' by default - self.sup = ('top',) + self.sup: Tuple[str, ...] = ('top',) else: - self.sup = d['SUP'] - return + self.sup = d.get('SUP', ()) - def __str__(self): + def __str__(self) -> str: result = [str(self.oid)] result.append(self.key_list('NAME',self.names,quoted=1)) result.append(self.key_attr('DESC',self.desc,quoted=1)) result.append(self.key_list('SUP',self.sup,sep=' $ ')) - result.append({0:'',1:' OBSOLETE'}[self.obsolete]) + result.append({False:'',True:' OBSOLETE'}[self.obsolete]) result.append({0:' STRUCTURAL',1:' ABSTRACT',2:' AUXILIARY'}[self.kind]) result.append(self.key_list('MUST',self.must,sep=' $ ')) result.append(self.key_list('MAY',self.may,sep=' $ ')) result.append(self.key_list('X-ORIGIN',self.x_origin,quoted=1)) return '( %s )' % ''.join(result) +SCHEMA_CLASS_MAPPING[ObjectClass.schema_attribute] = ObjectClass +SCHEMA_ATTR_MAPPING[ObjectClass] = ObjectClass.schema_attribute -AttributeUsage = ldap.cidict.cidict({ +AttributeUsage = cidict({ 'userApplication':0, # work-around for non-compliant schema 'userApplications':0, 'directoryOperation':1, @@ -205,16 +254,15 @@ class AttributeType(SchemaElement): desc Description text (DESC) of the attribute type (string, or None if missing) obsolete - Integer flag (0 or 1) indicating whether the attribute type is marked - as OBSOLETE in the schema + Boolean flag indicating whether the attribute type is marked as OBSOLETE in + the schema single_value - Integer flag (0 or 1) indicating whether the attribute must - have only one value + Boolean flag indicating whether the attribute must have only one value syntax OID of the LDAP syntax assigned to the attribute type no_user_mod - Integer flag (0 or 1) indicating whether the attribute is modifiable - by a client application + Boolean flag indicating whether the attribute is modifiable by a client + application equality NAME or OID of the matching rule used for checking whether attribute values are equal (string, or None if missing) @@ -241,35 +289,35 @@ class AttributeType(SchemaElement): element """ schema_attribute = 'attributeTypes' - token_defaults = { - 'NAME':(()), - 'DESC':(None,), - 'OBSOLETE':None, - 'SUP':(()), - 'EQUALITY':(None,), - 'ORDERING':(None,), - 'SUBSTR':(None,), - 'SYNTAX':(None,), - 'SINGLE-VALUE':None, - 'COLLECTIVE':None, - 'NO-USER-MODIFICATION':None, - 'USAGE':('userApplications',), - 'X-ORIGIN':(), - 'X-ORDERED':(None,), - } - - def _set_attrs(self,l,d): - self.names = d['NAME'] - self.desc = d['DESC'][0] - self.obsolete = d['OBSOLETE']!=None - self.sup = d['SUP'] - self.equality = d['EQUALITY'][0] - self.ordering = d['ORDERING'][0] - self.substr = d['SUBSTR'][0] - self.x_origin = d['X-ORIGIN'] - self.x_ordered = d['X-ORDERED'][0] + known_tokens = [ + 'NAME', + 'DESC', + 'OBSOLETE', + 'SUP', + 'EQUALITY', + 'ORDERING', + 'SUBSTR', + 'SYNTAX', + 'SINGLE-VALUE', + 'COLLECTIVE', + 'NO-USER-MODIFICATION', + 'USAGE', + 'X-ORIGIN', + 'X-ORDERED', + ] + + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + super()._set_attrs(l, d) + self.obsolete = 'OBSOLETE' in d + self.sup = d.get('SUP', ()) + self.equality = d.get('EQUALITY', (None,))[0] + self.ordering = d.get('ORDERING', (None,))[0] + self.substr = d.get('SUBSTR', (None,))[0] + self.x_origin = d.get('X-ORIGIN', ()) + self.x_ordered = d.get('X-ORDERED', (None,))[0] + try: - syntax = d['SYNTAX'][0] + syntax = d.get('SYNTAX', (None,))[0] except IndexError: self.syntax = None self.syntax_len = None @@ -279,22 +327,24 @@ def _set_attrs(self,l,d): self.syntax_len = None else: try: - self.syntax,syntax_len = d['SYNTAX'][0].split("{") + self.syntax,syntax_len = syntax.split("{") except ValueError: - self.syntax = d['SYNTAX'][0] + self.syntax = syntax self.syntax_len = None for i in l: if i.startswith("{") and i.endswith("}"): self.syntax_len = int(i[1:-1]) else: self.syntax_len = int(syntax_len[:-1]) - self.single_value = d['SINGLE-VALUE']!=None - self.collective = d['COLLECTIVE']!=None - self.no_user_mod = d['NO-USER-MODIFICATION']!=None - self.usage = AttributeUsage.get(d['USAGE'][0],0) - return - - def __str__(self): + self.single_value = 'SINGLE-VALUE' in d + self.collective = 'COLLECTIVE' in d + self.no_user_mod = 'NO-USER-MODIFICATION' in d + self.usage = 0 + usage = d.get('USAGE', (None,))[0] + if usage is not None: + self.usage = AttributeUsage.get(usage, 0) + + def __str__(self) -> str: result = [str(self.oid)] result.append(self.key_list('NAME',self.names,quoted=1)) result.append(self.key_attr('DESC',self.desc,quoted=1)) @@ -304,7 +354,7 @@ def __str__(self): result.append(self.key_attr('ORDERING',self.ordering)) result.append(self.key_attr('SUBSTR',self.substr)) result.append(self.key_attr('SYNTAX',self.syntax)) - if self.syntax_len!=None: + if self.syntax_len is not None: result.append(('{%d}' % (self.syntax_len))*(self.syntax_len>0)) result.append({0:'',1:' SINGLE-VALUE'}[self.single_value]) result.append({0:'',1:' COLLECTIVE'}[self.collective]) @@ -321,6 +371,9 @@ def __str__(self): result.append(self.key_attr('X-ORDERED',self.x_ordered,quoted=1)) return '( %s )' % ''.join(result) +SCHEMA_CLASS_MAPPING[AttributeType.schema_attribute] = AttributeType +SCHEMA_ATTR_MAPPING[AttributeType] = AttributeType.schema_attribute + class LDAPSyntax(SchemaElement): """ @@ -328,30 +381,32 @@ class LDAPSyntax(SchemaElement): oid OID assigned to the LDAP syntax + names + All NAMEs of the LDAP syntax (tuple of strings) desc Description text (DESC) of the LDAP syntax (string, or None if missing) not_human_readable - Integer flag (0 or 1) indicating whether the attribute type is marked - as not human-readable (X-NOT-HUMAN-READABLE) + Boolean flag indicating whether the attribute type is marked as not + human-readable (X-NOT-HUMAN-READABLE) """ schema_attribute = 'ldapSyntaxes' - token_defaults = { - 'DESC':(None,), - 'X-NOT-HUMAN-READABLE':(None,), - 'X-BINARY-TRANSFER-REQUIRED':(None,), - 'X-SUBST':(None,), - } - - def _set_attrs(self,l,d): - self.desc = d['DESC'][0] - self.x_subst = d['X-SUBST'][0] + known_tokens = [ + 'NAME', + 'DESC', + 'X-NOT-HUMAN-READABLE', + 'X-BINARY-TRANSFER-REQUIRED', + 'X-SUBST', + ] + + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + super()._set_attrs(l, d) + self.x_subst = d.get('X-SUBST', (None,))[0] self.not_human_readable = \ self.oid in NOT_HUMAN_READABLE_LDAP_SYNTAXES or \ - d['X-NOT-HUMAN-READABLE'][0]=='TRUE' - self.x_binary_transfer_required = d['X-BINARY-TRANSFER-REQUIRED'][0]=='TRUE' - return + d.get('X-NOT-HUMAN-READABLE', (None,))[0] == 'TRUE' + self.x_binary_transfer_required = d.get('X-BINARY-TRANSFER-REQUIRED', (None,))[0] == 'TRUE' - def __str__(self): + def __str__(self) -> str: result = [str(self.oid)] result.append(self.key_attr('DESC',self.desc,quoted=1)) result.append(self.key_attr('X-SUBST',self.x_subst,quoted=1)) @@ -360,6 +415,9 @@ def __str__(self): ) return '( %s )' % ''.join(result) +SCHEMA_CLASS_MAPPING[LDAPSyntax.schema_attribute] = LDAPSyntax +SCHEMA_ATTR_MAPPING[LDAPSyntax] = LDAPSyntax.schema_attribute + class MatchingRule(SchemaElement): """ @@ -377,28 +435,27 @@ class MatchingRule(SchemaElement): desc Description text (DESC) of the matching rule obsolete - Integer flag (0 or 1) indicating whether the matching rule is marked - as OBSOLETE in the schema + Boolean flag indicating whether the matching rule is marked as OBSOLETE in + the schema syntax OID of the LDAP syntax this matching rule is usable with (string, or None if missing) """ schema_attribute = 'matchingRules' - token_defaults = { - 'NAME':(()), - 'DESC':(None,), - 'OBSOLETE':None, - 'SYNTAX':(None,), - } - - def _set_attrs(self,l,d): - self.names = d['NAME'] - self.desc = d['DESC'][0] - self.obsolete = d['OBSOLETE']!=None - self.syntax = d['SYNTAX'][0] + known_tokens = [ + 'NAME', + 'DESC', + 'OBSOLETE', + 'SYNTAX', + ] + + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + super()._set_attrs(l, d) + self.obsolete = 'OBSOLETE' in d + self.syntax = d.get('SYNTAX', (None,))[0] return - def __str__(self): + def __str__(self) -> str: result = [str(self.oid)] result.append(self.key_list('NAME',self.names,quoted=1)) result.append(self.key_attr('DESC',self.desc,quoted=1)) @@ -406,6 +463,9 @@ def __str__(self): result.append(self.key_attr('SYNTAX',self.syntax)) return '( %s )' % ''.join(result) +SCHEMA_CLASS_MAPPING[MatchingRule.schema_attribute] = MatchingRule +SCHEMA_ATTR_MAPPING[MatchingRule] = MatchingRule.schema_attribute + class MatchingRuleUse(SchemaElement): """ @@ -423,28 +483,27 @@ class MatchingRuleUse(SchemaElement): desc Description text (DESC) of the matching rule (string, or None if missing) obsolete - Integer flag (0 or 1) indicating whether the matching rule is marked + Boolean flag indicating whether the matching rule is marked as OBSOLETE in the schema applies NAMEs or OIDs of attribute types for which this matching rule is used (tuple of strings) """ schema_attribute = 'matchingRuleUse' - token_defaults = { - 'NAME':(()), - 'DESC':(None,), - 'OBSOLETE':None, - 'APPLIES':(()), - } - - def _set_attrs(self,l,d): - self.names = d['NAME'] - self.desc = d['DESC'][0] - self.obsolete = d['OBSOLETE']!=None - self.applies = d['APPLIES'] + known_tokens = [ + 'NAME', + 'DESC', + 'OBSOLETE', + 'APPLIES', + ] + + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + super()._set_attrs(l, d) + self.obsolete = 'OBSOLETE' in d + self.applies = d.get('APPLIES', ()) return - def __str__(self): + def __str__(self) -> str: result = [str(self.oid)] result.append(self.key_list('NAME',self.names,quoted=1)) result.append(self.key_attr('DESC',self.desc,quoted=1)) @@ -452,6 +511,9 @@ def __str__(self): result.append(self.key_list('APPLIES',self.applies,sep=' $ ')) return '( %s )' % ''.join(result) +SCHEMA_CLASS_MAPPING[MatchingRuleUse.schema_attribute] = MatchingRuleUse +SCHEMA_ATTR_MAPPING[MatchingRuleUse] = MatchingRuleUse.schema_attribute + class DITContentRule(SchemaElement): """ @@ -470,7 +532,7 @@ class DITContentRule(SchemaElement): Description text (DESC) of the DIT content rule (string, or None if missing) obsolete - Integer flag (0 or 1) indicating whether the DIT content rule is marked + Boolean flag indicating whether the DIT content rule is marked as OBSOLETE in the schema aux NAMEs or OIDs of all auxiliary object classes usable in an entry of the @@ -490,27 +552,25 @@ class DITContentRule(SchemaElement): object class. (tuple of strings) """ schema_attribute = 'dITContentRules' - token_defaults = { - 'NAME':(()), - 'DESC':(None,), - 'OBSOLETE':None, - 'AUX':(()), - 'MUST':(()), - 'MAY':(()), - 'NOT':(()), - } - - def _set_attrs(self,l,d): - self.names = d['NAME'] - self.desc = d['DESC'][0] - self.obsolete = d['OBSOLETE']!=None - self.aux = d['AUX'] - self.must = d['MUST'] - self.may = d['MAY'] - self.nots = d['NOT'] - return - - def __str__(self): + known_tokens = [ + 'NAME', + 'DESC', + 'OBSOLETE', + 'AUX', + 'MUST', + 'MAY', + 'NOT', + ] + + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + super()._set_attrs(l ,d) + self.obsolete = 'OBSOLETE' in d + self.aux = d.get('AUX', ()) + self.must = d.get('MUST', ()) + self.may = d.get('MAY', ()) + self.nots = d.get('NOT', ()) + + def __str__(self) -> str: result = [str(self.oid)] result.append(self.key_list('NAME',self.names,quoted=1)) result.append(self.key_attr('DESC',self.desc,quoted=1)) @@ -521,6 +581,9 @@ def __str__(self): result.append(self.key_list('NOT',self.nots,sep=' $ ')) return '( %s )' % ''.join(result) +SCHEMA_CLASS_MAPPING[DITContentRule.schema_attribute] = DITContentRule +SCHEMA_ATTR_MAPPING[DITContentRule] = DITContentRule.schema_attribute + class DITStructureRule(SchemaElement): """ @@ -539,39 +602,37 @@ class DITStructureRule(SchemaElement): Description text (DESC) of the DIT structure rule (string, or None if missing) obsolete - Integer flag (0 or 1) indicating whether the DIT content rule is marked + Boolean flag indicating whether the DIT content rule is marked as OBSOLETE in the schema form - NAMEs or OIDs of associated name forms (tuple of strings) + NAMEs or OIDs of associated name forms (string) sup NAMEs or OIDs of allowed structural object classes of superior entries in the DIT (tuple of strings) """ schema_attribute = 'dITStructureRules' - - token_defaults = { - 'NAME':(()), - 'DESC':(None,), - 'OBSOLETE':None, - 'FORM':(None,), - 'SUP':(()), - } - - def set_id(self,element_id): + known_tokens = [ + 'NAME', + 'DESC', + 'OBSOLETE', + 'FORM', + 'SUP', + ] + + def set_id(self, element_id: str) -> None: self.ruleid = element_id - def get_id(self): + def get_id(self) -> str: return self.ruleid - def _set_attrs(self,l,d): - self.names = d['NAME'] - self.desc = d['DESC'][0] - self.obsolete = d['OBSOLETE']!=None - self.form = d['FORM'][0] - self.sup = d['SUP'] + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + super()._set_attrs(l ,d) + self.obsolete = 'OBSOLETE' in d + self.form = d.get('FORM', (None,))[0] + self.sup = d.get('SUP', ()) return - def __str__(self): + def __str__(self) -> str: result = [str(self.ruleid)] result.append(self.key_list('NAME',self.names,quoted=1)) result.append(self.key_attr('DESC',self.desc,quoted=1)) @@ -580,6 +641,9 @@ def __str__(self): result.append(self.key_list('SUP',self.sup,sep=' $ ')) return '( %s )' % ''.join(result) +SCHEMA_CLASS_MAPPING[DITStructureRule.schema_attribute] = DITStructureRule +SCHEMA_ATTR_MAPPING[DITStructureRule] = DITStructureRule.schema_attribute + class NameForm(SchemaElement): """ @@ -597,8 +661,8 @@ class NameForm(SchemaElement): desc Description text (DESC) of the name form (string, or None if missing) obsolete - Integer flag (0 or 1) indicating whether the name form is marked - as OBSOLETE in the schema + Boolean flag indicating whether the name form is marked as OBSOLETE in the + schema form NAMEs or OIDs of associated name forms (tuple of strings) oc @@ -611,25 +675,23 @@ class NameForm(SchemaElement): (tuple of strings) """ schema_attribute = 'nameForms' - token_defaults = { - 'NAME':(()), - 'DESC':(None,), - 'OBSOLETE':None, - 'OC':(None,), - 'MUST':(()), - 'MAY':(()), - } - - def _set_attrs(self,l,d): - self.names = d['NAME'] - self.desc = d['DESC'][0] - self.obsolete = d['OBSOLETE']!=None - self.oc = d['OC'][0] - self.must = d['MUST'] - self.may = d['MAY'] - return - - def __str__(self): + known_tokens = [ + 'NAME', + 'DESC', + 'OBSOLETE', + 'OC', + 'MUST', + 'MAY', + ] + + def _set_attrs(self, l: List[str], d: LDAPTokenDict) -> None: + super()._set_attrs(l ,d) + self.obsolete = 'OBSOLETE' in d + self.oc = d.get('OC', (None,))[0] + self.must = d.get('MUST', ()) + self.may = d.get('MAY', ()) + + def __str__(self) -> str: result = [str(self.oid)] result.append(self.key_list('NAME',self.names,quoted=1)) result.append(self.key_attr('DESC',self.desc,quoted=1)) @@ -639,8 +701,11 @@ def __str__(self): result.append(self.key_list('MAY',self.may,sep=' $ ')) return '( %s )' % ''.join(result) +SCHEMA_CLASS_MAPPING[NameForm.schema_attribute] = NameForm +SCHEMA_ATTR_MAPPING[NameForm] = NameForm.schema_attribute + -class Entry(UserDict): +class Entry(EntryBase): """ Schema-aware implementation of an LDAP entry class. @@ -648,15 +713,19 @@ class Entry(UserDict): the OID as key. """ - def __init__(self,schema,dn,entry): - self._keytuple2attrtype = {} - self._attrtype2keytuple = {} + def __init__(self, schema: "ldap.schema.subentry.SubSchema", dn: str, entry: LDAPEntryDict) -> None: + self._keytuple2attrtype: Dict[Tuple[str, ...], str] = {} + self._attrtype2keytuple: Dict[str, Tuple[str, ...]] = {} + # This class wants to act like it's a string-keyed dict, but under the + # hood it uses the tuple of OID and sub-types of an attribute type + # as the key, so we can't use the self.data dict and stay type-safe. + self._data: Dict[Tuple[str, ...], List[bytes]] = {} self._s = schema self.dn = dn super().__init__() self.update(entry) - def _at2key(self,nameoroid): + def _at2key(self, nameoroid: str) -> Tuple[str, ...]: """ Return tuple of OID and all sub-types of attribute type specified in nameoroid. @@ -666,55 +735,67 @@ def _at2key(self,nameoroid): return self._attrtype2keytuple[nameoroid] except KeyError: # Mapping has to be constructed - oid = self._s.getoid(ldap.schema.AttributeType,nameoroid) + oid = self._s.getoid(AttributeType,nameoroid) l = nameoroid.lower().split(';') l[0] = oid t = tuple(l) self._attrtype2keytuple[nameoroid] = t return t - def update(self,dict): + def update(self, dict: MutableMapping[str, List[bytes]]) -> None: # type: ignore for key, value in dict.items(): self[key] = value - def __contains__(self,nameoroid): - return self._at2key(nameoroid) in self.data + def __contains__(self, nameoroid: object) -> bool: + if not isinstance(nameoroid, str): + return False + return self._at2key(nameoroid) in self._data - def __getitem__(self,nameoroid): - return self.data[self._at2key(nameoroid)] + def __getitem__(self, nameoroid: object) -> List[bytes]: + if not isinstance(nameoroid, str): + raise KeyError + k = self._at2key(nameoroid) + return self._data[k] - def __setitem__(self,nameoroid,attr_values): + def __setitem__(self, nameoroid: object, attr_values: List[bytes]) -> None: + if not isinstance(nameoroid, str): + raise KeyError k = self._at2key(nameoroid) self._keytuple2attrtype[k] = nameoroid - self.data[k] = attr_values + self._data[k] = attr_values - def __delitem__(self,nameoroid): + def __delitem__(self, nameoroid: object) -> None: + if not isinstance(nameoroid, str): + raise KeyError k = self._at2key(nameoroid) - del self.data[k] + del self._data[k] del self._attrtype2keytuple[nameoroid] del self._keytuple2attrtype[k] - def has_key(self,nameoroid): + def has_key(self, nameoroid: str) -> bool: k = self._at2key(nameoroid) - return k in self.data + return k in self._data - def keys(self): - return self._keytuple2attrtype.values() + def keys(self) -> List[str]: # type: ignore + return self._keytuple2attrtype.values() # type: ignore - def items(self): + def items(self) -> List[Tuple[str, List[bytes]]]: # type: ignore return [ (k,self[k]) for k in self.keys() ] def attribute_types( - self,attr_type_filter=None,raise_keyerror=1 - ): + self, + attr_type_filter: Optional[List[Tuple[str, List[str]]]] = None, + raise_keyerror: int = 1, + ) -> Tuple[cidict[Optional[AttributeType]], cidict[Optional[AttributeType]]]: """ Convenience wrapper around SubSchema.attribute_types() which passes object classes of this particular entry as argument to SubSchema.attribute_types() """ - return self._s.attribute_types( - self.get('objectClass',[]),attr_type_filter,raise_keyerror - ) + bin_ocs = self.get('objectClass', []) + ocs = [oc.decode("utf-8") for oc in bin_ocs] + + return self._s.attribute_types(ocs,attr_type_filter,raise_keyerror) diff --git a/Lib/ldap/schema/subentry.py b/Lib/ldap/schema/subentry.py index b83d819b..9069df3b 100644 --- a/Lib/ldap/schema/subentry.py +++ b/Lib/ldap/schema/subentry.py @@ -7,20 +7,44 @@ import copy from urllib.request import urlopen -import ldap.cidict,ldap.schema -from ldap.schema.models import * +import ldap.schema +from ldap.cidict import cidict import ldapurl import ldif - -SCHEMA_CLASS_MAPPING = ldap.cidict.cidict() -SCHEMA_ATTR_MAPPING = {} - -for o in list(vars().values()): - if hasattr(o,'schema_attribute'): - SCHEMA_CLASS_MAPPING[o.schema_attribute] = o - SCHEMA_ATTR_MAPPING[o] = o.schema_attribute +from ldap.types import LDAPEntryDict +from typing import ( + Any, + Dict, + Iterable, + List, + MutableMapping, + Set, + Tuple, + Type, + TypeVar, + Optional, + Union, +) + +# Maps schema element description (from class.schema_attribute, +# e.g. 'ObjectClass') to the schema class. +SCHEMA_CLASS_MAPPING: cidict[Type["SchemaElement"]] = cidict() + +# The reverse of SCHEMA_CLASS_MAPPING +SCHEMA_ATTR_MAPPING: Dict[Type["SchemaElement"], str] = {} + +# Note: this cannot be moved up due to circular imports: +# ldap.schema.models imports the two dicts above +from ldap.schema.models import ( + SchemaElement, + AttributeType, + ObjectClass, + DITContentRule, +) + +SchemaElementSubclass = TypeVar('SchemaElementSubclass', bound=SchemaElement) SCHEMA_ATTRS = list(SCHEMA_CLASS_MAPPING) @@ -31,19 +55,19 @@ class SubschemaError(ValueError): class OIDNotUnique(SubschemaError): - def __init__(self,desc): + def __init__(self, desc: str) -> None: self.desc = desc - def __str__(self): + def __str__(self) -> str: return 'OID not unique for %s' % (self.desc) class NameNotUnique(SubschemaError): - def __init__(self,desc): + def __init__(self, desc: str) -> None: self.desc = desc - def __str__(self): + def __str__(self) -> str: return 'NAME not unique for %s' % (self.desc) @@ -79,20 +103,34 @@ class SubSchema: List of NAMEs used at least twice in the subschema for the same schema element """ - def __init__(self,sub_schema_sub_entry,check_uniqueness=1): + def __init__( + self, + sub_schema_sub_entry: LDAPEntryDict, + check_uniqueness: int = 1, + ) -> None: + + # SchemaElement class -> Element name -> Element OID + self.name2oid: Dict[Type[SchemaElement], cidict[str]] = {} + + # SchemaElement class -> Element OID -> Element object instance + self.sed: Dict[Type[SchemaElement], Dict[str, SchemaElement]] = {} + + # Temporary set to hold OIDs which are not unique + non_unique_oids: Set[str] = set() + + # Dict mapping schema element class to a cidict where keys are used to + # indicate OIDs with duplicate names (values are not used) + # FIXME: this seems incomplete (cf. class docstring above and + # compare to how non_unique_oids is handled at the end) + self.non_unique_names: Dict[Type[SchemaElement], cidict[None]] = {} - # Initialize all dictionaries - self.name2oid = {} - self.sed = {} - self.non_unique_oids = {} - self.non_unique_names = {} for c in SCHEMA_CLASS_MAPPING.values(): self.name2oid[c] = ldap.cidict.cidict() self.sed[c] = {} self.non_unique_names[c] = ldap.cidict.cidict() # Transform entry dict to case-insensitive dict - e = ldap.cidict.cidict(sub_schema_sub_entry) + e: LDAPEntryDict = ldap.cidict.cidict(sub_schema_sub_entry) # Build the schema registry in dictionaries for attr_type in SCHEMA_ATTRS: @@ -104,7 +142,7 @@ def __init__(self,sub_schema_sub_entry,check_uniqueness=1): se_id = se_instance.get_id() if check_uniqueness and se_id in self.sed[se_class]: - self.non_unique_oids[se_id] = None + non_unique_oids.add(se_id) if check_uniqueness==1: # Add to subschema by adding suffix to ID suffix_counter = 1 @@ -115,31 +153,35 @@ def __init__(self,sub_schema_sub_entry,check_uniqueness=1): else: se_id = new_se_id elif check_uniqueness>=2: - raise OIDNotUnique(attr_value) + raise OIDNotUnique(attr_value.decode('utf-8', errors='backslashreplace')) # Store the schema element instance in the central registry self.sed[se_class][se_id] = se_instance if hasattr(se_instance,'names'): for name in ldap.cidict.cidict({}.fromkeys(se_instance.names)): + # FIXME: should match behaviour for OIDs above? if check_uniqueness and name in self.name2oid[se_class]: self.non_unique_names[se_class][se_id] = None - raise NameNotUnique(attr_value) + raise NameNotUnique(attr_value.decode('utf-8', errors='backslashreplace')) else: self.name2oid[se_class][name] = se_id - # Turn dict into list maybe more handy for applications - self.non_unique_oids = list(self.non_unique_oids) + # Turn set into list, maybe more handy for applications + self.non_unique_oids = list(non_unique_oids) return # subSchema.__init__() - def ldap_entry(self): + def ldap_entry(self) -> Dict[str, List[str]]: """ Returns a dictionary containing the sub schema sub entry + + The keys of the dict are the schema element attribute name and + the values are lists of schema element definition strings. """ # Initialize the dictionary with empty lists - entry = {} + entry: Dict[str, List[str]] = {} # Collect the schema elements and store them in # entry's attributes for se_class, elements in self.sed.items(): @@ -151,10 +193,24 @@ def ldap_entry(self): entry[SCHEMA_ATTR_MAPPING[se_class]] = [ se_str ] return entry - def listall(self,schema_element_class,schema_element_filters=None): + def listall( + self, + schema_element_class: Type[SchemaElement], + schema_element_filters: Optional[Iterable[Tuple[str, Iterable[Union[str, int]]]]] = None, + ) -> List[str]: """ Returns a list of OIDs of all available schema elements of a given schema element class. + + Arguments: + + schema_element_class + The schema element class to limit the search to + + schema_element_filters + A list of 2-tuples containing an attribute name and a sequence + of possible values for the attribute name. If any filter matches, + the element will be included in the returned list. """ avail_se = self.sed[schema_element_class] if schema_element_filters: @@ -164,6 +220,7 @@ def listall(self,schema_element_class,schema_element_filters=None): try: if getattr(se,fk) in fv: result.append(se_key) + # FIXME: should break here? except AttributeError: pass else: @@ -171,15 +228,34 @@ def listall(self,schema_element_class,schema_element_filters=None): return result - def tree(self,schema_element_class,schema_element_filters=None): + def tree( + self, + schema_element_class: Union[Type[ObjectClass], Type[AttributeType]], + schema_element_filters: Optional[Iterable[Tuple[str, Iterable[Union[str, int]]]]] = None, + ) -> cidict[List[str]]: """ Returns a ldap.cidict.cidict dictionary representing the tree structure of the schema elements. + + The dict will have the key '_' as the root element, and each + key maps to a list of OIDs (inferior or child attributes), which + can in turn be used as keys to work down the hierarchy. + + Arguments: + + schema_element_class + The schema element class to limit the search to + Note that only ``ObjectClass`` and ``AttributeType`` are supported. + + schema_element_filters + A list of 2-tuples containing an attribute name and a sequence + of possible values for the attribute name. If any filter matches, + the element will be included in the returned dict. """ assert schema_element_class in [ObjectClass,AttributeType] avail_se = self.listall(schema_element_class,schema_element_filters) top_node = '_' - tree = ldap.cidict.cidict({top_node:[]}) + tree: cidict[List[str]] = ldap.cidict.cidict({top_node:[]}) # 1. Pass: Register all nodes for se in avail_se: tree[se] = [] @@ -190,11 +266,12 @@ def tree(self,schema_element_class,schema_element_filters=None): # Ignore schema elements not matching schema_element_class. # This helps with falsely assigned OIDs. continue + # FIXME: This assertion is superfluous? assert se_obj.__class__==schema_element_class, \ "Schema element referenced by {} must be of class {} but was {}".format( se_oid,schema_element_class.__name__,se_obj.__class__ ) - for s in se_obj.sup or ('_',): + for s in getattr(se_obj, "sup", ['_']): sup_oid = self.getoid(schema_element_class,s) try: tree[sup_oid].append(se_oid) @@ -203,7 +280,12 @@ def tree(self,schema_element_class,schema_element_filters=None): return tree - def getoid(self,se_class,nameoroid,raise_keyerror=0): + def getoid( + self, + se_class: Type[SchemaElementSubclass], + nameoroid: str, + raise_keyerror: int = 0, + ) -> str: """ Get an OID by name or OID """ @@ -222,7 +304,12 @@ def getoid(self,se_class,nameoroid,raise_keyerror=0): return result_oid - def get_inheritedattr(self,se_class,nameoroid,name): + def get_inheritedattr( + self, + se_class: Type[SchemaElementSubclass], + nameoroid: str, + name: str, + ) -> Any: """ Get a possibly inherited attribute specified by name of a schema element specified by nameoroid. @@ -235,12 +322,23 @@ def get_inheritedattr(self,se_class,nameoroid,name): result = getattr(se,name) except AttributeError: result = None - if result is None and se.sup: + if result is None and hasattr(se, 'sup') and se.sup: + # FIXME: sup can be multi-valued result = self.get_inheritedattr(se_class,se.sup[0],name) + + # The return type could be something like this: + # Tuple[str, ...] | Tuple[None] | str | int | None + # But we have no control over what is passed as "name"... return result - def get_obj(self,se_class,nameoroid,default=None,raise_keyerror=0): + def get_obj( + self, + se_class: Type[SchemaElementSubclass], + nameoroid: str, + default: Optional[SchemaElementSubclass] = None, + raise_keyerror: int = 0, + ) -> Optional[SchemaElementSubclass]: """ Get a schema element by name or OID """ @@ -252,25 +350,37 @@ def get_obj(self,se_class,nameoroid,default=None,raise_keyerror=0): raise KeyError('No ldap.schema.{} instance with nameoroid {} and se_oid {}'.format( se_class.__name__,repr(nameoroid),repr(se_oid)) ) + elif default is None: + return None else: se_obj = default + + assert isinstance(se_obj, se_class) return se_obj - def get_inheritedobj(self,se_class,nameoroid,inherited=None): + def get_inheritedobj( + self, + se_class: Type[SchemaElementSubclass], + nameoroid: str, + inherited: Optional[List[str]] = None, + ) -> Optional[SchemaElementSubclass]: """ Get a schema element by name or OID with all class attributes set including inherited class attributes """ + # FIXME: could use a TypeVar to limit the return value to an se_class instance inherited = inherited or [] se = copy.copy(self.sed[se_class].get(self.getoid(se_class,nameoroid))) if se and hasattr(se,'sup'): for class_attr_name in inherited: setattr(se,class_attr_name,self.get_inheritedattr(se_class,nameoroid,class_attr_name)) + + assert isinstance(se, se_class) return se - def get_syntax(self,nameoroid): + def get_syntax(self, nameoroid: str) -> Optional[str]: """ Get the syntax of an attribute type specified by name or OID """ @@ -279,11 +389,14 @@ def get_syntax(self,nameoroid): at_obj = self.get_inheritedobj(AttributeType,at_oid) except KeyError: return None + + if at_obj is None: + return None else: return at_obj.syntax - def get_structural_oc(self,oc_list): + def get_structural_oc(self, oc_list: Iterable[str]) -> Optional[str]: """ Returns OID of structural object class in oc_list if any is present. Returns None else. @@ -291,11 +404,11 @@ def get_structural_oc(self,oc_list): # Get tree of all STRUCTURAL object classes oc_tree = self.tree(ObjectClass,[('kind',[0])]) # Filter all STRUCTURAL object classes - struct_ocs = {} + struct_ocs = set() for oc_nameoroid in oc_list: oc_se = self.get_obj(ObjectClass,oc_nameoroid,None) if oc_se and oc_se.kind==0: - struct_ocs[oc_se.oid] = None + struct_ocs.add(oc_se.oid) result = None # Build a copy of the oid list, to be cleaned as we go. struct_oc_list = list(struct_ocs) @@ -309,7 +422,7 @@ def get_structural_oc(self,oc_list): return result - def get_applicable_aux_classes(self,nameoroid): + def get_applicable_aux_classes(self, nameoroid: str) -> List[str]: """ Return a list of the applicable AUXILIARY object classes for a STRUCTURAL object class specified by 'nameoroid' @@ -320,26 +433,30 @@ def get_applicable_aux_classes(self,nameoroid): content_rule = self.get_obj(DITContentRule,nameoroid) if content_rule: # Return AUXILIARY object classes from DITContentRule instance - return content_rule.aux + return list(content_rule.aux) else: # list all AUXILIARY object classes return self.listall(ObjectClass,[('kind',[2])]) def attribute_types( - self,object_class_list,attr_type_filter=None,raise_keyerror=1,ignore_dit_content_rule=0 - ): + self, + object_class_list: Iterable[str], + attr_type_filter: Optional[Iterable[Tuple[str, Iterable[Union[str, int]]]]] = None, + raise_keyerror: int = 1, + ignore_dit_content_rule: int = 0, + ) -> Tuple[cidict[Optional[AttributeType]], cidict[Optional[AttributeType]]]: """ Returns a 2-tuple of all must and may attributes including all inherited attributes of superior object classes by walking up classes along the SUP attribute. - The attributes are stored in a ldap.cidict.cidict dictionary. + The attributes are stored in ldap.cidict.cidict dictionaries. object_class_list list of strings specifying object class names or OIDs attr_type_filter - list of 2-tuples containing lists of class attributes - which has to be matched + list of 2-tuples containing a class attribute name and a + list of class attributes which has to be matched raise_keyerror All KeyError exceptions for non-existent schema elements are ignored @@ -347,8 +464,8 @@ def attribute_types( A DIT content rule governing the structural object class is ignored """ - AttributeType = ldap.schema.AttributeType - ObjectClass = ldap.schema.ObjectClass + AttributeType = ldap.schema.models.AttributeType + ObjectClass = ldap.schema.models.ObjectClass # Map object_class_list to object_class_oids (list of OIDs) object_class_oids = [ @@ -356,12 +473,15 @@ def attribute_types( for o in object_class_list ] # Initialize - oid_cache = {} + oid_cache: Dict[str, None] = {} + + r_must: cidict[Optional[ldap.schema.models.AttributeType]] = ldap.cidict.cidict() + r_may: cidict[Optional[ldap.schema.models.AttributeType]] = ldap.cidict.cidict() - r_must,r_may = ldap.cidict.cidict(),ldap.cidict.cidict() if '1.3.6.1.4.1.1466.101.120.111' in object_class_oids: # Object class 'extensibleObject' MAY carry every attribute type for at_obj in self.sed[AttributeType].values(): + assert isinstance(at_obj, AttributeType),ValueError(at_obj.oid) r_may[at_obj.oid] = at_obj # Loop over OIDs of all given object classes @@ -403,9 +523,10 @@ def attribute_types( try: dit_content_rule = self.get_obj(DITContentRule,structural_oc,raise_keyerror=1) except KeyError: - # Not DIT content rule found for structural objectclass + # No DIT content rule found for structural objectclass pass else: + assert dit_content_rule is not None for a in dit_content_rule.must: se_oid = self.getoid(AttributeType,a,raise_keyerror=raise_keyerror) r_must[se_oid] = self.get_obj(AttributeType,se_oid,raise_keyerror=raise_keyerror) @@ -447,7 +568,10 @@ def attribute_types( return r_must,r_may # attribute_types() -def urlfetch(uri,trace_level=0): +def urlfetch( + uri: str, + trace_level: int = 0, + ) -> Tuple[Optional[str], Optional[SubSchema]]: """ Fetches a parsed schema entry by uri. @@ -480,8 +604,9 @@ def urlfetch(uri,trace_level=0): ldif_parser = ldif.LDIFRecordList(ldif_file,max_entries=1) ldif_parser.parse() subschemasubentry_dn,s_temp = ldif_parser.all_records[0] + # Work-around for mixed-cased attribute names - subschemasubentry_entry = ldap.cidict.cidict() + subschemasubentry_entry: MutableMapping[str, List[bytes]] = ldap.cidict.cidict() s_temp = s_temp or {} for at,av in s_temp.items(): if at in SCHEMA_CLASS_MAPPING: @@ -491,7 +616,7 @@ def urlfetch(uri,trace_level=0): subschemasubentry_entry[at] = av # Finally parse the schema if subschemasubentry_dn!=None: - parsed_sub_schema = ldap.schema.SubSchema(subschemasubentry_entry) + parsed_sub_schema = SubSchema(subschemasubentry_entry) else: parsed_sub_schema = None return subschemasubentry_dn, parsed_sub_schema diff --git a/Lib/ldap/schema/tokenizer.py b/Lib/ldap/schema/tokenizer.py index 623b86d5..a68ee1eb 100644 --- a/Lib/ldap/schema/tokenizer.py +++ b/Lib/ldap/schema/tokenizer.py @@ -6,6 +6,15 @@ import re +from typing import Dict, List, Tuple, Mapping, Union +from typing_extensions import TypeAlias + +LDAPTokenDictValue: TypeAlias = "Tuple[()] | Tuple[str, ...]" +"""The kind of values which may be found in a token dict.""" + +LDAPTokenDict: TypeAlias = "Mapping[str, LDAPTokenDictValue]" +"""The type of the dict used to keep track of tokens while parsing schema (Mapping because of variance).""" + TOKENS_FINDALL = re.compile( r"(\()" # opening parenthesis r"|" # or @@ -24,7 +33,7 @@ UNESCAPE_PATTERN = re.compile(r"\\(.)") -def split_tokens(s): +def split_tokens(s: str) -> List[str]: """ Returns list of syntax elements with quotes and spaces stripped. """ @@ -50,35 +59,68 @@ def split_tokens(s): raise ValueError("Unbalanced parenthesis in %r" % (s)) return parts -def extract_tokens(l,known_tokens): - """ - Returns dictionary of known tokens with all values - """ - assert l[0].strip()=="(" and l[-1].strip()==")",ValueError(l) - result = {} - result.update(known_tokens) - i = 0 - l_len = len(l) - while i Tuple[str, LDAPTokenDict]: + """ + Process a list of tokens and return a dictionary of known tokens with all + values + + Arguments: + + tokens + A list of tokens to process. + + known_tokens + A list of known tokens, unknown tokens will be ignored + + Returns: + + A tuple of the oid of the schema element and a dictionary mapping the + found tokens to their value(s). + """ + + assert len(tokens) > 2, ValueError(tokens) + assert tokens[0].strip() == "(", ValueError(tokens) + assert tokens[-1].strip() == ")", ValueError(tokens) + + oid = tokens[1] + result = {} + + i = 2 + while i < len(tokens): + token = tokens[i] + i += 1 + + if token not in known_tokens: + # Skip unrecognized token + continue + + if i >= len(tokens): + break + + next_token = tokens[i] + + if next_token in known_tokens: + # non-valued + value: Union[Tuple[()], Tuple[str, ...]] = (()) + + elif next_token == "(": + # multi-valued + i += 1 # Consume left parentheses + start = i + while i < len(tokens) and tokens[i] != ")": + i += 1 + value = tuple(filter(lambda v: v != '$', tokens[start:i])) + i += 1 # Consume right parentheses + else: - # single-valued - result[token] = l[i], - i += 1 # Consume single value - else: - i += 1 # Consume unrecognized item - return result + # single-valued + value = (next_token,) + i += 1 # Consume single value + + result[token] = value + + return oid, result diff --git a/Lib/ldap/syncrepl.py b/Lib/ldap/syncrepl.py index 1708b468..8aeb0b96 100644 --- a/Lib/ldap/syncrepl.py +++ b/Lib/ldap/syncrepl.py @@ -12,26 +12,30 @@ from ldap.pkginfo import __version__, __author__, __license__ from ldap.controls import RequestControl, ResponseControl, KNOWN_RESPONSE_CONTROLS +import ldap + +from ldap.types import LDAPEntryDict +from typing import Any, Dict, List, Type, Tuple, Optional, Union __all__ = [ 'SyncreplConsumer', ] -class SyncUUID(univ.OctetString): +class SyncUUID(univ.OctetString): # type: ignore """ syncUUID ::= OCTET STRING (SIZE(16)) """ subtypeSpec = constraint.ValueSizeConstraint(16, 16) -class SyncCookie(univ.OctetString): +class SyncCookie(univ.OctetString): # type: ignore """ syncCookie ::= OCTET STRING """ -class SyncRequestMode(univ.Enumerated): +class SyncRequestMode(univ.Enumerated): # type: ignore """ mode ENUMERATED { -- 0 unused @@ -47,7 +51,7 @@ class SyncRequestMode(univ.Enumerated): subtypeSpec = univ.Enumerated.subtypeSpec + constraint.SingleValueConstraint(1, 3) -class SyncRequestValue(univ.Sequence): +class SyncRequestValue(univ.Sequence): # type: ignore """ syncRequestValue ::= SEQUENCE { mode ENUMERATED { @@ -79,23 +83,32 @@ class SyncRequestControl(RequestControl): """ controlType = '1.3.6.1.4.1.4203.1.9.1.1' - def __init__(self, criticality=1, cookie=None, mode='refreshOnly', reloadHint=False): - self.criticality = criticality + def __init__( + self, + criticality: Union[int, bool] = True, + cookie: Optional[str] = None, + mode: str = 'refreshOnly', + reloadHint: bool = False, + ) -> None: + if criticality: + self.criticality = True + else: + self.criticality = False self.cookie = cookie self.mode = mode self.reloadHint = reloadHint - def encodeControlValue(self): + def encodeControlValue(self) -> bytes: rcv = SyncRequestValue() rcv.setComponentByName('mode', SyncRequestMode(self.mode)) if self.cookie is not None: rcv.setComponentByName('cookie', SyncCookie(self.cookie)) if self.reloadHint: rcv.setComponentByName('reloadHint', univ.Boolean(self.reloadHint)) - return encoder.encode(rcv) + return encoder.encode(rcv) # type: ignore -class SyncStateOp(univ.Enumerated): +class SyncStateOp(univ.Enumerated): # type: ignore """ state ENUMERATED { present (0), @@ -113,7 +126,7 @@ class SyncStateOp(univ.Enumerated): subtypeSpec = univ.Enumerated.subtypeSpec + constraint.SingleValueConstraint(0, 1, 2, 3) -class SyncStateValue(univ.Sequence): +class SyncStateValue(univ.Sequence): # type: ignore """ syncStateValue ::= SEQUENCE { state ENUMERATED { @@ -146,13 +159,13 @@ class SyncStateControl(ResponseControl): controlType = '1.3.6.1.4.1.4203.1.9.1.2' opnames = ('present', 'add', 'modify', 'delete') - def decodeControlValue(self, encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: d = decoder.decode(encodedControlValue, asn1Spec=SyncStateValue()) state = d[0].getComponentByName('state') uuid = UUID(bytes=bytes(d[0].getComponentByName('entryUUID'))) cookie = d[0].getComponentByName('cookie') if cookie is not None and cookie.hasValue(): - self.cookie = str(cookie) + self.cookie: Optional[str] = str(cookie) else: self.cookie = None self.state = self.__class__.opnames[int(state)] @@ -161,7 +174,7 @@ def decodeControlValue(self, encodedControlValue): KNOWN_RESPONSE_CONTROLS[SyncStateControl.controlType] = SyncStateControl -class SyncDoneValue(univ.Sequence): +class SyncDoneValue(univ.Sequence): # type: ignore """ syncDoneValue ::= SEQUENCE { cookie syncCookie OPTIONAL, @@ -186,23 +199,23 @@ class SyncDoneControl(ResponseControl): """ controlType = '1.3.6.1.4.1.4203.1.9.1.3' - def decodeControlValue(self, encodedControlValue): + def decodeControlValue(self, encodedControlValue: bytes) -> None: d = decoder.decode(encodedControlValue, asn1Spec=SyncDoneValue()) cookie = d[0].getComponentByName('cookie') if cookie.hasValue(): - self.cookie = str(cookie) + self.cookie: Optional[str] = str(cookie) else: self.cookie = None refresh_deletes = d[0].getComponentByName('refreshDeletes') if refresh_deletes.hasValue(): self.refreshDeletes = bool(refresh_deletes) else: - self.refreshDeletes = None + self.refreshDeletes = False KNOWN_RESPONSE_CONTROLS[SyncDoneControl.controlType] = SyncDoneControl -class RefreshDelete(univ.Sequence): +class RefreshDelete(univ.Sequence): # type: ignore """ refreshDelete [1] SEQUENCE { cookie syncCookie OPTIONAL, @@ -215,7 +228,7 @@ class RefreshDelete(univ.Sequence): ) -class RefreshPresent(univ.Sequence): +class RefreshPresent(univ.Sequence): # type: ignore """ refreshPresent [2] SEQUENCE { cookie syncCookie OPTIONAL, @@ -228,14 +241,14 @@ class RefreshPresent(univ.Sequence): ) -class SyncUUIDs(univ.SetOf): +class SyncUUIDs(univ.SetOf): # type: ignore """ syncUUIDs SET OF syncUUID """ componentType = SyncUUID() -class SyncIdSet(univ.Sequence): +class SyncIdSet(univ.Sequence): # type: ignore """ syncIdSet [3] SEQUENCE { cookie syncCookie OPTIONAL, @@ -250,7 +263,7 @@ class SyncIdSet(univ.Sequence): ) -class SyncInfoValue(univ.Choice): +class SyncInfoValue(univ.Choice): # type: ignore """ syncInfoValue ::= CHOICE { newcookie [0] syncCookie, @@ -306,7 +319,7 @@ class SyncInfoMessage: """ responseName = '1.3.6.1.4.1.4203.1.9.1.4' - def __init__(self, encodedMessage): + def __init__(self, encodedMessage: bytes) -> None: d = decoder.decode(encodedMessage, asn1Spec=SyncInfoValue()) self.newcookie = None self.refreshDelete = None @@ -324,7 +337,7 @@ def __init__(self, encodedMessage): self.newcookie = str(comp) return - val = {} + val: Dict[str, Union[str, bool, List[str]]] = {} cookie = comp.getComponentByName('cookie') if cookie.hasValue(): @@ -344,12 +357,20 @@ def __init__(self, encodedMessage): setattr(self, attr, val) -class SyncreplConsumer: +# FIXME: This class expects to be a subclass of ldap.ldapobject.* +class SyncreplConsumer(): """ SyncreplConsumer - LDAP syncrepl consumer object. """ - def syncrepl_search(self, base, scope, mode='refreshOnly', cookie=None, **search_args): + def syncrepl_search( + self, + base: str, + scope: int, + mode: str = 'refreshOnly', + cookie: Optional[str] = None, + **search_args: Any, + ) -> int: """ Starts syncrepl search operation. @@ -384,9 +405,15 @@ def syncrepl_search(self, base, scope, mode='refreshOnly', cookie=None, **search search_args['serverctrls'] = [syncreq] self.__refreshDone = False - return self.search_ext(base, scope, **search_args) - - def syncrepl_poll(self, msgid=-1, timeout=None, all=0): + # FIXME: This assumes that we're subclassing LDAPObject + return self.search_ext(base, scope, **search_args) # type: ignore + + def syncrepl_poll( + self, + msgid: int = -1, + timeout: Optional[int] = None, + all: int = 0, + ) -> bool: """ polls for and processes responses to the syncrepl_search() operation. Returns False when operation finishes, True if it is in progress, or @@ -399,7 +426,8 @@ def syncrepl_poll(self, msgid=-1, timeout=None, all=0): """ while True: - type, msg, mid, ctrls, n, v = self.result4( + # FIXME: This assumes that we're subclassing LDAPObject + type, msg, mid, ctrls, n, v = self.result4( # type: ignore msgid=msgid, timeout=timeout, add_intermediates=1, @@ -411,8 +439,8 @@ def syncrepl_poll(self, msgid=-1, timeout=None, all=0): # search result. This marks the end of a refreshOnly session. # look for a SyncDone control, save the cookie, and if necessary # delete non-present entries. - for c in ctrls: - if c.__class__.__name__ != 'SyncDoneControl': + for c in ctrls or []: + if not isinstance(c, SyncDoneControl): continue self.syncrepl_present(None, refreshDeletes=c.refreshDeletes) if c.cookie is not None: @@ -422,10 +450,10 @@ def syncrepl_poll(self, msgid=-1, timeout=None, all=0): elif type == 100: # search entry with associated SyncState control - for m in msg: + for m in msg or []: dn, attrs, ctrls = m - for c in ctrls: - if c.__class__.__name__ != 'SyncStateControl': + for c in ctrls or []: + if not isinstance(c, SyncStateControl): continue if c.state == 'present': self.syncrepl_present([c.entryUUID]) @@ -441,7 +469,7 @@ def syncrepl_poll(self, msgid=-1, timeout=None, all=0): elif type == 121: # Intermediate message. If it is a SyncInfoMessage, parse it - for m in msg: + for m in msg or []: rname, resp, ctrls = m if rname != SyncInfoMessage.responseName: continue @@ -476,19 +504,25 @@ def syncrepl_poll(self, msgid=-1, timeout=None, all=0): # virtual methods -- subclass must override these to do useful work - def syncrepl_set_cookie(self, cookie): + def syncrepl_set_cookie(self, cookie: str) -> None: """ Called by syncrepl_poll() to store a new cookie provided by the server. """ + # FIXME: The cookie is an opaque octet string, so the type should be bytes? pass - def syncrepl_get_cookie(self): + def syncrepl_get_cookie(self) -> str: """ Called by syncrepl_search() to retrieve the cookie stored by syncrepl_set_cookie() """ - pass - - def syncrepl_present(self, uuids, refreshDeletes=False): + # FIXME: The cookie is an opaque octet string, so the type should be bytes? + return '' + + def syncrepl_present( + self, + uuids: Optional[List[str]], + refreshDeletes: bool = False, + ) -> None: """ Called by syncrepl_poll() whenever entry UUIDs are presented to the client. syncrepl_present() is given a list of entry UUIDs (uuids) and a flag @@ -508,7 +542,7 @@ def syncrepl_present(self, uuids, refreshDeletes=False): """ pass - def syncrepl_delete(self, uuids): + def syncrepl_delete(self, uuids: List[str]) -> None: """ Called by syncrepl_poll() to delete entries. A list of UUIDs of the entries to be deleted is given in the @@ -516,7 +550,7 @@ def syncrepl_delete(self, uuids): """ pass - def syncrepl_entry(self, dn, attrs, uuid): + def syncrepl_entry(self, dn: str, attrs: LDAPEntryDict, uuid: str) -> None: """ Called by syncrepl_poll() for any added or modified entries. @@ -526,7 +560,7 @@ def syncrepl_entry(self, dn, attrs, uuid): """ pass - def syncrepl_refreshdone(self): + def syncrepl_refreshdone(self) -> None: """ Called by syncrepl_poll() between refresh and persist phase. diff --git a/Lib/ldap/types.py b/Lib/ldap/types.py new file mode 100644 index 00000000..dd568110 --- /dev/null +++ b/Lib/ldap/types.py @@ -0,0 +1,59 @@ +""" +types - type annotations which are shared across modules + +See https://www.python-ldap.org/ for details. +""" +from ldap.pkginfo import __version__ + +from typing import ( + List, + MutableMapping, + Tuple, + Sequence, + Optional, + Union, +) +from typing_extensions import TypeAlias + +__all__ = [ + 'LDAPModListAddEntry', + 'LDAPModListModifyEntry', + 'LDAPModListEntry', + 'LDAPAddModList', + 'LDAPModifyModList', + 'LDAPModList', + 'LDAPEntryDict', + 'LDAPControl', + 'LDAPControls', + 'LDAPSearchResult', +] + +LDAPModListAddEntry: TypeAlias = "Tuple[str, List[bytes]]" +"""The type of an addition entry in a modlist.""" + +LDAPModListModifyEntry: TypeAlias = "Tuple[int, str, Optional[Union[bytes, List[bytes]]]]" +"""The type of a modification entry in a modlist.""" + +LDAPModListEntry: TypeAlias = "LDAPModListAddEntry | LDAPModListModifyEntry" +"""The type of any kind of entry in a modlist.""" + +LDAPAddModList: TypeAlias = "Sequence[LDAPModListAddEntry]" +"""The type of an add modlist.""" + +LDAPModifyModList: TypeAlias = "Sequence[LDAPModListModifyEntry]" +"""The type of a modify modlist.""" + +LDAPModList: TypeAlias = "Sequence[LDAPModListEntry]" +"""The type of a mixed modlist.""" + +LDAPEntryDict: TypeAlias = "MutableMapping[str, List[bytes]]" +"""The type used to store attribute-value mappings for a given LDAP entry (attribute name, list of binary values).""" + +LDAPControl: TypeAlias = "Tuple[str, str, Optional[str]]" +"""The type used to store controls (type, criticality, value).""" + +LDAPControls: TypeAlias = "List[LDAPControl]" +"""The type used to store control lists.""" + +LDAPSearchResult: TypeAlias = "Tuple[str, LDAPEntryDict]" +"""The type of a search result, a tuple with a DN string and a dict of attributes.""" diff --git a/Lib/ldapurl.py b/Lib/ldapurl.py index b4dfd890..ab2761a7 100644 --- a/Lib/ldapurl.py +++ b/Lib/ldapurl.py @@ -19,6 +19,8 @@ from collections.abc import MutableMapping from urllib.parse import quote, unquote +from typing import Dict, Iterator, List, Optional, TYPE_CHECKING + LDAP_SCOPE_BASE = 0 LDAP_SCOPE_ONELEVEL = 1 LDAP_SCOPE_SUBTREE = 2 @@ -42,21 +44,18 @@ 'subordinates':LDAP_SCOPE_SUBORDINATES, } -# Some widely used types -StringType = type('') -TupleType=type(()) - -def isLDAPUrl(s): +def isLDAPUrl(s: str) -> bool: """Returns True if s is a LDAP URL, else False """ return s.lower().startswith(('ldap://', 'ldaps://', 'ldapi://')) -def ldapUrlEscape(s): +def ldapUrlEscape(s: str) -> str: """Returns URL encoding of string s""" return quote(s).replace(',','%2C').replace('/','%2F') + class LDAPUrlExtension: """ Class for parsing and unparsing LDAP URL extensions @@ -71,14 +70,20 @@ class LDAPUrlExtension: Value of extension """ - def __init__(self,extensionStr=None,critical=0,extype=None,exvalue=None): + def __init__( + self, + extensionStr: Optional[str] = None, + critical: int = 0, + extype: Optional[str] = None, + exvalue: Optional[str] = None + ) -> None: self.critical = critical self.extype = extype self.exvalue = exvalue if extensionStr: self._parse(extensionStr) - def _parse(self,extension): + def _parse(self, extension: str) -> None: extension = extension.strip() if not extension: # Don't parse empty strings @@ -96,7 +101,7 @@ def _parse(self,extension): self.exvalue = unquote(self.exvalue.strip()) self.extype = self.extype.strip() - def unparse(self): + def unparse(self) -> str: if self.exvalue is None: return '{}{}'.format('!'*(self.critical>0),self.extype) else: @@ -105,10 +110,10 @@ def unparse(self): self.extype,quote(self.exvalue or '') ) - def __str__(self): + def __str__(self) -> str: return self.unparse() - def __repr__(self): + def __repr__(self) -> str: return '<{}.{} instance at {}: {}>'.format( self.__class__.__module__, self.__class__.__name__, @@ -116,29 +121,41 @@ def __repr__(self): self.__dict__ ) - def __eq__(self,other): - return \ - (self.critical==other.critical) and \ - (self.extype==other.extype) and \ - (self.exvalue==other.exvalue) + def __eq__(self, other: object) -> bool: + if not isinstance(other, LDAPUrlExtension): + return False + elif self.critical != other.critical: + return False + elif self.extype != other.extype: + return False + elif self.exvalue != other.exvalue: + return False + else: + return True - def __ne__(self,other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) -class LDAPUrlExtensions(MutableMapping): +if TYPE_CHECKING: + LDAPUrlExtensionsBase = MutableMapping[str, LDAPUrlExtension] +else: + # Python <= 3.8 compatibility + LDAPUrlExtensionsBase = MutableMapping + +class LDAPUrlExtensions(LDAPUrlExtensionsBase): """ Models a collection of LDAP URL extensions as a mapping type """ __slots__ = ('_data', ) - def __init__(self, default=None): - self._data = {} + def __init__(self, default: Optional[Dict[str, LDAPUrlExtension]] = None) -> None: + self._data: Dict[str, LDAPUrlExtension] = {} if default is not None: self.update(default) - def __setitem__(self, name, value): + def __setitem__(self, name: str, value: LDAPUrlExtension) -> None: """Store an extension name @@ -155,22 +172,22 @@ def __setitem__(self, name, value): name, value.extype)) self._data[name] = value - def __getitem__(self, name): + def __getitem__(self, name: str) -> LDAPUrlExtension: return self._data[name] - def __delitem__(self, name): + def __delitem__(self, name: str) -> None: del self._data[name] - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._data) - def __len__(self): + def __len__(self) -> int: return len(self._data) - def __str__(self): + def __str__(self) -> str: return ','.join(str(v) for v in self.values()) - def __repr__(self): + def __repr__(self) -> str: return '<{}.{} instance at {}: {}>'.format( self.__class__.__module__, self.__class__.__name__, @@ -178,18 +195,19 @@ def __repr__(self): self._data ) - def __eq__(self,other): + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented return self._data == other._data - def parse(self,extListStr): + def parse(self, extListStr: str) -> None: for extension_str in extListStr.strip().split(','): if extension_str: e = LDAPUrlExtension(extension_str) - self[e.extype] = e + if e.extype is not None: + self[e.extype] = e - def unparse(self): + def unparse(self) -> str: return ','.join(v.unparse() for v in self.values()) @@ -224,40 +242,59 @@ class LDAPUrl: def __init__( self, - ldapUrl=None, - urlscheme='ldap', - hostport='',dn='',attrs=None,scope=None,filterstr=None, - extensions=None, - who=None,cred=None - ): + ldapUrl: Optional[str] = None, + urlscheme: str = 'ldap', + hostport: str = '', + dn: str = '', + attrs: Optional[List[str]] = None, + scope: Optional[int] = None, + filterstr: Optional[str] = None, + extensions: Optional[LDAPUrlExtensions] = None, + who: Optional[str] = None, + cred: Optional[str] = None + ) -> None: + self.urlscheme=urlscheme.lower() self.hostport=hostport self.dn=dn self.attrs=attrs self.scope=scope self.filterstr=filterstr - self.extensions=(extensions or LDAPUrlExtensions({})) - if ldapUrl!=None: + self.extensions: Optional[LDAPUrlExtensions] = (extensions or LDAPUrlExtensions({})) + + if ldapUrl is not None: self._parse(ldapUrl) if who!=None: self.who = who if cred!=None: self.cred = cred - def __eq__(self,other): - return \ - self.urlscheme==other.urlscheme and \ - self.hostport==other.hostport and \ - self.dn==other.dn and \ - self.attrs==other.attrs and \ - self.scope==other.scope and \ - self.filterstr==other.filterstr and \ - self.extensions==other.extensions - - def __ne__(self,other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, LDAPUrl): + return False + elif self.urlscheme != other.urlscheme: + return False + elif self.urlscheme != other.urlscheme: + return False + elif self.hostport != other.hostport: + return False + elif self.dn != other.dn: + return False + elif self.attrs != other.attrs: + return False + elif self.scope != other.scope: + return False + elif self.filterstr != other.filterstr: + return False + elif self.extensions != other.extensions: + return False + else: + return True + + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def _parse(self,ldap_url): + def _parse(self, ldap_url: str) -> None: """ parse a LDAP URL and set the class attributes urlscheme,host,dn,attrs,scope,filterstr,extensions @@ -312,7 +349,7 @@ def _parse(self,ldap_url): self.extensions = None return - def applyDefaults(self,defaults): + def applyDefaults(self, defaults: Dict[str, str]) -> None: """ Apply defaults to all class attributes which are None. @@ -324,7 +361,7 @@ def applyDefaults(self,defaults): if getattr(self,k) is None: setattr(self, k, value) - def initializeUrl(self): + def initializeUrl(self) -> str: """ Returns LDAP URL suitable to be passed to ldap.initialize() """ @@ -335,7 +372,7 @@ def initializeUrl(self): hostport = self.hostport return f'{self.urlscheme}://{hostport}' - def unparse(self): + def unparse(self) -> str: """ Returns LDAP URL depending on class attributes set. """ @@ -362,7 +399,12 @@ def unparse(self): ldap_url = ldap_url+'?'+self.extensions.unparse() return ldap_url - def htmlHREF(self,urlPrefix='',hrefText=None,hrefTarget=None): + def htmlHREF( + self, + urlPrefix: str = '', + hrefText: Optional[str] = None, + hrefTarget: Optional[str] = None + ) -> str: """ Returns a string with HTML link for this LDAP URL. @@ -392,10 +434,10 @@ def htmlHREF(self,urlPrefix='',hrefText=None,hrefTarget=None): target, urlPrefix, self.unparse(), hrefText ) - def __str__(self): + def __str__(self) -> str: return self.unparse() - def __repr__(self): + def __repr__(self) -> str: return '<{}.{} instance at {}: {}>'.format( self.__class__.__module__, self.__class__.__name__, @@ -403,36 +445,41 @@ def __repr__(self): self.__dict__ ) - def __getattr__(self,name): - if name in self.attr2extype: - extype = self.attr2extype[name] - if self.extensions and \ - extype in self.extensions and \ - not self.extensions[extype].exvalue is None: - result = unquote(self.extensions[extype].exvalue) - else: - return None - else: + def __getattr__(self, name: str) -> Optional[str]: + if name not in self.attr2extype: raise AttributeError('{} has no attribute {}'.format( self.__class__.__name__,name )) - return result # __getattr__() - def __setattr__(self,name,value): + extype = self.attr2extype[name] + if self.extensions is None: + return None + elif extype not in self.extensions: + return None + + exvalue = self.extensions[extype].exvalue + if exvalue is None: + return None + else: + return unquote(exvalue) + + def __setattr__(self, name: str, value: str) -> None: if name in self.attr2extype: extype = self.attr2extype[name] if value is None: # A value of None means that extension is deleted delattr(self,name) - elif value!=None: + else: # Add appropriate extension + if self.extensions is None: + self.extensions = LDAPUrlExtensions() self.extensions[extype] = LDAPUrlExtension( extype=extype,exvalue=unquote(value) ) else: self.__dict__[name] = value - def __delattr__(self,name): + def __delattr__(self, name: str) -> None: if name in self.attr2extype: extype = self.attr2extype[name] if self.extensions: diff --git a/Lib/ldif.py b/Lib/ldif.py index fa41321c..fd789da8 100644 --- a/Lib/ldif.py +++ b/Lib/ldif.py @@ -25,6 +25,15 @@ from urllib.parse import urlparse from urllib.request import urlopen +from ldap.types import ( + LDAPEntryDict, + LDAPModList, + LDAPControls, + LDAPModListModifyEntry, + LDAPModListAddEntry, +) +from typing import BinaryIO, Dict, List, TextIO, Tuple, cast, Optional, Union + attrtype_pattern = r'[\w;.-]+(;[\w_-]+)*' attrvalue_pattern = r'(([^,]|\\,)+|".*?")' attrtypeandvalue_pattern = attrtype_pattern + r'[ ]*=[ ]*' + attrvalue_pattern @@ -46,25 +55,28 @@ } CHANGE_TYPES = ['add','delete','modify','modrdn'] -valid_changetype_dict = {} -for c in CHANGE_TYPES: - valid_changetype_dict[c]=None +valid_changetype_set = set(CHANGE_TYPES) -def is_dn(s): +def is_dn(s: str) -> int: """ returns 1 if s is a LDAP DN """ if s=='': return 1 rm = dn_regex.match(s) - return rm!=None and rm.group(0)==s + if rm is None: + return 0 + elif rm.group(0)!=s: + return 0 + else: + return 1 SAFE_STRING_PATTERN = b'(^(\000|\n|\r| |:|<)|[\000\n\r\200-\377]+|[ ]+$)' safe_string_re = re.compile(SAFE_STRING_PATTERN) -def list_dict(l): +def list_dict(l: List[str]) -> Dict[str, None]: """ return a dictionary with all items of l being the keys of the dictionary """ @@ -78,7 +90,13 @@ class LDIFWriter: via URLs """ - def __init__(self,output_file,base64_attrs=None,cols=76,line_sep='\n'): + def __init__( + self, + output_file: TextIO, + base64_attrs: Optional[List[str]] = [], + cols: int = 76, + line_sep: str = '\n' + ) -> None: """ output_file file object for output; should be opened in *text* mode @@ -96,7 +114,7 @@ def __init__(self,output_file,base64_attrs=None,cols=76,line_sep='\n'): self._last_line_sep = line_sep self.records_written = 0 - def _unfold_lines(self,line): + def _unfold_lines(self, line: str) -> None: """ Write string line as one or more folded lines """ @@ -115,9 +133,8 @@ def _unfold_lines(self,line): self._output_file.write(line[pos:min(line_len,pos+self._cols-1)]) self._output_file.write(self._last_line_sep) pos = pos+self._cols-1 - return # _unfold_lines() - def _needs_base64_encoding(self,attr_type,attr_value): + def _needs_base64_encoding(self, attr_type: str, attr_value: bytes) -> int: """ returns 1 if attr_value has to be base-64 encoded because of special chars or because attr_type is in self._base64_attrs @@ -125,7 +142,7 @@ def _needs_base64_encoding(self,attr_type,attr_value): return attr_type.lower() in self._base64_attrs or \ not safe_string_re.search(attr_value) is None - def _unparseAttrTypeandValue(self,attr_type,attr_value): + def _unparseAttrTypeandValue(self, attr_type: str, attr_value: bytes) -> None: """ Write a single attribute type/value pair @@ -141,9 +158,8 @@ def _unparseAttrTypeandValue(self,attr_type,attr_value): self._unfold_lines(':: '.join([attr_type, encoded])) else: self._unfold_lines(': '.join([attr_type, attr_value.decode('ascii')])) - return # _unparseAttrTypeandValue() - def _unparseEntryRecord(self,entry): + def _unparseEntryRecord(self, entry: LDAPEntryDict) -> None: """ entry dictionary holding an entry @@ -152,7 +168,7 @@ def _unparseEntryRecord(self,entry): for attr_value in values: self._unparseAttrTypeandValue(attr_type,attr_value) - def _unparseChangeRecord(self,modlist): + def _unparseChangeRecord(self, modlist: LDAPModList) -> None: """ modlist list of additions (2-tuple) or modifications (3-tuple) @@ -166,21 +182,27 @@ def _unparseChangeRecord(self,modlist): raise ValueError("modlist item of wrong length: %d" % (mod_len)) self._unparseAttrTypeandValue('changetype',changetype.encode('ascii')) for mod in modlist: - if mod_len==2: - mod_type,mod_vals = mod - elif mod_len==3: + # Note: the following order will give mod_vals the right type + if mod_len==3: + mod = cast(LDAPModListModifyEntry, mod) mod_op,mod_type,mod_vals = mod self._unparseAttrTypeandValue(MOD_OP_STR[mod_op], mod_type.encode('ascii')) + elif mod_len==2: + mod = cast(LDAPModListAddEntry, mod) + mod_type,mod_vals = mod else: raise ValueError("Subsequent modlist item of wrong length") if mod_vals: - for mod_val in mod_vals: - self._unparseAttrTypeandValue(mod_type,mod_val) + if isinstance(mod_vals, bytes): + self._unparseAttrTypeandValue(mod_type,mod_vals) + else: + for mod_val in mod_vals: + self._unparseAttrTypeandValue(mod_type,mod_val) if mod_len==3: self._output_file.write('-'+self._last_line_sep) - def unparse(self,dn,record): + def unparse(self, dn: str, record: Union[LDAPEntryDict, LDAPModList]) -> None: """ dn string-representation of distinguished name @@ -189,8 +211,7 @@ def unparse(self,dn,record): or a list with a modify list like for LDAPObject.modify(). """ # Start with line containing the distinguished name - dn = dn.encode('utf-8') - self._unparseAttrTypeandValue('dn', dn) + self._unparseAttrTypeandValue('dn', dn.encode('utf-8')) # Dispatch to record type specific writers if isinstance(record,dict): self._unparseEntryRecord(record) @@ -202,10 +223,14 @@ def unparse(self,dn,record): self._output_file.write(self._last_line_sep) # Count records written self.records_written = self.records_written+1 - return # unparse() -def CreateLDIF(dn,record,base64_attrs=None,cols=76): +def CreateLDIF( + dn: str, + record: Union[LDAPEntryDict, LDAPModList], + base64_attrs: List[str], + cols: int = 76, + ) -> str: """ Create LDIF single formatted record including trailing empty line. This is a compatibility function. @@ -248,12 +273,12 @@ class and override method handle() to implement something meaningful. def __init__( self, - input_file, - ignored_attr_types=None, - max_entries=0, - process_url_schemes=None, - line_sep='\n' - ): + input_file: Union[TextIO, BinaryIO], + ignored_attr_types: Optional[List[str]] = [], + max_entries: int = 0, + process_url_schemes: Optional[List[str]] = [], + line_sep: str = '\n', + ) -> None: """ Parameters: input_file @@ -270,14 +295,19 @@ def __init__( line_sep String used as line separator """ - self._input_file = input_file # Detect whether the file is open in text or bytes mode. - self._file_sends_bytes = isinstance(self._input_file.read(0), bytes) + if isinstance(input_file.read(0), bytes): + self._binary_input_file: Optional[BinaryIO] = cast(BinaryIO, input_file) + self._text_input_file: Optional[TextIO] = None + else: + self._binary_input_file = None + self._text_input_file = cast(TextIO, input_file) + self._max_entries = max_entries self._process_url_schemes = list_dict([s.lower() for s in (process_url_schemes or [])]) self._ignored_attr_types = list_dict([a.lower() for a in (ignored_attr_types or [])]) self._last_line_sep = line_sep - self.version = None + self.version: Optional[int] = None # Initialize counters self.line_counter = 0 self.byte_counter = 0 @@ -291,19 +321,23 @@ def __init__( except EOFError: self._last_line = '' - def handle(self,dn,entry): + def handle(self, dn: str, entry: LDAPEntryDict) -> Optional[str]: """ Process a single content LDIF record. This method should be implemented by applications using LDIFParser. """ pass - def _readline(self): - s = self._input_file.readline() - if self._file_sends_bytes: + def _readline(self) -> Optional[str]: + if self._text_input_file is not None: + s = self._text_input_file.readline() + elif self._binary_input_file is not None: # The RFC does not allow UTF-8 values; we support it as a # non-official, backwards compatibility layer - s = s.decode('utf-8') + s = self._binary_input_file.readline().decode('utf-8') + else: + return None + self.line_counter = self.line_counter + 1 self.byte_counter = self.byte_counter + len(s) if not s: @@ -315,7 +349,7 @@ def _readline(self): else: return s - def _unfold_lines(self): + def _unfold_lines(self) -> str: """ Unfold several folded lines with trailing space into one line """ @@ -332,7 +366,7 @@ def _unfold_lines(self): self._last_line = next_line return ''.join(unfolded_lines) - def _next_key_and_value(self): + def _next_key_and_value(self) -> Tuple[Optional[str], Optional[bytes]]: """ Parse a single attribute type and value pair from one or more lines of LDIF data @@ -356,16 +390,15 @@ def _next_key_and_value(self): # if needed attribute value is BASE64 decoded value_spec = unfolded_line[colon_pos:colon_pos+2] if value_spec==': ': - attr_value = unfolded_line[colon_pos+2:].lstrip() # All values should be valid ascii; we support UTF-8 as a # non-official, backwards compatibility layer. - attr_value = attr_value.encode('utf-8') + attr_value_str = unfolded_line[colon_pos+2:].lstrip() + attr_value = attr_value_str.encode('utf-8') elif value_spec=='::': # attribute value needs base64-decoding - # base64 makes sens only for ascii - attr_value = unfolded_line[colon_pos+2:] - attr_value = attr_value.encode('ascii') - attr_value = self._b64decode(attr_value) + # base64 makes sense only for ascii + attr_value_str = unfolded_line[colon_pos+2:] + attr_value = self._b64decode(attr_value_str.encode('ascii')) elif value_spec==':<': # fetch attribute value from URL url = unfolded_line[colon_pos+2:].strip() @@ -380,7 +413,7 @@ def _next_key_and_value(self): attr_value = unfolded_line[colon_pos+1:].encode('utf-8') return attr_type,attr_value - def _consume_empty_lines(self): + def _consume_empty_lines(self) -> Tuple[Optional[str], Optional[bytes]]: """ Consume empty lines until first non-empty line. Must only be used between full records! @@ -398,7 +431,7 @@ def _consume_empty_lines(self): k,v = None,None return k,v - def parse_entry_records(self): + def parse_entry_records(self) -> None: """ Continuously read and parse LDIF entry records """ @@ -410,7 +443,8 @@ def parse_entry_records(self): k,v = self._consume_empty_lines() # Consume 'version' line if k=='version': - self.version = int(v.decode('ascii')) + if v is not None: + self.version = int(v.decode('ascii')) k,v = self._consume_empty_lines() except EOFError: return @@ -423,42 +457,51 @@ def parse_entry_records(self): raise ValueError('Line %d: First line of record does not start with "dn:": %s' % (self.line_counter,repr(k))) # Value of a 'dn' field *has* to be valid UTF-8 # k is text, v is bytes. - v = v.decode('utf-8') - if not is_dn(v): + if v is None: + raise ValueError('Line %d: DN has None value.' % (self.line_counter)) + dn = v.decode('utf-8') + if not is_dn(dn): raise ValueError('Line %d: Not a valid string-representation for dn: %s.' % (self.line_counter,repr(v))) - dn = v - entry = {} - # Consume second line of record - k,v = next_key_and_value() + + entry: LDAPEntryDict = {} # Loop for reading the attributes - while k!=None: + while True: + try: + k,v = next_key_and_value() + except EOFError: + break + + if k is None: + break + elif v is None: + continue + # Add the attribute to the entry if not ignored attribute if not k.lower() in self._ignored_attr_types: try: entry[k].append(v) except KeyError: entry[k]=[v] - # Read the next line within the record - try: - k,v = next_key_and_value() - except EOFError: - k,v = None,None # handle record self.handle(dn,entry) self.records_read = self.records_read + 1 # Consume empty separator line(s) k,v = self._consume_empty_lines() - return # parse_entry_records() - def parse(self): + def parse(self) -> None: """ Invokes LDIFParser.parse_entry_records() for backward compatibility """ - return self.parse_entry_records() # parse() + self.parse_entry_records() - def handle_modify(self,dn,modops,controls=None): + def handle_modify( + self, + dn: str, + modops: LDAPModList, + controls: Optional[LDAPControls] = None, + ) -> None: """ Process a single LDIF record representing a single modify operation. This method should be implemented by applications using LDIFParser. @@ -466,14 +509,15 @@ def handle_modify(self,dn,modops,controls=None): controls = [] or None pass - def parse_change_records(self): + def parse_change_records(self) -> None: # Local symbol for better performance next_key_and_value = self._next_key_and_value # Consume empty lines k,v = self._consume_empty_lines() # Consume 'version' line if k=='version': - self.version = int(v) + if v is not None: + self.version = int(v.decode('ascii')) k,v = self._consume_empty_lines() # Loop for processing whole records @@ -484,60 +528,71 @@ def parse_change_records(self): raise ValueError('Line %d: First line of record does not start with "dn:": %s' % (self.line_counter,repr(k))) # Value of a 'dn' field *has* to be valid UTF-8 # k is text, v is bytes. - v = v.decode('utf-8') - if not is_dn(v): + if v is None: + raise ValueError('Line %d: DN has None value.' % (self.line_counter)) + dn = v.decode('utf-8') + if not is_dn(dn): raise ValueError('Line %d: Not a valid string-representation for dn: %s.' % (self.line_counter,repr(v))) - dn = v + # Consume second line of record k,v = next_key_and_value() # Read "control:" lines controls = [] while k!=None and k=='control': + if v is None: + raise ValueError('Line %d: control has None value.' % (self.line_counter)) # v is still bytes, spec says it should be valid utf-8; decode it. - v = v.decode('utf-8') + control = v.decode('utf-8') try: - control_type,criticality,control_value = v.split(' ',2) + control_type,criticality,control_value = control.split(' ',2) except ValueError: control_value = None - control_type,criticality = v.split(' ',1) + control_type,criticality = control.split(' ',1) controls.append((control_type,criticality,control_value)) k,v = next_key_and_value() # Determine changetype first - changetype = None + changetype = '' # Consume changetype line of record if k=='changetype': + if v is None: + raise ValueError('Line %d: changetype has None value.' % (self.line_counter)) # v is still bytes, spec says it should be valid utf-8; decode it. - v = v.decode('utf-8') - if not v in valid_changetype_dict: + changetype = v.decode('utf-8') + if not changetype in valid_changetype_set: raise ValueError('Invalid changetype: %s' % repr(v)) - changetype = v k,v = next_key_and_value() if changetype=='modify': - # From here we assume a change record is read with changetype: modify modops = [] try: # Loop for reading the list of modifications - while k!=None: + while True: + if k is None: + break + # Extract attribute mod-operation (add, delete, replace) try: modop = MOD_OP_INTEGER[k] except KeyError: raise ValueError('Line %d: Invalid mod-op string: %s' % (self.line_counter,repr(k))) + + if v is None: + raise ValueError('Line %d: mod-op has None value.' % (self.line_counter)) + # we now have the attribute name to be modified # v is still bytes, spec says it should be valid utf-8; decode it. - v = v.decode('utf-8') - modattr = v + modattr = v.decode('utf-8') modvalues = [] try: k,v = next_key_and_value() except EOFError: k,v = None,None while k==modattr: - modvalues.append(v) + if v is not None: + modvalues.append(v) try: k,v = next_key_and_value() except EOFError: @@ -570,8 +625,6 @@ def parse_change_records(self): self.changetype_counter[changetype] = 1 self.records_read = self.records_read + 1 - return # parse_change_records() - class LDIFRecordList(LDIFParser): """ @@ -583,22 +636,29 @@ class LDIFRecordList(LDIFParser): def __init__( self, - input_file, - ignored_attr_types=None,max_entries=0,process_url_schemes=None - ): + input_file: Union[TextIO, BinaryIO], + ignored_attr_types: Optional[List[str]] = [], + max_entries: int = 0, + process_url_schemes: Optional[List[str]] = [], + ) -> None: LDIFParser.__init__(self,input_file,ignored_attr_types,max_entries,process_url_schemes) #: List storing parsed records. - self.all_records = [] - self.all_modify_changes = [] + self.all_records: List[Tuple[str, LDAPEntryDict]] = [] + self.all_modify_changes: List[Tuple[str, LDAPModList, Optional[LDAPControls]]] = [] - def handle(self,dn,entry): + def handle(self, dn: str, entry: LDAPEntryDict) -> None: """ Append a single record to the list of all records (:attr:`.all_records`). """ self.all_records.append((dn,entry)) - def handle_modify(self,dn,modops,controls=None): + def handle_modify( + self, + dn: str, + modops: LDAPModList, + controls: Optional[LDAPControls] = None, + ) -> None: """ Process a single LDIF record representing a single modify operation. This method should be implemented by applications using LDIFParser. @@ -615,24 +675,33 @@ class LDIFCopy(LDIFParser): def __init__( self, - input_file,output_file, - ignored_attr_types=None,max_entries=0,process_url_schemes=None, - base64_attrs=None,cols=76,line_sep='\n' - ): + input_file: Union[TextIO, BinaryIO], + output_file: TextIO, + ignored_attr_types: Optional[List[str]] = [], + max_entries: int = 0, + process_url_schemes: Optional[List[str]] = [], + base64_attrs: List[str] = [], + cols: int = 76, + line_sep: str = '\n' + ) -> None: """ See LDIFParser.__init__() and LDIFWriter.__init__() """ LDIFParser.__init__(self,input_file,ignored_attr_types,max_entries,process_url_schemes) self._output_ldif = LDIFWriter(output_file,base64_attrs,cols,line_sep) - def handle(self,dn,entry): + def handle(self, dn: str, entry: LDAPEntryDict) -> None: """ Write single LDIF record to output file. """ self._output_ldif.unparse(dn,entry) -def ParseLDIF(f,ignore_attrs=None,maxentries=0): +def ParseLDIF( + f: Union[TextIO, BinaryIO], + ignore_attrs: Optional[List[str]] = [], + maxentries: int = 0 + ) -> List[Tuple[str, LDAPEntryDict]]: """ Parse LDIF records read from file. This is a compatibility function. @@ -644,7 +713,7 @@ def ParseLDIF(f,ignore_attrs=None,maxentries=0): stacklevel=2, ) ldif_parser = LDIFRecordList( - f,ignored_attr_types=ignore_attrs,max_entries=maxentries,process_url_schemes=0 + f,ignored_attr_types=ignore_attrs,max_entries=maxentries ) ldif_parser.parse() return ldif_parser.all_records diff --git a/Lib/slapdtest/__init__.py b/Lib/slapdtest/__init__.py index 7c410180..8f9f6620 100644 --- a/Lib/slapdtest/__init__.py +++ b/Lib/slapdtest/__init__.py @@ -6,7 +6,8 @@ __version__ = '3.4.4' -from slapdtest._slapdtest import SlapdObject, SlapdTestCase, SysLogHandler +from logging.handlers import SysLogHandler +from slapdtest._slapdtest import SlapdObject, SlapdTestCase from slapdtest._slapdtest import requires_ldapi, requires_sasl, requires_tls from slapdtest._slapdtest import requires_init_fd from slapdtest._slapdtest import skip_unless_ci diff --git a/Lib/slapdtest/_slapdtest.py b/Lib/slapdtest/_slapdtest.py index 36841110..458bc3c5 100644 --- a/Lib/slapdtest/_slapdtest.py +++ b/Lib/slapdtest/_slapdtest.py @@ -15,7 +15,11 @@ from shutil import which from urllib.parse import quote_plus -# Switch off processing .ldaprc or ldap.conf before importing _ldap +from typing import Any, Callable, Iterable, List, Optional, Type, TypeVar, Tuple +from types import TracebackType +from typing_extensions import Self + +# Switch off processing .ldaprc or ldap.conf before importing ldap._ldap os.environ['LDAPNOINIT'] = '1' import ldap @@ -60,14 +64,16 @@ HAVE_LDAPI = hasattr(socket, 'AF_UNIX') -def identity(test_item): +T = TypeVar('T', bound=Any) + +def identity(test_item: T) -> T: """Identity decorator """ return test_item -def skip_unless_ci(reason, feature=None): +def skip_unless_ci(reason: str, feature: Optional[str] = None) -> Callable[..., Any]: """Skip test unless test case is executed on CI like Travis CI """ if not os.environ.get('CI', False): @@ -79,7 +85,7 @@ def skip_unless_ci(reason, feature=None): return identity -def requires_tls(): +def requires_tls() -> Callable[..., Any]: """Decorator for TLS tests Tests are not skipped on CI (e.g. Travis CI) @@ -90,7 +96,7 @@ def requires_tls(): return identity -def requires_sasl(): +def requires_sasl() -> Callable[..., Any]: if not ldap.SASL_AVAIL: return skip_unless_ci( "test needs ldap.SASL_AVAIL", feature='SASL') @@ -98,14 +104,14 @@ def requires_sasl(): return identity -def requires_ldapi(): +def requires_ldapi() -> Callable[..., Any]: if not HAVE_LDAPI: return skip_unless_ci( "test needs ldapi support (AF_UNIX)", feature='LDAPI') else: return identity -def requires_init_fd(): +def requires_init_fd() -> Callable[..., Any]: if not ldap.INIT_FD_AVAIL: return skip_unless_ci( "test needs ldap.INIT_FD", feature='INIT_FD') @@ -113,7 +119,7 @@ def requires_init_fd(): return identity -def _add_sbin(path): +def _add_sbin(path: str) -> str: """Add /sbin and related directories to a command search path""" directories = path.split(os.pathsep) if sys.platform != 'win32': @@ -123,19 +129,19 @@ def _add_sbin(path): return os.pathsep.join(directories) def combined_logger( - log_name, - log_level=logging.WARN, - sys_log_format='%(levelname)s %(message)s', - console_log_format='%(asctime)s %(levelname)s %(message)s', - ): + log_name: str, + log_level: int = logging.WARN, + sys_log_format: str = '%(levelname)s %(message)s', + console_log_format: str = '%(asctime)s %(levelname)s %(message)s', + ) -> logging.Logger: """ Returns a combined SysLogHandler/StreamHandler logging instance with formatters """ if 'LOGLEVEL' in os.environ: - log_level = os.environ['LOGLEVEL'] + log_level_str = os.environ['LOGLEVEL'] try: - log_level = int(log_level) + log_level = int(log_level_str) except ValueError: pass # for writing to syslog @@ -210,8 +216,8 @@ class SlapdObject: # create loggers once, multiple calls mess up refleak tests _log = combined_logger('python-ldap-test') - def __init__(self): - self._proc = None + def __init__(self) -> None: + self._proc: Optional[subprocess.Popen[bytes]] = None self._port = self._avail_tcp_port() self.server_id = self._port % 4096 self.testrundir = os.path.join(self.TMPDIR, 'python-ldap-test-%d' % self._port) @@ -220,7 +226,7 @@ def __init__(self): self.ldap_uri = "ldap://%s:%d/" % (self.local_host, self._port) if HAVE_LDAPI: ldapi_path = os.path.join(self.testrundir, 'ldapi') - self.ldapi_uri = "ldapi://%s" % quote_plus(ldapi_path) + self.ldapi_uri: Optional[str] = "ldapi://%s" % quote_plus(ldapi_path) self.default_ldap_uri = self.ldapi_uri # use SASL/EXTERNAL via LDAPI when invoking OpenLDAP CLI tools self.cli_sasl_external = ldap.SASL_AVAIL @@ -243,29 +249,31 @@ def __init__(self): self.clientkey = os.path.join(HERE, 'certs/client.key') @property - def root_dn(self): + def root_dn(self) -> str: return 'cn={self.root_cn},{self.suffix}'.format(self=self) @property - def hostname(self): + def hostname(self) -> str: return self.local_host @property - def port(self): + def port(self) -> int: return self._port - def _find_commands(self): + def _find_commands(self) -> None: self.PATH_LDAPADD = self._find_command('ldapadd') self.PATH_LDAPDELETE = self._find_command('ldapdelete') self.PATH_LDAPMODIFY = self._find_command('ldapmodify') self.PATH_LDAPWHOAMI = self._find_command('ldapwhoami') self.PATH_SLAPADD = self._find_command('slapadd') - self.PATH_SLAPD = os.environ.get('SLAPD', None) - if not self.PATH_SLAPD: + env_path_slapd = os.environ.get('SLAPD', None) + if env_path_slapd is not None: + self.PATH_SLAPD = env_path_slapd + else: self.PATH_SLAPD = self._find_command('slapd', in_sbin=True) - def _find_command(self, cmd, in_sbin=False): + def _find_command(self, cmd: str, in_sbin: bool = False) -> str: if in_sbin: path = self.SBIN_PATH var_name = 'SBIN' @@ -280,7 +288,7 @@ def _find_command(self, cmd, in_sbin=False): ) return command - def setup_rundir(self): + def setup_rundir(self) -> None: """ creates rundir structure @@ -291,7 +299,7 @@ def setup_rundir(self): os.mkdir(self._db_directory) self._create_sub_dirs(self.testrunsubdirs) - def _cleanup_rundir(self): + def _cleanup_rundir(self) -> None: """ Recursively delete whole directory specified by `path' """ @@ -314,20 +322,20 @@ def _cleanup_rundir(self): os.rmdir(self.testrundir) self._log.info('cleaned-up %s', self.testrundir) - def _avail_tcp_port(self): + def _avail_tcp_port(self) -> int: """ find an available port for TCP connection """ sock = socket.socket() try: sock.bind((self.local_host, 0)) - port = sock.getsockname()[1] + port = int(sock.getsockname()[1]) finally: sock.close() self._log.info('Found available port %d', port) return port - def gen_config(self): + def gen_config(self) -> str: """ generates a slapd.conf and returns it as one string @@ -350,7 +358,7 @@ def gen_config(self): } return self.slapd_conf_template % config_dict - def _create_sub_dirs(self, dir_names): + def _create_sub_dirs(self, dir_names: Iterable[str]) -> None: """ create sub-directories beneath self.testrundir """ @@ -359,7 +367,7 @@ def _create_sub_dirs(self, dir_names): self._log.debug('Create directory %s', dir_name) os.mkdir(dir_name) - def _write_config(self): + def _write_config(self) -> None: """Loads the slapd.d configuration.""" self._log.debug("importing configuration: %s", self._slapd_conf) @@ -375,7 +383,7 @@ def _write_config(self): self._log.debug("import ok: %s", self._slapd_conf) - def _test_config(self): + def _test_config(self) -> None: self._log.debug('testing config %s', self._slapd_conf) popen_list = [ self.PATH_SLAPD, @@ -395,7 +403,7 @@ def _test_config(self): raise RuntimeError("configuration test failed") self._log.info("config ok: %s", self._slapd_conf) - def _start_slapd(self): + def _start_slapd(self) -> None: """ Spawns/forks the slapd process """ @@ -434,7 +442,7 @@ def _start_slapd(self): return raise RuntimeError("slapd did not start properly") - def start(self): + def start(self) -> None: """ Starts the slapd server process running, and waits for it to come up. """ @@ -447,12 +455,14 @@ def start(self): self._write_config() self._test_config() self._start_slapd() + if self._proc is None: + raise RuntimeError("started slapd but self._proc is None") self._log.debug( 'slapd with pid=%d listening on %s and %s', self._proc.pid, self.ldap_uri, self.ldapi_uri ) - def stop(self): + def stop(self) -> None: """ Stops the slapd server, and waits for it to terminate and cleans up """ @@ -463,27 +473,28 @@ def stop(self): self._cleanup_rundir() atexit.unregister(self.stop) - def restart(self): + def restart(self) -> None: """ Restarts the slapd server with same data """ - self._proc.terminate() - self.wait() + if self._proc is not None: + self._proc.terminate() + self.wait() self._start_slapd() - def wait(self): + def wait(self) -> None: """Waits for the slapd process to terminate by itself.""" if self._proc: self._proc.wait() self._stopped() - def _stopped(self): + def _stopped(self) -> None: """Called when the slapd server is known to have terminated""" if self._proc is not None: self._log.info('slapd[%d] terminated', self._proc.pid) self._proc = None - def _cli_auth_args(self): + def _cli_auth_args(self) -> List[str]: if self.cli_sasl_external: authc_args = [ '-Y', 'EXTERNAL', @@ -499,8 +510,13 @@ def _cli_auth_args(self): return authc_args # no cover to avoid spurious coverage changes - def _cli_popen(self, ldapcommand, extra_args=None, ldap_uri=None, - stdin_data=None): # pragma: no cover + def _cli_popen( + self, + ldapcommand: str, + extra_args: Optional[List[str]] = None, + ldap_uri: Optional[str] = None, + stdin_data: Optional[bytes] = None, + ) -> Tuple[bytes, bytes]: # pragma: no cover if ldap_uri is None: ldap_uri = self.default_ldap_uri @@ -530,27 +546,32 @@ def _cli_popen(self, ldapcommand, extra_args=None, ldap_uri=None, ) return stdout_data, stderr_data - def ldapwhoami(self, extra_args=None): + def ldapwhoami(self, extra_args: Optional[List[str]] = None) -> None: """ Runs ldapwhoami on this slapd instance """ self._cli_popen(self.PATH_LDAPWHOAMI, extra_args=extra_args) - def ldapadd(self, ldif, extra_args=None): + def ldapadd(self, ldif: str, extra_args: Optional[List[str]] = None) -> None: """ Runs ldapadd on this slapd instance, passing it the ldif content """ self._cli_popen(self.PATH_LDAPADD, extra_args=extra_args, stdin_data=ldif.encode('utf-8')) - def ldapmodify(self, ldif, extra_args=None): + def ldapmodify(self, ldif: str, extra_args: Optional[List[str]] = None) -> None: """ Runs ldapadd on this slapd instance, passing it the ldif content """ self._cli_popen(self.PATH_LDAPMODIFY, extra_args=extra_args, stdin_data=ldif.encode('utf-8')) - def ldapdelete(self, dn, recursive=False, extra_args=None): + def ldapdelete( + self, + dn: str, + recursive: bool = False, + extra_args: Optional[List[str]] = None + ) -> None: """ Runs ldapdelete on this slapd instance, deleting 'dn' """ @@ -561,7 +582,11 @@ def ldapdelete(self, dn, recursive=False, extra_args=None): extra_args.append(dn) self._cli_popen(self.PATH_LDAPDELETE, extra_args=extra_args) - def slapadd(self, ldif, extra_args=None): + def slapadd( + self, + ldif: Optional[str], + extra_args: Optional[List[str]] = None + ) -> None: """ Runs slapadd on this slapd instance, passing it the ldif content """ @@ -571,11 +596,16 @@ def slapadd(self, ldif, extra_args=None): extra_args=extra_args, ) - def __enter__(self): + def __enter__(self) -> Self: self.start() return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: self.stop() @@ -588,10 +618,20 @@ class SlapdTestCase(unittest.TestCase): server = None ldap_object_class = None - def _open_ldap_conn(self, who=None, cred=None, **kwargs): + def _open_ldap_conn( + self, + who: Optional[str] = None, + cred: Optional[str] = None, + **kwargs: Any, + ) -> ldap.ldapobject.LDAPObject: """ return a LDAPObject instance after simple bind """ + if self.server is None: + raise RuntimeError("_open_ldap_conn: self.server is None") + elif self.ldap_object_class is None: + raise RuntimeError("_open_ldap_conn: self.ldap_object_class is None") + ldap_conn = self.ldap_object_class(self.server.ldap_uri, **kwargs) ldap_conn.protocol_version = 3 #ldap_conn.set_option(ldap.OPT_REFERRALS, 0) @@ -599,10 +639,11 @@ def _open_ldap_conn(self, who=None, cred=None, **kwargs): return ldap_conn @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.server = cls.server_class() cls.server.start() @classmethod - def tearDownClass(cls): - cls.server.stop() + def tearDownClass(cls) -> None: + if cls.server is not None: + cls.server.stop() diff --git a/MANIFEST.in b/MANIFEST.in index 687d2b0c..f28afe80 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,8 @@ include MANIFEST.in Makefile CHANGES INSTALL LICENCE README TODO include tox.ini .coveragerc include Modules/*.c Modules/*.h +include Lib/ldap/py.typed +include Lib/ldap/_ldap.pyi recursive-include Build *.cfg* recursive-include Lib *.py recursive-include Demo *.py diff --git a/Tests/t_cext.py b/Tests/t_cext.py index 33fbf29a..f319d100 100644 --- a/Tests/t_cext.py +++ b/Tests/t_cext.py @@ -13,13 +13,13 @@ os.environ['LDAPNOINIT'] = '1' # import the plain C wrapper module -import _ldap +import ldap._ldap as _ldap from slapdtest import SlapdTestCase, requires_tls, requires_init_fd class TestLdapCExtension(SlapdTestCase): """ - These tests apply only to the _ldap module and therefore bypass the + These tests apply only to the ldap._ldap module and therefore bypass the LDAPObject wrapper completely. """ diff --git a/setup.cfg b/setup.cfg index fdb32fbc..e6d37db5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ defines = HAVE_SASL HAVE_TLS extra_compile_args = extra_objects = +extra_files = ./ldap/:Lib/ldap/_ldap.pyi # Uncomment this if your libldap is not thread-safe and you need libldap_r # instead @@ -49,3 +50,14 @@ python_files = t_*.py filterwarnings = error ignore::ldap.LDAPBytesWarning + +# mypy, https://mypy.readthedocs.io/en/latest/ +[mypy] +strict = True +files = Lib/ + +[mypy-pyasn1.*] +ignore_missing_imports = True + +[mypy-pyasn1_modules.*] +ignore_missing_imports = True diff --git a/setup.py b/setup.py index dbf66a04..e8bbf5c1 100644 --- a/setup.py +++ b/setup.py @@ -104,7 +104,7 @@ class OpenLDAP2: #-- C extension modules ext_modules = [ Extension( - '_ldap', + 'ldap._ldap', [ 'Modules/LDAPObject.c', 'Modules/ldapcontrol.c', @@ -141,7 +141,6 @@ class OpenLDAP2: py_modules = [ 'ldapurl', 'ldif', - ], packages = [ 'ldap', @@ -152,11 +151,13 @@ class OpenLDAP2: 'slapdtest.certs', ], package_dir = {'': 'Lib',}, + package_data = {'ldap': ['py.typed'],}, data_files = LDAP_CLASS.extra_files, include_package_data=True, install_requires=[ 'pyasn1 >= 0.3.7', 'pyasn1_modules >= 0.1.5', + 'typing_extensions >= 0.4.1', ], zip_safe=False, python_requires='>=3.6', diff --git a/tox.ini b/tox.ini index beade024..1c6ce1d8 100644 --- a/tox.ini +++ b/tox.ini @@ -4,18 +4,25 @@ # and then run "tox" from this directory. [tox] -# Note: when updating Python versions, also change setup.py and .github/worlflows/* -envlist = py{36,37,38,39,310,311,312},c90-py{36,37},py3-nosasltls,doc,py3-trace,pypy3.9 +# Note: when updating Python versions, also change setup.py and .github/workflows/* +envlist = + py{36,37,38,39,310,311,312} + c90-py{36,37} + py3-nosasltls + doc + mypy + py3-trace + pypy3.9 minver = 1.8 [gh-actions] python = 3.6: py36 - 3.7: py37 - 3.8: py38, doc, py3-nosasltls - 3.9: py39, py3-trace - 3.10: py310 - 3.11: py311 + 3.7: py37, mypy + 3.8: py38, mypy, doc, py3-nosasltls + 3.9: py39, mypy, py3-trace + 3.10: py310, mypy + 3.11: py311, mypy 3.12: py312 pypy3.9: pypy3.9 @@ -42,6 +49,7 @@ deps = {[testenv]deps} pyasn1 pyasn1_modules + typing_extensions passenv = {[testenv]passenv} setenv = CI_DISABLED=LDAPI:SASL:TLS @@ -91,6 +99,13 @@ basepython = pypy3 deps = pytest commands = {envpython} -m pytest {posargs} +[testenv:mypy] +basepython = python3 +deps = mypy +commands = + {envpython} -m mypy --config-file {toxinidir}/setup.cfg {posargs} + {envpython} -m mypy.stubtest --mypy-config-file {toxinidir}/setup.cfg ldap._ldap {posargs} + [testenv:doc] basepython = python3 deps =