Skip to content

Commit

Permalink
fix some types when downstream enabled this package in mypy (#815)
Browse files Browse the repository at this point in the history
* fix: select arg type

* fix some type error when mypy enable

* revert

* revert

* fix
  • Loading branch information
trim21 authored Nov 18, 2024
1 parent c1a2fe1 commit a50b367
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 28 deletions.
54 changes: 54 additions & 0 deletions pypika/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,63 @@
FunctionException,
)


__author__ = "Timothy Heys"
__email__ = "theys@kayak.com"
__version__ = "0.48.9"

NULL = NullValue()
SYSTEM_TIME = SystemTimeValue()

__all__ = (
'ClickHouseQuery',
'Dialects',
'MSSQLQuery',
'MySQLQuery',
'OracleQuery',
'PostgreSQLQuery',
'RedshiftQuery',
'SQLLiteQuery',
'VerticaQuery',
'DatePart',
'JoinType',
'Order',
'AliasedQuery',
'Query',
'Schema',
'Table',
'Column',
'Database',
'Tables',
'Columns',
'Array',
'Bracket',
'Case',
'Criterion',
'EmptyCriterion',
'Field',
'Index',
'Interval',
'JSON',
'Not',
'NullValue',
'SystemTimeValue',
'Parameter',
'QmarkParameter',
'NumericParameter',
'NamedParameter',
'FormatParameter',
'PyformatParameter',
'Rollup',
'Tuple',
'CustomFunction',
'CaseException',
'GroupingException',
'JoinException',
'QueryException',
'RollupException',
'SetOperationException',
'FunctionException',
'NULL',
'SYSTEM_TIME',
)
60 changes: 33 additions & 27 deletions pypika/functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""
Package for SQL functions wrappers
"""
from __future__ import annotations

from typing import Optional

from pypika import Field
from pypika.enums import SqlTypes
from pypika.terms import (
AggregateFunction,
Expand All @@ -10,6 +15,7 @@
)
from pypika.utils import builder


__author__ = "Timothy Heys"
__email__ = "theys@kayak.com"

Expand All @@ -34,64 +40,64 @@ def distinct(self):


class Count(DistinctOptionFunction):
def __init__(self, param, alias=None):
def __init__(self, param: str | Field, alias: Optional[str] = None) -> None:
is_star = isinstance(param, str) and "*" == param
super(Count, self).__init__("COUNT", Star() if is_star else param, alias=alias)


# Arithmetic Functions
class Sum(DistinctOptionFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Sum, self).__init__("SUM", term, alias=alias)


class Avg(AggregateFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Avg, self).__init__("AVG", term, alias=alias)


class Min(AggregateFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Min, self).__init__("MIN", term, alias=alias)


class Max(AggregateFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Max, self).__init__("MAX", term, alias=alias)


class Std(AggregateFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Std, self).__init__("STD", term, alias=alias)


class StdDev(AggregateFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(StdDev, self).__init__("STDDEV", term, alias=alias)


class Abs(AggregateFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Abs, self).__init__("ABS", term, alias=alias)


class First(AggregateFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(First, self).__init__("FIRST", term, alias=alias)


class Last(AggregateFunction):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Last, self).__init__("LAST", term, alias=alias)


class Sqrt(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Sqrt, self).__init__("SQRT", term, alias=alias)


class Floor(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Floor, self).__init__("FLOOR", term, alias=alias)


Expand Down Expand Up @@ -131,17 +137,17 @@ def __init__(self, term, as_type, alias=None):


class Signed(Cast):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Signed, self).__init__(term, SqlTypes.SIGNED, alias=alias)


class Unsigned(Cast):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Unsigned, self).__init__(term, SqlTypes.UNSIGNED, alias=alias)


class Date(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Date, self).__init__("DATE", term, alias=alias)


Expand All @@ -156,7 +162,7 @@ def __init__(self, start_time, end_time, alias=None):


class DateAdd(Function):
def __init__(self, date_part, interval, term, alias=None):
def __init__(self, date_part, interval, term: str, alias: Optional[str] = None):
date_part = getattr(date_part, "value", date_part)
super(DateAdd, self).__init__("DATE_ADD", LiteralValue(date_part), interval, term, alias=alias)

Expand All @@ -167,19 +173,19 @@ def __init__(self, value, format_mask, alias=None):


class Timestamp(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Timestamp, self).__init__("TIMESTAMP", term, alias=alias)


class TimestampAdd(Function):
def __init__(self, date_part, interval, term, alias=None):
def __init__(self, date_part, interval, term: str, alias: Optional[str] = None):
date_part = getattr(date_part, 'value', date_part)
super(TimestampAdd, self).__init__("TIMESTAMPADD", LiteralValue(date_part), interval, term, alias=alias)


# String Functions
class Ascii(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Ascii, self).__init__("ASCII", term, alias=alias)


Expand All @@ -189,7 +195,7 @@ def __init__(self, term, condition, **kwargs):


class Bin(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Bin, self).__init__("BIN", term, alias=alias)


Expand All @@ -205,17 +211,17 @@ def __init__(self, term, start, stop, subterm, alias=None):


class Length(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Length, self).__init__("LENGTH", term, alias=alias)


class Upper(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Upper, self).__init__("UPPER", term, alias=alias)


class Lower(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Lower, self).__init__("LOWER", term, alias=alias)


Expand All @@ -225,12 +231,12 @@ def __init__(self, term, start, stop, alias=None):


class Reverse(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Reverse, self).__init__("REVERSE", term, alias=alias)


class Trim(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(Trim, self).__init__("TRIM", term, alias=alias)


Expand Down Expand Up @@ -297,7 +303,7 @@ def get_special_params_sql(self, **kwargs):

# Null Functions
class IsNull(Function):
def __init__(self, term, alias=None):
def __init__(self, term: str | Field, alias: Optional[str] = None):
super(IsNull, self).__init__("ISNULL", term, alias=alias)


Expand All @@ -312,5 +318,5 @@ def __init__(self, condition, term, **kwargs):


class NVL(Function):
def __init__(self, condition, term, alias=None):
def __init__(self, condition, term: str, alias: Optional[str] = None):
super(NVL, self).__init__("NVL", condition, term, alias=alias)
2 changes: 1 addition & 1 deletion pypika/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __ne__(self, other: Any) -> bool:
def __hash__(self) -> int:
return hash(str(self))

def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> "QueryBuilder":
def select(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilder":
"""
Perform a SELECT operation on the current table
Expand Down

0 comments on commit a50b367

Please sign in to comment.