diff --git a/.github/workflows/import_packages.yml b/.github/workflows/import_packages.yml index 3da31b63..e7ada4d4 100644 --- a/.github/workflows/import_packages.yml +++ b/.github/workflows/import_packages.yml @@ -47,6 +47,7 @@ jobs: MALICIOUS_KEY=$(jq -r '.latest.malicious_packages' manifest.json) DEPRECATED_KEY=$(jq -r '.latest.deprecated_packages' manifest.json) ARCHIVED_KEY=$(jq -r '.latest.archived_packages' manifest.json) + VULNERABLE_KEY=$(jq -r '.latest.vulnerable_packages' manifest.json) echo "Malicious key: $MALICIOUS_KEY" echo "Deprecated key: $DEPRECATED_KEY" @@ -58,6 +59,7 @@ jobs: aws s3 cp s3://codegate-data-prod/$MALICIOUS_KEY /tmp/jsonl-files/malicious.jsonl --region $AWS_REGION aws s3 cp s3://codegate-data-prod/$DEPRECATED_KEY /tmp/jsonl-files/deprecated.jsonl --region $AWS_REGION aws s3 cp s3://codegate-data-prod/$ARCHIVED_KEY /tmp/jsonl-files/archived.jsonl --region $AWS_REGION + aws s3 cp s3://codegate-data-prod/$VULNERABLE_KEY /tmp/jsonl-files/vulnerable.jsonl --region $AWS_REGION - name: Install Poetry run: | diff --git a/Dockerfile b/Dockerfile index a12d0f76..70849c13 100644 --- a/Dockerfile +++ b/Dockerfile @@ -72,6 +72,7 @@ FROM python:3.12-slim AS runtime RUN apt-get update && apt-get install -y --no-install-recommends \ libgomp1 \ nginx \ + gettext-base \ && rm -rf /var/lib/apt/lists/* # Create a non-root user @@ -81,6 +82,7 @@ RUN useradd -m -u 1000 -r codegate # Set permissions for user codegate to run nginx RUN chown -R codegate /var/lib/nginx && \ chown -R codegate /var/log/nginx && \ + chown -R codegate /etc/nginx && \ chown -R codegate /run COPY nginx.conf /etc/nginx/nginx.conf diff --git a/api/openapi.json b/api/openapi.json index deb5c2de..4ea57d0e 100644 --- a/api/openapi.json +++ b/api/openapi.json @@ -418,7 +418,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateOrRenameWorkspaceRequest" + "$ref": "#/components/schemas/FullWorkspace-Input" } } }, @@ -430,7 +430,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/Workspace" + "$ref": "#/components/schemas/FullWorkspace-Output" } } } @@ -522,6 +522,58 @@ } }, "/api/v1/workspaces/{workspace_name}": { + "put": { + "tags": [ + "CodeGate API", + "Workspaces" + ], + "summary": "Update Workspace", + "description": "Update a workspace.", + "operationId": "v1_update_workspace", + "parameters": [ + { + "name": "workspace_name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workspace Name" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FullWorkspace-Input" + } + } + } + }, + "responses": { + "201": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FullWorkspace-Output" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, "delete": { "tags": [ "CodeGate API", @@ -720,6 +772,50 @@ } } }, + "/api/v1/workspaces/{workspace_name}/alerts-summary": { + "get": { + "tags": [ + "CodeGate API", + "Workspaces" + ], + "summary": "Get Workspace Alerts Summary", + "description": "Get alert summary for a workspace.", + "operationId": "v1_get_workspace_alerts_summary", + "parameters": [ + { + "name": "workspace_name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workspace Name" + } + 
} + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AlertSummary" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/api/v1/workspaces/{workspace_name}/messages": { "get": { "tags": [ @@ -1123,6 +1219,205 @@ } } } + }, + "/api/v1/personas": { + "get": { + "tags": [ + "CodeGate API", + "Personas" + ], + "summary": "List Personas", + "description": "List all personas.", + "operationId": "v1_list_personas", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/Persona" + }, + "type": "array", + "title": "Response V1 List Personas" + } + } + } + } + } + }, + "post": { + "tags": [ + "CodeGate API", + "Personas" + ], + "summary": "Create Persona", + "description": "Create a new persona.", + "operationId": "v1_create_persona", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PersonaRequest" + } + } + }, + "required": true + }, + "responses": { + "201": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Persona" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/api/v1/personas/{persona_name}": { + "get": { + "tags": [ + "CodeGate API", + "Personas" + ], + "summary": "Get Persona", + "description": "Get a persona by name.", + "operationId": "v1_get_persona", + "parameters": [ + { + "name": "persona_name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Persona Name" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Persona" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "put": { + "tags": [ + "CodeGate API", + "Personas" + ], + "summary": "Update Persona", + "description": "Update an existing persona.", + "operationId": "v1_update_persona", + "parameters": [ + { + "name": "persona_name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Persona Name" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PersonaUpdateRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Persona" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": [ + "CodeGate API", + "Personas" + ], + "summary": "Delete Persona", + "description": "Delete a persona.", + "operationId": "v1_delete_persona", + "parameters": [ + { + "name": "persona_name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Persona Name" + } + } + ], + "responses": { + "204": { + 
"description": "Successful Response" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } } }, "components": { @@ -1352,6 +1647,30 @@ ], "title": "AlertSeverity" }, + "AlertSummary": { + "properties": { + "malicious_packages": { + "type": "integer", + "title": "Malicious Packages" + }, + "pii": { + "type": "integer", + "title": "Pii" + }, + "secrets": { + "type": "integer", + "title": "Secrets" + } + }, + "type": "object", + "required": [ + "malicious_packages", + "pii", + "secrets" + ], + "title": "AlertSummary", + "description": "Represents a set of summary alerts" + }, "ChatMessage": { "properties": { "message": { @@ -1521,7 +1840,20 @@ "title": "Conversation", "description": "Represents a conversation." }, - "CreateOrRenameWorkspaceRequest": { + "CustomInstructions": { + "properties": { + "prompt": { + "type": "string", + "title": "Prompt" + } + }, + "type": "object", + "required": [ + "prompt" + ], + "title": "CustomInstructions" + }, + "FullWorkspace-Input": { "properties": { "name": { "type": "string", @@ -1530,43 +1862,42 @@ "config": { "anyOf": [ { - "$ref": "#/components/schemas/WorkspaceConfig" + "$ref": "#/components/schemas/WorkspaceConfig-Input" }, { "type": "null" } ] - }, - "rename_to": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Rename To" } }, "type": "object", "required": [ "name" ], - "title": "CreateOrRenameWorkspaceRequest" + "title": "FullWorkspace" }, - "CustomInstructions": { + "FullWorkspace-Output": { "properties": { - "prompt": { + "name": { "type": "string", - "title": "Prompt" + "title": "Name" + }, + "config": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkspaceConfig-Output" + }, + { + "type": "null" + } + ] } }, "type": "object", "required": [ - "prompt" + "name" ], - "title": "CustomInstructions" + "title": "FullWorkspace" }, "HTTPValidationError": { "properties": { @@ -1693,6 +2024,68 @@ "title": "MuxRule", "description": "Represents a mux rule for a provider." }, + "Persona": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "type": "string", + "title": "Description" + } + }, + "type": "object", + "required": [ + "id", + "name", + "description" + ], + "title": "Persona", + "description": "Represents a persona object." + }, + "PersonaRequest": { + "properties": { + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "type": "string", + "title": "Description" + } + }, + "type": "object", + "required": [ + "name", + "description" + ], + "title": "PersonaRequest", + "description": "Model for creating a new Persona." + }, + "PersonaUpdateRequest": { + "properties": { + "new_name": { + "type": "string", + "title": "New Name" + }, + "new_description": { + "type": "string", + "title": "New Description" + } + }, + "type": "object", + "required": [ + "new_name", + "new_description" + ], + "title": "PersonaUpdateRequest", + "description": "Model for updating a Persona." 
+ }, "ProviderAuthType": { "type": "string", "enum": [ @@ -1914,11 +2307,32 @@ ], "title": "Workspace" }, - "WorkspaceConfig": { + "WorkspaceConfig-Input": { + "properties": { + "custom_instructions": { + "type": "string", + "title": "Custom Instructions" + }, + "muxing_rules": { + "items": { + "$ref": "#/components/schemas/MuxRule" + }, + "type": "array", + "title": "Muxing Rules" + } + }, + "type": "object", + "required": [ + "custom_instructions", + "muxing_rules" + ], + "title": "WorkspaceConfig" + }, + "WorkspaceConfig-Output": { "properties": { - "system_prompt": { + "custom_instructions": { "type": "string", - "title": "System Prompt" + "title": "Custom Instructions" }, "muxing_rules": { "items": { @@ -1930,7 +2344,7 @@ }, "type": "object", "required": [ - "system_prompt", + "custom_instructions", "muxing_rules" ], "title": "WorkspaceConfig" diff --git a/docs/workspaces.md b/docs/workspaces.md new file mode 100644 index 00000000..cdc4e751 --- /dev/null +++ b/docs/workspaces.md @@ -0,0 +1,111 @@ +# CodeGate Workspaces + +Workspaces help you group related resources together. They can be used to organize your +configurations, muxing rules and custom prompts. It is important to note that workspaces +are not a tenancy concept; CodeGate assumes that it's serving a single user. + +## Global vs Workspace resources + +In CodeGate, resources can be either global (available across all workspaces) or workspace-specific: + +- **Global resources**: These are shared across all workspaces and include provider endpoints, + authentication configurations, and personas. + +- **Workspace resources**: These are specific to a workspace and include custom instructions, + muxing rules, and conversation history. + +### Sessions and Active Workspaces + +CodeGate uses the concept of "sessions" to track which workspace is active. A session represents +a user's interaction context with the system and maintains a reference to the active workspace. + +- **Sessions**: Each session has an ID, an active workspace ID, and a last update timestamp. +- **Active workspace**: The workspace that is currently being used for processing requests. + +Currently, the implementation expects only one active session at a time, meaning only one +workspace can be active. However, the underlying architecture is designed to potentially +support multiple concurrent sessions in the future, which would allow different contexts +to have different active workspaces simultaneously. + +When a workspace is activated, the session's active_workspace_id is updated to point to that +workspace, and the muxing registry is updated to use that workspace's rules for routing requests. + +## Workspace Lifecycle + +Workspaces in CodeGate follow a specific lifecycle: + +1. **Creation**: Workspaces are created with a unique name and optional custom instructions and muxing rules. +2. **Activation**: A workspace can be activated, making it the current context for processing requests. +3. **Archiving**: Workspaces can be archived (soft-deleted) when no longer needed but might be used again. +4. **Recovery**: Archived workspaces can be recovered to make them available again. +5. **Deletion**: Archived workspaces can be permanently deleted (hard-deleted). + +### Default Workspace + +CodeGate includes a default workspace that cannot be deleted or archived. This workspace is used +when no other workspace is explicitly activated. 
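+
+The lifecycle operations map onto the REST API described in `api/openapi.json`. Below is a
+minimal sketch using Python's third-party `requests` package; the base URL, port 8989, and
+the workspace names are assumptions for a local instance, not shipped defaults.
+
+```python
+import requests
+
+BASE_URL = "http://localhost:8989/api/v1"  # assumed local CodeGate instance
+
+# Creation: POST /workspaces accepts a FullWorkspace payload; config is optional,
+# but when present it must carry both custom_instructions and muxing_rules.
+resp = requests.post(
+    f"{BASE_URL}/workspaces",
+    json={
+        "name": "my-project",
+        "config": {
+            "custom_instructions": "Prefer typed Python examples.",
+            "muxing_rules": [],
+        },
+    },
+)
+resp.raise_for_status()  # 201 with the created FullWorkspace on success
+
+# Update: PUT /workspaces/{name} renames the workspace and/or replaces its config.
+resp = requests.put(f"{BASE_URL}/workspaces/my-project", json={"name": "my-project-v2"})
+resp.raise_for_status()
+```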
+ +## Workspace Features + +### Custom Instructions + +Each workspace can have its own set of custom instructions that are applied to LLM requests. +These instructions can be used to customize the behavior of the LLM for specific use cases. + +### Muxing Rules + +Workspaces can define muxing rules that determine which provider and model to use for different +types of requests. Rules are evaluated in priority order (first rule in the list has highest priority). + +### Token Usage Tracking + +CodeGate tracks token usage per workspace, allowing you to monitor and analyze resource consumption +across different contexts or projects. + +### Prompts, Alerts and Monitoring + +Workspaces maintain their own prompt and alert history, making it easier to monitor and respond to issues within specific contexts. + +## Developing + +### When to use workspaces? + +Consider using separate workspaces when: + +- You need different custom instructions for different projects or use cases +- You want to route different types of requests to different models +- You need to track token usage separately for different projects +- You want to isolate alerts and monitoring for specific contexts +- You're experimenting with different configurations and want to switch between them easily + +### When should a resource be global? + +Resources should be global when: + +- They need to be shared across multiple workspaces +- They represent infrastructure configuration rather than usage patterns +- They're related to provider connectivity rather than specific use cases +- They represent reusable components like personas that might be used in multiple contexts + +### Exporting resources + +Exporting resources in CodeGate is designed to facilitate sharing workspaces between different instances. +This is particularly useful for: + +- **Standardizing configurations**: When you want to ensure consistent behavior across multiple CodeGate instances +- **Sharing best practices**: When you've developed effective muxing rules or custom instructions that others could benefit from +- **Backup and recovery**: To preserve important workspace configurations before making significant changes + +When deciding whether to export resources, consider: + +- **Export workspace configurations** when they represent reusable patterns that could be valuable in other contexts +- **Export muxing rules** when they represent well-tested routing strategies that could be applied in other instances +- **Export custom instructions** when they contain general-purpose prompting strategies not specific to your instance + +Avoid exporting: +- Workspaces with instance-specific configurations that wouldn't be applicable elsewhere +- Workspaces containing sensitive or organization-specific custom instructions +- Resources that are tightly coupled to your specific provider endpoints or authentication setup + +Note that conversation history, alerts, and token usage statistics are not included in exports as they +represent instance-specific usage data rather than reusable configurations. 
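+
+As a concrete example of such instance-specific data, the per-workspace alert counts
+described under "Prompts, Alerts and Monitoring" can be read through the `alerts-summary`
+endpoint. A minimal sketch, reusing the assumed local instance and workspace name from the
+lifecycle example above; the response fields follow the `AlertSummary` schema:
+
+```python
+import requests
+
+summary = requests.get(
+    "http://localhost:8989/api/v1/workspaces/my-project/alerts-summary"
+).json()
+
+# AlertSummary counts detections of malicious packages, PII and secrets
+print(f"malicious packages: {summary['malicious_packages']}")
+print(f"PII: {summary['pii']}, secrets: {summary['secrets']}")
+```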
diff --git a/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py new file mode 100644 index 00000000..775e3967 --- /dev/null +++ b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py @@ -0,0 +1,63 @@ +"""add installation table + +Revision ID: e4c05d7591a8 +Revises: 3ec2b4ab569c +Create Date: 2025-03-05 21:26:19.034319+00:00 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "e4c05d7591a8" +down_revision: Union[str, None] = "3ec2b4ab569c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute("BEGIN TRANSACTION;") + + op.execute( + """ + CREATE TABLE IF NOT EXISTS instance ( + id TEXT PRIMARY KEY, -- UUID stored as TEXT + created_at DATETIME NOT NULL + ); + """ + ) + + op.execute( + """ + -- The following trigger prevents multiple insertions in the + -- instance table. It is safe since the dimension of the table + -- is fixed. + + CREATE TRIGGER single_instance + BEFORE INSERT ON instance + WHEN (SELECT COUNT(*) FROM instance) >= 1 + BEGIN + SELECT RAISE(FAIL, 'only one instance!'); + END; + """ + ) + + # Finish transaction + op.execute("COMMIT;") + + +def downgrade() -> None: + op.execute("BEGIN TRANSACTION;") + + op.execute( + """ + DROP TABLE instance; + """ + ) + + # Finish transaction + op.execute("COMMIT;") diff --git a/poetry.lock b/poetry.lock index f757d84e..15579a88 100644 --- a/poetry.lock +++ b/poetry.lock @@ -151,23 +151,23 @@ docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"] [[package]] name = "alembic" -version = "1.14.1" +version = "1.15.1" description = "A database migration tool for SQLAlchemy." optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "alembic-1.14.1-py3-none-any.whl", hash = "sha256:1acdd7a3a478e208b0503cd73614d5e4c6efafa4e73518bb60e4f2846a37b1c5"}, - {file = "alembic-1.14.1.tar.gz", hash = "sha256:496e888245a53adf1498fcab31713a469c65836f8de76e01399aa1c3e90dd213"}, + {file = "alembic-1.15.1-py3-none-any.whl", hash = "sha256:197de710da4b3e91cf66a826a5b31b5d59a127ab41bd0fc42863e2902ce2bbbe"}, + {file = "alembic-1.15.1.tar.gz", hash = "sha256:e1a1c738577bca1f27e68728c910cd389b9a92152ff91d902da649c192e30c49"}, ] [package.dependencies] Mako = "*" -SQLAlchemy = ">=1.3.0" -typing-extensions = ">=4" +SQLAlchemy = ">=1.4.0" +typing-extensions = ">=4.12" [package.extras] -tz = ["backports.zoneinfo ; python_version < \"3.9\"", "tzdata"] +tz = ["tzdata"] [[package]] name = "annotated-types" @@ -1343,14 +1343,14 @@ files = [ [[package]] name = "jinja2" -version = "3.1.5" +version = "3.1.6" description = "A very fast and expressive template engine." 
optional = false python-versions = ">=3.7" groups = ["main", "dev"] files = [ - {file = "jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb"}, - {file = "jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb"}, + {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"}, + {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"}, ] [package.dependencies] @@ -1546,14 +1546,14 @@ files = [ [[package]] name = "litellm" -version = "1.62.1" +version = "1.63.0" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" groups = ["main", "dev"] files = [ - {file = "litellm-1.62.1-py3-none-any.whl", hash = "sha256:f576358c72b477207d1f45ce5ac895ede7bd84377f6420a6b522909c829a79dc"}, - {file = "litellm-1.62.1.tar.gz", hash = "sha256:eee9cc40dc9c1da7e411af2f4ef145a67bb61702ae4e1218c1bc15b9e6404daa"}, + {file = "litellm-1.63.0-py3-none-any.whl", hash = "sha256:38961eaeb81fa2500c2725e01be898fb5d6347e73286b6d13d2f4d2f006d99e9"}, + {file = "litellm-1.63.0.tar.gz", hash = "sha256:872fb3fa4c8875d82fe998a5e4249c21a15bb08800286f03f90ed1700203f62e"}, ] [package.dependencies] @@ -4279,4 +4279,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<3.13" -content-hash = "1cae360ec3078b2da000dfea1d112e32256502aa9b3e2a4d8ca919384a49aff6" +content-hash = "ba9315e5bd243ff23b9f1044c43228b0658f7c345bc081dcfe9f2af8f2511e0c" diff --git a/pyproject.toml b/pyproject.toml index a6748387..8da9768a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ PyYAML = "==6.0.2" fastapi = "==0.115.11" uvicorn = "==0.34.0" structlog = "==25.1.0" -litellm = "==1.62.1" +litellm = "==1.63.0" llama_cpp_python = "==0.3.5" cryptography = "==44.0.2" sqlalchemy = "==2.0.38" @@ -28,7 +28,7 @@ tree-sitter-java = "==0.23.5" tree-sitter-javascript = "==0.23.1" tree-sitter-python = "==0.23.6" tree-sitter-rust = "==0.23.2" -alembic = "==1.14.1" +alembic = "==1.15.1" pygments = "==2.19.1" sqlite-vec-sl-tmp = "==0.0.4" greenlet = "==3.1.1" @@ -50,7 +50,7 @@ ruff = "==0.9.9" bandit = "==1.8.3" build = "==1.2.2.post1" wheel = "==0.45.1" -litellm = "==1.62.1" +litellm = "==1.63.0" pytest-asyncio = "==0.25.3" llama_cpp_python = "==0.3.5" scikit-learn = "==1.6.1" diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh index b28f6704..dd1f70d7 100755 --- a/scripts/entrypoint.sh +++ b/scripts/entrypoint.sh @@ -28,11 +28,10 @@ generate_certs() { # Function to start Nginx server for the dashboard start_dashboard() { - if [ -n "${DASHBOARD_BASE_URL}" ]; then - echo "Overriding dashboard url with $DASHBOARD_BASE_URL" - sed -ibck "s|http://localhost:8989|http://$DASHBOARD_BASE_URL:8989|g" /var/www/html/assets/*.js - fi echo "Starting the dashboard..." 
+ + envsubst '${DASHBOARD_API_BASE_URL}' < /var/www/html/index.html > /var/www/html/index.html.tmp && mv /var/www/html/index.html.tmp /var/www/html/index.html + nginx -g 'daemon off;' & } diff --git a/scripts/import_packages.py b/scripts/import_packages.py index 1cfdfd1e..c4a2dad1 100644 --- a/scripts/import_packages.py +++ b/scripts/import_packages.py @@ -20,6 +20,7 @@ def __init__(self, jsonl_dir="data", vec_db_path="./sqlite_data/vectordb.db"): os.path.join(jsonl_dir, "archived.jsonl"), os.path.join(jsonl_dir, "deprecated.jsonl"), os.path.join(jsonl_dir, "malicious.jsonl"), + os.path.join(jsonl_dir, "vulnerable.jsonl"), ] self.conn = self._get_connection() Config.load() # Load the configuration @@ -48,13 +49,41 @@ def setup_schema(self): """ ) + # table for packages that has at least one vulnerability high or critical + cursor.execute( + """ + CREATE TABLE cve_packages ( + name TEXT NOT NULL, + version TEXT NOT NULL, + type TEXT NOT NULL + ) + """ + ) + # Create indexes for faster querying cursor.execute("CREATE INDEX IF NOT EXISTS idx_name ON packages(name)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_type ON packages(type)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_status ON packages(status)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_pkg_cve_name ON cve_packages(name)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_pkg_cve_type ON cve_packages(type)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_pkg_cve_version ON cve_packages(version)") self.conn.commit() + async def process_cve_packages(self, package): + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT INTO cve_packages (name, version, type) VALUES (?, ?, ?) + """, + ( + package["name"], + package["version"], + package["type"], + ), + ) + self.conn.commit() + async def process_package(self, package): vector_str = generate_vector_string(package) vector = await self.inference_engine.embed( @@ -101,14 +130,19 @@ async def add_data(self): package["status"] = json_file.split("/")[-1].split(".")[0] key = f"{package['name']}/{package['type']}" - if key in existing_packages and existing_packages[key] == { - "status": package["status"], - "description": package["description"], - }: - print("Package already exists", key) - continue - - await self.process_package(package) + if package["status"] == "vulnerable": + # Process vulnerable packages using the cve flow + await self.process_cve_packages(package) + else: + # For non-vulnerable packages, check for duplicates and process normally + if key in existing_packages and existing_packages[key] == { + "status": package["status"], + "description": package["description"], + }: + print("Package already exists", key) + continue + + await self.process_package(package) async def run_import(self): self.setup_schema() diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index ebd9be79..33efea33 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -12,7 +12,12 @@ from codegate import __version__ from codegate.api import v1_models, v1_processing from codegate.db.connection import AlreadyExistsError, DbReader -from codegate.db.models import AlertSeverity, WorkspaceWithModel +from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel +from codegate.muxing.persona import ( + PersonaDoesNotExistError, + PersonaManager, + PersonaSimilarDescriptionError, +) from codegate.providers import crud as provendcrud from codegate.workspaces import crud @@ -21,6 +26,7 @@ v1 = APIRouter() wscrud = crud.WorkspaceCrud() pcrud = 
provendcrud.ProviderCrud() +persona_manager = PersonaManager() # This is a singleton object dbreader = DbReader() @@ -248,22 +254,18 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status @v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201) async def create_workspace( - request: v1_models.CreateOrRenameWorkspaceRequest, -) -> v1_models.Workspace: + request: v1_models.FullWorkspace, +) -> v1_models.FullWorkspace: """Create a new workspace.""" - if request.rename_to is not None: - return await rename_workspace(request) - return await create_new_workspace(request) - - -async def create_new_workspace( - request: v1_models.CreateOrRenameWorkspaceRequest, -) -> v1_models.Workspace: - # Input validation is done in the model try: - _ = await wscrud.add_workspace(request.name) - except AlreadyExistsError: - raise HTTPException(status_code=409, detail="Workspace already exists") + custom_instructions = request.config.custom_instructions if request.config else None + muxing_rules = request.config.muxing_rules if request.config else None + + workspace_row, mux_rules = await wscrud.add_workspace( + request.name, custom_instructions, muxing_rules + ) + except crud.WorkspaceNameAlreadyInUseError: + raise HTTPException(status_code=409, detail="Workspace name already in use") except ValidationError: raise HTTPException( status_code=400, @@ -277,18 +279,40 @@ async def create_new_workspace( except Exception: raise HTTPException(status_code=500, detail="Internal server error") - return v1_models.Workspace(name=request.name, is_active=False) + return v1_models.FullWorkspace( + name=workspace_row.name, + config=v1_models.WorkspaceConfig( + custom_instructions=workspace_row.custom_instructions or "", + muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules], + ), + ) -async def rename_workspace( - request: v1_models.CreateOrRenameWorkspaceRequest, -) -> v1_models.Workspace: +@v1.put( + "/workspaces/{workspace_name}", + tags=["Workspaces"], + generate_unique_id_function=uniq_name, + status_code=201, +) +async def update_workspace( + workspace_name: str, + request: v1_models.FullWorkspace, +) -> v1_models.FullWorkspace: + """Update a workspace.""" try: - _ = await wscrud.rename_workspace(request.name, request.rename_to) + custom_instructions = request.config.custom_instructions if request.config else None + muxing_rules = request.config.muxing_rules if request.config else None + + workspace_row, mux_rules = await wscrud.update_workspace( + workspace_name, + request.name, + custom_instructions, + muxing_rules, + ) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") - except AlreadyExistsError: - raise HTTPException(status_code=409, detail="Workspace already exists") + except crud.WorkspaceNameAlreadyInUseError: + raise HTTPException(status_code=409, detail="Workspace name already in use") except ValidationError: raise HTTPException( status_code=400, @@ -302,7 +326,13 @@ async def rename_workspace( except Exception: raise HTTPException(status_code=500, detail="Internal server error") - return v1_models.Workspace(name=request.rename_to, is_active=False) + return v1_models.FullWorkspace( + name=workspace_row.name, + config=v1_models.WorkspaceConfig( + custom_instructions=workspace_row.custom_instructions or "", + muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules], + ), + ) @v1.delete( @@ -397,6 +427,33 @@ async def 
get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A raise HTTPException(status_code=500, detail="Internal server error") +@v1.get( + "/workspaces/{workspace_name}/alerts-summary", + tags=["Workspaces"], + generate_unique_id_function=uniq_name, +) +async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSummary: + """Get alert summary for a workspace.""" + try: + ws = await wscrud.get_workspace_by_name(workspace_name) + except crud.WorkspaceDoesNotExistError: + raise HTTPException(status_code=404, detail="Workspace does not exist") + except Exception: + logger.exception("Error while getting workspace") + raise HTTPException(status_code=500, detail="Internal server error") + + try: + summary = await dbreader.get_alerts_summary_by_workspace(ws.id) + return v1_models.AlertSummary( + malicious_packages=summary["codegate_context_retriever_count"], + pii=summary["codegate_pii_count"], + secrets=summary["codegate_secrets_count"], + ) + except Exception: + logger.exception("Error while getting alerts summary") + raise HTTPException(status_code=500, detail="Internal server error") + + @v1.get( "/workspaces/{workspace_name}/messages", tags=["Workspaces"], @@ -614,3 +671,103 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage except Exception: logger.exception("Error while getting messages") raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.get("/personas", tags=["Personas"], generate_unique_id_function=uniq_name) +async def list_personas() -> List[Persona]: + """List all personas.""" + try: + personas = await persona_manager.get_all_personas() + return personas + except Exception: + logger.exception("Error while getting personas") + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.get("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name) +async def get_persona(persona_name: str) -> Persona: + """Get a persona by name.""" + try: + persona = await persona_manager.get_persona(persona_name) + return persona + except PersonaDoesNotExistError: + logger.exception("Error while getting persona") + raise HTTPException(status_code=404, detail="Persona does not exist") + + +@v1.post("/personas", tags=["Personas"], generate_unique_id_function=uniq_name, status_code=201) +async def create_persona(request: v1_models.PersonaRequest) -> Persona: + """Create a new persona.""" + try: + await persona_manager.add_persona(request.name, request.description) + persona = await dbreader.get_persona_by_name(request.name) + return persona + except PersonaSimilarDescriptionError: + logger.exception("Error while creating persona") + raise HTTPException(status_code=409, detail="Persona has a similar description to another") + except AlreadyExistsError: + logger.exception("Error while creating persona") + raise HTTPException(status_code=409, detail="Persona already exists") + except ValidationError: + logger.exception("Error while creating persona") + raise HTTPException( + status_code=400, + detail=( + "Persona has invalid name, check is alphanumeric " + "and only contains dashes and underscores" + ), + ) + except Exception: + logger.exception("Error while creating persona") + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.put("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name) +async def update_persona(persona_name: str, request: v1_models.PersonaUpdateRequest) -> Persona: + """Update an existing 
persona.""" + try: + await persona_manager.update_persona( + persona_name, request.new_name, request.new_description + ) + persona = await dbreader.get_persona_by_name(request.new_name) + return persona + except PersonaSimilarDescriptionError: + logger.exception("Error while updating persona") + raise HTTPException(status_code=409, detail="Persona has a similar description to another") + except PersonaDoesNotExistError: + logger.exception("Error while updating persona") + raise HTTPException(status_code=404, detail="Persona does not exist") + except AlreadyExistsError: + logger.exception("Error while updating persona") + raise HTTPException(status_code=409, detail="Persona already exists") + except ValidationError: + logger.exception("Error while creating persona") + raise HTTPException( + status_code=400, + detail=( + "Persona has invalid name, check is alphanumeric " + "and only contains dashes and underscores" + ), + ) + except Exception: + logger.exception("Error while updating persona") + raise HTTPException(status_code=500, detail="Internal server error") + + +@v1.delete( + "/personas/{persona_name}", + tags=["Personas"], + generate_unique_id_function=uniq_name, + status_code=204, +) +async def delete_persona(persona_name: str): + """Delete a persona.""" + try: + await persona_manager.delete_persona(persona_name) + return Response(status_code=204) + except PersonaDoesNotExistError: + logger.exception("Error while updating persona") + raise HTTPException(status_code=404, detail="Persona does not exist") + except Exception: + logger.exception("Error while deleting persona") + raise HTTPException(status_code=500, detail="Internal server error") diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index c608484c..dff26489 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -61,7 +61,7 @@ def from_db_workspaces( class WorkspaceConfig(pydantic.BaseModel): - system_prompt: str + custom_instructions: str muxing_rules: List[mux_models.MuxRule] @@ -72,13 +72,6 @@ class FullWorkspace(pydantic.BaseModel): config: Optional[WorkspaceConfig] = None -class CreateOrRenameWorkspaceRequest(FullWorkspace): - # If set, rename the workspace to this name. Note that - # the 'name' field is still required and the workspace - # workspace must exist. - rename_to: Optional[str] = None - - class ActivateWorkspaceRequest(pydantic.BaseModel): name: str @@ -190,6 +183,16 @@ def from_db_model(db_model: db_models.Alert) -> "Alert": timestamp: datetime.datetime +class AlertSummary(pydantic.BaseModel): + """ + Represents a set of summary alerts + """ + + malicious_packages: int + pii: int + secrets: int + + class PartialQuestionAnswer(pydantic.BaseModel): """ Represents a partial conversation. @@ -312,3 +315,21 @@ class ModelByProvider(pydantic.BaseModel): def __str__(self): return f"{self.provider_name} / {self.name}" + + +class PersonaRequest(pydantic.BaseModel): + """ + Model for creating a new Persona. + """ + + name: str + description: str + + +class PersonaUpdateRequest(pydantic.BaseModel): + """ + Model for updating a Persona. 
+ """ + + new_name: str + new_description: str diff --git a/src/codegate/cli.py b/src/codegate/cli.py index 455d9001..1ae3f9c2 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -14,7 +14,11 @@ from codegate.ca.codegate_ca import CertificateAuthority from codegate.codegate_logging import LogFormat, LogLevel, setup_logging from codegate.config import Config, ConfigurationError -from codegate.db.connection import init_db_sync, init_session_if_not_exists +from codegate.db.connection import ( + init_db_sync, + init_session_if_not_exists, + init_instance, +) from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers import crud as provendcrud @@ -318,6 +322,7 @@ def serve( # noqa: C901 logger = structlog.get_logger("codegate").bind(origin="cli") init_db_sync(cfg.db_path) + init_instance(cfg.db_path) init_session_if_not_exists(cfg.db_path) # Check certificates and create CA if necessary diff --git a/src/codegate/config.py b/src/codegate/config.py index 761ca09e..179ec4d3 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -57,9 +57,14 @@ class Config: force_certs: bool = False max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes. + # Min value is 0 (max similarity), max value is 2 (orthogonal) # The value 0.75 was found through experimentation. See /tests/muxing/test_semantic_router.py + # It's the threshold value to determine if a query matches a persona. persona_threshold = 0.75 + # The value 0.3 was found through experimentation. See /tests/muxing/test_semantic_router.py + # It's the threshold value to determine if a persona description is similar to existing personas + persona_diff_desc_threshold = 0.3 # Provider URLs with defaults provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy()) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 803943b3..3f439aea 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -1,4 +1,5 @@ import asyncio +import datetime import json import sqlite3 import uuid @@ -14,7 +15,8 @@ from sqlalchemy import CursorResult, TextClause, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker from codegate.db.fim_cache import FimCache from codegate.db.models import ( @@ -22,6 +24,7 @@ Alert, GetPromptWithOutputsRow, GetWorkspaceByNameConditions, + Instance, IntermediatePromptWithOutputUsageAlerts, MuxRule, Output, @@ -560,15 +563,62 @@ async def add_persona(self, persona: PersonaEmbedding) -> None: ) try: - # For Pydantic we convert the numpy array to string when serializing with .model_dumpy() - # We need to convert it back to a numpy array before inserting it into the DB. - persona_dict = persona.model_dump() - persona_dict["description_embedding"] = persona.description_embedding - await self._execute_with_no_return(sql, persona_dict) + await self._execute_with_no_return(sql, persona.model_dump()) except IntegrityError as e: logger.debug(f"Exception type: {type(e)}") raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") + async def update_persona(self, persona: PersonaEmbedding) -> None: + """ + Update an existing Persona in the DB. + + This handles validation and update of an existing persona. 
+ """ + sql = text( + """ + UPDATE personas + SET name = :name, + description = :description, + description_embedding = :description_embedding + WHERE id = :id + """ + ) + + try: + await self._execute_with_no_return(sql, persona.model_dump()) + except IntegrityError as e: + logger.debug(f"Exception type: {type(e)}") + raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") + + async def delete_persona(self, persona_id: str) -> None: + """ + Delete an existing Persona from the DB. + """ + sql = text("DELETE FROM personas WHERE id = :id") + conditions = {"id": persona_id} + await self._execute_with_no_return(sql, conditions) + + async def init_instance(self) -> None: + """ + Initializes instance details in the database. + """ + sql = text( + """ + INSERT INTO instance (id, created_at) + VALUES (:id, :created_at) + """ + ) + + try: + instance = Instance( + id=str(uuid.uuid4()), + created_at=datetime.datetime.now(datetime.timezone.utc), + ) + await self._execute_with_no_return(sql, instance.model_dump()) + except IntegrityError as e: + logger.debug(f"Exception type: {type(e)}") + raise AlreadyExistsError(f"Instance already initialized.") + class DbReader(DbCodeGate): def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs): @@ -587,7 +637,10 @@ async def _dump_result_to_pydantic_model( return None async def _execute_select_pydantic_model( - self, model_type: Type[BaseModel], sql_command: TextClause + self, + model_type: Type[BaseModel], + sql_command: TextClause, + should_raise: bool = False, ) -> Optional[List[BaseModel]]: async with self._async_db_engine.begin() as conn: try: @@ -595,6 +648,9 @@ async def _execute_select_pydantic_model( return await self._dump_result_to_pydantic_model(model_type, result) except Exception as e: logger.error(f"Failed to select model: {model_type}.", error=str(e)) + # Exposes errors to the caller + if should_raise: + raise e return None async def _exec_select_conditions_to_pydantic( @@ -746,6 +802,38 @@ async def get_alerts_by_workspace( ) return prompts + async def get_alerts_summary_by_workspace(self, workspace_id: str) -> dict: + """Get aggregated alert summary counts for a given workspace_id.""" + sql = text( + """ + SELECT + COUNT(*) AS total_alerts, + SUM(CASE WHEN a.trigger_type = 'codegate-secrets' THEN 1 ELSE 0 END) + AS codegate_secrets_count, + SUM(CASE WHEN a.trigger_type = 'codegate-context-retriever' THEN 1 ELSE 0 END) + AS codegate_context_retriever_count, + SUM(CASE WHEN a.trigger_type = 'codegate-pii' THEN 1 ELSE 0 END) + AS codegate_pii_count + FROM alerts a + INNER JOIN prompts p ON p.id = a.prompt_id + WHERE p.workspace_id = :workspace_id + """ + ) + conditions = {"workspace_id": workspace_id} + + async with self._async_db_engine.begin() as conn: + result = await conn.execute(sql, conditions) + row = result.fetchone() + + # Return a dictionary with counts (handling None values safely) + return { + "codegate_secrets_count": row.codegate_secrets_count or 0 if row else 0, + "codegate_context_retriever_count": ( + row.codegate_context_retriever_count or 0 if row else 0 + ), + "codegate_pii_count": row.codegate_pii_count or 0 if row else 0, + } + async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]: sql = text( """ @@ -971,6 +1059,33 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]: ) return personas[0] if personas else None + async def get_distance_to_existing_personas( + self, query_embedding: np.ndarray, exclude_id: Optional[str] + ) -> List[PersonaDistance]: + """ + Get the 
distance between a persona and a query embedding. + """ + sql = """ + SELECT + id, + name, + description, + vec_distance_cosine(description_embedding, :query_embedding) as distance + FROM personas + """ + conditions = {"query_embedding": query_embedding} + + # Exclude this persona from the SQL query. Used when checking the descriptions + # for updating the persona. Exclude the persona to update itself from the query. + if exclude_id: + sql += " WHERE id != :exclude_id" + conditions["exclude_id"] = exclude_id + + persona_distances = await self._exec_vec_db_query_to_pydantic( + sql, conditions, PersonaDistance + ) + return persona_distances + async def get_distance_to_persona( self, persona_id: str, query_embedding: np.ndarray ) -> PersonaDistance: @@ -992,6 +1107,55 @@ async def get_distance_to_persona( ) return persona_distance[0] + async def get_all_personas(self) -> List[Persona]: + """ + Get all the personas. + """ + sql = text( + """ + SELECT + id, name, description + FROM personas + """ + ) + personas = await self._execute_select_pydantic_model(Persona, sql, should_raise=True) + return personas + + async def get_instance(self) -> Instance: + """ + Get the details of the instance. + """ + sql = text("SELECT id, created_at FROM instance") + return await self._execute_select_pydantic_model(Instance, sql) + + +class DbTransaction: + def __init__(self): + self._session = None + + async def __aenter__(self): + self._session = sessionmaker( + bind=DbCodeGate()._async_db_engine, + class_=AsyncSession, + expire_on_commit=False, + )() + await self._session.begin() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type: + await self._session.rollback() + raise exc_val + else: + await self._session.commit() + await self._session.close() + + async def commit(self): + await self._session.commit() + + async def rollback(self): + await self._session.rollback() + def init_db_sync(db_path: Optional[str] = None): """DB will be initialized in the constructor in case it doesn't exist.""" @@ -1014,8 +1178,6 @@ def init_db_sync(db_path: Optional[str] = None): def init_session_if_not_exists(db_path: Optional[str] = None): - import datetime - db_reader = DbReader(db_path) sessions = asyncio.run(db_reader.get_sessions()) # If there are no sessions, create a new one @@ -1035,5 +1197,19 @@ def init_session_if_not_exists(db_path: Optional[str] = None): logger.info("Session in DB initialized successfully.") +def init_instance(db_path: Optional[str] = None): + db_reader = DbReader(db_path) + instance = asyncio.run(db_reader.get_instance()) + # Initialize instance if not already initialized. 
+ if not instance: + db_recorder = DbRecorder(db_path) + try: + asyncio.run(db_recorder.init_instance()) + except Exception as e: + logger.error(f"Failed to initialize instance in DB: {e}") + raise + logger.info("Instance initialized successfully.") + + if __name__ == "__main__": init_db_sync() diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index a5941e96..07c4c8ed 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -3,7 +3,15 @@ from typing import Annotated, Any, Dict, List, Optional import numpy as np -from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints +import regex as re +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + PlainSerializer, + StringConstraints, + field_validator, +) class AlertSeverity(str, Enum): @@ -120,6 +128,11 @@ class Session(BaseModel): last_update: datetime.datetime +class Instance(BaseModel): + id: str + created_at: datetime.datetime + + # Models for select queries @@ -245,12 +258,14 @@ class MuxRule(BaseModel): def nd_array_custom_before_validator(x): # custome before validation logic + if isinstance(x, bytes): + return np.frombuffer(x, dtype=np.float32) return x def nd_array_custom_serializer(x): # custome serialization logic - return str(x) + return x # Pydantic doesn't support numpy arrays out of the box hence we need to construct a custom type. @@ -264,6 +279,8 @@ def nd_array_custom_serializer(x): PlainSerializer(nd_array_custom_serializer, return_type=str), ] +VALID_PERSONA_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_ -]+$") + class Persona(BaseModel): """ @@ -274,6 +291,15 @@ class Persona(BaseModel): name: str description: str + @field_validator("name", mode="after") + @classmethod + def validate_persona_name(cls, value: str) -> str: + if VALID_PERSONA_NAME_PATTERN.match(value): + return value + raise ValueError( + "Invalid persona name. It should be alphanumeric with underscores and dashes." + ) + class PersonaEmbedding(Persona): """ diff --git a/src/codegate/muxing/semantic_router.py b/src/codegate/muxing/persona.py similarity index 58% rename from src/codegate/muxing/semantic_router.py rename to src/codegate/muxing/persona.py index ce240b1f..ac21205c 100644 --- a/src/codegate/muxing/semantic_router.py +++ b/src/codegate/muxing/persona.py @@ -1,5 +1,6 @@ import unicodedata import uuid +from typing import List, Optional import numpy as np import regex as re @@ -28,14 +29,20 @@ class PersonaDoesNotExistError(Exception): pass -class SemanticRouter: +class PersonaSimilarDescriptionError(Exception): + pass + + +class PersonaManager: def __init__(self): - self._inference_engine = LlamaCppInferenceEngine() + Config.load() conf = Config.get_config() + self._inference_engine = LlamaCppInferenceEngine() self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}" self._n_gpu = conf.chat_model_n_gpu_layers self._persona_threshold = conf.persona_threshold + self._persona_diff_desc_threshold = conf.persona_diff_desc_threshold self._db_recorder = DbRecorder() self._db_reader = DbReader() @@ -105,12 +112,50 @@ async def _embed_text(self, text: str) -> np.ndarray: logger.debug("Text embedded in semantic routing", text=cleaned_text[:50]) return np.array(embed_list[0], dtype=np.float32) + async def _is_persona_description_diff( + self, emb_persona_desc: np.ndarray, exclude_id: Optional[str] + ) -> bool: + """ + Check if the persona description is different enough from existing personas. 
+ """ + # The distance calculation is done in the database + persona_distances = await self._db_reader.get_distance_to_existing_personas( + emb_persona_desc, exclude_id + ) + if not persona_distances: + return True + + for persona_distance in persona_distances: + logger.info( + f"Persona description distance to {persona_distance.name}", + distance=persona_distance.distance, + ) + # If the distance is less than the threshold, the persona description is too similar + if persona_distance.distance < self._persona_diff_desc_threshold: + return False + return True + + async def _validate_persona_description( + self, persona_desc: str, exclude_id: str = None + ) -> np.ndarray: + """ + Validate the persona description by embedding the text and checking if it is + different enough from existing personas. + """ + emb_persona_desc = await self._embed_text(persona_desc) + if not await self._is_persona_description_diff(emb_persona_desc, exclude_id): + raise PersonaSimilarDescriptionError( + "The persona description is too similar to existing personas." + ) + return emb_persona_desc + async def add_persona(self, persona_name: str, persona_desc: str) -> None: """ Add a new persona to the database. The persona description is embedded and stored in the database. """ - emb_persona_desc = await self._embed_text(persona_desc) + emb_persona_desc = await self._validate_persona_description(persona_desc) + new_persona = db_models.PersonaEmbedding( id=str(uuid.uuid4()), name=persona_name, @@ -120,6 +165,58 @@ async def add_persona(self, persona_name: str, persona_desc: str) -> None: await self._db_recorder.add_persona(new_persona) logger.info(f"Added persona {persona_name} to the database.") + async def get_persona(self, persona_name: str) -> db_models.Persona: + """ + Get a persona from the database by name. + """ + persona = await self._db_reader.get_persona_by_name(persona_name) + if not persona: + raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.") + return persona + + async def get_all_personas(self) -> List[db_models.Persona]: + """ + Get all personas from the database. + """ + return await self._db_reader.get_all_personas() + + async def update_persona( + self, persona_name: str, new_persona_name: str, new_persona_desc: str + ) -> None: + """ + Update an existing persona in the database. The name and description are + updated in the database, but the ID remains the same. + """ + # First we check if the persona exists, if not we raise an error + found_persona = await self._db_reader.get_persona_by_name(persona_name) + if not found_persona: + raise PersonaDoesNotExistError(f"Person {persona_name} does not exist.") + + emb_persona_desc = await self._validate_persona_description( + new_persona_desc, exclude_id=found_persona.id + ) + + # Then we update the attributes in the database + updated_persona = db_models.PersonaEmbedding( + id=found_persona.id, + name=new_persona_name, + description=new_persona_desc, + description_embedding=emb_persona_desc, + ) + await self._db_recorder.update_persona(updated_persona) + logger.info(f"Updated persona {persona_name} in the database.") + + async def delete_persona(self, persona_name: str) -> None: + """ + Delete a persona from the database. 
+ """ + persona = await self._db_reader.get_persona_by_name(persona_name) + if not persona: + raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.") + + await self._db_recorder.delete_persona(persona.id) + logger.info(f"Deleted persona {persona_name} from the database.") + async def check_persona_match(self, persona_name: str, query: str) -> bool: """ Check if the query matches the persona description. A vector similarity diff --git a/src/codegate/pipeline/cli/cli.py b/src/codegate/pipeline/cli/cli.py index be2222c8..fde37f94 100644 --- a/src/codegate/pipeline/cli/cli.py +++ b/src/codegate/pipeline/cli/cli.py @@ -95,6 +95,25 @@ def _get_cli_from_continue(last_user_message_str: str) -> Optional[re.Match[str] return codegate_regex.match(last_user_message_str) +def _get_cli_from_copilot(last_user_message_str: str) -> Optional[re.Match[str]]: + """ + Process Copilot-specific CLI command format. + + Copilot sends messages in the format: + file contentscodegate command + + Args: + last_user_message_str (str): The message string from Copilot + + Returns: + Optional[re.Match[str]]: A regex match object if command is found, None otherwise + """ + cleaned_text = re.sub( + r".*", "", last_user_message_str, flags=re.DOTALL + ) + return codegate_regex.match(cleaned_text.strip()) + + class CodegateCli(PipelineStep): """Pipeline step that handles codegate cli.""" @@ -136,6 +155,8 @@ async def process( match = _get_cli_from_open_interpreter(last_user_message_str) elif context.client in [ClientType.CONTINUE]: match = _get_cli_from_continue(last_user_message_str) + elif context.client in [ClientType.COPILOT]: + match = _get_cli_from_copilot(last_user_message_str) else: # Check if "codegate" is the first word in the message match = codegate_regex.match(last_user_message_str) diff --git a/src/codegate/pipeline/cli/commands.py b/src/codegate/pipeline/cli/commands.py index 5b101400..c5655ec3 100644 --- a/src/codegate/pipeline/cli/commands.py +++ b/src/codegate/pipeline/cli/commands.py @@ -98,7 +98,6 @@ def help(self) -> str: class CodegateCommandSubcommand(CodegateCommand): - @property @abstractmethod def subcommands(self) -> Dict[str, Callable[[List[str]], Awaitable[str]]]: @@ -174,7 +173,6 @@ async def run(self, args: List[str]) -> str: class Workspace(CodegateCommandSubcommand): - def __init__(self): self.workspace_crud = crud.WorkspaceCrud() @@ -258,7 +256,7 @@ async def _rename_workspace(self, flags: Dict[str, str], args: List[str]) -> str ) try: - await self.workspace_crud.rename_workspace(old_workspace_name, new_workspace_name) + await self.workspace_crud.update_workspace(old_workspace_name, new_workspace_name) except crud.WorkspaceDoesNotExistError: return f"Workspace **{old_workspace_name}** does not exist" except AlreadyExistsError: @@ -410,7 +408,6 @@ def help(self) -> str: class CustomInstructions(CodegateCommandSubcommand): - def __init__(self): self.workspace_crud = crud.WorkspaceCrud() diff --git a/src/codegate/pipeline/pii/analyzer.py b/src/codegate/pipeline/pii/analyzer.py index 96442824..706deb9b 100644 --- a/src/codegate/pipeline/pii/analyzer.py +++ b/src/codegate/pipeline/pii/analyzer.py @@ -1,10 +1,9 @@ -from typing import Any, List, Optional +from typing import List, Optional import structlog from presidio_analyzer import AnalyzerEngine from presidio_anonymizer import AnonymizerEngine -from codegate.db.models import AlertSeverity from codegate.pipeline.base import PipelineContext from codegate.pipeline.sensitive_data.session_store import SessionStore diff --git 
a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index fde89428..d7f33d67 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -1,5 +1,4 @@ from typing import Any, Dict, List, Optional, Tuple -import uuid import regex as re import structlog diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 527c817f..c299469e 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -316,8 +316,17 @@ async def process( # Process all messages for i, message in enumerate(new_request["messages"]): if "content" in message and message["content"]: + message_content = message["content"] + + # cline with anthropic seems to be sending a list of dicts with type:text instead of + # a string + # this hack will not be needed once we access the native functions through an API + # (I tested this actually) + if isinstance(message_content, list) and message_content and "text" in message_content[0]: + message_content = message_content[0]["text"] + redacted_content, secrets_matched = self._redact_message_content( - message["content"], sensitive_data_manager, session_id, context + message_content, sensitive_data_manager, session_id, context ) new_request["messages"][i]["content"] = redacted_content if i > last_assistant_idx: diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py index 89506d15..bf467878 100644 --- a/src/codegate/pipeline/sensitive_data/manager.py +++ b/src/codegate/pipeline/sensitive_data/manager.py @@ -1,7 +1,8 @@ -import json from typing import Dict, Optional + import pydantic import structlog + from codegate.pipeline.sensitive_data.session_store import SessionStore logger = structlog.get_logger("codegate") diff --git a/src/codegate/pipeline/sensitive_data/session_store.py b/src/codegate/pipeline/sensitive_data/session_store.py index 5e508847..7a33abd2 100644 --- a/src/codegate/pipeline/sensitive_data/session_store.py +++ b/src/codegate/pipeline/sensitive_data/session_store.py @@ -1,5 +1,5 @@ -from typing import Dict, Optional import uuid +from typing import Dict, Optional class SessionStore: diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index a81426a8..fbaf5b99 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -3,7 +3,7 @@ from uuid import uuid4 as uuid from codegate.db import models as db_models -from codegate.db.connection import DbReader, DbRecorder +from codegate.db.connection import AlreadyExistsError, DbReader, DbRecorder, DbTransaction from codegate.muxing import models as mux_models from codegate.muxing import rulematcher @@ -16,6 +16,10 @@ class WorkspaceDoesNotExistError(WorkspaceCrudError): pass +class WorkspaceNameAlreadyInUseError(WorkspaceCrudError): + pass + + class WorkspaceAlreadyActiveError(WorkspaceCrudError): pass @@ -31,34 +35,73 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError): pass class WorkspaceCrud: - def __init__(self): self._db_reader = DbReader() - - async def add_workspace(self, new_workspace_name: str) -> db_models.WorkspaceRow: + self._db_recorder = DbRecorder() + + async def add_workspace( + self, + new_workspace_name: str, + custom_instructions: Optional[str] = None, + muxing_rules: Optional[List[mux_models.MuxRule]] = None, + ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]: """ Add a workspace Args: - name (str): The name of the workspace + new_workspace_name (str): The name of the workspace +
diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py
index 89506d15..bf467878 100644
--- a/src/codegate/pipeline/sensitive_data/manager.py
+++ b/src/codegate/pipeline/sensitive_data/manager.py
@@ -1,7 +1,8 @@
-import json
 from typing import Dict, Optional
+
 import pydantic
 import structlog
+
 from codegate.pipeline.sensitive_data.session_store import SessionStore

 logger = structlog.get_logger("codegate")
diff --git a/src/codegate/pipeline/sensitive_data/session_store.py b/src/codegate/pipeline/sensitive_data/session_store.py
index 5e508847..7a33abd2 100644
--- a/src/codegate/pipeline/sensitive_data/session_store.py
+++ b/src/codegate/pipeline/sensitive_data/session_store.py
@@ -1,5 +1,5 @@
-from typing import Dict, Optional
 import uuid
+from typing import Dict, Optional


 class SessionStore:
diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py
index a81426a8..fbaf5b99 100644
--- a/src/codegate/workspaces/crud.py
+++ b/src/codegate/workspaces/crud.py
@@ -3,7 +3,7 @@
 from uuid import uuid4 as uuid

 from codegate.db import models as db_models
-from codegate.db.connection import DbReader, DbRecorder
+from codegate.db.connection import AlreadyExistsError, DbReader, DbRecorder, DbTransaction
 from codegate.muxing import models as mux_models
 from codegate.muxing import rulematcher

@@ -16,6 +16,10 @@ class WorkspaceDoesNotExistError(WorkspaceCrudError):
     pass


+class WorkspaceNameAlreadyInUseError(WorkspaceCrudError):
+    pass
+
+
 class WorkspaceAlreadyActiveError(WorkspaceCrudError):
     pass

@@ -31,34 +35,73 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError):


 class WorkspaceCrud:
-
     def __init__(self):
         self._db_reader = DbReader()
-
-    async def add_workspace(self, new_workspace_name: str) -> db_models.WorkspaceRow:
+        self._db_recorder = DbRecorder()
+
+    async def add_workspace(
+        self,
+        new_workspace_name: str,
+        custom_instructions: Optional[str] = None,
+        muxing_rules: Optional[List[mux_models.MuxRule]] = None,
+    ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]:
         """
         Add a workspace

         Args:
-            name (str): The name of the workspace
+            new_workspace_name (str): The name of the workspace
+            custom_instructions (Optional[str]): The custom instructions for the workspace
+            muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace
         """
         if new_workspace_name == "":
             raise WorkspaceCrudError("Workspace name cannot be empty.")
         if new_workspace_name in RESERVED_WORKSPACE_KEYWORDS:
             raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.")
-        db_recorder = DbRecorder()
-        workspace_created = await db_recorder.add_workspace(new_workspace_name)
-        return workspace_created

-    async def rename_workspace(
-        self, old_workspace_name: str, new_workspace_name: str
-    ) -> db_models.WorkspaceRow:
+        async with DbTransaction() as transaction:
+            try:
+                existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name)
+                if existing_ws:
+                    raise WorkspaceNameAlreadyInUseError(
+                        f"Workspace name {new_workspace_name} is already in use."
+                    )
+
+                workspace_created = await self._db_recorder.add_workspace(new_workspace_name)
+
+                if custom_instructions:
+                    workspace_created.custom_instructions = custom_instructions
+                    await self._db_recorder.update_workspace(workspace_created)
+
+                mux_rules = []
+                if muxing_rules:
+                    mux_rules = await self.set_muxes(new_workspace_name, muxing_rules)
+
+                await transaction.commit()
+                return workspace_created, mux_rules
+            except (
+                AlreadyExistsError,
+                WorkspaceDoesNotExistError,
+                WorkspaceNameAlreadyInUseError,
+            ) as e:
+                raise e
+            except Exception as e:
+                raise WorkspaceCrudError(f"Error adding workspace {new_workspace_name}: {str(e)}")
+
+    async def update_workspace(
+        self,
+        old_workspace_name: str,
+        new_workspace_name: str,
+        custom_instructions: Optional[str] = None,
+        muxing_rules: Optional[List[mux_models.MuxRule]] = None,
+    ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]:
         """
-        Rename a workspace
+        Update a workspace

         Args:
-            old_name (str): The old name of the workspace
-            new_name (str): The new name of the workspace
+            old_workspace_name (str): The old name of the workspace
+            new_workspace_name (str): The new name of the workspace
+            custom_instructions (Optional[str]): The custom instructions for the workspace
+            muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace
         """
         if new_workspace_name == "":
             raise WorkspaceCrudError("Workspace name cannot be empty.")
@@ -70,15 +113,40 @@ async def rename_workspace(
             raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.")
         if old_workspace_name == new_workspace_name:
             raise WorkspaceCrudError("Old and new workspace names are the same.")
-        ws = await self._db_reader.get_workspace_by_name(old_workspace_name)
-        if not ws:
-            raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.")
-        db_recorder = DbRecorder()
-        new_ws = db_models.WorkspaceRow(
-            id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
-        )
-        workspace_renamed = await db_recorder.update_workspace(new_ws)
-        return workspace_renamed
+
+        async with DbTransaction() as transaction:
+            try:
+                ws = await self._db_reader.get_workspace_by_name(old_workspace_name)
+                if not ws:
+                    raise WorkspaceDoesNotExistError(
+                        f"Workspace {old_workspace_name} does not exist."
+                    )
+
+                existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name)
+                if existing_ws:
+                    raise WorkspaceNameAlreadyInUseError(
+                        f"Workspace name {new_workspace_name} is already in use."
+ ) + + new_ws = db_models.WorkspaceRow( + id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions + ) + workspace_renamed = await self._db_recorder.update_workspace(new_ws) + + if custom_instructions: + workspace_renamed.custom_instructions = custom_instructions + await self._db_recorder.update_workspace(workspace_renamed) + + mux_rules = [] + if muxing_rules: + mux_rules = await self.set_muxes(new_workspace_name, muxing_rules) + + await transaction.commit() + return workspace_renamed, mux_rules + except (WorkspaceNameAlreadyInUseError, WorkspaceDoesNotExistError) as e: + raise e + except Exception as e: + raise WorkspaceCrudError(f"Error updating workspace {old_workspace_name}: {str(e)}") async def get_workspaces(self) -> List[db_models.WorkspaceWithSessionInfo]: """ @@ -128,8 +196,7 @@ async def activate_workspace(self, workspace_name: str): session.active_workspace_id = workspace.id session.last_update = datetime.datetime.now(datetime.timezone.utc) - db_recorder = DbRecorder() - await db_recorder.update_session(session) + await self._db_recorder.update_session(session) # Ensure the mux registry is updated mux_registry = await rulematcher.get_muxing_rules_registry() @@ -144,8 +211,7 @@ async def recover_workspace(self, workspace_name: str): if not selected_workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") - db_recorder = DbRecorder() - await db_recorder.recover_workspace(selected_workspace) + await self._db_recorder.recover_workspace(selected_workspace) return async def update_workspace_custom_instructions( @@ -161,8 +227,7 @@ async def update_workspace_custom_instructions( name=selected_workspace.name, custom_instructions=custom_instructions, ) - db_recorder = DbRecorder() - updated_workspace = await db_recorder.update_workspace(workspace_update) + updated_workspace = await self._db_recorder.update_workspace(workspace_update) return updated_workspace async def soft_delete_workspace(self, workspace_name: str): @@ -183,9 +248,8 @@ async def soft_delete_workspace(self, workspace_name: str): if active_workspace and active_workspace.id == selected_workspace.id: raise WorkspaceCrudError("Cannot archive active workspace.") - db_recorder = DbRecorder() try: - _ = await db_recorder.soft_delete_workspace(selected_workspace) + _ = await self._db_recorder.soft_delete_workspace(selected_workspace) except Exception: raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}") @@ -205,9 +269,8 @@ async def hard_delete_workspace(self, workspace_name: str): if not selected_workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") - db_recorder = DbRecorder() try: - _ = await db_recorder.hard_delete_workspace(selected_workspace) + _ = await self._db_recorder.hard_delete_workspace(selected_workspace) except Exception: raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}") return @@ -247,15 +310,16 @@ async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]: return muxes - async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> None: + async def set_muxes( + self, workspace_name: str, muxes: List[mux_models.MuxRule] + ) -> List[db_models.MuxRule]: # Verify if workspace exists workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") # Delete all muxes for the workspace - db_recorder = DbRecorder() - await 
db_recorder.delete_muxes_by_workspace(workspace.id) + await self._db_recorder.delete_muxes_by_workspace(workspace.id) # Add the new muxes priority = 0 @@ -268,6 +332,7 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non muxes_with_routes.append((mux, route)) matchers: List[rulematcher.MuxingRuleMatcher] = [] + dbmuxes: List[db_models.MuxRule] = [] for mux, route in muxes_with_routes: new_mux = db_models.MuxRule( @@ -279,7 +344,8 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non matcher_blob=mux.matcher if mux.matcher else "", priority=priority, ) - dbmux = await db_recorder.add_mux(new_mux) + dbmux = await self._db_recorder.add_mux(new_mux) + dbmuxes.append(dbmux) matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, route)) @@ -289,6 +355,8 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non mux_registry = await rulematcher.get_muxing_rules_registry() await mux_registry.set_ws_rules(workspace_name, matchers) + return dbmuxes + async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute: """Get the routing for a mux diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py new file mode 100644 index 00000000..8bfcbfaf --- /dev/null +++ b/tests/api/test_v1_workspaces.py @@ -0,0 +1,378 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch +from uuid import uuid4 as uuid + +import httpx +import pytest +import structlog +from httpx import AsyncClient + +from codegate.db import connection +from codegate.pipeline.factory import PipelineFactory +from codegate.providers.crud.crud import ProviderCrud +from codegate.server import init_app +from codegate.workspaces.crud import WorkspaceCrud + +logger = structlog.get_logger("codegate") + + +@pytest.fixture +def db_path(): + """Creates a temporary database file path.""" + current_test_dir = Path(__file__).parent + db_filepath = current_test_dir / f"codegate_test_{uuid()}.db" + db_fullpath = db_filepath.absolute() + connection.init_db_sync(str(db_fullpath)) + yield db_fullpath + if db_fullpath.is_file(): + db_fullpath.unlink() + + +@pytest.fixture() +def db_recorder(db_path) -> connection.DbRecorder: + """Creates a DbRecorder instance with test database.""" + return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def db_reader(db_path) -> connection.DbReader: + """Creates a DbReader instance with test database.""" + return connection.DbReader(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def mock_workspace_crud(db_recorder, db_reader) -> WorkspaceCrud: + """Creates a WorkspaceCrud instance with test database.""" + ws_crud = WorkspaceCrud() + ws_crud._db_reader = db_reader + ws_crud._db_recorder = db_recorder + return ws_crud + + +@pytest.fixture() +def mock_provider_crud(db_recorder, db_reader, mock_workspace_crud) -> ProviderCrud: + """Creates a ProviderCrud instance with test database.""" + p_crud = ProviderCrud() + p_crud._db_reader = db_reader + p_crud._db_writer = db_recorder + p_crud._ws_crud = mock_workspace_crud + return p_crud + + +@pytest.fixture +def mock_pipeline_factory(): + """Create a mock pipeline factory.""" + mock_factory = MagicMock(spec=PipelineFactory) + mock_factory.create_input_pipeline.return_value = MagicMock() + mock_factory.create_fim_pipeline.return_value = MagicMock() + mock_factory.create_output_pipeline.return_value = MagicMock() + mock_factory.create_fim_output_pipeline.return_value = 
MagicMock() + return mock_factory + + +@pytest.mark.asyncio +async def test_create_update_workspace_happy_path( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test creating & updating a workspace (happy path).""" + + app = init_app(mock_pipeline_factory) + + provider_payload_1 = { + "name": "foo", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xzy", + } + + provider_payload_2 = { + "name": "bar", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xzy", + } + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create the first provider + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + provider_1 = response.json() + + # Create the second provider + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + provider_2 = response.json() + + name_1: str = str(uuid()) + custom_instructions_1: str = "Respond to every request in iambic pentameter" + muxing_rules_1 = [ + { + "provider_name": None, # optional & not implemented yet + "provider_id": provider_1["id"], + "model": "foo-bar-001", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": None, # optional & not implemented yet + "provider_id": provider_2["id"], + "model": "foo-bar-002", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + payload_create = { + "name": name_1, + "config": { + "custom_instructions": custom_instructions_1, + "muxing_rules": muxing_rules_1, + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + response_body = response.json() + + assert response_body["name"] == name_1 + assert response_body["config"]["custom_instructions"] == custom_instructions_1 + for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["model"] == muxing_rules_1[i]["model"] + assert rule["matcher"] == muxing_rules_1[i]["matcher"] + assert rule["matcher_type"] == muxing_rules_1[i]["matcher_type"] + + name_2: str = str(uuid()) + custom_instructions_2: str = "Respond to every request in cockney rhyming slang" + muxing_rules_2 = [ + { + "provider_name": None, # optional & not implemented yet + "provider_id": provider_2["id"], + "model": "foo-bar-002", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": None, # optional & not implemented yet + "provider_id": provider_1["id"], + "model": "foo-bar-001", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + payload_update = { + "name": name_2, + "config": { + "custom_instructions": custom_instructions_2, + "muxing_rules": muxing_rules_2, + }, + } + + response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) + assert response.status_code == 201 + response_body = response.json() + + assert response_body["name"] == name_2 + assert response_body["config"]["custom_instructions"] == custom_instructions_2 + for i, rule in enumerate(response_body["config"]["muxing_rules"]): + 
assert rule["model"] == muxing_rules_2[i]["model"] + assert rule["matcher"] == muxing_rules_2[i]["matcher"] + assert rule["matcher_type"] == muxing_rules_2[i]["matcher_type"] + + +@pytest.mark.asyncio +async def test_create_update_workspace_name_only( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test creating & updating a workspace (happy path).""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name_1: str = str(uuid()) + + payload_create = { + "name": name_1, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + response_body = response.json() + + assert response_body["name"] == name_1 + + name_2: str = str(uuid()) + + payload_update = { + "name": name_2, + } + + response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) + assert response.status_code == 201 + response_body = response.json() + + assert response_body["name"] == name_2 + + +@pytest.mark.asyncio +async def test_create_workspace_name_already_in_use( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test creating a workspace when the name is already in use.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name: str = str(uuid()) + + payload_create = { + "name": name, + } + + # Create the workspace for the first time + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Try to create the workspace again with the same name + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 409 + assert response.json()["detail"] == "Workspace name already in use" + + +@pytest.mark.asyncio +async def test_rename_workspace_name_already_in_use( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test renaming a workspace when the new name is already in use.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name_1: str = str(uuid()) + name_2: str = str(uuid()) + + payload_create_1 = { + "name": name_1, + } + + payload_create_2 = { + "name": name_2, + } + + # Create two workspaces + response = await ac.post("/api/v1/workspaces", json=payload_create_1) + assert response.status_code == 201 + + response = await ac.post("/api/v1/workspaces", json=payload_create_2) + assert response.status_code == 201 + + # Try to rename the first workspace to the name of the second workspace + payload_update = { + "name": name_2, + } + + response = await 
ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) + assert response.status_code == 409 + assert response.json()["detail"] == "Workspace name already in use" + + +@pytest.mark.asyncio +async def test_create_workspace_with_nonexistent_model_in_muxing_rule( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test creating a workspace with a muxing rule that uses a nonexistent model.""" + + app = init_app(mock_pipeline_factory) + + provider_payload = { + "name": "foo", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xzy", + } + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create the first provider + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + provider = response.json() + + name: str = str(uuid()) + custom_instructions: str = "Respond to every request in iambic pentameter" + muxing_rules = [ + { + "provider_name": None, + "provider_id": provider["id"], + "model": "nonexistent-model", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + ] + + payload_create = { + "name": name, + "config": { + "custom_instructions": custom_instructions, + "muxing_rules": muxing_rules, + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 400 + assert "Model nonexistent-model does not exist" in response.json()["detail"] diff --git a/tests/muxing/test_persona.py b/tests/muxing/test_persona.py new file mode 100644 index 00000000..fd0003c9 --- /dev/null +++ b/tests/muxing/test_persona.py @@ -0,0 +1,490 @@ +import uuid +from pathlib import Path +from typing import List + +import pytest +from pydantic import BaseModel, ValidationError + +from codegate.db import connection +from codegate.muxing.persona import ( + PersonaDoesNotExistError, + PersonaManager, + PersonaSimilarDescriptionError, +) + + +@pytest.fixture +def db_path(): + """Creates a temporary database file path.""" + current_test_dir = Path(__file__).parent + db_filepath = current_test_dir / f"codegate_test_{uuid.uuid4()}.db" + db_fullpath = db_filepath.absolute() + connection.init_db_sync(str(db_fullpath)) + yield db_fullpath + if db_fullpath.is_file(): + db_fullpath.unlink() + + +@pytest.fixture() +def db_recorder(db_path) -> connection.DbRecorder: + """Creates a DbRecorder instance with test database.""" + return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def db_reader(db_path) -> connection.DbReader: + """Creates a DbReader instance with test database.""" + return connection.DbReader(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def semantic_router_mocked_db( + db_recorder: connection.DbRecorder, db_reader: connection.DbReader +) -> PersonaManager: + """Creates a SemanticRouter instance with mocked database.""" + semantic_router = PersonaManager() + semantic_router._db_reader = db_reader + semantic_router._db_recorder = db_recorder + return semantic_router + + +@pytest.mark.asyncio +async def test_add_persona(semantic_router_mocked_db: PersonaManager): + """Test adding a persona to the database.""" + persona_name 
= "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + retrieved_persona = await semantic_router_mocked_db.get_persona(persona_name) + assert retrieved_persona.name == persona_name + assert retrieved_persona.description == persona_desc + + +@pytest.mark.asyncio +async def test_add_duplicate_persona(semantic_router_mocked_db: PersonaManager): + """Test adding a persona to the database.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + # Update the description to not trigger the similarity check + updated_description = "foo and bar description" + with pytest.raises(connection.AlreadyExistsError): + await semantic_router_mocked_db.add_persona(persona_name, updated_description) + + +@pytest.mark.asyncio +async def test_add_persona_invalid_name(semantic_router_mocked_db: PersonaManager): + """Test adding a persona to the database.""" + persona_name = "test_persona&" + persona_desc = "test_persona_desc" + with pytest.raises(ValidationError): + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + with pytest.raises(PersonaDoesNotExistError): + await semantic_router_mocked_db.delete_persona(persona_name) + + +@pytest.mark.asyncio +async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager): + """Test checking persona match when persona does not exist""" + persona_name = "test_persona" + query = "test_query" + with pytest.raises(PersonaDoesNotExistError): + await semantic_router_mocked_db.check_persona_match(persona_name, query) + + +class PersonaMatchTest(BaseModel): + persona_name: str + persona_desc: str + pass_queries: List[str] + fail_queries: List[str] + + +simple_persona = PersonaMatchTest( + persona_name="test_persona", + persona_desc="test_desc", + pass_queries=["test_desc", "test_desc2"], + fail_queries=["foo"], +) + +# Architect Persona +architect = PersonaMatchTest( + persona_name="architect", + persona_desc=""" + Expert in designing and planning software systems, technical infrastructure, and solution + architecture. + Specializes in creating scalable, maintainable, and resilient system designs. + Deep knowledge of architectural patterns, principles, and best practices. + Experienced in evaluating technology stacks and making strategic technical decisions. + Skilled at creating architecture diagrams, technical specifications, and system + documentation. + Focuses on both functional and non-functional requirements like performance, security, + and reliability. + Guides development teams on implementing complex systems and following architectural + guidelines. + + Designs system architectures that balance business needs with technical constraints. + Creates technical roadmaps and migration strategies for legacy system modernization. + Evaluates trade-offs between different architectural approaches (monolithic, microservices, + serverless). + Implements domain-driven design principles to align software with business domains. + + Develops reference architectures and technical standards for organization-wide adoption. + Conducts architecture reviews and provides recommendations for improvement. + Collaborates with stakeholders to translate business requirements into technical solutions. + Stays current with emerging technologies and evaluates their potential application. + + Designs for cloud-native environments using containerization, orchestration, and managed + services. 
+    Implements event-driven architectures using message queues, event buses, and streaming
+    platforms.
+    Creates data architectures that address storage, processing, and analytics requirements.
+    Develops integration strategies for connecting disparate systems and services.
+    """,
+    pass_queries=[
+        """
+        How should I design a system architecture that can scale with our growing user base?
+        """,
+        """
+        What's the best approach for migrating our monolithic application to microservices?
+        """,
+        """
+        I need to create a technical roadmap for modernizing our legacy systems. Where should
+        I start?
+        """,
+        """
+        Can you help me evaluate different cloud providers for our new infrastructure?
+        """,
+        """
+        What architectural patterns would you recommend for a distributed e-commerce platform?
+        """,
+    ],
+    fail_queries=[
+        """
+        How do I fix this specific bug in my JavaScript code?
+        """,
+        """
+        What's the syntax for a complex SQL query joining multiple tables?
+        """,
+        """
+        How do I implement authentication in my React application?
+        """,
+        """
+        What's the best way to optimize the performance of this specific function?
+        """,
+    ],
+)
+
+# Coder Persona
+coder = PersonaMatchTest(
+    persona_name="coder",
+    persona_desc="""
+    Expert in full stack development, programming, and software implementation.
+    Specializes in writing, debugging, and optimizing code across the entire technology stack.
+
+    Proficient in multiple programming languages including JavaScript, Python, Java, C#, and
+    TypeScript.
+    Implements efficient algorithms and data structures to solve complex programming challenges.
+    Develops maintainable code with appropriate patterns and practices for different contexts.
+
+    Experienced in frontend development using modern frameworks and libraries.
+    Creates responsive, accessible user interfaces with HTML, CSS, and JavaScript frameworks.
+    Implements state management, component architecture,
+    and client-side performance optimization for frontend applications.
+
+    Skilled in backend development and server-side programming.
+    Builds RESTful APIs, GraphQL services, and microservices architectures.
+    Implements authentication, authorization, and security best practices in web applications.
+    Understands best practices for common backend problems, like file uploads, caching,
+    and database interactions.
+
+    Designs and manages databases including schema design, query optimization,
+    and data modeling.
+    Works with both SQL and NoSQL databases to implement efficient data storage solutions.
+    Creates data access layers and ORM implementations for application data requirements.
+
+    Handles integration between different systems and third-party services.
+    Implements webhooks, API clients, and service communication patterns.
+    Develops data transformation and processing pipelines for various application needs.
+
+    Identifies and resolves performance issues across the application stack.
+    Uses debugging tools, profilers, and testing frameworks to ensure code quality.
+    Implements comprehensive testing strategies including unit, integration,
+    and end-to-end tests.
+    """,
+    pass_queries=[
+        """
+        How do I implement authentication in my web application?
+        """,
+        """
+        What's the best way to structure a RESTful API for my project?
+        """,
+        """
+        I need help optimizing my database queries for better performance.
+        """,
+        """
+        How should I implement state management in my frontend application?
+        """,
+        """
+        What's the difference between SQL and NoSQL databases, and when should I use each?
+ """, + ], + fail_queries=[ + """ + What's the best approach for setting up a CI/CD pipeline for our team? + """, + """ + Can you help me configure auto-scaling for our Kubernetes cluster? + """, + """ + How should I structure our cloud infrastructure for better cost efficiency? + """, + """ + How do I cook a delicious lasagna for dinner? + """, + ], +) + +# DevOps/SRE Engineer Persona +devops_sre = PersonaMatchTest( + persona_name="devops sre engineer", + persona_desc=""" + Expert in infrastructure automation, deployment pipelines, and operational reliability. + Specializes in building and maintaining scalable, resilient, and secure infrastructure. + Proficient with cloud platforms (AWS, Azure, GCP), containerization, and orchestration. + Experienced with infrastructure as code, configuration management, and automation tools. + Skilled in implementing CI/CD pipelines, monitoring systems, and observability solutions. + Focuses on reliability, performance, security, and operational efficiency. + Practices site reliability engineering principles and DevOps methodologies. + + Designs and implements cloud infrastructure using services like compute, storage, + networking, and databases. + Creates infrastructure as code using tools like Terraform, CloudFormation, or Pulumi. + Configures and manages container orchestration platforms like Kubernetes and ECS. + Implements CI/CD pipelines using tools like Jenkins, GitHub Actions, GitLab CI, or CircleCI. + + Sets up comprehensive monitoring, alerting, and observability solutions. + Implements logging aggregation, metrics collection, and distributed tracing. + Creates dashboards and visualizations for system performance and health. + Designs and implements disaster recovery and backup strategies. + + Automates routine operational tasks and infrastructure maintenance. + Conducts capacity planning, performance tuning, and cost optimization. + Implements security best practices, compliance controls, and access management. + Performs incident response, troubleshooting, and post-mortem analysis. + + Designs for high availability, fault tolerance, and graceful degradation. + Implements auto-scaling, load balancing, and traffic management solutions. + Creates runbooks, documentation, and operational procedures. + Conducts chaos engineering experiments to improve system resilience. + """, + pass_queries=[ + """ + How do I set up a Kubernetes cluster with proper high availability? + """, + """ + What's the best approach for implementing a CI/CD pipeline for our microservices? + """, + """ + How can I automate our infrastructure provisioning using Terraform? + """, + """ + What monitoring metrics should I track to ensure the reliability of our system? + """, + ], + fail_queries=[ + """ + How do I implement a sorting algorithm in Python? + """, + """ + What's the best way to structure my React components for a single-page application? + """, + """ + Can you help me design a database schema for my e-commerce application? + """, + """ + How do I create a responsive layout using CSS Grid and Flexbox? + """, + """ + What's the most efficient algorithm for finding the shortest path in a graph? 
+ """, + ], +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "persona_match_test", + [ + simple_persona, + architect, + coder, + devops_sre, + ], +) +async def test_check_persona_pass_match( + semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest +): + """Test checking persona match.""" + await semantic_router_mocked_db.add_persona( + persona_match_test.persona_name, persona_match_test.persona_desc + ) + + # Check for the queries that should pass + for query in persona_match_test.pass_queries: + match = await semantic_router_mocked_db.check_persona_match( + persona_match_test.persona_name, query + ) + assert match is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "persona_match_test", + [ + simple_persona, + architect, + coder, + devops_sre, + ], +) +async def test_check_persona_fail_match( + semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest +): + """Test checking persona match.""" + await semantic_router_mocked_db.add_persona( + persona_match_test.persona_name, persona_match_test.persona_desc + ) + + # Check for the queries that should fail + for query in persona_match_test.fail_queries: + match = await semantic_router_mocked_db.check_persona_match( + persona_match_test.persona_name, query + ) + assert match is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "personas", + [ + [ + coder, + devops_sre, + architect, + ] + ], +) +async def test_persona_diff_description( + semantic_router_mocked_db: PersonaManager, + personas: List[PersonaMatchTest], +): + # First, add all existing personas + for persona in personas: + await semantic_router_mocked_db.add_persona(persona.persona_name, persona.persona_desc) + + last_added_persona = personas[-1] + with pytest.raises(PersonaSimilarDescriptionError): + await semantic_router_mocked_db.add_persona( + "repeated persona", last_added_persona.persona_desc + ) + + +@pytest.mark.asyncio +async def test_update_persona(semantic_router_mocked_db: PersonaManager): + """Test updating a persona to the database different name and description.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + updated_description = "foo and bar description" + await semantic_router_mocked_db.update_persona( + persona_name, new_persona_name="new test persona", new_persona_desc=updated_description + ) + + +@pytest.mark.asyncio +async def test_update_persona_same_desc(semantic_router_mocked_db: PersonaManager): + """Test updating a persona to the database with same description.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + await semantic_router_mocked_db.update_persona( + persona_name, new_persona_name="new test persona", new_persona_desc=persona_desc + ) + + +@pytest.mark.asyncio +async def test_update_persona_not_exists(semantic_router_mocked_db: PersonaManager): + """Test updating a persona to the database.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + + with pytest.raises(PersonaDoesNotExistError): + await semantic_router_mocked_db.update_persona( + persona_name, new_persona_name="new test persona", new_persona_desc=persona_desc + ) + + +@pytest.mark.asyncio +async def test_update_persona_same_name(semantic_router_mocked_db: PersonaManager): + """Test updating a persona to the database.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await 
+@pytest.mark.asyncio
+async def test_update_persona_same_name(semantic_router_mocked_db: PersonaManager):
+    """Test renaming a persona to a name that is already in use."""
+    persona_name = "test_persona"
+    persona_desc = "test_persona_desc"
+    await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+    persona_name_2 = "test_persona_2"
+    persona_desc_2 = "foo and bar"
+    await semantic_router_mocked_db.add_persona(persona_name_2, persona_desc_2)
+
+    with pytest.raises(connection.AlreadyExistsError):
+        await semantic_router_mocked_db.update_persona(
+            persona_name_2, new_persona_name=persona_name, new_persona_desc=persona_desc_2
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_persona(semantic_router_mocked_db: PersonaManager):
+    """Test deleting a persona from the database."""
+    persona_name = "test_persona"
+    persona_desc = "test_persona_desc"
+    await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+    await semantic_router_mocked_db.delete_persona(persona_name)
+
+    with pytest.raises(PersonaDoesNotExistError):
+        await semantic_router_mocked_db.get_persona(persona_name)
+
+
+@pytest.mark.asyncio
+async def test_delete_persona_not_exists(semantic_router_mocked_db: PersonaManager):
+    persona_name = "test_persona"
+
+    with pytest.raises(PersonaDoesNotExistError):
+        await semantic_router_mocked_db.delete_persona(persona_name)
+
+
+@pytest.mark.asyncio
+async def test_get_personas(semantic_router_mocked_db: PersonaManager):
+    """Test getting personas from the database."""
+    persona_name = "test_persona"
+    persona_desc = "test_persona_desc"
+    await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
+
+    persona_name_2 = "test_persona_2"
+    persona_desc_2 = "foo and bar"
+    await semantic_router_mocked_db.add_persona(persona_name_2, persona_desc_2)
+
+    all_personas = await semantic_router_mocked_db.get_all_personas()
+    assert len(all_personas) == 2
+    assert all_personas[0].name == persona_name
+    assert all_personas[1].name == persona_name_2
+
+
+@pytest.mark.asyncio
+async def test_get_personas_empty(semantic_router_mocked_db: PersonaManager):
+    """Test getting personas when none exist."""
+
+    all_personas = await semantic_router_mocked_db.get_all_personas()
+    assert len(all_personas) == 0
diff --git a/tests/muxing/test_semantic_router.py b/tests/muxing/test_semantic_router.py
deleted file mode 100644
index c8c7edc6..00000000
--- a/tests/muxing/test_semantic_router.py
+++ /dev/null
@@ -1,590 +0,0 @@
-import uuid
-from pathlib import Path
-from typing import List
-
-import pytest
-from pydantic import BaseModel
-
-from codegate.db import connection
-from codegate.muxing.semantic_router import PersonaDoesNotExistError, SemanticRouter
-
-
-@pytest.fixture
-def db_path():
-    """Creates a temporary database file path."""
-    current_test_dir = Path(__file__).parent
-    db_filepath = current_test_dir / f"codegate_test_{uuid.uuid4()}.db"
-    db_fullpath = db_filepath.absolute()
-    connection.init_db_sync(str(db_fullpath))
-    yield db_fullpath
-    if db_fullpath.is_file():
-        db_fullpath.unlink()
-
-
-@pytest.fixture()
-def db_recorder(db_path) -> connection.DbRecorder:
-    """Creates a DbRecorder instance with test database."""
-    return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True)
-
-
-@pytest.fixture()
-def db_reader(db_path) -> connection.DbReader:
-    """Creates a DbReader instance with test database."""
-    return connection.DbReader(sqlite_path=db_path, _no_singleton=True)
-
-
-@pytest.fixture()
-def semantic_router_mocked_db(
-    db_recorder: connection.DbRecorder, db_reader: connection.DbReader
-) -> SemanticRouter:
-    """Creates a SemanticRouter instance with mocked database."""
-    semantic_router = SemanticRouter()
-    semantic_router._db_reader = db_reader
-    semantic_router._db_recorder = 
db_recorder - return semantic_router - - -@pytest.mark.asyncio -async def test_add_persona(semantic_router_mocked_db: SemanticRouter): - """Test adding a persona to the database.""" - persona_name = "test_persona" - persona_desc = "test_persona_desc" - await semantic_router_mocked_db.add_persona(persona_name, persona_desc) - retrieved_persona = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name) - assert retrieved_persona.name == persona_name - assert retrieved_persona.description == persona_desc - - -@pytest.mark.asyncio -async def test_persona_not_exist_match(semantic_router_mocked_db: SemanticRouter): - """Test checking persona match when persona does not exist""" - persona_name = "test_persona" - query = "test_query" - with pytest.raises(PersonaDoesNotExistError): - await semantic_router_mocked_db.check_persona_match(persona_name, query) - - -class PersonaMatchTest(BaseModel): - persona_name: str - persona_desc: str - pass_queries: List[str] - fail_queries: List[str] - - -simple_persona = PersonaMatchTest( - persona_name="test_persona", - persona_desc="test_desc", - pass_queries=["test_desc", "test_desc2"], - fail_queries=["foo"], -) - -software_architect = PersonaMatchTest( - persona_name="software architect", - persona_desc=""" - Expert in designing large-scale software systems and technical infrastructure. - Specializes in distributed systems, microservices architecture, - and cloud-native applications. - Deep knowledge of architectural patterns like CQRS, event sourcing, hexagonal architecture, - and domain-driven design. - Experienced in designing scalable, resilient, and maintainable software solutions. - Proficient in evaluating technology stacks and making strategic technical decisions. - Skilled at creating architecture diagrams, technical specifications, - and system documentation. - Focuses on non-functional requirements like performance, security, and reliability. - Guides development teams on best practices for implementing complex systems. - """, - pass_queries=[ - """ - How should I design a microservices architecture that can handle high traffic loads? - """, - """ - What's the best approach for implementing event sourcing in a distributed system? - """, - """ - I need to design a system that can scale to millions of users. What architecture would you - recommend? - """, - """ - Can you explain the trade-offs between monolithic and microservices architectures for our - new project? - """, - ], - fail_queries=[ - """ - How do I create a simple landing page with HTML and CSS? - """, - """ - What's the best way to optimize my SQL query performance? - """, - """ - Can you help me debug this JavaScript function that's throwing an error? - """, - """ - How do I implement user authentication in my React application? - """, - ], -) - -# Data Scientist Persona -data_scientist = PersonaMatchTest( - persona_name="data scientist", - persona_desc=""" - Expert in analyzing and interpreting complex data to solve business problems. - Specializes in statistical analysis, machine learning algorithms, and predictive modeling. - Builds and deploys models for classification, regression, clustering, and anomaly detection. - Proficient in data preprocessing, feature engineering, and model evaluation techniques. - Uses Python with libraries like NumPy, Pandas, scikit-learn, TensorFlow, and PyTorch. - Experienced with data visualization using Matplotlib, Seaborn, and interactive dashboards. - Applies experimental design principles and A/B testing methodologies. 
- Works with structured and unstructured data, including time series and text. - Implements data pipelines for model training, validation, and deployment. - Communicates insights and recommendations based on data analysis to stakeholders. - - Handles class imbalance problems in classification tasks using techniques like SMOTE, - undersampling, oversampling, and class weighting. Addresses customer churn prediction - challenges by identifying key features that indicate potential churners. - - Applies feature selection methods for high-dimensional datasets, including filter methods - (correlation, chi-square), wrapper methods (recursive feature elimination), and embedded - methods (LASSO regularization). - - Prevents overfitting and high variance in tree-based models like random forests through - techniques such as pruning, setting maximum depth, adjusting minimum samples per leaf, - and cross-validation. - - Specializes in time series forecasting for sales and demand prediction, using methods like - ARIMA, SARIMA, Prophet, and exponential smoothing to handle seasonal patterns and trends. - Implements forecasting models that account for quarterly business cycles and seasonal - variations in customer behavior. - - Evaluates model performance using appropriate metrics: accuracy, precision, recall, - F1-score - for classification; RMSE, MAE, R-squared for regression; and specialized metrics for - time series forecasting like MAPE and SMAPE. - - Experienced in developing customer segmentation models, recommendation systems, - anomaly detection algorithms, and predictive maintenance solutions. - """, - pass_queries=[ - """ - How should I handle class imbalance in my customer churn prediction model? - """, - """ - What feature selection techniques would work best for my high-dimensional dataset? - """, - """ - I'm getting high variance in my random forest model. How can I prevent overfitting? - """, - """ - What's the best approach for forecasting seasonal time series data for our sales - predictions? - """, - ], - fail_queries=[ - """ - How do I structure my React components for a single-page application? - """, - """ - What's the best way to implement a CI/CD pipeline for my microservices? - """, - """ - Can you help me design a responsive layout for mobile and desktop browsers? - """, - """ - How should I configure my Kubernetes cluster for high availability? - """, - ], -) - -# UX Designer Persona -ux_designer = PersonaMatchTest( - persona_name="ux designer", - persona_desc=""" - Expert in creating intuitive, user-centered digital experiences and interfaces. - Specializes in user research, usability testing, and interaction design. - Creates wireframes, prototypes, and user flows to visualize design solutions. - Conducts user interviews, usability studies, and analyzes user feedback. - Develops user personas and journey maps to understand user needs and pain points. - Designs information architecture and navigation systems for complex applications. - Applies design thinking methodology to solve user experience problems. - Knowledgeable about accessibility standards and inclusive design principles. - Collaborates with product managers and developers to implement user-friendly features. - Uses tools like Figma, Sketch, and Adobe XD to create high-fidelity mockups. - """, - pass_queries=[ - """ - How can I improve the user onboarding experience for my mobile application? - """, - """ - What usability testing methods would you recommend for evaluating our new interface design? 
- """, - """ - I'm designing a complex dashboard. What information architecture would make it most - intuitive for users? - """, - """ - How should I structure user research to identify pain points in our current - checkout process? - """, - ], - fail_queries=[ - """ - How do I configure a load balancer for my web servers? - """, - """ - What's the best way to implement a caching layer in my application? - """, - """ - Can you explain how to set up a CI/CD pipeline with GitHub Actions? - """, - """ - How do I optimize my database queries for better performance? - """, - ], -) - -# DevOps Engineer Persona -devops_engineer = PersonaMatchTest( - persona_name="devops engineer", - persona_desc=""" - Expertise: Infrastructure automation, CI/CD pipelines, cloud services, containerization, - and monitoring. - Proficient with tools like Docker, Kubernetes, Terraform, Ansible, and Jenkins. - Experienced with cloud platforms including AWS, Azure, and Google Cloud. - Strong knowledge of Linux/Unix systems administration and shell scripting. - Skilled in implementing microservices architectures and service mesh technologies. - Focus on reliability, scalability, security, and operational efficiency. - Practices infrastructure as code, GitOps, and site reliability engineering principles. - Experienced with monitoring tools like Prometheus, Grafana, and ELK stack. - """, - pass_queries=[ - """ - What's the best way to set up auto-scaling for my Kubernetes cluster on AWS? - """, - """ - I need to implement a zero-downtime deployment strategy for my microservices. - What approaches would you recommend? - """, - """ - How can I improve the security of my CI/CD pipeline and prevent supply chain attacks? - """, - """ - What monitoring metrics should I track to ensure the reliability of my distributed system? - """, - ], - fail_queries=[ - """ - How do I design an effective user onboarding flow for my mobile app? - """, - """ - What's the best algorithm for sentiment analysis on customer reviews? - """, - """ - Can you help me with color theory for my website redesign? - """, - """ - I need advice on optimizing my SQL queries for a reporting dashboard. - """, - ], -) - -# Security Specialist Persona -security_specialist = PersonaMatchTest( - persona_name="security specialist", - persona_desc=""" - Expert in cybersecurity, application security, and secure system design. - Specializes in identifying and mitigating security vulnerabilities and threats. - Performs security assessments, penetration testing, and code security reviews. - Implements security controls like authentication, authorization, and encryption. - Knowledgeable about common attack vectors such as injection attacks, XSS, CSRF, and SSRF. - Experienced with security frameworks and standards like OWASP Top 10, NIST, and ISO 27001. - Designs secure architectures and implements defense-in-depth strategies. - Conducts security incident response and forensic analysis. - Implements security monitoring, logging, and alerting systems. - Stays current with emerging security threats and mitigation techniques. - """, - pass_queries=[ - """ - How can I protect my web application from SQL injection attacks? - """, - """ - What security controls should I implement for storing sensitive user data? - """, - """ - How do I conduct a thorough security assessment of our cloud infrastructure? - """, - """ - What's the best approach for implementing secure authentication in my API? - """, - ], - fail_queries=[ - """ - How do I optimize the loading speed of my website? 
- """, - """ - What's the best way to implement responsive design for mobile devices? - """, - """ - Can you help me design a database schema for my e-commerce application? - """, - """ - How should I structure my React components for better code organization? - """, - ], -) - -# Mobile Developer Persona -mobile_developer = PersonaMatchTest( - persona_name="mobile developer", - persona_desc=""" - Expert in building native and cross-platform mobile applications for iOS and Android. - Specializes in mobile UI development, responsive layouts, and platform-specific - design patterns. - Proficient in Swift and SwiftUI for iOS, Kotlin for Android, and React Native or - Flutter for cross-platform. - Implements mobile-specific features like push notifications, offline storage, and - location services. - Optimizes mobile applications for performance, battery efficiency, and limited - network connectivity. - Experienced with mobile app architecture patterns like MVVM, MVC, and Redux. - Integrates with device hardware features including camera, biometrics, sensors, - and Bluetooth. - Familiar with app store submission processes, app signing, and distribution workflows. - Implements secure data storage, authentication, and API communication on mobile devices. - Designs and develops responsive interfaces that work across different screen sizes - and orientations. - - Implements sophisticated offline-first data synchronization strategies - for mobile applications, - handling conflict resolution, data merging, and background syncing when connectivity - is restored. - Uses technologies like Realm, SQLite, Core Data, and Room Database to enable seamless - offline - experiences in React Native and native apps. - - Structures Swift code following the MVVM (Model-View-ViewModel) architectural pattern - to create - maintainable, testable iOS applications. Implements proper separation of concerns - with bindings - between views and view models using Combine, RxSwift, or SwiftUI's native state management. - - Specializes in deep linking implementation for both Android and iOS, enabling app-to-app - communication, marketing campaign tracking, and seamless user experiences when navigating - between web and mobile contexts. Configures Universal Links, App Links, and custom URL - schemes. - - Optimizes battery usage for location-based features by implementing intelligent location - tracking - strategies, including geofencing, significant location changes, deferred location updates, - and - region monitoring. Balances accuracy requirements with power consumption constraints. - - Develops efficient state management solutions for complex mobile applications using Redux, - MobX, Provider, or Riverpod for React Native apps, and native state management approaches - for iOS and Android. - - Creates responsive mobile interfaces that adapt to different device orientations, - screen sizes, - and pixel densities using constraint layouts, auto layout, size classes, and flexible - grid systems. - """, - pass_queries=[ - """ - What's the best approach for implementing offline-first data synchronization in my mobile - app? - """, - """ - How should I structure my Swift code to implement the MVVM pattern effectively? - """, - """ - What's the most efficient way to handle deep linking and app-to-app communication on - Android? - """, - """ - How can I optimize battery usage when implementing background location tracking? 
- """, - ], - fail_queries=[ - """ - How do I design a database schema with proper normalization for my web application? - """, - """ - What's the best approach for implementing a distributed caching layer in my microservices? - """, - """ - Can you help me set up a data pipeline for processing large datasets with Apache Spark? - """, - """ - How should I configure my load balancer to distribute traffic across my web servers? - """, - ], -) - -# Database Administrator Persona -database_administrator = PersonaMatchTest( - persona_name="database administrator", - persona_desc=""" - Expert in designing, implementing, and managing database systems for optimal performance and - reliability. - Specializes in database architecture, schema design, and query optimization techniques. - Proficient with relational databases like PostgreSQL, MySQL, Oracle, and SQL Server. - Implements and manages database security, access controls, and data protection measures. - Designs high-availability solutions using replication, clustering, and failover mechanisms. - Develops and executes backup strategies, disaster recovery plans, and data retention - policies. - Monitors database performance, identifies bottlenecks, and implements optimization - solutions. - Creates and maintains indexes, partitioning schemes, and other performance-enhancing - structures. - Experienced with database migration, version control, and change management processes. - Implements data integrity constraints, stored procedures, triggers, and database automation. - - Optimizes complex JOIN query performance in PostgreSQL through advanced techniques including - query rewriting, proper indexing strategies, materialized views, and query plan analysis. - Uses EXPLAIN ANALYZE to identify bottlenecks in query execution plans and implements - appropriate optimizations for specific query patterns. - - Designs and implements high-availability MySQL configurations with automatic failover using - technologies like MySQL Group Replication, Galera Cluster, Percona XtraDB Cluster, or MySQL - InnoDB Cluster with MySQL Router. Configures synchronous and asynchronous replication - strategies - to balance consistency and performance requirements. - - Develops sophisticated indexing strategies for tables with frequent write operations and - complex - read queries, balancing write performance with read optimization. Implements partial - indexes, - covering indexes, and composite indexes based on query patterns and cardinality analysis. - - Specializes in large-scale database migrations between different database engines, - particularly - Oracle to PostgreSQL transitions. Uses tools like ora2pg, AWS DMS, and custom ETL processes - to - ensure data integrity, schema compatibility, and minimal downtime during migration. - - Implements table partitioning schemes based on data access patterns, including range - partitioning - for time-series data, list partitioning for categorical data, and hash partitioning for - evenly - distributed workloads. - - Configures and manages database connection pooling, query caching, and buffer management to - optimize resource utilization and throughput under varying workloads. - - Designs and implements database sharding strategies for horizontal scaling, including - consistent hashing algorithms, shard key selection, and cross-shard query optimization. - """, - pass_queries=[ - """ - How can I optimize the performance of complex JOIN queries in my PostgreSQL database? 
- """, - """ - What's the best approach for implementing a high-availability MySQL setup with automatic - failover? - """, - """ - How should I design my indexing strategy for a table with frequent writes and complex read - queries? - """, - """ - What's the most efficient way to migrate a large Oracle database to PostgreSQL with minimal - downtime? - """, - ], - fail_queries=[ - """ - How do I structure my React components to implement the Redux state management pattern? - """, - """ - What's the best approach for implementing responsive design with CSS Grid and Flexbox? - """, - """ - Can you help me set up a CI/CD pipeline for my containerized microservices? - """, - ], -) - -# Natural Language Processing Specialist Persona -nlp_specialist = PersonaMatchTest( - persona_name="nlp specialist", - persona_desc=""" - Expertise: Natural language processing, computational linguistics, and text analytics. - Proficient with NLP libraries and frameworks like NLTK, spaCy, Hugging Face Transformers, - and Gensim. - Experience with language models such as BERT, GPT, T5, and their applications. - Skilled in text preprocessing, tokenization, lemmatization, and feature extraction - techniques. - Knowledge of sentiment analysis, named entity recognition, topic modeling, and text - classification. - Familiar with word embeddings, contextual embeddings, and language representation methods. - Understanding of machine translation, question answering, and text summarization systems. - Background in information retrieval, semantic search, and conversational AI development. - """, - pass_queries=[ - """ - What approach should I take to fine-tune BERT for my custom text classification task? - """, - """ - How can I improve the accuracy of my named entity recognition system for medical texts? - """, - """ - What's the best way to implement semantic search using embeddings from language models? - """, - """ - I need to build a sentiment analysis system that can handle sarcasm and idioms. - Any suggestions? - """, - ], - fail_queries=[ - """ - How do I optimize my React components to reduce rendering time? - """, - """ - What's the best approach for implementing a CI/CD pipeline with Jenkins? - """, - """ - Can you help me design a responsive UI for my web application? - """, - """ - How should I structure my microservices architecture for scalability? 
- """, - ], -) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "persona_match_test", - [ - simple_persona, - software_architect, - data_scientist, - ux_designer, - devops_engineer, - security_specialist, - mobile_developer, - database_administrator, - nlp_specialist, - ], -) -async def test_check_persona_match( - semantic_router_mocked_db: SemanticRouter, persona_match_test: PersonaMatchTest -): - """Test checking persona match.""" - await semantic_router_mocked_db.add_persona( - persona_match_test.persona_name, persona_match_test.persona_desc - ) - - # Check for the queries that should pass - for query in persona_match_test.pass_queries: - match = await semantic_router_mocked_db.check_persona_match( - persona_match_test.persona_name, query - ) - assert match is True - - # Check for the queries that should fail - for query in persona_match_test.fail_queries: - match = await semantic_router_mocked_db.check_persona_match( - persona_match_test.persona_name, query - ) - assert match is False diff --git a/tests/pipeline/pii/test_analyzer.py b/tests/pipeline/pii/test_analyzer.py index d626b8cf..e856653c 100644 --- a/tests/pipeline/pii/test_analyzer.py +++ b/tests/pipeline/pii/test_analyzer.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from presidio_analyzer import RecognizerResult from codegate.pipeline.pii.analyzer import PiiAnalyzer diff --git a/tests/pipeline/sensitive_data/test_manager.py b/tests/pipeline/sensitive_data/test_manager.py index 6115ad14..66305388 100644 --- a/tests/pipeline/sensitive_data/test_manager.py +++ b/tests/pipeline/sensitive_data/test_manager.py @@ -1,6 +1,7 @@ -import json from unittest.mock import MagicMock, patch + import pytest + from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager from codegate.pipeline.sensitive_data.session_store import SessionStore diff --git a/tests/pipeline/sensitive_data/test_session_store.py b/tests/pipeline/sensitive_data/test_session_store.py index b9ab64fe..e90b953e 100644 --- a/tests/pipeline/sensitive_data/test_session_store.py +++ b/tests/pipeline/sensitive_data/test_session_store.py @@ -1,5 +1,5 @@ -import uuid import pytest + from codegate.pipeline.sensitive_data.session_store import SessionStore diff --git a/tests/test_server.py b/tests/test_server.py index aa549810..bcf55e7e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,7 +14,6 @@ from codegate import __version__ from codegate.pipeline.factory import PipelineFactory -from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers.registry import ProviderRegistry from codegate.server import init_app from src.codegate.cli import UvicornServer, cli