diff --git a/.github/workflows/import_packages.yml b/.github/workflows/import_packages.yml
index 3da31b63..e7ada4d4 100644
--- a/.github/workflows/import_packages.yml
+++ b/.github/workflows/import_packages.yml
@@ -47,6 +47,7 @@ jobs:
MALICIOUS_KEY=$(jq -r '.latest.malicious_packages' manifest.json)
DEPRECATED_KEY=$(jq -r '.latest.deprecated_packages' manifest.json)
ARCHIVED_KEY=$(jq -r '.latest.archived_packages' manifest.json)
+ VULNERABLE_KEY=$(jq -r '.latest.vulnerable_packages' manifest.json)
echo "Malicious key: $MALICIOUS_KEY"
echo "Deprecated key: $DEPRECATED_KEY"
@@ -58,6 +59,7 @@ jobs:
aws s3 cp s3://codegate-data-prod/$MALICIOUS_KEY /tmp/jsonl-files/malicious.jsonl --region $AWS_REGION
aws s3 cp s3://codegate-data-prod/$DEPRECATED_KEY /tmp/jsonl-files/deprecated.jsonl --region $AWS_REGION
aws s3 cp s3://codegate-data-prod/$ARCHIVED_KEY /tmp/jsonl-files/archived.jsonl --region $AWS_REGION
+ aws s3 cp s3://codegate-data-prod/$VULNERABLE_KEY /tmp/jsonl-files/vulnerable.jsonl --region $AWS_REGION
- name: Install Poetry
run: |
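For reference, the `jq` lookups in this step imply a `manifest.json` shaped like the sketch below. The keys shown are hypothetical; the real S3 object names are not part of this diff.

```python
import json

# Hypothetical manifest shape implied by the jq queries above; the
# object keys are placeholders, not real S3 paths.
manifest = json.loads("""
{
  "latest": {
    "malicious_packages": "2025/03/05/malicious.jsonl",
    "deprecated_packages": "2025/03/05/deprecated.jsonl",
    "archived_packages": "2025/03/05/archived.jsonl",
    "vulnerable_packages": "2025/03/05/vulnerable.jsonl"
  }
}
""")

# Equivalent of: jq -r '.latest.vulnerable_packages' manifest.json
print(manifest["latest"]["vulnerable_packages"])
```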
diff --git a/Dockerfile b/Dockerfile
index a12d0f76..70849c13 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -72,6 +72,7 @@ FROM python:3.12-slim AS runtime
RUN apt-get update && apt-get install -y --no-install-recommends \
libgomp1 \
nginx \
+ gettext-base \
&& rm -rf /var/lib/apt/lists/*
# Create a non-root user
@@ -81,6 +82,7 @@ RUN useradd -m -u 1000 -r codegate
# Set permissions for user codegate to run nginx
RUN chown -R codegate /var/lib/nginx && \
chown -R codegate /var/log/nginx && \
+ chown -R codegate /etc/nginx && \
chown -R codegate /run
COPY nginx.conf /etc/nginx/nginx.conf
diff --git a/api/openapi.json b/api/openapi.json
index deb5c2de..4ea57d0e 100644
--- a/api/openapi.json
+++ b/api/openapi.json
@@ -418,7 +418,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/CreateOrRenameWorkspaceRequest"
+ "$ref": "#/components/schemas/FullWorkspace-Input"
}
}
},
@@ -430,7 +430,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/Workspace"
+ "$ref": "#/components/schemas/FullWorkspace-Output"
}
}
}
@@ -522,6 +522,58 @@
}
},
"/api/v1/workspaces/{workspace_name}": {
+ "put": {
+ "tags": [
+ "CodeGate API",
+ "Workspaces"
+ ],
+ "summary": "Update Workspace",
+ "description": "Update a workspace.",
+ "operationId": "v1_update_workspace",
+ "parameters": [
+ {
+ "name": "workspace_name",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string",
+ "title": "Workspace Name"
+ }
+ }
+ ],
+ "requestBody": {
+ "required": true,
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/FullWorkspace-Input"
+ }
+ }
+ }
+ },
+ "responses": {
+ "201": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/FullWorkspace-Output"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ },
"delete": {
"tags": [
"CodeGate API",
@@ -720,6 +772,50 @@
}
}
},
+ "/api/v1/workspaces/{workspace_name}/alerts-summary": {
+ "get": {
+ "tags": [
+ "CodeGate API",
+ "Workspaces"
+ ],
+ "summary": "Get Workspace Alerts Summary",
+ "description": "Get alert summary for a workspace.",
+ "operationId": "v1_get_workspace_alerts_summary",
+ "parameters": [
+ {
+ "name": "workspace_name",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string",
+ "title": "Workspace Name"
+ }
+ }
+ ],
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/AlertSummary"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
"/api/v1/workspaces/{workspace_name}/messages": {
"get": {
"tags": [
@@ -1123,6 +1219,205 @@
}
}
}
+ },
+ "/api/v1/personas": {
+ "get": {
+ "tags": [
+ "CodeGate API",
+ "Personas"
+ ],
+ "summary": "List Personas",
+ "description": "List all personas.",
+ "operationId": "v1_list_personas",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "items": {
+ "$ref": "#/components/schemas/Persona"
+ },
+ "type": "array",
+ "title": "Response V1 List Personas"
+ }
+ }
+ }
+ }
+ }
+ },
+ "post": {
+ "tags": [
+ "CodeGate API",
+ "Personas"
+ ],
+ "summary": "Create Persona",
+ "description": "Create a new persona.",
+ "operationId": "v1_create_persona",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/PersonaRequest"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "201": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Persona"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "/api/v1/personas/{persona_name}": {
+ "get": {
+ "tags": [
+ "CodeGate API",
+ "Personas"
+ ],
+ "summary": "Get Persona",
+ "description": "Get a persona by name.",
+ "operationId": "v1_get_persona",
+ "parameters": [
+ {
+ "name": "persona_name",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string",
+ "title": "Persona Name"
+ }
+ }
+ ],
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Persona"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ },
+ "put": {
+ "tags": [
+ "CodeGate API",
+ "Personas"
+ ],
+ "summary": "Update Persona",
+ "description": "Update an existing persona.",
+ "operationId": "v1_update_persona",
+ "parameters": [
+ {
+ "name": "persona_name",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string",
+ "title": "Persona Name"
+ }
+ }
+ ],
+ "requestBody": {
+ "required": true,
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/PersonaUpdateRequest"
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Persona"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ },
+ "delete": {
+ "tags": [
+ "CodeGate API",
+ "Personas"
+ ],
+ "summary": "Delete Persona",
+ "description": "Delete a persona.",
+ "operationId": "v1_delete_persona",
+ "parameters": [
+ {
+ "name": "persona_name",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string",
+ "title": "Persona Name"
+ }
+ }
+ ],
+ "responses": {
+ "204": {
+ "description": "Successful Response"
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ }
}
},
"components": {
@@ -1352,6 +1647,30 @@
],
"title": "AlertSeverity"
},
+ "AlertSummary": {
+ "properties": {
+ "malicious_packages": {
+ "type": "integer",
+ "title": "Malicious Packages"
+ },
+ "pii": {
+ "type": "integer",
+ "title": "Pii"
+ },
+ "secrets": {
+ "type": "integer",
+ "title": "Secrets"
+ }
+ },
+ "type": "object",
+ "required": [
+ "malicious_packages",
+ "pii",
+ "secrets"
+ ],
+ "title": "AlertSummary",
+ "description": "Represents a set of summary alerts"
+ },
"ChatMessage": {
"properties": {
"message": {
@@ -1521,7 +1840,20 @@
"title": "Conversation",
"description": "Represents a conversation."
},
- "CreateOrRenameWorkspaceRequest": {
+ "CustomInstructions": {
+ "properties": {
+ "prompt": {
+ "type": "string",
+ "title": "Prompt"
+ }
+ },
+ "type": "object",
+ "required": [
+ "prompt"
+ ],
+ "title": "CustomInstructions"
+ },
+ "FullWorkspace-Input": {
"properties": {
"name": {
"type": "string",
@@ -1530,43 +1862,42 @@
"config": {
"anyOf": [
{
- "$ref": "#/components/schemas/WorkspaceConfig"
+ "$ref": "#/components/schemas/WorkspaceConfig-Input"
},
{
"type": "null"
}
]
- },
- "rename_to": {
- "anyOf": [
- {
- "type": "string"
- },
- {
- "type": "null"
- }
- ],
- "title": "Rename To"
}
},
"type": "object",
"required": [
"name"
],
- "title": "CreateOrRenameWorkspaceRequest"
+ "title": "FullWorkspace"
},
- "CustomInstructions": {
+ "FullWorkspace-Output": {
"properties": {
- "prompt": {
+ "name": {
"type": "string",
- "title": "Prompt"
+ "title": "Name"
+ },
+ "config": {
+ "anyOf": [
+ {
+ "$ref": "#/components/schemas/WorkspaceConfig-Output"
+ },
+ {
+ "type": "null"
+ }
+ ]
}
},
"type": "object",
"required": [
- "prompt"
+ "name"
],
- "title": "CustomInstructions"
+ "title": "FullWorkspace"
},
"HTTPValidationError": {
"properties": {
@@ -1693,6 +2024,68 @@
"title": "MuxRule",
"description": "Represents a mux rule for a provider."
},
+ "Persona": {
+ "properties": {
+ "id": {
+ "type": "string",
+ "title": "Id"
+ },
+ "name": {
+ "type": "string",
+ "title": "Name"
+ },
+ "description": {
+ "type": "string",
+ "title": "Description"
+ }
+ },
+ "type": "object",
+ "required": [
+ "id",
+ "name",
+ "description"
+ ],
+ "title": "Persona",
+ "description": "Represents a persona object."
+ },
+ "PersonaRequest": {
+ "properties": {
+ "name": {
+ "type": "string",
+ "title": "Name"
+ },
+ "description": {
+ "type": "string",
+ "title": "Description"
+ }
+ },
+ "type": "object",
+ "required": [
+ "name",
+ "description"
+ ],
+ "title": "PersonaRequest",
+ "description": "Model for creating a new Persona."
+ },
+ "PersonaUpdateRequest": {
+ "properties": {
+ "new_name": {
+ "type": "string",
+ "title": "New Name"
+ },
+ "new_description": {
+ "type": "string",
+ "title": "New Description"
+ }
+ },
+ "type": "object",
+ "required": [
+ "new_name",
+ "new_description"
+ ],
+ "title": "PersonaUpdateRequest",
+ "description": "Model for updating a Persona."
+ },
"ProviderAuthType": {
"type": "string",
"enum": [
@@ -1914,11 +2307,32 @@
],
"title": "Workspace"
},
- "WorkspaceConfig": {
+ "WorkspaceConfig-Input": {
+ "properties": {
+ "custom_instructions": {
+ "type": "string",
+ "title": "Custom Instructions"
+ },
+ "muxing_rules": {
+ "items": {
+ "$ref": "#/components/schemas/MuxRule"
+ },
+ "type": "array",
+ "title": "Muxing Rules"
+ }
+ },
+ "type": "object",
+ "required": [
+ "custom_instructions",
+ "muxing_rules"
+ ],
+ "title": "WorkspaceConfig"
+ },
+ "WorkspaceConfig-Output": {
"properties": {
- "system_prompt": {
+ "custom_instructions": {
"type": "string",
- "title": "System Prompt"
+ "title": "Custom Instructions"
},
"muxing_rules": {
"items": {
@@ -1930,7 +2344,7 @@
},
"type": "object",
"required": [
- "system_prompt",
+ "custom_instructions",
"muxing_rules"
],
"title": "WorkspaceConfig"
diff --git a/docs/workspaces.md b/docs/workspaces.md
new file mode 100644
index 00000000..cdc4e751
--- /dev/null
+++ b/docs/workspaces.md
@@ -0,0 +1,111 @@
+# CodeGate Workspaces
+
+Workspaces help you group related resources. They can be used to organize your
+configurations, muxing rules, and custom prompts. Note that workspaces are not a
+tenancy concept; CodeGate assumes that it is serving a single user.
+
+## Global vs Workspace resources
+
+In CodeGate, resources can be either global (available across all workspaces) or workspace-specific:
+
+- **Global resources**: These are shared across all workspaces and include provider endpoints,
+ authentication configurations, and personas.
+
+- **Workspace resources**: These are specific to a workspace and include custom instructions,
+ muxing rules, and conversation history.
+
+### Sessions and Active Workspaces
+
+CodeGate uses the concept of "sessions" to track which workspace is active. A session represents
+a user's interaction context with the system and maintains a reference to the active workspace.
+
+- **Sessions**: Each session has an ID, an active workspace ID, and a last update timestamp.
+- **Active workspace**: The workspace that is currently being used for processing requests.
+
+Currently, the implementation expects only one active session at a time, meaning only one
+workspace can be active. However, the underlying architecture is designed to potentially
+support multiple concurrent sessions in the future, which would allow different contexts
+to have different active workspaces simultaneously.
+
+When a workspace is activated, the session's `active_workspace_id` is updated to point to that
+workspace, and the muxing registry is updated to use that workspace's rules for routing requests.
+
+## Workspace Lifecycle
+
+Workspaces in CodeGate follow a specific lifecycle:
+
+1. **Creation**: Workspaces are created with a unique name, plus optional custom instructions and muxing rules.
+2. **Activation**: A workspace can be activated, making it the current context for processing requests.
+3. **Archiving**: Workspaces can be archived (soft-deleted) when no longer needed but might be used again.
+4. **Recovery**: Archived workspaces can be recovered to make them available again.
+5. **Deletion**: Archived workspaces can be permanently deleted (hard-deleted).
+
+### Default Workspace
+
+CodeGate includes a default workspace that cannot be deleted or archived. This workspace is used
+when no other workspace is explicitly activated.
+
+## Workspace Features
+
+### Custom Instructions
+
+Each workspace can have its own set of custom instructions that are applied to LLM requests.
+These instructions can be used to customize the behavior of the LLM for specific use cases.
+
+### Muxing Rules
+
+Workspaces can define muxing rules that determine which provider and model to use for different
+types of requests. Rules are evaluated in priority order (first rule in the list has highest priority).
+
+### Token Usage Tracking
+
+CodeGate tracks token usage per workspace, allowing you to monitor and analyze resource consumption
+across different contexts or projects.
+
+### Prompts, Alerts and Monitoring
+
+Workspaces maintain their own prompt and alert history, making it easier to monitor and respond to issues within specific contexts.
+
+## Developing
+
+### When to use workspaces?
+
+Consider using separate workspaces when:
+
+- You need different custom instructions for different projects or use cases
+- You want to route different types of requests to different models
+- You need to track token usage separately for different projects
+- You want to isolate alerts and monitoring for specific contexts
+- You're experimenting with different configurations and want to switch between them easily
+
+### When should a resource be global?
+
+Resources should be global when:
+
+- They need to be shared across multiple workspaces
+- They represent infrastructure configuration rather than usage patterns
+- They're related to provider connectivity rather than specific use cases
+- They represent reusable components like personas that might be used in multiple contexts
+
+### Exporting resources
+
+Exporting resources in CodeGate is designed to facilitate sharing workspaces between different instances.
+This is particularly useful for:
+
+- **Standardizing configurations**: When you want to ensure consistent behavior across multiple CodeGate instances
+- **Sharing best practices**: When you've developed effective muxing rules or custom instructions that others could benefit from
+- **Backup and recovery**: To preserve important workspace configurations before making significant changes
+
+When deciding whether to export resources, consider:
+
+- **Export workspace configurations** when they represent reusable patterns that could be valuable in other contexts
+- **Export muxing rules** when they represent well-tested routing strategies that could be applied in other instances
+- **Export custom instructions** when they contain general-purpose prompting strategies not specific to your instance
+
+Avoid exporting:
+- Workspaces with instance-specific configurations that wouldn't be applicable elsewhere
+- Workspaces containing sensitive or organization-specific custom instructions
+- Resources that are tightly coupled to your specific provider endpoints or authentication setup
+
+Note that conversation history, alerts, and token usage statistics are not included in exports as they
+represent instance-specific usage data rather than reusable configurations.
diff --git a/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py
new file mode 100644
index 00000000..775e3967
--- /dev/null
+++ b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py
@@ -0,0 +1,63 @@
+"""add installation table
+
+Revision ID: e4c05d7591a8
+Revises: 3ec2b4ab569c
+Create Date: 2025-03-05 21:26:19.034319+00:00
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "e4c05d7591a8"
+down_revision: Union[str, None] = "3ec2b4ab569c"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ op.execute("BEGIN TRANSACTION;")
+
+ op.execute(
+ """
+ CREATE TABLE IF NOT EXISTS instance (
+ id TEXT PRIMARY KEY, -- UUID stored as TEXT
+ created_at DATETIME NOT NULL
+ );
+ """
+ )
+
+ op.execute(
+ """
+ -- The following trigger prevents multiple insertions in the
+ -- instance table. It is safe since the dimension of the table
+ -- is fixed.
+
+ CREATE TRIGGER single_instance
+ BEFORE INSERT ON instance
+ WHEN (SELECT COUNT(*) FROM instance) >= 1
+ BEGIN
+ SELECT RAISE(FAIL, 'only one instance!');
+ END;
+ """
+ )
+
+ # Finish transaction
+ op.execute("COMMIT;")
+
+
+def downgrade() -> None:
+ op.execute("BEGIN TRANSACTION;")
+
+ op.execute(
+ """
+ DROP TABLE instance;
+ """
+ )
+
+ # Finish transaction
+ op.execute("COMMIT;")
diff --git a/poetry.lock b/poetry.lock
index f757d84e..15579a88 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -151,23 +151,23 @@ docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"]
[[package]]
name = "alembic"
-version = "1.14.1"
+version = "1.15.1"
description = "A database migration tool for SQLAlchemy."
optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
groups = ["main"]
files = [
- {file = "alembic-1.14.1-py3-none-any.whl", hash = "sha256:1acdd7a3a478e208b0503cd73614d5e4c6efafa4e73518bb60e4f2846a37b1c5"},
- {file = "alembic-1.14.1.tar.gz", hash = "sha256:496e888245a53adf1498fcab31713a469c65836f8de76e01399aa1c3e90dd213"},
+ {file = "alembic-1.15.1-py3-none-any.whl", hash = "sha256:197de710da4b3e91cf66a826a5b31b5d59a127ab41bd0fc42863e2902ce2bbbe"},
+ {file = "alembic-1.15.1.tar.gz", hash = "sha256:e1a1c738577bca1f27e68728c910cd389b9a92152ff91d902da649c192e30c49"},
]
[package.dependencies]
Mako = "*"
-SQLAlchemy = ">=1.3.0"
-typing-extensions = ">=4"
+SQLAlchemy = ">=1.4.0"
+typing-extensions = ">=4.12"
[package.extras]
-tz = ["backports.zoneinfo ; python_version < \"3.9\"", "tzdata"]
+tz = ["tzdata"]
[[package]]
name = "annotated-types"
@@ -1343,14 +1343,14 @@ files = [
[[package]]
name = "jinja2"
-version = "3.1.5"
+version = "3.1.6"
description = "A very fast and expressive template engine."
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
files = [
- {file = "jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb"},
- {file = "jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb"},
+ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"},
+ {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"},
]
[package.dependencies]
@@ -1546,14 +1546,14 @@ files = [
[[package]]
name = "litellm"
-version = "1.62.1"
+version = "1.63.0"
description = "Library to easily interface with LLM API providers"
optional = false
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
groups = ["main", "dev"]
files = [
- {file = "litellm-1.62.1-py3-none-any.whl", hash = "sha256:f576358c72b477207d1f45ce5ac895ede7bd84377f6420a6b522909c829a79dc"},
- {file = "litellm-1.62.1.tar.gz", hash = "sha256:eee9cc40dc9c1da7e411af2f4ef145a67bb61702ae4e1218c1bc15b9e6404daa"},
+ {file = "litellm-1.63.0-py3-none-any.whl", hash = "sha256:38961eaeb81fa2500c2725e01be898fb5d6347e73286b6d13d2f4d2f006d99e9"},
+ {file = "litellm-1.63.0.tar.gz", hash = "sha256:872fb3fa4c8875d82fe998a5e4249c21a15bb08800286f03f90ed1700203f62e"},
]
[package.dependencies]
@@ -4279,4 +4279,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.12,<3.13"
-content-hash = "1cae360ec3078b2da000dfea1d112e32256502aa9b3e2a4d8ca919384a49aff6"
+content-hash = "ba9315e5bd243ff23b9f1044c43228b0658f7c345bc081dcfe9f2af8f2511e0c"
diff --git a/pyproject.toml b/pyproject.toml
index a6748387..8da9768a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ PyYAML = "==6.0.2"
fastapi = "==0.115.11"
uvicorn = "==0.34.0"
structlog = "==25.1.0"
-litellm = "==1.62.1"
+litellm = "==1.63.0"
llama_cpp_python = "==0.3.5"
cryptography = "==44.0.2"
sqlalchemy = "==2.0.38"
@@ -28,7 +28,7 @@ tree-sitter-java = "==0.23.5"
tree-sitter-javascript = "==0.23.1"
tree-sitter-python = "==0.23.6"
tree-sitter-rust = "==0.23.2"
-alembic = "==1.14.1"
+alembic = "==1.15.1"
pygments = "==2.19.1"
sqlite-vec-sl-tmp = "==0.0.4"
greenlet = "==3.1.1"
@@ -50,7 +50,7 @@ ruff = "==0.9.9"
bandit = "==1.8.3"
build = "==1.2.2.post1"
wheel = "==0.45.1"
-litellm = "==1.62.1"
+litellm = "==1.63.0"
pytest-asyncio = "==0.25.3"
llama_cpp_python = "==0.3.5"
scikit-learn = "==1.6.1"
diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh
index b28f6704..dd1f70d7 100755
--- a/scripts/entrypoint.sh
+++ b/scripts/entrypoint.sh
@@ -28,11 +28,10 @@ generate_certs() {
# Function to start Nginx server for the dashboard
start_dashboard() {
- if [ -n "${DASHBOARD_BASE_URL}" ]; then
- echo "Overriding dashboard url with $DASHBOARD_BASE_URL"
- sed -ibck "s|http://localhost:8989|http://$DASHBOARD_BASE_URL:8989|g" /var/www/html/assets/*.js
- fi
echo "Starting the dashboard..."
+
+ envsubst '${DASHBOARD_API_BASE_URL}' < /var/www/html/index.html > /var/www/html/index.html.tmp && mv /var/www/html/index.html.tmp /var/www/html/index.html
+
nginx -g 'daemon off;' &
}
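`envsubst` comes from the `gettext-base` package added in the Dockerfile hunk above. Passing `'${DASHBOARD_API_BASE_URL}'` as its argument restricts substitution to that one variable; every other `${...}` in `index.html` is left untouched. A rough Python equivalent of that restricted substitution:

```python
import os
import re

def envsubst(text: str, allowed: set[str]) -> str:
    # Substitute only the allow-listed variables; leave the rest intact.
    def repl(m: re.Match) -> str:
        name = m.group(1)
        return os.environ.get(name, "") if name in allowed else m.group(0)
    return re.sub(r"\$\{(\w+)\}", repl, text)

os.environ["DASHBOARD_API_BASE_URL"] = "https://codegate.example"
html = '<script>window.API = "${DASHBOARD_API_BASE_URL}"; x = "${OTHER}";</script>'
print(envsubst(html, {"DASHBOARD_API_BASE_URL"}))  # ${OTHER} survives
```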
diff --git a/scripts/import_packages.py b/scripts/import_packages.py
index 1cfdfd1e..c4a2dad1 100644
--- a/scripts/import_packages.py
+++ b/scripts/import_packages.py
@@ -20,6 +20,7 @@ def __init__(self, jsonl_dir="data", vec_db_path="./sqlite_data/vectordb.db"):
os.path.join(jsonl_dir, "archived.jsonl"),
os.path.join(jsonl_dir, "deprecated.jsonl"),
os.path.join(jsonl_dir, "malicious.jsonl"),
+ os.path.join(jsonl_dir, "vulnerable.jsonl"),
]
self.conn = self._get_connection()
Config.load() # Load the configuration
@@ -48,13 +49,41 @@ def setup_schema(self):
"""
)
+    # table for packages that have at least one high or critical vulnerability
+ cursor.execute(
+ """
+ CREATE TABLE cve_packages (
+ name TEXT NOT NULL,
+ version TEXT NOT NULL,
+ type TEXT NOT NULL
+ )
+ """
+ )
+
# Create indexes for faster querying
cursor.execute("CREATE INDEX IF NOT EXISTS idx_name ON packages(name)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_type ON packages(type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_status ON packages(status)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_pkg_cve_name ON cve_packages(name)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_pkg_cve_type ON cve_packages(type)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_pkg_cve_version ON cve_packages(version)")
self.conn.commit()
+ async def process_cve_packages(self, package):
+ cursor = self.conn.cursor()
+ cursor.execute(
+ """
+ INSERT INTO cve_packages (name, version, type) VALUES (?, ?, ?)
+ """,
+ (
+ package["name"],
+ package["version"],
+ package["type"],
+ ),
+ )
+ self.conn.commit()
+
async def process_package(self, package):
vector_str = generate_vector_string(package)
vector = await self.inference_engine.embed(
@@ -101,14 +130,19 @@ async def add_data(self):
package["status"] = json_file.split("/")[-1].split(".")[0]
key = f"{package['name']}/{package['type']}"
- if key in existing_packages and existing_packages[key] == {
- "status": package["status"],
- "description": package["description"],
- }:
- print("Package already exists", key)
- continue
-
- await self.process_package(package)
+ if package["status"] == "vulnerable":
+ # Process vulnerable packages using the cve flow
+ await self.process_cve_packages(package)
+ else:
+ # For non-vulnerable packages, check for duplicates and process normally
+ if key in existing_packages and existing_packages[key] == {
+ "status": package["status"],
+ "description": package["description"],
+ }:
+ print("Package already exists", key)
+ continue
+
+ await self.process_package(package)
async def run_import(self):
self.setup_schema()
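The importer now derives each record's `status` from the source file name and routes `vulnerable` records to the `cve_packages` table instead of the embedding flow. A sketch with a hypothetical JSONL record (real records may carry more fields, such as a description):

```python
import json

line = '{"name": "requests", "version": "2.19.0", "type": "pypi"}'  # hypothetical
package = json.loads(line)
package["status"] = "vulnerable"  # derived from "vulnerable.jsonl"

if package["status"] == "vulnerable":
    # -> INSERT INTO cve_packages (name, version, type) VALUES (?, ?, ?)
    print("cve flow:", (package["name"], package["version"], package["type"]))
else:
    print("embedding flow")  # duplicate check + vector embedding
```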
diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index ebd9be79..33efea33 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -12,7 +12,12 @@
from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.db.connection import AlreadyExistsError, DbReader
-from codegate.db.models import AlertSeverity, WorkspaceWithModel
+from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel
+from codegate.muxing.persona import (
+ PersonaDoesNotExistError,
+ PersonaManager,
+ PersonaSimilarDescriptionError,
+)
from codegate.providers import crud as provendcrud
from codegate.workspaces import crud
@@ -21,6 +26,7 @@
v1 = APIRouter()
wscrud = crud.WorkspaceCrud()
pcrud = provendcrud.ProviderCrud()
+persona_manager = PersonaManager()
# This is a singleton object
dbreader = DbReader()
@@ -248,22 +254,18 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status
@v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201)
async def create_workspace(
- request: v1_models.CreateOrRenameWorkspaceRequest,
-) -> v1_models.Workspace:
+ request: v1_models.FullWorkspace,
+) -> v1_models.FullWorkspace:
"""Create a new workspace."""
- if request.rename_to is not None:
- return await rename_workspace(request)
- return await create_new_workspace(request)
-
-
-async def create_new_workspace(
- request: v1_models.CreateOrRenameWorkspaceRequest,
-) -> v1_models.Workspace:
- # Input validation is done in the model
try:
- _ = await wscrud.add_workspace(request.name)
- except AlreadyExistsError:
- raise HTTPException(status_code=409, detail="Workspace already exists")
+ custom_instructions = request.config.custom_instructions if request.config else None
+ muxing_rules = request.config.muxing_rules if request.config else None
+
+ workspace_row, mux_rules = await wscrud.add_workspace(
+ request.name, custom_instructions, muxing_rules
+ )
+ except crud.WorkspaceNameAlreadyInUseError:
+ raise HTTPException(status_code=409, detail="Workspace name already in use")
except ValidationError:
raise HTTPException(
status_code=400,
@@ -277,18 +279,40 @@ async def create_new_workspace(
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
- return v1_models.Workspace(name=request.name, is_active=False)
+ return v1_models.FullWorkspace(
+ name=workspace_row.name,
+ config=v1_models.WorkspaceConfig(
+ custom_instructions=workspace_row.custom_instructions or "",
+ muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules],
+ ),
+ )
-async def rename_workspace(
- request: v1_models.CreateOrRenameWorkspaceRequest,
-) -> v1_models.Workspace:
+@v1.put(
+ "/workspaces/{workspace_name}",
+ tags=["Workspaces"],
+ generate_unique_id_function=uniq_name,
+ status_code=201,
+)
+async def update_workspace(
+ workspace_name: str,
+ request: v1_models.FullWorkspace,
+) -> v1_models.FullWorkspace:
+ """Update a workspace."""
try:
- _ = await wscrud.rename_workspace(request.name, request.rename_to)
+ custom_instructions = request.config.custom_instructions if request.config else None
+ muxing_rules = request.config.muxing_rules if request.config else None
+
+ workspace_row, mux_rules = await wscrud.update_workspace(
+ workspace_name,
+ request.name,
+ custom_instructions,
+ muxing_rules,
+ )
except crud.WorkspaceDoesNotExistError:
raise HTTPException(status_code=404, detail="Workspace does not exist")
- except AlreadyExistsError:
- raise HTTPException(status_code=409, detail="Workspace already exists")
+ except crud.WorkspaceNameAlreadyInUseError:
+ raise HTTPException(status_code=409, detail="Workspace name already in use")
except ValidationError:
raise HTTPException(
status_code=400,
@@ -302,7 +326,13 @@ async def rename_workspace(
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
- return v1_models.Workspace(name=request.rename_to, is_active=False)
+ return v1_models.FullWorkspace(
+ name=workspace_row.name,
+ config=v1_models.WorkspaceConfig(
+ custom_instructions=workspace_row.custom_instructions or "",
+ muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules],
+ ),
+ )
@v1.delete(
@@ -397,6 +427,33 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A
raise HTTPException(status_code=500, detail="Internal server error")
+@v1.get(
+ "/workspaces/{workspace_name}/alerts-summary",
+ tags=["Workspaces"],
+ generate_unique_id_function=uniq_name,
+)
+async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSummary:
+ """Get alert summary for a workspace."""
+ try:
+ ws = await wscrud.get_workspace_by_name(workspace_name)
+ except crud.WorkspaceDoesNotExistError:
+ raise HTTPException(status_code=404, detail="Workspace does not exist")
+ except Exception:
+ logger.exception("Error while getting workspace")
+ raise HTTPException(status_code=500, detail="Internal server error")
+
+ try:
+ summary = await dbreader.get_alerts_summary_by_workspace(ws.id)
+ return v1_models.AlertSummary(
+ malicious_packages=summary["codegate_context_retriever_count"],
+ pii=summary["codegate_pii_count"],
+ secrets=summary["codegate_secrets_count"],
+ )
+ except Exception:
+ logger.exception("Error while getting alerts summary")
+ raise HTTPException(status_code=500, detail="Internal server error")
+
+
@v1.get(
"/workspaces/{workspace_name}/messages",
tags=["Workspaces"],
@@ -614,3 +671,103 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
except Exception:
logger.exception("Error while getting messages")
raise HTTPException(status_code=500, detail="Internal server error")
+
+
+@v1.get("/personas", tags=["Personas"], generate_unique_id_function=uniq_name)
+async def list_personas() -> List[Persona]:
+ """List all personas."""
+ try:
+ personas = await persona_manager.get_all_personas()
+ return personas
+ except Exception:
+ logger.exception("Error while getting personas")
+ raise HTTPException(status_code=500, detail="Internal server error")
+
+
+@v1.get("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name)
+async def get_persona(persona_name: str) -> Persona:
+ """Get a persona by name."""
+ try:
+ persona = await persona_manager.get_persona(persona_name)
+ return persona
+ except PersonaDoesNotExistError:
+ logger.exception("Error while getting persona")
+ raise HTTPException(status_code=404, detail="Persona does not exist")
+
+
+@v1.post("/personas", tags=["Personas"], generate_unique_id_function=uniq_name, status_code=201)
+async def create_persona(request: v1_models.PersonaRequest) -> Persona:
+ """Create a new persona."""
+ try:
+ await persona_manager.add_persona(request.name, request.description)
+ persona = await dbreader.get_persona_by_name(request.name)
+ return persona
+ except PersonaSimilarDescriptionError:
+ logger.exception("Error while creating persona")
+ raise HTTPException(status_code=409, detail="Persona has a similar description to another")
+ except AlreadyExistsError:
+ logger.exception("Error while creating persona")
+ raise HTTPException(status_code=409, detail="Persona already exists")
+ except ValidationError:
+ logger.exception("Error while creating persona")
+ raise HTTPException(
+ status_code=400,
+ detail=(
+ "Persona has invalid name, check is alphanumeric "
+ "and only contains dashes and underscores"
+ ),
+ )
+ except Exception:
+ logger.exception("Error while creating persona")
+ raise HTTPException(status_code=500, detail="Internal server error")
+
+
+@v1.put("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name)
+async def update_persona(persona_name: str, request: v1_models.PersonaUpdateRequest) -> Persona:
+ """Update an existing persona."""
+ try:
+ await persona_manager.update_persona(
+ persona_name, request.new_name, request.new_description
+ )
+ persona = await dbreader.get_persona_by_name(request.new_name)
+ return persona
+ except PersonaSimilarDescriptionError:
+ logger.exception("Error while updating persona")
+ raise HTTPException(status_code=409, detail="Persona has a similar description to another")
+ except PersonaDoesNotExistError:
+ logger.exception("Error while updating persona")
+ raise HTTPException(status_code=404, detail="Persona does not exist")
+ except AlreadyExistsError:
+ logger.exception("Error while updating persona")
+ raise HTTPException(status_code=409, detail="Persona already exists")
+ except ValidationError:
+ logger.exception("Error while creating persona")
+ raise HTTPException(
+ status_code=400,
+ detail=(
+ "Persona has invalid name, check is alphanumeric "
+ "and only contains dashes and underscores"
+ ),
+ )
+ except Exception:
+ logger.exception("Error while updating persona")
+ raise HTTPException(status_code=500, detail="Internal server error")
+
+
+@v1.delete(
+ "/personas/{persona_name}",
+ tags=["Personas"],
+ generate_unique_id_function=uniq_name,
+ status_code=204,
+)
+async def delete_persona(persona_name: str):
+ """Delete a persona."""
+ try:
+ await persona_manager.delete_persona(persona_name)
+ return Response(status_code=204)
+ except PersonaDoesNotExistError:
+ logger.exception("Error while updating persona")
+ raise HTTPException(status_code=404, detail="Persona does not exist")
+ except Exception:
+ logger.exception("Error while deleting persona")
+ raise HTTPException(status_code=500, detail="Internal server error")
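A hedged sketch of exercising the new persona routes end to end, assuming a local instance on the default port:

```python
import requests

BASE = "http://localhost:8989"  # assumed local CodeGate API address

r = requests.post(f"{BASE}/api/v1/personas",
                  json={"name": "reviewer",
                        "description": "Reviews diffs for security issues."})
assert r.status_code == 201

r = requests.put(f"{BASE}/api/v1/personas/reviewer",
                 json={"new_name": "security-reviewer",
                       "new_description": "Audits changes for security risks."})
assert r.status_code == 200

r = requests.delete(f"{BASE}/api/v1/personas/security-reviewer")
assert r.status_code == 204
```

Note that create and update can each return 409 for two distinct reasons: a duplicate name (`AlreadyExistsError`) or a description too close to an existing persona (`PersonaSimilarDescriptionError`).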
diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py
index c608484c..dff26489 100644
--- a/src/codegate/api/v1_models.py
+++ b/src/codegate/api/v1_models.py
@@ -61,7 +61,7 @@ def from_db_workspaces(
class WorkspaceConfig(pydantic.BaseModel):
- system_prompt: str
+ custom_instructions: str
muxing_rules: List[mux_models.MuxRule]
@@ -72,13 +72,6 @@ class FullWorkspace(pydantic.BaseModel):
config: Optional[WorkspaceConfig] = None
-class CreateOrRenameWorkspaceRequest(FullWorkspace):
- # If set, rename the workspace to this name. Note that
- # the 'name' field is still required and the workspace
- # workspace must exist.
- rename_to: Optional[str] = None
-
-
class ActivateWorkspaceRequest(pydantic.BaseModel):
name: str
@@ -190,6 +183,16 @@ def from_db_model(db_model: db_models.Alert) -> "Alert":
timestamp: datetime.datetime
+class AlertSummary(pydantic.BaseModel):
+ """
+ Represents a set of summary alerts
+ """
+
+ malicious_packages: int
+ pii: int
+ secrets: int
+
+
class PartialQuestionAnswer(pydantic.BaseModel):
"""
Represents a partial conversation.
@@ -312,3 +315,21 @@ class ModelByProvider(pydantic.BaseModel):
def __str__(self):
return f"{self.provider_name} / {self.name}"
+
+
+class PersonaRequest(pydantic.BaseModel):
+ """
+ Model for creating a new Persona.
+ """
+
+ name: str
+ description: str
+
+
+class PersonaUpdateRequest(pydantic.BaseModel):
+ """
+ Model for updating a Persona.
+ """
+
+ new_name: str
+ new_description: str
diff --git a/src/codegate/cli.py b/src/codegate/cli.py
index 455d9001..1ae3f9c2 100644
--- a/src/codegate/cli.py
+++ b/src/codegate/cli.py
@@ -14,7 +14,11 @@
from codegate.ca.codegate_ca import CertificateAuthority
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.config import Config, ConfigurationError
-from codegate.db.connection import init_db_sync, init_session_if_not_exists
+from codegate.db.connection import (
+ init_db_sync,
+ init_session_if_not_exists,
+ init_instance,
+)
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
from codegate.providers import crud as provendcrud
@@ -318,6 +322,7 @@ def serve( # noqa: C901
logger = structlog.get_logger("codegate").bind(origin="cli")
init_db_sync(cfg.db_path)
+ init_instance(cfg.db_path)
init_session_if_not_exists(cfg.db_path)
# Check certificates and create CA if necessary
diff --git a/src/codegate/config.py b/src/codegate/config.py
index 761ca09e..179ec4d3 100644
--- a/src/codegate/config.py
+++ b/src/codegate/config.py
@@ -57,9 +57,14 @@ class Config:
force_certs: bool = False
max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes.
+
    # Min value is 0 (max similarity), max value is 2 (opposite); 1 is orthogonal
# The value 0.75 was found through experimentation. See /tests/muxing/test_semantic_router.py
+ # It's the threshold value to determine if a query matches a persona.
persona_threshold = 0.75
+ # The value 0.3 was found through experimentation. See /tests/muxing/test_semantic_router.py
+ # It's the threshold value to determine if a persona description is similar to existing personas
+ persona_diff_desc_threshold = 0.3
# Provider URLs with defaults
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
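Both thresholds are cosine distances (1 minus cosine similarity), so 0 means identical direction and 1 means orthogonal. A small sketch with made-up vectors showing how the two thresholds gate persona matching and description uniqueness:

```python
import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Mirrors vec_distance_cosine in the DB: 1 - cosine similarity.
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

PERSONA_THRESHOLD = 0.75           # query matches a persona below this distance
PERSONA_DIFF_DESC_THRESHOLD = 0.3  # descriptions "too similar" below this

a = np.array([1.0, 0.2, 0.0], dtype=np.float32)  # made-up embeddings
b = np.array([0.9, 0.3, 0.1], dtype=np.float32)
d = cosine_distance(a, b)
print(d, "matches persona:", d < PERSONA_THRESHOLD,
      "too similar:", d < PERSONA_DIFF_DESC_THRESHOLD)
```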
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 803943b3..3f439aea 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -1,4 +1,5 @@
import asyncio
+import datetime
import json
import sqlite3
import uuid
@@ -14,7 +15,8 @@
from sqlalchemy import CursorResult, TextClause, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError, OperationalError
-from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
+from sqlalchemy.orm import sessionmaker
from codegate.db.fim_cache import FimCache
from codegate.db.models import (
@@ -22,6 +24,7 @@
Alert,
GetPromptWithOutputsRow,
GetWorkspaceByNameConditions,
+ Instance,
IntermediatePromptWithOutputUsageAlerts,
MuxRule,
Output,
@@ -560,15 +563,62 @@ async def add_persona(self, persona: PersonaEmbedding) -> None:
)
try:
- # For Pydantic we convert the numpy array to string when serializing with .model_dumpy()
- # We need to convert it back to a numpy array before inserting it into the DB.
- persona_dict = persona.model_dump()
- persona_dict["description_embedding"] = persona.description_embedding
- await self._execute_with_no_return(sql, persona_dict)
+ await self._execute_with_no_return(sql, persona.model_dump())
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Persona '{persona.name}' already exists.")
+ async def update_persona(self, persona: PersonaEmbedding) -> None:
+ """
+ Update an existing Persona in the DB.
+
+ This handles validation and update of an existing persona.
+ """
+ sql = text(
+ """
+ UPDATE personas
+ SET name = :name,
+ description = :description,
+ description_embedding = :description_embedding
+ WHERE id = :id
+ """
+ )
+
+ try:
+ await self._execute_with_no_return(sql, persona.model_dump())
+ except IntegrityError as e:
+ logger.debug(f"Exception type: {type(e)}")
+ raise AlreadyExistsError(f"Persona '{persona.name}' already exists.")
+
+ async def delete_persona(self, persona_id: str) -> None:
+ """
+ Delete an existing Persona from the DB.
+ """
+ sql = text("DELETE FROM personas WHERE id = :id")
+ conditions = {"id": persona_id}
+ await self._execute_with_no_return(sql, conditions)
+
+ async def init_instance(self) -> None:
+ """
+ Initializes instance details in the database.
+ """
+ sql = text(
+ """
+ INSERT INTO instance (id, created_at)
+ VALUES (:id, :created_at)
+ """
+ )
+
+ try:
+ instance = Instance(
+ id=str(uuid.uuid4()),
+ created_at=datetime.datetime.now(datetime.timezone.utc),
+ )
+ await self._execute_with_no_return(sql, instance.model_dump())
+ except IntegrityError as e:
+ logger.debug(f"Exception type: {type(e)}")
+ raise AlreadyExistsError(f"Instance already initialized.")
+
class DbReader(DbCodeGate):
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
@@ -587,7 +637,10 @@ async def _dump_result_to_pydantic_model(
return None
async def _execute_select_pydantic_model(
- self, model_type: Type[BaseModel], sql_command: TextClause
+ self,
+ model_type: Type[BaseModel],
+ sql_command: TextClause,
+ should_raise: bool = False,
) -> Optional[List[BaseModel]]:
async with self._async_db_engine.begin() as conn:
try:
@@ -595,6 +648,9 @@ async def _execute_select_pydantic_model(
return await self._dump_result_to_pydantic_model(model_type, result)
except Exception as e:
logger.error(f"Failed to select model: {model_type}.", error=str(e))
+ # Exposes errors to the caller
+ if should_raise:
+ raise e
return None
async def _exec_select_conditions_to_pydantic(
@@ -746,6 +802,38 @@ async def get_alerts_by_workspace(
)
return prompts
+ async def get_alerts_summary_by_workspace(self, workspace_id: str) -> dict:
+ """Get aggregated alert summary counts for a given workspace_id."""
+ sql = text(
+ """
+ SELECT
+ COUNT(*) AS total_alerts,
+ SUM(CASE WHEN a.trigger_type = 'codegate-secrets' THEN 1 ELSE 0 END)
+ AS codegate_secrets_count,
+ SUM(CASE WHEN a.trigger_type = 'codegate-context-retriever' THEN 1 ELSE 0 END)
+ AS codegate_context_retriever_count,
+ SUM(CASE WHEN a.trigger_type = 'codegate-pii' THEN 1 ELSE 0 END)
+ AS codegate_pii_count
+ FROM alerts a
+ INNER JOIN prompts p ON p.id = a.prompt_id
+ WHERE p.workspace_id = :workspace_id
+ """
+ )
+ conditions = {"workspace_id": workspace_id}
+
+ async with self._async_db_engine.begin() as conn:
+ result = await conn.execute(sql, conditions)
+ row = result.fetchone()
+
+ # Return a dictionary with counts (handling None values safely)
+ return {
+ "codegate_secrets_count": row.codegate_secrets_count or 0 if row else 0,
+ "codegate_context_retriever_count": (
+ row.codegate_context_retriever_count or 0 if row else 0
+ ),
+ "codegate_pii_count": row.codegate_pii_count or 0 if row else 0,
+ }
+
async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
sql = text(
"""
@@ -971,6 +1059,33 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
)
return personas[0] if personas else None
+ async def get_distance_to_existing_personas(
+ self, query_embedding: np.ndarray, exclude_id: Optional[str]
+ ) -> List[PersonaDistance]:
+ """
+        Get the distances between a query embedding and all existing personas.
+ """
+ sql = """
+ SELECT
+ id,
+ name,
+ description,
+ vec_distance_cosine(description_embedding, :query_embedding) as distance
+ FROM personas
+ """
+ conditions = {"query_embedding": query_embedding}
+
+        # Optionally exclude a persona from the query. Used when validating the
+        # description during an update, so the persona is not compared to itself.
+ if exclude_id:
+ sql += " WHERE id != :exclude_id"
+ conditions["exclude_id"] = exclude_id
+
+ persona_distances = await self._exec_vec_db_query_to_pydantic(
+ sql, conditions, PersonaDistance
+ )
+ return persona_distances
+
async def get_distance_to_persona(
self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
@@ -992,6 +1107,55 @@ async def get_distance_to_persona(
)
return persona_distance[0]
+ async def get_all_personas(self) -> List[Persona]:
+ """
+ Get all the personas.
+ """
+ sql = text(
+ """
+ SELECT
+ id, name, description
+ FROM personas
+ """
+ )
+ personas = await self._execute_select_pydantic_model(Persona, sql, should_raise=True)
+ return personas
+
+ async def get_instance(self) -> Instance:
+ """
+ Get the details of the instance.
+ """
+ sql = text("SELECT id, created_at FROM instance")
+ return await self._execute_select_pydantic_model(Instance, sql)
+
+
+class DbTransaction:
+ def __init__(self):
+ self._session = None
+
+ async def __aenter__(self):
+ self._session = sessionmaker(
+ bind=DbCodeGate()._async_db_engine,
+ class_=AsyncSession,
+ expire_on_commit=False,
+ )()
+ await self._session.begin()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ if exc_type:
+ await self._session.rollback()
+ raise exc_val
+ else:
+ await self._session.commit()
+ await self._session.close()
+
+ async def commit(self):
+ await self._session.commit()
+
+ async def rollback(self):
+ await self._session.rollback()
+
def init_db_sync(db_path: Optional[str] = None):
"""DB will be initialized in the constructor in case it doesn't exist."""
@@ -1014,8 +1178,6 @@ def init_db_sync(db_path: Optional[str] = None):
def init_session_if_not_exists(db_path: Optional[str] = None):
- import datetime
-
db_reader = DbReader(db_path)
sessions = asyncio.run(db_reader.get_sessions())
# If there are no sessions, create a new one
@@ -1035,5 +1197,19 @@ def init_session_if_not_exists(db_path: Optional[str] = None):
logger.info("Session in DB initialized successfully.")
+def init_instance(db_path: Optional[str] = None):
+ db_reader = DbReader(db_path)
+ instance = asyncio.run(db_reader.get_instance())
+ # Initialize instance if not already initialized.
+ if not instance:
+ db_recorder = DbRecorder(db_path)
+ try:
+ asyncio.run(db_recorder.init_instance())
+ except Exception as e:
+ logger.error(f"Failed to initialize instance in DB: {e}")
+ raise
+ logger.info("Instance initialized successfully.")
+
+
if __name__ == "__main__":
init_db_sync()
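For intuition, the `SUM(CASE ...)` aggregation in `get_alerts_summary_by_workspace` above counts alerts per trigger type in a single pass. The same computation over hypothetical rows in plain Python:

```python
alerts = [  # hypothetical alert rows joined to one workspace's prompts
    {"trigger_type": "codegate-secrets"},
    {"trigger_type": "codegate-pii"},
    {"trigger_type": "codegate-context-retriever"},
    {"trigger_type": "codegate-secrets"},
]
summary = {
    "codegate_secrets_count": sum(
        a["trigger_type"] == "codegate-secrets" for a in alerts),
    "codegate_context_retriever_count": sum(
        a["trigger_type"] == "codegate-context-retriever" for a in alerts),
    "codegate_pii_count": sum(
        a["trigger_type"] == "codegate-pii" for a in alerts),
}
print(summary)  # {'codegate_secrets_count': 2, ..., 'codegate_pii_count': 1}
```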
diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py
index a5941e96..07c4c8ed 100644
--- a/src/codegate/db/models.py
+++ b/src/codegate/db/models.py
@@ -3,7 +3,15 @@
from typing import Annotated, Any, Dict, List, Optional
import numpy as np
-from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints
+import regex as re
+from pydantic import (
+ BaseModel,
+ BeforeValidator,
+ ConfigDict,
+ PlainSerializer,
+ StringConstraints,
+ field_validator,
+)
class AlertSeverity(str, Enum):
@@ -120,6 +128,11 @@ class Session(BaseModel):
last_update: datetime.datetime
+class Instance(BaseModel):
+ id: str
+ created_at: datetime.datetime
+
+
# Models for select queries
@@ -245,12 +258,14 @@ class MuxRule(BaseModel):
def nd_array_custom_before_validator(x):
    # custom before-validation logic
+ if isinstance(x, bytes):
+ return np.frombuffer(x, dtype=np.float32)
return x
def nd_array_custom_serializer(x):
    # custom serialization logic
- return str(x)
+ return x
# Pydantic doesn't support numpy arrays out of the box hence we need to construct a custom type.
@@ -264,6 +279,8 @@ def nd_array_custom_serializer(x):
PlainSerializer(nd_array_custom_serializer, return_type=str),
]
+VALID_PERSONA_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_ -]+$")
+
class Persona(BaseModel):
"""
@@ -274,6 +291,15 @@ class Persona(BaseModel):
name: str
description: str
+ @field_validator("name", mode="after")
+ @classmethod
+ def validate_persona_name(cls, value: str) -> str:
+ if VALID_PERSONA_NAME_PATTERN.match(value):
+ return value
+ raise ValueError(
+ "Invalid persona name. It should be alphanumeric with underscores and dashes."
+ )
+
class PersonaEmbedding(Persona):
"""
diff --git a/src/codegate/muxing/semantic_router.py b/src/codegate/muxing/persona.py
similarity index 58%
rename from src/codegate/muxing/semantic_router.py
rename to src/codegate/muxing/persona.py
index ce240b1f..ac21205c 100644
--- a/src/codegate/muxing/semantic_router.py
+++ b/src/codegate/muxing/persona.py
@@ -1,5 +1,6 @@
import unicodedata
import uuid
+from typing import List, Optional
import numpy as np
import regex as re
@@ -28,14 +29,20 @@ class PersonaDoesNotExistError(Exception):
pass
-class SemanticRouter:
+class PersonaSimilarDescriptionError(Exception):
+ pass
+
+
+class PersonaManager:
def __init__(self):
- self._inference_engine = LlamaCppInferenceEngine()
+ Config.load()
conf = Config.get_config()
+ self._inference_engine = LlamaCppInferenceEngine()
self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}"
self._n_gpu = conf.chat_model_n_gpu_layers
self._persona_threshold = conf.persona_threshold
+ self._persona_diff_desc_threshold = conf.persona_diff_desc_threshold
self._db_recorder = DbRecorder()
self._db_reader = DbReader()
@@ -105,12 +112,50 @@ async def _embed_text(self, text: str) -> np.ndarray:
logger.debug("Text embedded in semantic routing", text=cleaned_text[:50])
return np.array(embed_list[0], dtype=np.float32)
+ async def _is_persona_description_diff(
+ self, emb_persona_desc: np.ndarray, exclude_id: Optional[str]
+ ) -> bool:
+ """
+ Check if the persona description is different enough from existing personas.
+ """
+ # The distance calculation is done in the database
+ persona_distances = await self._db_reader.get_distance_to_existing_personas(
+ emb_persona_desc, exclude_id
+ )
+ if not persona_distances:
+ return True
+
+ for persona_distance in persona_distances:
+ logger.info(
+ f"Persona description distance to {persona_distance.name}",
+ distance=persona_distance.distance,
+ )
+ # If the distance is less than the threshold, the persona description is too similar
+ if persona_distance.distance < self._persona_diff_desc_threshold:
+ return False
+ return True
+
+ async def _validate_persona_description(
+ self, persona_desc: str, exclude_id: str = None
+ ) -> np.ndarray:
+ """
+ Validate the persona description by embedding the text and checking if it is
+ different enough from existing personas.
+ """
+ emb_persona_desc = await self._embed_text(persona_desc)
+ if not await self._is_persona_description_diff(emb_persona_desc, exclude_id):
+ raise PersonaSimilarDescriptionError(
+ "The persona description is too similar to existing personas."
+ )
+ return emb_persona_desc
+
async def add_persona(self, persona_name: str, persona_desc: str) -> None:
"""
Add a new persona to the database. The persona description is embedded
and stored in the database.
"""
- emb_persona_desc = await self._embed_text(persona_desc)
+ emb_persona_desc = await self._validate_persona_description(persona_desc)
+
new_persona = db_models.PersonaEmbedding(
id=str(uuid.uuid4()),
name=persona_name,
@@ -120,6 +165,58 @@ async def add_persona(self, persona_name: str, persona_desc: str) -> None:
await self._db_recorder.add_persona(new_persona)
logger.info(f"Added persona {persona_name} to the database.")
+ async def get_persona(self, persona_name: str) -> db_models.Persona:
+ """
+ Get a persona from the database by name.
+ """
+ persona = await self._db_reader.get_persona_by_name(persona_name)
+ if not persona:
+ raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")
+ return persona
+
+ async def get_all_personas(self) -> List[db_models.Persona]:
+ """
+ Get all personas from the database.
+ """
+ return await self._db_reader.get_all_personas()
+
+ async def update_persona(
+ self, persona_name: str, new_persona_name: str, new_persona_desc: str
+ ) -> None:
+ """
+ Update an existing persona in the database. The name and description are
+ updated in the database, but the ID remains the same.
+ """
+ # First we check if the persona exists, if not we raise an error
+ found_persona = await self._db_reader.get_persona_by_name(persona_name)
+ if not found_persona:
+ raise PersonaDoesNotExistError(f"Person {persona_name} does not exist.")
+
+ emb_persona_desc = await self._validate_persona_description(
+ new_persona_desc, exclude_id=found_persona.id
+ )
+
+ # Then we update the attributes in the database
+ updated_persona = db_models.PersonaEmbedding(
+ id=found_persona.id,
+ name=new_persona_name,
+ description=new_persona_desc,
+ description_embedding=emb_persona_desc,
+ )
+ await self._db_recorder.update_persona(updated_persona)
+ logger.info(f"Updated persona {persona_name} in the database.")
+
+ async def delete_persona(self, persona_name: str) -> None:
+ """
+ Delete a persona from the database.
+ """
+ persona = await self._db_reader.get_persona_by_name(persona_name)
+ if not persona:
+ raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")
+
+ await self._db_recorder.delete_persona(persona.id)
+ logger.info(f"Deleted persona {persona_name} from the database.")
+
async def check_persona_match(self, persona_name: str, query: str) -> bool:
"""
Check if the query matches the persona description. A vector similarity
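A hedged usage sketch of the renamed `PersonaManager`; it assumes a configured instance with the embedding model and database available, since `add_persona` embeds the description before validating it:

```python
import asyncio

from codegate.muxing.persona import PersonaManager, PersonaSimilarDescriptionError

async def main() -> None:
    manager = PersonaManager()
    await manager.add_persona("reviewer", "Reviews diffs for security issues.")
    try:
        # A near-duplicate description should be rejected: its cosine distance
        # to the existing persona falls below persona_diff_desc_threshold (0.3).
        await manager.add_persona("reviewer2", "Reviews diffs for security bugs.")
    except PersonaSimilarDescriptionError as e:
        print(e)

asyncio.run(main())
```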
diff --git a/src/codegate/pipeline/cli/cli.py b/src/codegate/pipeline/cli/cli.py
index be2222c8..fde37f94 100644
--- a/src/codegate/pipeline/cli/cli.py
+++ b/src/codegate/pipeline/cli/cli.py
@@ -95,6 +95,25 @@ def _get_cli_from_continue(last_user_message_str: str) -> Optional[re.Match[str]
return codegate_regex.match(last_user_message_str)
+def _get_cli_from_copilot(last_user_message_str: str) -> Optional[re.Match[str]]:
+ """
+ Process Copilot-specific CLI command format.
+
+ Copilot sends messages in the format:
+        <attachment>file contents</attachment>codegate command
+
+ Args:
+ last_user_message_str (str): The message string from Copilot
+
+ Returns:
+ Optional[re.Match[str]]: A regex match object if command is found, None otherwise
+ """
+ cleaned_text = re.sub(
+ r".*", "", last_user_message_str, flags=re.DOTALL
+ )
+ return codegate_regex.match(cleaned_text.strip())
+
+
class CodegateCli(PipelineStep):
"""Pipeline step that handles codegate cli."""
@@ -136,6 +155,8 @@ async def process(
match = _get_cli_from_open_interpreter(last_user_message_str)
elif context.client in [ClientType.CONTINUE]:
match = _get_cli_from_continue(last_user_message_str)
+ elif context.client in [ClientType.COPILOT]:
+ match = _get_cli_from_copilot(last_user_message_str)
else:
# Check if "codegate" is the first word in the message
match = codegate_regex.match(last_user_message_str)
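Assuming the `<attachment>` format reconstructed above, the Copilot branch strips the file payload before matching the CLI command:

```python
import re

message = "<attachment>def foo(): ...</attachment>codegate workspace list"
cleaned = re.sub(r"<attachment>.*?</attachment>", "", message, flags=re.DOTALL)
print(cleaned.strip())  # codegate workspace list
```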
diff --git a/src/codegate/pipeline/cli/commands.py b/src/codegate/pipeline/cli/commands.py
index 5b101400..c5655ec3 100644
--- a/src/codegate/pipeline/cli/commands.py
+++ b/src/codegate/pipeline/cli/commands.py
@@ -98,7 +98,6 @@ def help(self) -> str:
class CodegateCommandSubcommand(CodegateCommand):
-
@property
@abstractmethod
def subcommands(self) -> Dict[str, Callable[[List[str]], Awaitable[str]]]:
@@ -174,7 +173,6 @@ async def run(self, args: List[str]) -> str:
class Workspace(CodegateCommandSubcommand):
-
def __init__(self):
self.workspace_crud = crud.WorkspaceCrud()
@@ -258,7 +256,7 @@ async def _rename_workspace(self, flags: Dict[str, str], args: List[str]) -> str
)
try:
- await self.workspace_crud.rename_workspace(old_workspace_name, new_workspace_name)
+ await self.workspace_crud.update_workspace(old_workspace_name, new_workspace_name)
except crud.WorkspaceDoesNotExistError:
return f"Workspace **{old_workspace_name}** does not exist"
except AlreadyExistsError:
@@ -410,7 +408,6 @@ def help(self) -> str:
class CustomInstructions(CodegateCommandSubcommand):
-
def __init__(self):
self.workspace_crud = crud.WorkspaceCrud()
diff --git a/src/codegate/pipeline/pii/analyzer.py b/src/codegate/pipeline/pii/analyzer.py
index 96442824..706deb9b 100644
--- a/src/codegate/pipeline/pii/analyzer.py
+++ b/src/codegate/pipeline/pii/analyzer.py
@@ -1,10 +1,9 @@
-from typing import Any, List, Optional
+from typing import List, Optional
import structlog
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine
-from codegate.db.models import AlertSeverity
from codegate.pipeline.base import PipelineContext
from codegate.pipeline.sensitive_data.session_store import SessionStore
diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py
index fde89428..d7f33d67 100644
--- a/src/codegate/pipeline/pii/pii.py
+++ b/src/codegate/pipeline/pii/pii.py
@@ -1,5 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple
-import uuid
import regex as re
import structlog
diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py
index 527c817f..c299469e 100644
--- a/src/codegate/pipeline/secrets/secrets.py
+++ b/src/codegate/pipeline/secrets/secrets.py
@@ -316,8 +316,17 @@ async def process(
# Process all messages
for i, message in enumerate(new_request["messages"]):
if "content" in message and message["content"]:
+ message_content = message["content"]
+
+ # Cline with Anthropic seems to send a list of dicts with type:text
+ # instead of a plain string. This workaround will not be needed once
+ # we access the native functions through an API (verified in testing).
+ if (
+ isinstance(message_content, list)
+ and isinstance(message_content[0], dict)
+ and "text" in message_content[0]
+ ):
+ message_content = message_content[0]["text"]
+
redacted_content, secrets_matched = self._redact_message_content(
- message["content"], sensitive_data_manager, session_id, context
+ message_content, sensitive_data_manager, session_id, context
)
new_request["messages"][i]["content"] = redacted_content
if i > last_assistant_idx:
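The list-of-dicts normalization above is easiest to see standalone. A minimal sketch; `normalize_message_content` is a hypothetical helper name, not part of the change:

```python
from typing import Any


def normalize_message_content(message_content: Any) -> Any:
    # Cline with Anthropic may send [{"type": "text", "text": "..."}]
    # instead of a plain string; unwrap the first text block if so.
    if (
        isinstance(message_content, list)
        and message_content
        and isinstance(message_content[0], dict)
        and "text" in message_content[0]
    ):
        return message_content[0]["text"]
    return message_content


assert normalize_message_content("hello") == "hello"
assert normalize_message_content([{"type": "text", "text": "hello"}]) == "hello"
```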
diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py
index 89506d15..bf467878 100644
--- a/src/codegate/pipeline/sensitive_data/manager.py
+++ b/src/codegate/pipeline/sensitive_data/manager.py
@@ -1,7 +1,8 @@
-import json
from typing import Dict, Optional
+
import pydantic
import structlog
+
from codegate.pipeline.sensitive_data.session_store import SessionStore
logger = structlog.get_logger("codegate")
diff --git a/src/codegate/pipeline/sensitive_data/session_store.py b/src/codegate/pipeline/sensitive_data/session_store.py
index 5e508847..7a33abd2 100644
--- a/src/codegate/pipeline/sensitive_data/session_store.py
+++ b/src/codegate/pipeline/sensitive_data/session_store.py
@@ -1,5 +1,5 @@
-from typing import Dict, Optional
import uuid
+from typing import Dict, Optional
class SessionStore:
diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py
index a81426a8..fbaf5b99 100644
--- a/src/codegate/workspaces/crud.py
+++ b/src/codegate/workspaces/crud.py
@@ -3,7 +3,7 @@
from uuid import uuid4 as uuid
from codegate.db import models as db_models
-from codegate.db.connection import DbReader, DbRecorder
+from codegate.db.connection import AlreadyExistsError, DbReader, DbRecorder, DbTransaction
from codegate.muxing import models as mux_models
from codegate.muxing import rulematcher
@@ -16,6 +16,10 @@ class WorkspaceDoesNotExistError(WorkspaceCrudError):
pass
+class WorkspaceNameAlreadyInUseError(WorkspaceCrudError):
+ pass
+
+
class WorkspaceAlreadyActiveError(WorkspaceCrudError):
pass
@@ -31,34 +35,73 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError):
class WorkspaceCrud:
-
def __init__(self):
self._db_reader = DbReader()
-
- async def add_workspace(self, new_workspace_name: str) -> db_models.WorkspaceRow:
+ self._db_recorder = DbRecorder()
+
+ async def add_workspace(
+ self,
+ new_workspace_name: str,
+ custom_instructions: Optional[str] = None,
+ muxing_rules: Optional[List[mux_models.MuxRule]] = None,
+ ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]:
"""
Add a workspace
Args:
- name (str): The name of the workspace
+ new_workspace_name (str): The name of the workspace
+ custom_instructions (Optional[str]): The custom instructions for the workspace
+ muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace
"""
if new_workspace_name == "":
raise WorkspaceCrudError("Workspace name cannot be empty.")
if new_workspace_name in RESERVED_WORKSPACE_KEYWORDS:
raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.")
- db_recorder = DbRecorder()
- workspace_created = await db_recorder.add_workspace(new_workspace_name)
- return workspace_created
- async def rename_workspace(
- self, old_workspace_name: str, new_workspace_name: str
- ) -> db_models.WorkspaceRow:
+ async with DbTransaction() as transaction:
+ try:
+ existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name)
+ if existing_ws:
+ raise WorkspaceNameAlreadyInUseError(
+ f"Workspace name {new_workspace_name} is already in use."
+ )
+
+ workspace_created = await self._db_recorder.add_workspace(new_workspace_name)
+
+ if custom_instructions:
+ workspace_created.custom_instructions = custom_instructions
+ await self._db_recorder.update_workspace(workspace_created)
+
+ mux_rules = []
+ if muxing_rules:
+ mux_rules = await self.set_muxes(new_workspace_name, muxing_rules)
+
+ await transaction.commit()
+ return workspace_created, mux_rules
+ except (
+ AlreadyExistsError,
+ WorkspaceDoesNotExistError,
+ WorkspaceNameAlreadyInUseError,
+ ) as e:
+ raise e
+ except Exception as e:
+ raise WorkspaceCrudError(f"Error adding workspace {new_workspace_name}: {str(e)}")
+
+ async def update_workspace(
+ self,
+ old_workspace_name: str,
+ new_workspace_name: str,
+ custom_instructions: Optional[str] = None,
+ muxing_rules: Optional[List[mux_models.MuxRule]] = None,
+ ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]:
"""
- Rename a workspace
+ Update a workspace
Args:
- old_name (str): The old name of the workspace
- new_name (str): The new name of the workspace
+ old_workspace_name (str): The old name of the workspace
+ new_workspace_name (str): The new name of the workspace
+ custom_instructions (Optional[str]): The custom instructions for the workspace
+ muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace
"""
if new_workspace_name == "":
raise WorkspaceCrudError("Workspace name cannot be empty.")
@@ -70,15 +113,40 @@ async def rename_workspace(
raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.")
if old_workspace_name == new_workspace_name:
raise WorkspaceCrudError("Old and new workspace names are the same.")
- ws = await self._db_reader.get_workspace_by_name(old_workspace_name)
- if not ws:
- raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.")
- db_recorder = DbRecorder()
- new_ws = db_models.WorkspaceRow(
- id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
- )
- workspace_renamed = await db_recorder.update_workspace(new_ws)
- return workspace_renamed
+
+ async with DbTransaction() as transaction:
+ try:
+ ws = await self._db_reader.get_workspace_by_name(old_workspace_name)
+ if not ws:
+ raise WorkspaceDoesNotExistError(
+ f"Workspace {old_workspace_name} does not exist."
+ )
+
+ existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name)
+ if existing_ws:
+ raise WorkspaceNameAlreadyInUseError(
+ f"Workspace name {new_workspace_name} is already in use."
+ )
+
+ new_ws = db_models.WorkspaceRow(
+ id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
+ )
+ workspace_renamed = await self._db_recorder.update_workspace(new_ws)
+
+ if custom_instructions:
+ workspace_renamed.custom_instructions = custom_instructions
+ await self._db_recorder.update_workspace(workspace_renamed)
+
+ mux_rules = []
+ if muxing_rules:
+ mux_rules = await self.set_muxes(new_workspace_name, muxing_rules)
+
+ await transaction.commit()
+ return workspace_renamed, mux_rules
+ except (WorkspaceNameAlreadyInUseError, WorkspaceDoesNotExistError) as e:
+ raise e
+ except Exception as e:
+ raise WorkspaceCrudError(f"Error updating workspace {old_workspace_name}: {str(e)}")
async def get_workspaces(self) -> List[db_models.WorkspaceWithSessionInfo]:
"""
@@ -128,8 +196,7 @@ async def activate_workspace(self, workspace_name: str):
session.active_workspace_id = workspace.id
session.last_update = datetime.datetime.now(datetime.timezone.utc)
- db_recorder = DbRecorder()
- await db_recorder.update_session(session)
+ await self._db_recorder.update_session(session)
# Ensure the mux registry is updated
mux_registry = await rulematcher.get_muxing_rules_registry()
@@ -144,8 +211,7 @@ async def recover_workspace(self, workspace_name: str):
if not selected_workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
- db_recorder = DbRecorder()
- await db_recorder.recover_workspace(selected_workspace)
+ await self._db_recorder.recover_workspace(selected_workspace)
return
async def update_workspace_custom_instructions(
@@ -161,8 +227,7 @@ async def update_workspace_custom_instructions(
name=selected_workspace.name,
custom_instructions=custom_instructions,
)
- db_recorder = DbRecorder()
- updated_workspace = await db_recorder.update_workspace(workspace_update)
+ updated_workspace = await self._db_recorder.update_workspace(workspace_update)
return updated_workspace
async def soft_delete_workspace(self, workspace_name: str):
@@ -183,9 +248,8 @@ async def soft_delete_workspace(self, workspace_name: str):
if active_workspace and active_workspace.id == selected_workspace.id:
raise WorkspaceCrudError("Cannot archive active workspace.")
- db_recorder = DbRecorder()
try:
- _ = await db_recorder.soft_delete_workspace(selected_workspace)
+ _ = await self._db_recorder.soft_delete_workspace(selected_workspace)
except Exception:
raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}")
@@ -205,9 +269,8 @@ async def hard_delete_workspace(self, workspace_name: str):
if not selected_workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
- db_recorder = DbRecorder()
try:
- _ = await db_recorder.hard_delete_workspace(selected_workspace)
+ _ = await self._db_recorder.hard_delete_workspace(selected_workspace)
except Exception:
raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}")
return
@@ -247,15 +310,16 @@ async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]:
return muxes
- async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> None:
+ async def set_muxes(
+ self, workspace_name: str, muxes: List[mux_models.MuxRule]
+ ) -> List[db_models.MuxRule]:
# Verify if workspace exists
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
# Delete all muxes for the workspace
- db_recorder = DbRecorder()
- await db_recorder.delete_muxes_by_workspace(workspace.id)
+ await self._db_recorder.delete_muxes_by_workspace(workspace.id)
# Add the new muxes
priority = 0
@@ -268,6 +332,7 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non
muxes_with_routes.append((mux, route))
matchers: List[rulematcher.MuxingRuleMatcher] = []
+ dbmuxes: List[db_models.MuxRule] = []
for mux, route in muxes_with_routes:
new_mux = db_models.MuxRule(
@@ -279,7 +344,8 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non
matcher_blob=mux.matcher if mux.matcher else "",
priority=priority,
)
- dbmux = await db_recorder.add_mux(new_mux)
+ dbmux = await self._db_recorder.add_mux(new_mux)
+ dbmuxes.append(dbmux)
matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, route))
@@ -289,6 +355,8 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non
mux_registry = await rulematcher.get_muxing_rules_registry()
await mux_registry.set_ws_rules(workspace_name, matchers)
+ return dbmuxes
+
async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute:
"""Get the routing for a mux
diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py
new file mode 100644
index 00000000..8bfcbfaf
--- /dev/null
+++ b/tests/api/test_v1_workspaces.py
@@ -0,0 +1,378 @@
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+from uuid import uuid4 as uuid
+
+import httpx
+import pytest
+import structlog
+from httpx import AsyncClient
+
+from codegate.db import connection
+from codegate.pipeline.factory import PipelineFactory
+from codegate.providers.crud.crud import ProviderCrud
+from codegate.server import init_app
+from codegate.workspaces.crud import WorkspaceCrud
+
+logger = structlog.get_logger("codegate")
+
+
+@pytest.fixture
+def db_path():
+ """Creates a temporary database file path."""
+ current_test_dir = Path(__file__).parent
+ db_filepath = current_test_dir / f"codegate_test_{uuid()}.db"
+ db_fullpath = db_filepath.absolute()
+ connection.init_db_sync(str(db_fullpath))
+ yield db_fullpath
+ if db_fullpath.is_file():
+ db_fullpath.unlink()
+
+
+@pytest.fixture()
+def db_recorder(db_path) -> connection.DbRecorder:
+ """Creates a DbRecorder instance with test database."""
+ return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True)
+
+
+@pytest.fixture()
+def db_reader(db_path) -> connection.DbReader:
+ """Creates a DbReader instance with test database."""
+ return connection.DbReader(sqlite_path=db_path, _no_singleton=True)
+
+
+@pytest.fixture()
+def mock_workspace_crud(db_recorder, db_reader) -> WorkspaceCrud:
+ """Creates a WorkspaceCrud instance with test database."""
+ ws_crud = WorkspaceCrud()
+ ws_crud._db_reader = db_reader
+ ws_crud._db_recorder = db_recorder
+ return ws_crud
+
+
+@pytest.fixture()
+def mock_provider_crud(db_recorder, db_reader, mock_workspace_crud) -> ProviderCrud:
+ """Creates a ProviderCrud instance with test database."""
+ p_crud = ProviderCrud()
+ p_crud._db_reader = db_reader
+ p_crud._db_writer = db_recorder
+ p_crud._ws_crud = mock_workspace_crud
+ return p_crud
+
+
+@pytest.fixture
+def mock_pipeline_factory():
+ """Create a mock pipeline factory."""
+ mock_factory = MagicMock(spec=PipelineFactory)
+ mock_factory.create_input_pipeline.return_value = MagicMock()
+ mock_factory.create_fim_pipeline.return_value = MagicMock()
+ mock_factory.create_output_pipeline.return_value = MagicMock()
+ mock_factory.create_fim_output_pipeline.return_value = MagicMock()
+ return mock_factory
+
+
+@pytest.mark.asyncio
+async def test_create_update_workspace_happy_path(
+ mock_pipeline_factory, mock_workspace_crud, mock_provider_crud
+) -> None:
+ with (
+ patch("codegate.api.v1.wscrud", mock_workspace_crud),
+ patch("codegate.api.v1.pcrud", mock_provider_crud),
+ patch(
+ "codegate.providers.openai.provider.OpenAIProvider.models",
+ return_value=["foo-bar-001", "foo-bar-002"],
+ ),
+ ):
+ """Test creating & updating a workspace (happy path)."""
+
+ app = init_app(mock_pipeline_factory)
+
+ provider_payload_1 = {
+ "name": "foo",
+ "description": "",
+ "auth_type": "none",
+ "provider_type": "openai",
+ "endpoint": "https://api.openai.com",
+ "api_key": "sk-proj-foo-bar-123-xzy",
+ }
+
+ provider_payload_2 = {
+ "name": "bar",
+ "description": "",
+ "auth_type": "none",
+ "provider_type": "openai",
+ "endpoint": "https://api.openai.com",
+ "api_key": "sk-proj-foo-bar-123-xzy",
+ }
+
+ async with AsyncClient(
+ transport=httpx.ASGITransport(app=app), base_url="http://test"
+ ) as ac:
+ # Create the first provider
+ response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1)
+ assert response.status_code == 201
+ provider_1 = response.json()
+
+ # Create the second provider
+ response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2)
+ assert response.status_code == 201
+ provider_2 = response.json()
+
+ name_1: str = str(uuid())
+ custom_instructions_1: str = "Respond to every request in iambic pentameter"
+ muxing_rules_1 = [
+ {
+ "provider_name": None, # optional & not implemented yet
+ "provider_id": provider_1["id"],
+ "model": "foo-bar-001",
+ "matcher": "*.ts",
+ "matcher_type": "filename_match",
+ },
+ {
+ "provider_name": None, # optional & not implemented yet
+ "provider_id": provider_2["id"],
+ "model": "foo-bar-002",
+ "matcher_type": "catch_all",
+ "matcher": "",
+ },
+ ]
+
+ payload_create = {
+ "name": name_1,
+ "config": {
+ "custom_instructions": custom_instructions_1,
+ "muxing_rules": muxing_rules_1,
+ },
+ }
+
+ response = await ac.post("/api/v1/workspaces", json=payload_create)
+ assert response.status_code == 201
+ response_body = response.json()
+
+ assert response_body["name"] == name_1
+ assert response_body["config"]["custom_instructions"] == custom_instructions_1
+ for i, rule in enumerate(response_body["config"]["muxing_rules"]):
+ assert rule["model"] == muxing_rules_1[i]["model"]
+ assert rule["matcher"] == muxing_rules_1[i]["matcher"]
+ assert rule["matcher_type"] == muxing_rules_1[i]["matcher_type"]
+
+ name_2: str = str(uuid())
+ custom_instructions_2: str = "Respond to every request in cockney rhyming slang"
+ muxing_rules_2 = [
+ {
+ "provider_name": None, # optional & not implemented yet
+ "provider_id": provider_2["id"],
+ "model": "foo-bar-002",
+ "matcher": "*.ts",
+ "matcher_type": "filename_match",
+ },
+ {
+ "provider_name": None, # optional & not implemented yet
+ "provider_id": provider_1["id"],
+ "model": "foo-bar-001",
+ "matcher_type": "catch_all",
+ "matcher": "",
+ },
+ ]
+
+ payload_update = {
+ "name": name_2,
+ "config": {
+ "custom_instructions": custom_instructions_2,
+ "muxing_rules": muxing_rules_2,
+ },
+ }
+
+ response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update)
+ assert response.status_code == 201
+ response_body = response.json()
+
+ assert response_body["name"] == name_2
+ assert response_body["config"]["custom_instructions"] == custom_instructions_2
+ for i, rule in enumerate(response_body["config"]["muxing_rules"]):
+ assert rule["model"] == muxing_rules_2[i]["model"]
+ assert rule["matcher"] == muxing_rules_2[i]["matcher"]
+ assert rule["matcher_type"] == muxing_rules_2[i]["matcher_type"]
+
+
+@pytest.mark.asyncio
+async def test_create_update_workspace_name_only(
+ mock_pipeline_factory, mock_workspace_crud, mock_provider_crud
+) -> None:
+ with (
+ patch("codegate.api.v1.wscrud", mock_workspace_crud),
+ patch("codegate.api.v1.pcrud", mock_provider_crud),
+ patch(
+ "codegate.providers.openai.provider.OpenAIProvider.models",
+ return_value=["foo-bar-001", "foo-bar-002"],
+ ),
+ ):
+ """Test creating & updating a workspace (happy path)."""
+
+ app = init_app(mock_pipeline_factory)
+
+ async with AsyncClient(
+ transport=httpx.ASGITransport(app=app), base_url="http://test"
+ ) as ac:
+ name_1: str = str(uuid())
+
+ payload_create = {
+ "name": name_1,
+ }
+
+ response = await ac.post("/api/v1/workspaces", json=payload_create)
+ assert response.status_code == 201
+ response_body = response.json()
+
+ assert response_body["name"] == name_1
+
+ name_2: str = str(uuid())
+
+ payload_update = {
+ "name": name_2,
+ }
+
+ response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update)
+ assert response.status_code == 201
+ response_body = response.json()
+
+ assert response_body["name"] == name_2
+
+
+@pytest.mark.asyncio
+async def test_create_workspace_name_already_in_use(
+ mock_pipeline_factory, mock_workspace_crud, mock_provider_crud
+) -> None:
+ with (
+ patch("codegate.api.v1.wscrud", mock_workspace_crud),
+ patch("codegate.api.v1.pcrud", mock_provider_crud),
+ patch(
+ "codegate.providers.openai.provider.OpenAIProvider.models",
+ return_value=["foo-bar-001", "foo-bar-002"],
+ ),
+ ):
+ """Test creating a workspace when the name is already in use."""
+
+ app = init_app(mock_pipeline_factory)
+
+ async with AsyncClient(
+ transport=httpx.ASGITransport(app=app), base_url="http://test"
+ ) as ac:
+ name: str = str(uuid())
+
+ payload_create = {
+ "name": name,
+ }
+
+ # Create the workspace for the first time
+ response = await ac.post("/api/v1/workspaces", json=payload_create)
+ assert response.status_code == 201
+
+ # Try to create the workspace again with the same name
+ response = await ac.post("/api/v1/workspaces", json=payload_create)
+ assert response.status_code == 409
+ assert response.json()["detail"] == "Workspace name already in use"
+
+
+@pytest.mark.asyncio
+async def test_rename_workspace_name_already_in_use(
+ mock_pipeline_factory, mock_workspace_crud, mock_provider_crud
+) -> None:
+ with (
+ patch("codegate.api.v1.wscrud", mock_workspace_crud),
+ patch("codegate.api.v1.pcrud", mock_provider_crud),
+ patch(
+ "codegate.providers.openai.provider.OpenAIProvider.models",
+ return_value=["foo-bar-001", "foo-bar-002"],
+ ),
+ ):
+ """Test renaming a workspace when the new name is already in use."""
+
+ app = init_app(mock_pipeline_factory)
+
+ async with AsyncClient(
+ transport=httpx.ASGITransport(app=app), base_url="http://test"
+ ) as ac:
+ name_1: str = str(uuid())
+ name_2: str = str(uuid())
+
+ payload_create_1 = {
+ "name": name_1,
+ }
+
+ payload_create_2 = {
+ "name": name_2,
+ }
+
+ # Create two workspaces
+ response = await ac.post("/api/v1/workspaces", json=payload_create_1)
+ assert response.status_code == 201
+
+ response = await ac.post("/api/v1/workspaces", json=payload_create_2)
+ assert response.status_code == 201
+
+ # Try to rename the first workspace to the name of the second workspace
+ payload_update = {
+ "name": name_2,
+ }
+
+ response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update)
+ assert response.status_code == 409
+ assert response.json()["detail"] == "Workspace name already in use"
+
+
+@pytest.mark.asyncio
+async def test_create_workspace_with_nonexistent_model_in_muxing_rule(
+ mock_pipeline_factory, mock_workspace_crud, mock_provider_crud
+) -> None:
+ with (
+ patch("codegate.api.v1.wscrud", mock_workspace_crud),
+ patch("codegate.api.v1.pcrud", mock_provider_crud),
+ patch(
+ "codegate.providers.openai.provider.OpenAIProvider.models",
+ return_value=["foo-bar-001", "foo-bar-002"],
+ ),
+ ):
+ """Test creating a workspace with a muxing rule that uses a nonexistent model."""
+
+ app = init_app(mock_pipeline_factory)
+
+ provider_payload = {
+ "name": "foo",
+ "description": "",
+ "auth_type": "none",
+ "provider_type": "openai",
+ "endpoint": "https://api.openai.com",
+ "api_key": "sk-proj-foo-bar-123-xzy",
+ }
+
+ async with AsyncClient(
+ transport=httpx.ASGITransport(app=app), base_url="http://test"
+ ) as ac:
+ # Create the first provider
+ response = await ac.post("/api/v1/provider-endpoints", json=provider_payload)
+ assert response.status_code == 201
+ provider = response.json()
+
+ name: str = str(uuid())
+ custom_instructions: str = "Respond to every request in iambic pentameter"
+ muxing_rules = [
+ {
+ "provider_name": None,
+ "provider_id": provider["id"],
+ "model": "nonexistent-model",
+ "matcher": "*.ts",
+ "matcher_type": "filename_match",
+ },
+ ]
+
+ payload_create = {
+ "name": name,
+ "config": {
+ "custom_instructions": custom_instructions,
+ "muxing_rules": muxing_rules,
+ },
+ }
+
+ response = await ac.post("/api/v1/workspaces", json=payload_create)
+ assert response.status_code == 400
+ assert "Model nonexistent-model does not exist" in response.json()["detail"]
diff --git a/tests/muxing/test_persona.py b/tests/muxing/test_persona.py
new file mode 100644
index 00000000..fd0003c9
--- /dev/null
+++ b/tests/muxing/test_persona.py
@@ -0,0 +1,490 @@
+import uuid
+from pathlib import Path
+from typing import List
+
+import pytest
+from pydantic import BaseModel, ValidationError
+
+from codegate.db import connection
+from codegate.muxing.persona import (
+ PersonaDoesNotExistError,
+ PersonaManager,
+ PersonaSimilarDescriptionError,
+)
+
+
+@pytest.fixture
+def db_path():
+ """Creates a temporary database file path."""
+ current_test_dir = Path(__file__).parent
+ db_filepath = current_test_dir / f"codegate_test_{uuid.uuid4()}.db"
+ db_fullpath = db_filepath.absolute()
+ connection.init_db_sync(str(db_fullpath))
+ yield db_fullpath
+ if db_fullpath.is_file():
+ db_fullpath.unlink()
+
+
+@pytest.fixture()
+def db_recorder(db_path) -> connection.DbRecorder:
+ """Creates a DbRecorder instance with test database."""
+ return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True)
+
+
+@pytest.fixture()
+def db_reader(db_path) -> connection.DbReader:
+ """Creates a DbReader instance with test database."""
+ return connection.DbReader(sqlite_path=db_path, _no_singleton=True)
+
+
+@pytest.fixture()
+def semantic_router_mocked_db(
+ db_recorder: connection.DbRecorder, db_reader: connection.DbReader
+) -> PersonaManager:
+ """Creates a SemanticRouter instance with mocked database."""
+ semantic_router = PersonaManager()
+ semantic_router._db_reader = db_reader
+ semantic_router._db_recorder = db_recorder
+ return semantic_router
+
+
+@pytest.mark.asyncio
+async def test_add_persona(semantic_router_mocked_db: PersonaManager):
+ """Test adding a persona to the database."""
+ persona_name = "test_persona"
+ persona_desc = "test_persona_desc"
+ await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+ retrieved_persona = await semantic_router_mocked_db.get_persona(persona_name)
+ assert retrieved_persona.name == persona_name
+ assert retrieved_persona.description == persona_desc
+
+
+@pytest.mark.asyncio
+async def test_add_duplicate_persona(semantic_router_mocked_db: PersonaManager):
+ """Test adding a persona to the database."""
+ persona_name = "test_persona"
+ persona_desc = "test_persona_desc"
+ await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+ # Update the description to not trigger the similarity check
+ updated_description = "foo and bar description"
+ with pytest.raises(connection.AlreadyExistsError):
+ await semantic_router_mocked_db.add_persona(persona_name, updated_description)
+
+
+@pytest.mark.asyncio
+async def test_add_persona_invalid_name(semantic_router_mocked_db: PersonaManager):
+ """Test adding a persona to the database."""
+ persona_name = "test_persona&"
+ persona_desc = "test_persona_desc"
+ with pytest.raises(ValidationError):
+ await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+ with pytest.raises(PersonaDoesNotExistError):
+ await semantic_router_mocked_db.delete_persona(persona_name)
+
+
+@pytest.mark.asyncio
+async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager):
+ """Test checking persona match when persona does not exist"""
+ persona_name = "test_persona"
+ query = "test_query"
+ with pytest.raises(PersonaDoesNotExistError):
+ await semantic_router_mocked_db.check_persona_match(persona_name, query)
+
+
+class PersonaMatchTest(BaseModel):
+ persona_name: str
+ persona_desc: str
+ pass_queries: List[str]
+ fail_queries: List[str]
+
+
+simple_persona = PersonaMatchTest(
+ persona_name="test_persona",
+ persona_desc="test_desc",
+ pass_queries=["test_desc", "test_desc2"],
+ fail_queries=["foo"],
+)
+
+# Architect Persona
+architect = PersonaMatchTest(
+ persona_name="architect",
+ persona_desc="""
+ Expert in designing and planning software systems, technical infrastructure, and solution
+ architecture.
+ Specializes in creating scalable, maintainable, and resilient system designs.
+ Deep knowledge of architectural patterns, principles, and best practices.
+ Experienced in evaluating technology stacks and making strategic technical decisions.
+ Skilled at creating architecture diagrams, technical specifications, and system
+ documentation.
+ Focuses on both functional and non-functional requirements like performance, security,
+ and reliability.
+ Guides development teams on implementing complex systems and following architectural
+ guidelines.
+
+ Designs system architectures that balance business needs with technical constraints.
+ Creates technical roadmaps and migration strategies for legacy system modernization.
+ Evaluates trade-offs between different architectural approaches (monolithic, microservices,
+ serverless).
+ Implements domain-driven design principles to align software with business domains.
+
+ Develops reference architectures and technical standards for organization-wide adoption.
+ Conducts architecture reviews and provides recommendations for improvement.
+ Collaborates with stakeholders to translate business requirements into technical solutions.
+ Stays current with emerging technologies and evaluates their potential application.
+
+ Designs for cloud-native environments using containerization, orchestration, and managed
+ services.
+ Implements event-driven architectures using message queues, event buses, and streaming
+ platforms.
+ Creates data architectures that address storage, processing, and analytics requirements.
+ Develops integration strategies for connecting disparate systems and services.
+ """,
+ pass_queries=[
+ """
+ How should I design a system architecture that can scale with our growing user base?
+ """,
+ """
+ What's the best approach for migrating our monolithic application to microservices?
+ """,
+ """
+ I need to create a technical roadmap for modernizing our legacy systems. Where should
+ I start?
+ """,
+ """
+ Can you help me evaluate different cloud providers for our new infrastructure?
+ """,
+ """
+ What architectural patterns would you recommend for a distributed e-commerce platform?
+ """,
+ ],
+ fail_queries=[
+ """
+ How do I fix this specific bug in my JavaScript code?
+ """,
+ """
+ What's the syntax for a complex SQL query joining multiple tables?
+ """,
+ """
+ How do I implement authentication in my React application?
+ """,
+ """
+ What's the best way to optimize the performance of this specific function?
+ """,
+ ],
+)
+
+# Coder Persona
+coder = PersonaMatchTest(
+ persona_name="coder",
+ persona_desc="""
+ Expert in full stack development, programming, and software implementation.
+ Specializes in writing, debugging, and optimizing code across the entire technology stack.
+
+ Proficient in multiple programming languages including JavaScript, Python, Java, C#, and
+ TypeScript.
+ Implements efficient algorithms and data structures to solve complex programming challenges.
+ Develops maintainable code with appropriate patterns and practices for different contexts.
+
+ Experienced in frontend development using modern frameworks and libraries.
+ Creates responsive, accessible user interfaces with HTML, CSS, and JavaScript frameworks.
+ Implements state management, component architecture,
+ and client-side performance optimization for frontend applications.
+
+ Skilled in backend development and server-side programming.
+ Builds RESTful APIs, GraphQL services, and microservices architectures.
+ Implements authentication, authorization, and security best practices in web applications.
+ Understands best practices for common backend problems, like file uploads, caching,
+ and database interactions.
+
+ Designs and manages databases including schema design, query optimization,
+ and data modeling.
+ Works with both SQL and NoSQL databases to implement efficient data storage solutions.
+ Creates data access layers and ORM implementations for application data requirements.
+
+ Handles integration between different systems and third-party services.
+ Implements webhooks, API clients, and service communication patterns.
+ Develops data transformation and processing pipelines for various application needs.
+
+ Identifies and resolves performance issues across the application stack.
+ Uses debugging tools, profilers, and testing frameworks to ensure code quality.
+ Implements comprehensive testing strategies including unit, integration,
+ and end-to-end tests.
+ """,
+ pass_queries=[
+ """
+ How do I implement authentication in my web application?
+ """,
+ """
+ What's the best way to structure a RESTful API for my project?
+ """,
+ """
+ I need help optimizing my database queries for better performance.
+ """,
+ """
+ How should I implement state management in my frontend application?
+ """,
+ """
+ What's the difference between SQL and NoSQL databases, and when should I use each?
+ """,
+ ],
+ fail_queries=[
+ """
+ What's the best approach for setting up a CI/CD pipeline for our team?
+ """,
+ """
+ Can you help me configure auto-scaling for our Kubernetes cluster?
+ """,
+ """
+ How should I structure our cloud infrastructure for better cost efficiency?
+ """,
+ """
+ How do I cook a delicious lasagna for dinner?
+ """,
+ ],
+)
+
+# DevOps/SRE Engineer Persona
+devops_sre = PersonaMatchTest(
+ persona_name="devops sre engineer",
+ persona_desc="""
+ Expert in infrastructure automation, deployment pipelines, and operational reliability.
+ Specializes in building and maintaining scalable, resilient, and secure infrastructure.
+ Proficient with cloud platforms (AWS, Azure, GCP), containerization, and orchestration.
+ Experienced with infrastructure as code, configuration management, and automation tools.
+ Skilled in implementing CI/CD pipelines, monitoring systems, and observability solutions.
+ Focuses on reliability, performance, security, and operational efficiency.
+ Practices site reliability engineering principles and DevOps methodologies.
+
+ Designs and implements cloud infrastructure using services like compute, storage,
+ networking, and databases.
+ Creates infrastructure as code using tools like Terraform, CloudFormation, or Pulumi.
+ Configures and manages container orchestration platforms like Kubernetes and ECS.
+ Implements CI/CD pipelines using tools like Jenkins, GitHub Actions, GitLab CI, or CircleCI.
+
+ Sets up comprehensive monitoring, alerting, and observability solutions.
+ Implements logging aggregation, metrics collection, and distributed tracing.
+ Creates dashboards and visualizations for system performance and health.
+ Designs and implements disaster recovery and backup strategies.
+
+ Automates routine operational tasks and infrastructure maintenance.
+ Conducts capacity planning, performance tuning, and cost optimization.
+ Implements security best practices, compliance controls, and access management.
+ Performs incident response, troubleshooting, and post-mortem analysis.
+
+ Designs for high availability, fault tolerance, and graceful degradation.
+ Implements auto-scaling, load balancing, and traffic management solutions.
+ Creates runbooks, documentation, and operational procedures.
+ Conducts chaos engineering experiments to improve system resilience.
+ """,
+ pass_queries=[
+ """
+ How do I set up a Kubernetes cluster with proper high availability?
+ """,
+ """
+ What's the best approach for implementing a CI/CD pipeline for our microservices?
+ """,
+ """
+ How can I automate our infrastructure provisioning using Terraform?
+ """,
+ """
+ What monitoring metrics should I track to ensure the reliability of our system?
+ """,
+ ],
+ fail_queries=[
+ """
+ How do I implement a sorting algorithm in Python?
+ """,
+ """
+ What's the best way to structure my React components for a single-page application?
+ """,
+ """
+ Can you help me design a database schema for my e-commerce application?
+ """,
+ """
+ How do I create a responsive layout using CSS Grid and Flexbox?
+ """,
+ """
+ What's the most efficient algorithm for finding the shortest path in a graph?
+ """,
+ ],
+)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "persona_match_test",
+ [
+ simple_persona,
+ architect,
+ coder,
+ devops_sre,
+ ],
+)
+async def test_check_persona_pass_match(
+ semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest
+):
+ """Test checking persona match."""
+ await semantic_router_mocked_db.add_persona(
+ persona_match_test.persona_name, persona_match_test.persona_desc
+ )
+
+ # Check for the queries that should pass
+ for query in persona_match_test.pass_queries:
+ match = await semantic_router_mocked_db.check_persona_match(
+ persona_match_test.persona_name, query
+ )
+ assert match is True
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "persona_match_test",
+ [
+ simple_persona,
+ architect,
+ coder,
+ devops_sre,
+ ],
+)
+async def test_check_persona_fail_match(
+ semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest
+):
+ """Test checking persona match."""
+ await semantic_router_mocked_db.add_persona(
+ persona_match_test.persona_name, persona_match_test.persona_desc
+ )
+
+ # Check for the queries that should fail
+ for query in persona_match_test.fail_queries:
+ match = await semantic_router_mocked_db.check_persona_match(
+ persona_match_test.persona_name, query
+ )
+ assert match is False
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "personas",
+ [
+ [
+ coder,
+ devops_sre,
+ architect,
+ ]
+ ],
+)
+async def test_persona_diff_description(
+ semantic_router_mocked_db: PersonaManager,
+ personas: List[PersonaMatchTest],
+):
+ # First, add all existing personas
+ for persona in personas:
+ await semantic_router_mocked_db.add_persona(persona.persona_name, persona.persona_desc)
+
+ last_added_persona = personas[-1]
+ with pytest.raises(PersonaSimilarDescriptionError):
+ await semantic_router_mocked_db.add_persona(
+ "repeated persona", last_added_persona.persona_desc
+ )
+
+
+@pytest.mark.asyncio
+async def test_update_persona(semantic_router_mocked_db: PersonaManager):
+ """Test updating a persona to the database different name and description."""
+ persona_name = "test_persona"
+ persona_desc = "test_persona_desc"
+ await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+ updated_description = "foo and bar description"
+ await semantic_router_mocked_db.update_persona(
+ persona_name, new_persona_name="new test persona", new_persona_desc=updated_description
+ )
+
+
+@pytest.mark.asyncio
+async def test_update_persona_same_desc(semantic_router_mocked_db: PersonaManager):
+ """Test updating a persona to the database with same description."""
+ persona_name = "test_persona"
+ persona_desc = "test_persona_desc"
+ await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+ await semantic_router_mocked_db.update_persona(
+ persona_name, new_persona_name="new test persona", new_persona_desc=persona_desc
+ )
+
+
+@pytest.mark.asyncio
+async def test_update_persona_not_exists(semantic_router_mocked_db: PersonaManager):
+ """Test updating a persona to the database."""
+ persona_name = "test_persona"
+ persona_desc = "test_persona_desc"
+
+ with pytest.raises(PersonaDoesNotExistError):
+ await semantic_router_mocked_db.update_persona(
+ persona_name, new_persona_name="new test persona", new_persona_desc=persona_desc
+ )
+
+
+@pytest.mark.asyncio
+async def test_update_persona_same_name(semantic_router_mocked_db: PersonaManager):
+ """Test updating a persona to the database."""
+ persona_name = "test_persona"
+ persona_desc = "test_persona_desc"
+ await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+ persona_name_2 = "test_persona_2"
+ persona_desc_2 = "foo and bar"
+ await semantic_router_mocked_db.add_persona(persona_name_2, persona_desc_2)
+
+ with pytest.raises(connection.AlreadyExistsError):
+ await semantic_router_mocked_db.update_persona(
+ persona_name_2, new_persona_name=persona_name, new_persona_desc=persona_desc_2
+ )
+
+
+@pytest.mark.asyncio
+async def test_delete_persona(semantic_router_mocked_db: PersonaManager):
+ """Test deleting a persona from the database."""
+ persona_name = "test_persona"
+ persona_desc = "test_persona_desc"
+ await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+ await semantic_router_mocked_db.delete_persona(persona_name)
+
+ with pytest.raises(PersonaDoesNotExistError):
+ await semantic_router_mocked_db.get_persona(persona_name)
+
+
+@pytest.mark.asyncio
+async def test_delete_persona_not_exists(semantic_router_mocked_db: PersonaManager):
+ persona_name = "test_persona"
+
+ with pytest.raises(PersonaDoesNotExistError):
+ await semantic_router_mocked_db.delete_persona(persona_name)
+
+
+@pytest.mark.asyncio
+async def test_get_personas(semantic_router_mocked_db: PersonaManager):
+ """Test getting personas from the database."""
+ persona_name = "test_persona"
+ persona_desc = "test_persona_desc"
+ await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+ persona_name_2 = "test_persona_2"
+ persona_desc_2 = "foo and bar"
+ await semantic_router_mocked_db.add_persona(persona_name_2, persona_desc_2)
+
+ all_personas = await semantic_router_mocked_db.get_all_personas()
+ assert len(all_personas) == 2
+ assert all_personas[0].name == persona_name
+ assert all_personas[1].name == persona_name_2
+
+
+@pytest.mark.asyncio
+async def test_get_personas_empty(semantic_router_mocked_db: PersonaManager):
+ """Test adding a persona to the database."""
+
+ all_personas = await semantic_router_mocked_db.get_all_personas()
+ assert len(all_personas) == 0
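The parametrized persona tests above reduce to the pattern below. A minimal sketch: whether a given query matches depends on the embedding model and similarity threshold, so the queries here are only illustrative:

```python
import pytest

from codegate.muxing.persona import PersonaManager


@pytest.mark.asyncio
async def test_minimal_persona_match(semantic_router_mocked_db: PersonaManager):
    await semantic_router_mocked_db.add_persona(
        "sql helper", "Expert in SQL query optimization and schema design."
    )
    # Semantically close query: expected to match.
    assert await semantic_router_mocked_db.check_persona_match(
        "sql helper", "How do I speed up a slow JOIN across three tables?"
    )
    # Unrelated query: expected not to match.
    assert not await semantic_router_mocked_db.check_persona_match(
        "sql helper", "How do I cook a delicious lasagna for dinner?"
    )
```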
diff --git a/tests/muxing/test_semantic_router.py b/tests/muxing/test_semantic_router.py
deleted file mode 100644
index c8c7edc6..00000000
--- a/tests/muxing/test_semantic_router.py
+++ /dev/null
@@ -1,590 +0,0 @@
-import uuid
-from pathlib import Path
-from typing import List
-
-import pytest
-from pydantic import BaseModel
-
-from codegate.db import connection
-from codegate.muxing.semantic_router import PersonaDoesNotExistError, SemanticRouter
-
-
-@pytest.fixture
-def db_path():
- """Creates a temporary database file path."""
- current_test_dir = Path(__file__).parent
- db_filepath = current_test_dir / f"codegate_test_{uuid.uuid4()}.db"
- db_fullpath = db_filepath.absolute()
- connection.init_db_sync(str(db_fullpath))
- yield db_fullpath
- if db_fullpath.is_file():
- db_fullpath.unlink()
-
-
-@pytest.fixture()
-def db_recorder(db_path) -> connection.DbRecorder:
- """Creates a DbRecorder instance with test database."""
- return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True)
-
-
-@pytest.fixture()
-def db_reader(db_path) -> connection.DbReader:
- """Creates a DbReader instance with test database."""
- return connection.DbReader(sqlite_path=db_path, _no_singleton=True)
-
-
-@pytest.fixture()
-def semantic_router_mocked_db(
- db_recorder: connection.DbRecorder, db_reader: connection.DbReader
-) -> SemanticRouter:
- """Creates a SemanticRouter instance with mocked database."""
- semantic_router = SemanticRouter()
- semantic_router._db_reader = db_reader
- semantic_router._db_recorder = db_recorder
- return semantic_router
-
-
-@pytest.mark.asyncio
-async def test_add_persona(semantic_router_mocked_db: SemanticRouter):
- """Test adding a persona to the database."""
- persona_name = "test_persona"
- persona_desc = "test_persona_desc"
- await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
- retrieved_persona = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name)
- assert retrieved_persona.name == persona_name
- assert retrieved_persona.description == persona_desc
-
-
-@pytest.mark.asyncio
-async def test_persona_not_exist_match(semantic_router_mocked_db: SemanticRouter):
- """Test checking persona match when persona does not exist"""
- persona_name = "test_persona"
- query = "test_query"
- with pytest.raises(PersonaDoesNotExistError):
- await semantic_router_mocked_db.check_persona_match(persona_name, query)
-
-
-class PersonaMatchTest(BaseModel):
- persona_name: str
- persona_desc: str
- pass_queries: List[str]
- fail_queries: List[str]
-
-
-simple_persona = PersonaMatchTest(
- persona_name="test_persona",
- persona_desc="test_desc",
- pass_queries=["test_desc", "test_desc2"],
- fail_queries=["foo"],
-)
-
-software_architect = PersonaMatchTest(
- persona_name="software architect",
- persona_desc="""
- Expert in designing large-scale software systems and technical infrastructure.
- Specializes in distributed systems, microservices architecture,
- and cloud-native applications.
- Deep knowledge of architectural patterns like CQRS, event sourcing, hexagonal architecture,
- and domain-driven design.
- Experienced in designing scalable, resilient, and maintainable software solutions.
- Proficient in evaluating technology stacks and making strategic technical decisions.
- Skilled at creating architecture diagrams, technical specifications,
- and system documentation.
- Focuses on non-functional requirements like performance, security, and reliability.
- Guides development teams on best practices for implementing complex systems.
- """,
- pass_queries=[
- """
- How should I design a microservices architecture that can handle high traffic loads?
- """,
- """
- What's the best approach for implementing event sourcing in a distributed system?
- """,
- """
- I need to design a system that can scale to millions of users. What architecture would you
- recommend?
- """,
- """
- Can you explain the trade-offs between monolithic and microservices architectures for our
- new project?
- """,
- ],
- fail_queries=[
- """
- How do I create a simple landing page with HTML and CSS?
- """,
- """
- What's the best way to optimize my SQL query performance?
- """,
- """
- Can you help me debug this JavaScript function that's throwing an error?
- """,
- """
- How do I implement user authentication in my React application?
- """,
- ],
-)
-
-# Data Scientist Persona
-data_scientist = PersonaMatchTest(
- persona_name="data scientist",
- persona_desc="""
- Expert in analyzing and interpreting complex data to solve business problems.
- Specializes in statistical analysis, machine learning algorithms, and predictive modeling.
- Builds and deploys models for classification, regression, clustering, and anomaly detection.
- Proficient in data preprocessing, feature engineering, and model evaluation techniques.
- Uses Python with libraries like NumPy, Pandas, scikit-learn, TensorFlow, and PyTorch.
- Experienced with data visualization using Matplotlib, Seaborn, and interactive dashboards.
- Applies experimental design principles and A/B testing methodologies.
- Works with structured and unstructured data, including time series and text.
- Implements data pipelines for model training, validation, and deployment.
- Communicates insights and recommendations based on data analysis to stakeholders.
-
- Handles class imbalance problems in classification tasks using techniques like SMOTE,
- undersampling, oversampling, and class weighting. Addresses customer churn prediction
- challenges by identifying key features that indicate potential churners.
-
- Applies feature selection methods for high-dimensional datasets, including filter methods
- (correlation, chi-square), wrapper methods (recursive feature elimination), and embedded
- methods (LASSO regularization).
-
- Prevents overfitting and high variance in tree-based models like random forests through
- techniques such as pruning, setting maximum depth, adjusting minimum samples per leaf,
- and cross-validation.
-
- Specializes in time series forecasting for sales and demand prediction, using methods like
- ARIMA, SARIMA, Prophet, and exponential smoothing to handle seasonal patterns and trends.
- Implements forecasting models that account for quarterly business cycles and seasonal
- variations in customer behavior.
-
- Evaluates model performance using appropriate metrics: accuracy, precision, recall,
- F1-score
- for classification; RMSE, MAE, R-squared for regression; and specialized metrics for
- time series forecasting like MAPE and SMAPE.
-
- Experienced in developing customer segmentation models, recommendation systems,
- anomaly detection algorithms, and predictive maintenance solutions.
- """,
- pass_queries=[
- """
- How should I handle class imbalance in my customer churn prediction model?
- """,
- """
- What feature selection techniques would work best for my high-dimensional dataset?
- """,
- """
- I'm getting high variance in my random forest model. How can I prevent overfitting?
- """,
- """
- What's the best approach for forecasting seasonal time series data for our sales
- predictions?
- """,
- ],
- fail_queries=[
- """
- How do I structure my React components for a single-page application?
- """,
- """
- What's the best way to implement a CI/CD pipeline for my microservices?
- """,
- """
- Can you help me design a responsive layout for mobile and desktop browsers?
- """,
- """
- How should I configure my Kubernetes cluster for high availability?
- """,
- ],
-)
-
-# UX Designer Persona
-ux_designer = PersonaMatchTest(
- persona_name="ux designer",
- persona_desc="""
- Expert in creating intuitive, user-centered digital experiences and interfaces.
- Specializes in user research, usability testing, and interaction design.
- Creates wireframes, prototypes, and user flows to visualize design solutions.
- Conducts user interviews, usability studies, and analyzes user feedback.
- Develops user personas and journey maps to understand user needs and pain points.
- Designs information architecture and navigation systems for complex applications.
- Applies design thinking methodology to solve user experience problems.
- Knowledgeable about accessibility standards and inclusive design principles.
- Collaborates with product managers and developers to implement user-friendly features.
- Uses tools like Figma, Sketch, and Adobe XD to create high-fidelity mockups.
- """,
- pass_queries=[
- """
- How can I improve the user onboarding experience for my mobile application?
- """,
- """
- What usability testing methods would you recommend for evaluating our new interface design?
- """,
- """
- I'm designing a complex dashboard. What information architecture would make it most
- intuitive for users?
- """,
- """
- How should I structure user research to identify pain points in our current
- checkout process?
- """,
- ],
- fail_queries=[
- """
- How do I configure a load balancer for my web servers?
- """,
- """
- What's the best way to implement a caching layer in my application?
- """,
- """
- Can you explain how to set up a CI/CD pipeline with GitHub Actions?
- """,
- """
- How do I optimize my database queries for better performance?
- """,
- ],
-)
-
-# DevOps Engineer Persona
-devops_engineer = PersonaMatchTest(
- persona_name="devops engineer",
- persona_desc="""
- Expertise: Infrastructure automation, CI/CD pipelines, cloud services, containerization,
- and monitoring.
- Proficient with tools like Docker, Kubernetes, Terraform, Ansible, and Jenkins.
- Experienced with cloud platforms including AWS, Azure, and Google Cloud.
- Strong knowledge of Linux/Unix systems administration and shell scripting.
- Skilled in implementing microservices architectures and service mesh technologies.
- Focus on reliability, scalability, security, and operational efficiency.
- Practices infrastructure as code, GitOps, and site reliability engineering principles.
- Experienced with monitoring tools like Prometheus, Grafana, and ELK stack.
- """,
- pass_queries=[
- """
- What's the best way to set up auto-scaling for my Kubernetes cluster on AWS?
- """,
- """
- I need to implement a zero-downtime deployment strategy for my microservices.
- What approaches would you recommend?
- """,
- """
- How can I improve the security of my CI/CD pipeline and prevent supply chain attacks?
- """,
- """
- What monitoring metrics should I track to ensure the reliability of my distributed system?
- """,
- ],
- fail_queries=[
- """
- How do I design an effective user onboarding flow for my mobile app?
- """,
- """
- What's the best algorithm for sentiment analysis on customer reviews?
- """,
- """
- Can you help me with color theory for my website redesign?
- """,
- """
- I need advice on optimizing my SQL queries for a reporting dashboard.
- """,
- ],
-)
-
-# Security Specialist Persona
-security_specialist = PersonaMatchTest(
- persona_name="security specialist",
- persona_desc="""
- Expert in cybersecurity, application security, and secure system design.
- Specializes in identifying and mitigating security vulnerabilities and threats.
- Performs security assessments, penetration testing, and code security reviews.
- Implements security controls like authentication, authorization, and encryption.
- Knowledgeable about common attack vectors such as injection attacks, XSS, CSRF, and SSRF.
- Experienced with security frameworks and standards like OWASP Top 10, NIST, and ISO 27001.
- Designs secure architectures and implements defense-in-depth strategies.
- Conducts security incident response and forensic analysis.
- Implements security monitoring, logging, and alerting systems.
- Stays current with emerging security threats and mitigation techniques.
- """,
- pass_queries=[
- """
- How can I protect my web application from SQL injection attacks?
- """,
- """
- What security controls should I implement for storing sensitive user data?
- """,
- """
- How do I conduct a thorough security assessment of our cloud infrastructure?
- """,
- """
- What's the best approach for implementing secure authentication in my API?
- """,
- ],
- fail_queries=[
- """
- How do I optimize the loading speed of my website?
- """,
- """
- What's the best way to implement responsive design for mobile devices?
- """,
- """
- Can you help me design a database schema for my e-commerce application?
- """,
- """
- How should I structure my React components for better code organization?
- """,
- ],
-)
-
-# Mobile Developer Persona
-mobile_developer = PersonaMatchTest(
- persona_name="mobile developer",
- persona_desc="""
- Expert in building native and cross-platform mobile applications for iOS and Android.
- Specializes in mobile UI development, responsive layouts, and platform-specific
- design patterns.
- Proficient in Swift and SwiftUI for iOS, Kotlin for Android, and React Native or
- Flutter for cross-platform.
- Implements mobile-specific features like push notifications, offline storage, and
- location services.
- Optimizes mobile applications for performance, battery efficiency, and limited
- network connectivity.
- Experienced with mobile app architecture patterns like MVVM, MVC, and Redux.
- Integrates with device hardware features including camera, biometrics, sensors,
- and Bluetooth.
- Familiar with app store submission processes, app signing, and distribution workflows.
- Implements secure data storage, authentication, and API communication on mobile devices.
- Designs and develops responsive interfaces that work across different screen sizes
- and orientations.
-
- Implements sophisticated offline-first data synchronization strategies
- for mobile applications,
- handling conflict resolution, data merging, and background syncing when connectivity
- is restored.
- Uses technologies like Realm, SQLite, Core Data, and Room Database to enable seamless
- offline
- experiences in React Native and native apps.
-
- Structures Swift code following the MVVM (Model-View-ViewModel) architectural pattern
- to create
- maintainable, testable iOS applications. Implements proper separation of concerns
- with bindings
- between views and view models using Combine, RxSwift, or SwiftUI's native state management.
-
- Specializes in deep linking implementation for both Android and iOS, enabling app-to-app
- communication, marketing campaign tracking, and seamless user experiences when navigating
- between web and mobile contexts. Configures Universal Links, App Links, and custom URL
- schemes.
-
- Optimizes battery usage for location-based features by implementing intelligent location
- tracking
- strategies, including geofencing, significant location changes, deferred location updates,
- and
- region monitoring. Balances accuracy requirements with power consumption constraints.
-
- Develops efficient state management solutions for complex mobile applications using Redux,
- MobX, Provider, or Riverpod for React Native apps, and native state management approaches
- for iOS and Android.
-
- Creates responsive mobile interfaces that adapt to different device orientations,
- screen sizes,
- and pixel densities using constraint layouts, auto layout, size classes, and flexible
- grid systems.
- """,
- pass_queries=[
- """
- What's the best approach for implementing offline-first data synchronization in my mobile
- app?
- """,
- """
- How should I structure my Swift code to implement the MVVM pattern effectively?
- """,
- """
- What's the most efficient way to handle deep linking and app-to-app communication on
- Android?
- """,
- """
- How can I optimize battery usage when implementing background location tracking?
- """,
- ],
- fail_queries=[
- """
- How do I design a database schema with proper normalization for my web application?
- """,
- """
- What's the best approach for implementing a distributed caching layer in my microservices?
- """,
- """
- Can you help me set up a data pipeline for processing large datasets with Apache Spark?
- """,
- """
- How should I configure my load balancer to distribute traffic across my web servers?
- """,
- ],
-)
-
-# Database Administrator Persona
-database_administrator = PersonaMatchTest(
- persona_name="database administrator",
- persona_desc="""
- Expert in designing, implementing, and managing database systems for optimal performance and
- reliability.
- Specializes in database architecture, schema design, and query optimization techniques.
- Proficient with relational databases like PostgreSQL, MySQL, Oracle, and SQL Server.
- Implements and manages database security, access controls, and data protection measures.
- Designs high-availability solutions using replication, clustering, and failover mechanisms.
- Develops and executes backup strategies, disaster recovery plans, and data retention
- policies.
- Monitors database performance, identifies bottlenecks, and implements optimization
- solutions.
- Creates and maintains indexes, partitioning schemes, and other performance-enhancing
- structures.
- Experienced with database migration, version control, and change management processes.
- Implements data integrity constraints, stored procedures, triggers, and database automation.
-
- Optimizes complex JOIN query performance in PostgreSQL through advanced techniques including
- query rewriting, proper indexing strategies, materialized views, and query plan analysis.
- Uses EXPLAIN ANALYZE to identify bottlenecks in query execution plans and implements
- appropriate optimizations for specific query patterns.
-
- Designs and implements high-availability MySQL configurations with automatic failover using
- technologies like MySQL Group Replication, Galera Cluster, Percona XtraDB Cluster, or MySQL
- InnoDB Cluster with MySQL Router. Configures synchronous and asynchronous replication
- strategies
- to balance consistency and performance requirements.
-
- Develops sophisticated indexing strategies for tables with frequent write operations and
- complex
- read queries, balancing write performance with read optimization. Implements partial
- indexes,
- covering indexes, and composite indexes based on query patterns and cardinality analysis.
-
- Specializes in large-scale database migrations between different database engines,
- particularly
- Oracle to PostgreSQL transitions. Uses tools like ora2pg, AWS DMS, and custom ETL processes
- to
- ensure data integrity, schema compatibility, and minimal downtime during migration.
-
- Implements table partitioning schemes based on data access patterns, including range
- partitioning
- for time-series data, list partitioning for categorical data, and hash partitioning for
- evenly
- distributed workloads.
-
- Configures and manages database connection pooling, query caching, and buffer management to
- optimize resource utilization and throughput under varying workloads.
-
- Designs and implements database sharding strategies for horizontal scaling, including
- consistent hashing algorithms, shard key selection, and cross-shard query optimization.
- """,
- pass_queries=[
- """
- How can I optimize the performance of complex JOIN queries in my PostgreSQL database?
- """,
- """
- What's the best approach for implementing a high-availability MySQL setup with automatic
- failover?
- """,
- """
- How should I design my indexing strategy for a table with frequent writes and complex read
- queries?
- """,
- """
- What's the most efficient way to migrate a large Oracle database to PostgreSQL with minimal
- downtime?
- """,
- ],
- fail_queries=[
- """
- How do I structure my React components to implement the Redux state management pattern?
- """,
- """
- What's the best approach for implementing responsive design with CSS Grid and Flexbox?
- """,
- """
- Can you help me set up a CI/CD pipeline for my containerized microservices?
- """,
- ],
-)
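The database administrator persona above leans on `EXPLAIN ANALYZE` for JOIN tuning. A minimal sketch of that workflow, assuming `psycopg2` and a hypothetical `orders`/`users` schema (neither is part of this test suite), looks like this:

```python
import psycopg2  # assumed to be installed; any DB-API driver is similar


def explain(conn, query: str, params=()):
    # EXPLAIN ANALYZE runs the statement and returns the real plan with
    # timings and row counts, one text row per plan line.
    with conn.cursor() as cur:
        cur.execute("EXPLAIN ANALYZE " + query, params)
        for (line,) in cur.fetchall():
            print(line)


conn = psycopg2.connect("dbname=app")  # hypothetical DSN
explain(
    conn,
    "SELECT o.* FROM orders o JOIN users u ON u.id = o.user_id WHERE u.email = %s",
    ("user@example.com",),
)
```

Because `EXPLAIN ANALYZE` actually executes the statement, it is normally wrapped in a rolled-back transaction when pointed at INSERT/UPDATE/DELETE queries.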
-
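For the sharding strategy mentioned at the end of that description, consistent hashing is the standard technique for keeping key movement small when shards are added or removed. A self-contained sketch (the node names are placeholders):

```python
import bisect
import hashlib


class ConsistentHashRing:
    """Route shard keys to nodes so that resizing remaps only ~1/N keys."""

    def __init__(self, nodes, vnodes=100):
        self._ring = sorted(
            (self._hash(f"{node}:{i}"), node)
            for node in nodes
            for i in range(vnodes)
        )
        self._hashes = [h for h, _ in self._ring]

    @staticmethod
    def _hash(value: str) -> int:
        return int(hashlib.md5(value.encode()).hexdigest(), 16)

    def node_for(self, shard_key: str) -> str:
        # First virtual node clockwise from the key's position, wrapping.
        idx = bisect.bisect(self._hashes, self._hash(shard_key)) % len(self._hashes)
        return self._ring[idx][1]


ring = ConsistentHashRing(["pg-shard-1", "pg-shard-2", "pg-shard-3"])
print(ring.node_for("customer:42"))
```

With `vnodes` virtual nodes per physical shard, load spreads evenly around the ring, and removing one shard remaps only that shard's keys to its neighbors.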
-# Natural Language Processing Specialist Persona
-nlp_specialist = PersonaMatchTest(
- persona_name="nlp specialist",
- persona_desc="""
- Expertise: Natural language processing, computational linguistics, and text analytics.
- Proficient with NLP libraries and frameworks like NLTK, spaCy, Hugging Face Transformers,
- and Gensim.
- Experience with language models such as BERT, GPT, T5, and their applications.
- Skilled in text preprocessing, tokenization, lemmatization, and feature extraction
- techniques.
- Knowledge of sentiment analysis, named entity recognition, topic modeling, and text
- classification.
- Familiar with word embeddings, contextual embeddings, and language representation methods.
- Understanding of machine translation, question answering, and text summarization systems.
- Background in information retrieval, semantic search, and conversational AI development.
- """,
- pass_queries=[
- """
- What approach should I take to fine-tune BERT for my custom text classification task?
- """,
- """
- How can I improve the accuracy of my named entity recognition system for medical texts?
- """,
- """
- What's the best way to implement semantic search using embeddings from language models?
- """,
- """
- I need to build a sentiment analysis system that can handle sarcasm and idioms.
- Any suggestions?
- """,
- ],
- fail_queries=[
- """
- How do I optimize my React components to reduce rendering time?
- """,
- """
- What's the best approach for implementing a CI/CD pipeline with Jenkins?
- """,
- """
- Can you help me design a responsive UI for my web application?
- """,
- """
- How should I structure my microservices architecture for scalability?
- """,
- ],
-)
-
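The NLP persona above mentions semantic search over language-model embeddings, and one of its pass queries asks exactly that. Stripped of any particular encoder, the retrieval step is cosine similarity over normalized vectors, as in this sketch (the random vectors stand in for real sentence embeddings):

```python
import numpy as np


def cosine_top_k(query_vec, doc_vecs, k=3):
    # Normalize, then rank documents by cosine similarity to the query.
    q = query_vec / np.linalg.norm(query_vec)
    d = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    scores = d @ q
    top = np.argsort(scores)[::-1][:k]
    return [(int(i), float(scores[i])) for i in top]


rng = np.random.default_rng(0)
docs = rng.normal(size=(10, 384))  # stand-ins for real embeddings
print(cosine_top_k(docs[0], docs, k=2))  # index 0 ranks first, score 1.0
```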
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
- "persona_match_test",
- [
- simple_persona,
- software_architect,
- data_scientist,
- ux_designer,
- devops_engineer,
- security_specialist,
- mobile_developer,
- database_administrator,
- nlp_specialist,
- ],
-)
-async def test_check_persona_match(
- semantic_router_mocked_db: SemanticRouter, persona_match_test: PersonaMatchTest
-):
- """Test checking persona match."""
- await semantic_router_mocked_db.add_persona(
- persona_match_test.persona_name, persona_match_test.persona_desc
- )
-
- # Check for the queries that should pass
- for query in persona_match_test.pass_queries:
- match = await semantic_router_mocked_db.check_persona_match(
- persona_match_test.persona_name, query
- )
- assert match is True
-
- # Check for the queries that should fail
- for query in persona_match_test.fail_queries:
- match = await semantic_router_mocked_db.check_persona_match(
- persona_match_test.persona_name, query
- )
- assert match is False
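For context on what the deleted test exercised: `check_persona_match` compares an incoming query against a stored persona description. The snippet below is not CodeGate's implementation, only the usual shape such a check takes, i.e. an embedding-similarity threshold (the 0.75 cutoff is an arbitrary illustration):

```python
import numpy as np


def matches_persona(persona_vec, query_vec, threshold=0.75):
    # Cosine similarity between the stored persona-description embedding
    # and the query embedding, gated by an illustrative cutoff.
    sim = float(
        persona_vec @ query_vec
        / (np.linalg.norm(persona_vec) * np.linalg.norm(query_vec))
    )
    return sim >= threshold
```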
diff --git a/tests/pipeline/pii/test_analyzer.py b/tests/pipeline/pii/test_analyzer.py
index d626b8cf..e856653c 100644
--- a/tests/pipeline/pii/test_analyzer.py
+++ b/tests/pipeline/pii/test_analyzer.py
@@ -1,7 +1,6 @@
from unittest.mock import MagicMock, patch
import pytest
-from presidio_analyzer import RecognizerResult
from codegate.pipeline.pii.analyzer import PiiAnalyzer
diff --git a/tests/pipeline/sensitive_data/test_manager.py b/tests/pipeline/sensitive_data/test_manager.py
index 6115ad14..66305388 100644
--- a/tests/pipeline/sensitive_data/test_manager.py
+++ b/tests/pipeline/sensitive_data/test_manager.py
@@ -1,6 +1,7 @@
-import json
from unittest.mock import MagicMock, patch
+
import pytest
+
from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager
from codegate.pipeline.sensitive_data.session_store import SessionStore
diff --git a/tests/pipeline/sensitive_data/test_session_store.py b/tests/pipeline/sensitive_data/test_session_store.py
index b9ab64fe..e90b953e 100644
--- a/tests/pipeline/sensitive_data/test_session_store.py
+++ b/tests/pipeline/sensitive_data/test_session_store.py
@@ -1,5 +1,5 @@
-import uuid
import pytest
+
from codegate.pipeline.sensitive_data.session_store import SessionStore
diff --git a/tests/test_server.py b/tests/test_server.py
index aa549810..bcf55e7e 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -14,7 +14,6 @@
from codegate import __version__
from codegate.pipeline.factory import PipelineFactory
-from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
from codegate.providers.registry import ProviderRegistry
from codegate.server import init_app
from src.codegate.cli import UvicornServer, cli