import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, List, Optional
import structlog
from codegate.db.connection import DbRecorder
from codegate.extract_snippets.message_extractor import CodeSnippet
from codegate.pipeline.base import PipelineContext
logger = structlog.get_logger("codegate")


@dataclass
class OutputPipelineContext:
    """
    Context passed between output pipeline steps.

    Does not include the input context, that one is separate.
    """

    # We store the messages that are not yet sent to the client in the buffer.
    # One reason for this might be that the buffer contains a secret that we want to de-obfuscate
    buffer: list[str] = field(default_factory=list)
    # Store extracted code snippets
    snippets: List[CodeSnippet] = field(default_factory=list)
    # Store all content that has been processed by the pipeline
    processed_content: List[str] = field(default_factory=list)
    # Partial buffer to store prefixes
    prefix_buffer: str = ""


class OutputPipelineStep(ABC):
    """
    Base class for output pipeline steps.

    The process_chunk method should be implemented by subclasses and handles
    processing of a single chunk of the stream.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Returns the name of this pipeline step"""
        pass

    @abstractmethod
    async def process_chunk(
        self,
        chunk: Any,
        context: OutputPipelineContext,
        input_context: Optional[PipelineContext] = None,
    ) -> List[Any]:
        """
        Process a single chunk of the stream.

        Args:
        - chunk: The input chunk to process, normalized to Any
        - context: The output pipeline context. Can be used to store state between steps,
          mainly the buffer.
        - input_context: The input context from processing the user's input. Can include the
          secrets obfuscated in the user message or code snippets in the user message.

        Returns:
        - Empty list to pause the stream
        - List containing one or more Any objects to emit
        """
        pass
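

# Example (illustrative sketch): a minimal concrete step showing the pause/emit
# protocol documented above. Returning [] withholds the chunk (its text stays in
# the instance buffer for a later step or flush); returning [chunk] emits it
# unchanged. The class name and the "REDACTED" marker are hypothetical and not
# part of the real step implementations.
class ExampleMarkerStep(OutputPipelineStep):
    @property
    def name(self) -> str:
        return "example-marker"

    async def process_chunk(
        self,
        chunk: Any,
        context: OutputPipelineContext,
        input_context: Optional[PipelineContext] = None,
    ) -> List[Any]:
        for content in chunk.get_content():
            text = content.get_text()
            if text and "REDACTED" in text:
                # Pause the stream: a real step would rewrite and re-emit the
                # buffered text once it has enough context.
                return []
        # Emit the chunk unchanged
        return [chunk]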


class OutputPipelineInstance:
    """
    Handles processing of a single stream.

    Think of this class as steps + buffer.
    """

    def __init__(
        self,
        pipeline_steps: list[OutputPipelineStep],
        input_context: Optional[PipelineContext] = None,
        db_recorder: Optional[DbRecorder] = None,
    ):
        self._input_context = input_context
        self._pipeline_steps = pipeline_steps
        self._context = OutputPipelineContext()
        # We won't actually buffer the chunk, but in case we need to pass
        # the remaining content in the buffer when the stream ends, we need
        # to store the parameters like model, timestamp, etc.
        self._buffered_chunk = None
        if not db_recorder:
            self._db_recorder = DbRecorder()
        else:
            self._db_recorder = db_recorder

    def _buffer_chunk(self, chunk: Any) -> None:
        """
        Add chunk content to the buffer. This is used to store content that is not yet processed
        when a pipeline step pauses streaming.
        """
        self._buffered_chunk = chunk
        for content in chunk.get_content():
            text = content.get_text()
            if text is not None:
                self._context.buffer.append(text)

    def _store_chunk_content(self, chunk: Any) -> None:
        """
        Store chunk content in processed content. This keeps track of the content that has been
        streamed through the pipeline.
        """
        for content in chunk.get_content():
            text = content.get_text()
            if text:
                self._context.processed_content.append(text)

    def _record_to_db(self) -> None:
        """
        Record the context to the database.

        Important: We cannot use `await` in the finally block. Otherwise, the stream
        will not be transmitted properly. Hence we get the running loop and create a task to
        record the context.
        """
        loop = asyncio.get_running_loop()
        loop.create_task(self._db_recorder.record_context(self._input_context))

    async def process_stream(
        self,
        stream: AsyncIterator[Any],
        cleanup_sensitive: bool = True,
        finish_stream: bool = True,
    ) -> AsyncIterator[Any]:
        """
        Process a stream through all pipeline steps.
        """
        try:
            async for chunk in stream:
                # Store chunk content in buffer
                self._buffer_chunk(chunk)
                self._input_context.add_output(chunk)

                # Process chunk through each step of the pipeline
                current_chunks = [chunk]
                for step in self._pipeline_steps:
                    if not current_chunks:
                        # Stop processing if a step returned an empty list
                        break

                    processed_chunks = []
                    for c in current_chunks:
                        try:
                            step_result = await step.process_chunk(
                                c, self._context, self._input_context
                            )
                            if not step_result:
                                break
                        except Exception as e:
                            logger.error(f"Error processing step '{step.name}'", exc_info=e)
                            # Re-raise to maintain the current behaviour.
                            raise e
                        processed_chunks.extend(step_result)

                    current_chunks = processed_chunks

                # Yield all processed chunks
                for c in current_chunks:
                    self._store_chunk_content(c)
                    self._context.buffer.clear()
                    yield c
        except Exception as e:
            # Log exception and stop processing
            logger.error(f"Error processing stream: {e}", exc_info=e)
            raise e
        finally:
            # NOTE: Don't use await in the finally block, it will break the stream.
            # Don't flush the buffer if we assume we'll call the pipeline again.
            if cleanup_sensitive is False:
                if finish_stream:
                    self._record_to_db()
                return

            # TODO: figure out what the logic here should be.
            # Process any remaining content in the buffer when the stream ends.
            if self._context.buffer:
                final_content = "".join(self._context.buffer)
                logger.error(
                    "Context buffer was not empty, it should have been!",
                    content=final_content,
                    len=len(self._context.buffer),
                )
                # NOTE: this block ensured that buffered chunks were
                # flushed at the end of the pipeline. This was
                # possible as long as the current implementation
                # assumed that all messages were equivalent and
                # position was not relevant.
                #
                # This is not the case for Anthropic, whose protocol
                # is much more structured than that of the others.
                #
                # We're not there yet to ensure that such a protocol
                # is not broken in the face of messages being arbitrarily
                # retained at each pipeline step, so we decided to
                # treat a clogged pipeline as a bug.
                self._context.buffer.clear()

            if finish_stream:
                self._record_to_db()
            # Cleanup sensitive data through the input context
            if cleanup_sensitive and self._input_context and self._input_context.sensitive:
                self._input_context.sensitive.secure_cleanup()
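

# Example (illustrative sketch): the pipeline treats chunks as duck-typed
# objects. _buffer_chunk() and _store_chunk_content() above only require
# chunk.get_content() to return items exposing get_text(). The two dataclasses
# below are hypothetical stand-ins for the provider-specific chunk types and
# exist only to make that implicit interface explicit.
@dataclass
class ExampleContent:
    text: Optional[str] = None

    def get_text(self) -> Optional[str]:
        return self.text


@dataclass
class ExampleChunk:
    contents: List[ExampleContent] = field(default_factory=list)

    def get_content(self) -> List[ExampleContent]:
        return self.contents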


class OutputPipelineProcessor:
    """
    Since we want to provide each run of the pipeline with a fresh context,
    we need a factory to create new instances of the pipeline.
    """

    def __init__(self, pipeline_steps: list[OutputPipelineStep]):
        self.pipeline_steps = pipeline_steps

    def _create_instance(self) -> OutputPipelineInstance:
        """Create a new pipeline instance for processing a stream"""
        return OutputPipelineInstance(self.pipeline_steps)

    async def process_stream(self, stream: AsyncIterator[Any]) -> AsyncIterator[Any]:
        """Create a new pipeline instance and process the stream"""
        instance = self._create_instance()
        async for chunk in instance.process_stream(stream):
            yield chunk
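

# Example (illustrative sketch): how a caller might wire the pieces together,
# using the hypothetical ExampleMarkerStep, ExampleChunk and ExampleContent
# defined above. The input context and DB recorder are taken as parameters
# because constructing them is outside this file's scope; process_stream()
# calls input_context.add_output(), so a real PipelineContext is expected.
async def example_usage(
    input_context: PipelineContext, db_recorder: DbRecorder
) -> List[str]:
    async def chunks() -> AsyncIterator[ExampleChunk]:
        for text in ("Hello, ", "world"):
            yield ExampleChunk(contents=[ExampleContent(text=text)])

    instance = OutputPipelineInstance(
        pipeline_steps=[ExampleMarkerStep()],
        input_context=input_context,
        db_recorder=db_recorder,
    )

    collected: List[str] = []
    async for chunk in instance.process_stream(chunks()):
        for content in chunk.get_content():
            text = content.get_text()
            if text:
                collected.append(text)
    return collected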