-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathbase.py
81 lines (68 loc) · 2.63 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from abc import ABC, abstractmethod
from typing import Any, AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, Union
class ModelInputNormalizer(ABC):
    """
    Base class for input normalizers.

    A normalizer is responsible for converting the input data (the raw
    request) to the format expected by the pipeline before the data is
    passed to it, and for converting it back again.
    """

    def _normalize_content_messages(self, data: Dict) -> Dict:
        """
        Return a copy of ``data`` in which every message content is a string.

        If the request contains the "messages" key, any message whose
        "content" is a list of parts is flattened: the "text" values of the
        parts whose "type" is "text" are joined with a single space. Parts of
        any other type, parts that are not dicts, and text parts whose "text"
        value is missing or not a string are dropped.

        The input dict and its message dicts are never modified; the returned
        dict is a shallow copy, and each message in it is a shallow copy too.
        """
        # Always hand back a copy so callers can mutate the result freely.
        normalized_data = data.copy()
        if "messages" not in data:
            return normalized_data

        converted_messages = []
        for msg in normalized_data["messages"]:
            new_msg = msg.copy()
            content = msg.get("content", "")
            if isinstance(content, list):
                # Keep only well-formed text parts. A text part without a
                # string "text" value would otherwise raise KeyError on the
                # lookup or TypeError inside str.join.
                text_parts = [
                    part["text"]
                    for part in content
                    if isinstance(part, dict)
                    and part.get("type") == "text"
                    and isinstance(part.get("text"), str)
                ]
                new_msg["content"] = " ".join(text_parts)
            converted_messages.append(new_msg)

        normalized_data["messages"] = converted_messages
        return normalized_data

    @abstractmethod
    def normalize(self, data: Dict) -> Any:
        """Normalize the input data"""
        pass

    @abstractmethod
    def denormalize(self, data: Any) -> Dict:
        """Denormalize the input data"""
        pass
class ModelOutputNormalizer(ABC):
    """
    Interface for output normalizers.

    An output normalizer converts the output data coming from a model into
    the format expected by the output pipeline, and back again.

    The normalize methods are not implemented yet - they will be once output
    pipelines are implemented.
    """

    @abstractmethod
    def normalize_streaming(
        self, model_reply: Union[AsyncIterable[Any], Iterable[Any]]
    ) -> Union[AsyncIterator[Any], Iterator[Any]]:
        """Normalize a streamed model reply (sync or async iterable)."""

    @abstractmethod
    def normalize(self, model_reply: Any) -> Any:
        """Normalize a complete (non-streamed) model reply."""

    @abstractmethod
    def denormalize(self, normalized_reply: Any) -> Any:
        """Denormalize a complete (non-streamed) normalized reply."""

    @abstractmethod
    def denormalize_streaming(
        self, normalized_reply: Union[AsyncIterable[Any], Iterable[Any]]
    ) -> Union[AsyncIterator[Any], Iterator[Any]]:
        """Denormalize a streamed normalized reply (sync or async iterable)."""