-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathbase.py
348 lines (294 loc) · 11.9 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import datetime
import json
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import structlog
from pydantic import BaseModel
from codegate.clients.clients import ClientType
from codegate.db.models import Alert, AlertSeverity, Output, Prompt
from codegate.extract_snippets.message_extractor import CodeSnippet
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
logger = structlog.get_logger("codegate")
@dataclass
class PipelineSensitiveData:
    """Per-session sensitive data carried through the pipeline.

    The connection details (api_key, provider, api_base) are filled in by
    InputPipelineInstance.process_request; declaring them here keeps the
    dataclass definition in sync with how the object is actually used.
    """

    manager: SensitiveDataManager
    session_id: str
    model: Optional[str] = None
    api_key: Optional[str] = None
    provider: Optional[str] = None
    api_base: Optional[str] = None

    def secure_cleanup(self):
        """Securely cleanup sensitive data for this session."""
        if self.manager is None or self.session_id == "":
            return

        self.manager.cleanup_session(self.session_id)
        self.session_id = ""
        self.model = None
        # The API key is sensitive too — drop it together with the session.
        self.api_key = None
@dataclass
class PipelineContext:
    """Mutable state shared by every step while one request is processed.

    Collects alerts, the normalized input request and the model outputs so
    they can be persisted once the pipeline finishes.
    """

    metadata: Dict[str, Any] = field(default_factory=dict)
    sensitive: Optional[PipelineSensitiveData] = None
    alerts_raised: List[Alert] = field(default_factory=list)
    prompt_id: Optional[str] = None
    input_request: Optional[Prompt] = None
    output_responses: List[Output] = field(default_factory=list)
    shortcut_response: bool = False
    # TODO(jakub): Remove these flags, they couple the steps to the context too much
    # instead we should be using the metadata field scoped to the step to store anything
    # the step wants
    bad_packages_found: bool = False
    secrets_found: bool = False
    pii_found: bool = False
    client: ClientType = ClientType.GENERIC

    def add_alert(
        self,
        step_name: str,
        severity_category: AlertSeverity = AlertSeverity.INFO,
        code_snippet: Optional[CodeSnippet] = None,
        trigger_string: Optional[str] = None,
    ) -> None:
        """
        Add an alert to the pipeline step alerts_raised.

        At least one of code_snippet or trigger_string must be provided,
        otherwise the alert is dropped with a warning.
        """
        if self.prompt_id is None:
            self.prompt_id = str(uuid.uuid4())

        if not code_snippet and not trigger_string:
            logger.warning("No code snippet or trigger string provided for alert. Will not create")
            return

        code_snippet_str = code_snippet.model_dump_json() if code_snippet else None

        self.alerts_raised.append(
            Alert(
                id=str(uuid.uuid4()),
                prompt_id=self.prompt_id,
                code_snippet=code_snippet_str,
                trigger_string=trigger_string,
                trigger_type=step_name,
                trigger_category=severity_category.value,
                timestamp=datetime.datetime.now(datetime.timezone.utc),
            )
        )
        # Uncomment the below to debug
        # logger.debug(f"Added alert to context: {self.alerts_raised[-1]}")

    def add_input_request(
        self, normalized_request: Any, is_fim_request: bool, provider: str
    ) -> None:
        """Record the normalized input request. Best-effort: never raises."""
        try:
            if self.prompt_id is None:
                self.prompt_id = str(uuid.uuid4())

            self.input_request = Prompt(
                id=self.prompt_id,
                timestamp=datetime.datetime.now(datetime.timezone.utc),
                provider=provider,
                type="fim" if is_fim_request else "chat",
                request=normalized_request,
                workspace_id=None,
            )
            # Uncomment the below to debug the input
            # logger.debug(f"Added input request to context: {self.input_request}")
        except Exception as e:
            logger.warning(f"Failed to serialize input request: {normalized_request}", error=str(e))

    def add_output(self, model_response: Any) -> None:
        """Record one model response. Best-effort: never raises."""
        try:
            if self.prompt_id is None:
                logger.warning(f"Tried to record output without response: {model_response}")
                return

            if isinstance(model_response, BaseModel):
                output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
            else:
                output_str = json.dumps(model_response)

            self.output_responses.append(
                Output(
                    id=str(uuid.uuid4()),
                    prompt_id=self.prompt_id,
                    timestamp=datetime.datetime.now(datetime.timezone.utc),
                    output=output_str,
                )
            )
            # Uncomment the below to debug the responses
            # logger.debug(f"Added output to context: {self.output_responses[-1]}")
        except Exception as e:
            logger.error(f"Failed to serialize output: {model_response}", error=str(e))
            return
@dataclass
class PipelineResponse:
    """Response generated by a pipeline step.

    Attributes:
        content: The response text produced by the step.
        step_name: Name of the pipeline step that generated this response.
        model: Model identifier, taken from the original request's model field.
    """

    content: str
    step_name: str
    model: str
@dataclass
class PipelineResult:
    """
    Outcome of a pipeline operation.

    Carries either a (possibly modified) request that should continue through
    the pipeline, or a response / error message that should be returned to the
    client instead.
    """

    request: Optional[Any] = None
    response: Optional[PipelineResponse] = None
    context: Optional[PipelineContext] = None
    error_message: Optional[str] = None

    def shortcuts_processing(self) -> bool:
        """Returns True if this result should end pipeline processing"""
        return not (self.response is None and self.error_message is None)

    @property
    def success(self) -> bool:
        """Returns True if the pipeline step completed without errors"""
        return self.error_message is None
class PipelineStep(ABC):
    """Base class for all pipeline steps in the processing chain."""

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Returns the name of the pipeline step.

        Returns:
            str: A unique identifier for this pipeline step
        """
        pass

    @staticmethod
    def get_last_user_message(
        request: Any,
    ) -> Optional[tuple[str, int]]:
        """
        Get the last user message and its index from the request.

        Args:
            request (Any): The chat completion request to process

        Returns:
            Optional[tuple[str, int]]: A tuple containing the message content and
                its index, or None if no user message is found
        """
        msg = request.last_user_message()
        if msg is None:
            return None
        # unpack the tuple
        msg, idx = msg
        return "".join([content.get_text() for content in msg.get_content()]), idx

    @staticmethod
    def get_last_user_message_block(
        request: Any,
    ) -> Optional[tuple[str, int]]:
        """
        Get the last block of consecutive 'user' messages from the request.

        Args:
            request (Any): The chat completion request to process

        Returns:
            Optional[tuple[str, int]]: A tuple containing all consecutive user
                messages in the last user message block, separated by newlines,
                and the index of the first message detected in the block;
                or None if no user message block is found.
        """
        user_messages = []
        last_idx = -1
        # Collected texts are reversed before joining — last_user_block()
        # presumably yields messages newest-first; verify against its
        # implementation if this ordering ever looks wrong.
        for msg, idx in request.last_user_block():
            for content in msg.get_content():
                txt = content.get_text()
                if not txt:
                    continue
                user_messages.append(txt)
            last_idx = idx

        if not user_messages:
            return None

        return "\n".join(reversed(user_messages)), last_idx

    @abstractmethod
    async def process(self, request: Any, context: PipelineContext) -> PipelineResult:
        """Process a request and return either modified request or response stream"""
        pass
class InputPipelineInstance:
    """Runs a fixed list of pipeline steps against a single request."""

    def __init__(
        self,
        pipeline_steps: List[PipelineStep],
        sensitive_data_manager: SensitiveDataManager,
        is_fim: bool,
        client: ClientType = ClientType.GENERIC,
    ):
        self.pipeline_steps = pipeline_steps
        self.sensitive_data_manager = sensitive_data_manager
        self.is_fim = is_fim
        self.context = PipelineContext(client=client)

        # we create the sensitive context here so that it is not shared between individual requests
        # TODO: could we get away with just generating the session ID for an instance?
        self.context.sensitive = PipelineSensitiveData(
            manager=self.sensitive_data_manager,
            session_id=str(uuid.uuid4()),
        )
        self.context.metadata["is_fim"] = is_fim

    async def process_request(
        self,
        request: Any,
        provider: str,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> PipelineResult:
        """Process a request through all pipeline steps.

        Steps run sequentially. Each step may rewrite the request, replace the
        context, or short-circuit processing by returning a response/error.
        The (possibly obfuscated) input request is recorded before returning.
        """
        self.context.metadata["extra_headers"] = extra_headers
        current_request = request

        self.context.sensitive.api_key = api_key
        self.context.sensitive.model = model
        self.context.sensitive.provider = provider
        self.context.sensitive.api_base = api_base

        # For Copilot provider=openai. Use a flag to not clash with other places that may use that.
        provider_db = provider
        if self.context.client == ClientType.COPILOT:
            provider_db = "copilot"

        for step in self.pipeline_steps:
            try:
                result = await step.process(current_request, self.context)
                if result is None:
                    continue
            except Exception as e:
                logger.error(f"Error processing step '{step.name}'", exc_info=e)
                # Re-raise to maintain the current behaviour.
                raise e

            if result.shortcuts_processing():
                # Also record the input when shortcutting
                self.context.add_input_request(
                    current_request, is_fim_request=self.is_fim, provider=provider_db
                )
                return result

            if result.request is not None:
                current_request = result.request
            if result.context is not None:
                self.context = result.context

        # Create the input request at the end so we make sure the secrets are obfuscated
        self.context.add_input_request(
            current_request, is_fim_request=self.is_fim, provider=provider_db
        )
        return PipelineResult(request=current_request, context=self.context)
class SequentialPipelineProcessor:
    """Owns a single InputPipelineInstance and forwards requests to it."""

    def __init__(
        self,
        pipeline_steps: List[PipelineStep],
        sensitive_data_manager: SensitiveDataManager,
        client_type: ClientType,
        is_fim: bool,
    ):
        self.pipeline_steps = pipeline_steps
        self.sensitive_data_manager = sensitive_data_manager
        self.is_fim = is_fim
        self.instance = self._create_instance(client_type)

    def _create_instance(self, client_type: ClientType) -> InputPipelineInstance:
        """Create a new pipeline instance for processing a request"""
        return InputPipelineInstance(
            self.pipeline_steps,
            self.sensitive_data_manager,
            self.is_fim,
            client_type,
        )

    async def process_request(
        self,
        request: Any,
        provider: str,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> PipelineResult:
        """Process the request on the pipeline instance built in __init__.

        NOTE(review): despite the old docstring, no new instance is created
        per request — the instance (and therefore its context and sensitive
        session) is reused across calls to this method.
        """
        return await self.instance.process_request(
            request,
            provider,
            model,
            api_key,
            api_base,
            extra_headers,
        )