Skip to content

Commit 1993705

Browse files
authored
Enhance EnterpriseActionTool with improved schema processing and erro… (crewAIInc#371)
* Enhance EnterpriseActionTool with improved schema processing and error handling - Added methods for sanitizing names and processing schema types, including support for nested models and nullable types. - Improved error handling during schema creation and processing, with warnings for failures. - Updated parameter handling in the `_run` method to clean up `kwargs` before sending requests. - Introduced a detailed description generation for nested schema structures to enhance tool documentation. * Add tests for EnterpriseActionTool schema conversion and validation - Introduced a new test class for validating complex nested schemas in EnterpriseActionTool. - Added tests for schema conversion, optional fields, enum validation, and required nested fields. - Implemented execution tests to ensure the tool can handle complex validated input correctly. - Verified model naming conventions and added tests for simpler schemas with basic enum validation. - Enhanced overall test coverage for the EnterpriseActionTool functionality. * Update chromadb dependency version in pyproject.toml and uv.lock - Changed chromadb version from >=0.4.22 to ==0.5.23 in both pyproject.toml and uv.lock to ensure compatibility and stability. * Update test workflow configuration - Changed EMBEDCHAIN_DB_URI to point to a temporary test database location. - Added CHROMA_PERSIST_PATH for specifying the path to the Chroma test database. - Cleaned up the test run command in the workflow file. * reverted
1 parent 504062f commit 1993705

File tree

5 files changed

+458
-73
lines changed

5 files changed

+458
-73
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
# This is a temporary fix until the runner includes SQLite with FTS5 or Python's sqlite3
4545
# module is compiled with FTS5 support by default.
4646
# TODO: Remove this workaround once GitHub Actions runners include SQLite FTS5 support
47-
47+
4848
# Install pysqlite3-binary which has FTS5 support
4949
uv pip install pysqlite3-binary
5050
# Create a sitecustomize.py to override sqlite3 with pysqlite3
@@ -55,5 +55,4 @@ jobs:
5555
PYTHONPATH=.pytest_sqlite_override uv run python -c "import sqlite3; conn = sqlite3.connect(':memory:'); conn.execute('CREATE VIRTUAL TABLE test USING fts5(content)'); print('FTS5 module available')"
5656
5757
- name: Run tests
58-
run: PYTHONPATH=.pytest_sqlite_override uv run pytest tests -vv
59-
58+
run: PYTHONPATH=.pytest_sqlite_override uv run pytest tests -vv

crewai_tools/adapters/enterprise_adapter.py

Lines changed: 186 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22
import json
33
import requests
4-
from typing import List, Any, Dict, Optional
4+
from typing import List, Any, Dict, Literal, Optional, Union, get_origin
55
from pydantic import Field, create_model
66
from crewai.tools import BaseTool
7+
import re
8+
79

810
# DEFAULTS
911
ENTERPRISE_ACTION_KIT_PROJECT_ID = "dd525517-df22-49d2-a69e-6a0eed211166"
@@ -37,29 +39,46 @@ def __init__(
3739
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL,
3840
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
3941
):
42+
self._model_registry = {}
43+
self._base_name = self._sanitize_name(name)
44+
4045
schema_props, required = self._extract_schema_info(action_schema)
4146

4247
# Define field definitions for the model
4348
field_definitions = {}
4449
for param_name, param_details in schema_props.items():
4550
param_desc = param_details.get("description", "")
4651
is_required = param_name in required
47-
is_nullable, param_type = self._analyze_field_type(param_details)
4852

49-
# Create field definition based on nullable and required status
53+
try:
54+
field_type = self._process_schema_type(
55+
param_details, self._sanitize_name(param_name).title()
56+
)
57+
except Exception as e:
58+
print(f"Warning: Could not process schema for {param_name}: {e}")
59+
field_type = str
60+
61+
# Create field definition based on requirement
5062
field_definitions[param_name] = self._create_field_definition(
51-
param_type, is_required, is_nullable, param_desc
63+
field_type, is_required, param_desc
5264
)
5365

5466
# Create the model
5567
if field_definitions:
56-
args_schema = create_model(
57-
f"{name.capitalize()}Schema", **field_definitions
58-
)
68+
try:
69+
args_schema = create_model(
70+
f"{self._base_name}Schema", **field_definitions
71+
)
72+
except Exception as e:
73+
print(f"Warning: Could not create main schema model: {e}")
74+
args_schema = create_model(
75+
f"{self._base_name}Schema",
76+
input_text=(str, Field(description="Input for the action")),
77+
)
5978
else:
6079
# Fallback for empty schema
6180
args_schema = create_model(
62-
f"{name.capitalize()}Schema",
81+
f"{self._base_name}Schema",
6382
input_text=(str, Field(description="Input for the action")),
6483
)
6584

@@ -73,6 +92,12 @@ def __init__(
7392
if enterprise_action_kit_project_url is not None:
7493
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
7594

95+
def _sanitize_name(self, name: str) -> str:
96+
"""Sanitize names to create proper Python class names."""
97+
sanitized = re.sub(r"[^a-zA-Z0-9_]", "", name)
98+
parts = sanitized.split("_")
99+
return "".join(word.capitalize() for word in parts if word)
100+
76101
def _extract_schema_info(
77102
self, action_schema: Dict[str, Any]
78103
) -> tuple[Dict[str, Any], List[str]]:
@@ -87,58 +112,105 @@ def _extract_schema_info(
87112
)
88113
return schema_props, required
89114

90-
def _analyze_field_type(self, param_details: Dict[str, Any]) -> tuple[bool, type]:
91-
"""Analyze field type and nullability from parameter details."""
92-
is_nullable = False
93-
param_type = str # Default type
94-
95-
if "anyOf" in param_details:
96-
any_of_types = param_details["anyOf"]
115+
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> type:
116+
"""Process a JSON schema and return appropriate Python type."""
117+
if "anyOf" in schema:
118+
any_of_types = schema["anyOf"]
97119
is_nullable = any(t.get("type") == "null" for t in any_of_types)
98120
non_null_types = [t for t in any_of_types if t.get("type") != "null"]
121+
99122
if non_null_types:
100-
first_type = non_null_types[0].get("type", "string")
101-
param_type = self._map_json_type_to_python(
102-
first_type, non_null_types[0]
123+
base_type = self._process_schema_type(non_null_types[0], type_name)
124+
return Optional[base_type] if is_nullable else base_type
125+
return Optional[str]
126+
127+
if "oneOf" in schema:
128+
return self._process_schema_type(schema["oneOf"][0], type_name)
129+
130+
if "allOf" in schema:
131+
return self._process_schema_type(schema["allOf"][0], type_name)
132+
133+
json_type = schema.get("type", "string")
134+
135+
if "enum" in schema:
136+
enum_values = schema["enum"]
137+
if not enum_values:
138+
return self._map_json_type_to_python(json_type)
139+
return Literal[tuple(enum_values)] # type: ignore
140+
141+
if json_type == "array":
142+
items_schema = schema.get("items", {"type": "string"})
143+
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
144+
return List[item_type]
145+
146+
if json_type == "object":
147+
return self._create_nested_model(schema, type_name)
148+
149+
return self._map_json_type_to_python(json_type)
150+
151+
def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> type:
152+
"""Create a nested Pydantic model for complex objects."""
153+
full_model_name = f"{self._base_name}{model_name}"
154+
155+
if full_model_name in self._model_registry:
156+
return self._model_registry[full_model_name]
157+
158+
properties = schema.get("properties", {})
159+
required_fields = schema.get("required", [])
160+
161+
if not properties:
162+
return dict
163+
164+
field_definitions = {}
165+
for prop_name, prop_schema in properties.items():
166+
prop_desc = prop_schema.get("description", "")
167+
is_required = prop_name in required_fields
168+
169+
try:
170+
prop_type = self._process_schema_type(
171+
prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
103172
)
104-
else:
105-
json_type = param_details.get("type", "string")
106-
param_type = self._map_json_type_to_python(json_type, param_details)
107-
is_nullable = json_type == "null"
173+
except Exception as e:
174+
print(f"Warning: Could not process schema for {prop_name}: {e}")
175+
prop_type = str
108176

109-
return is_nullable, param_type
177+
field_definitions[prop_name] = self._create_field_definition(
178+
prop_type, is_required, prop_desc
179+
)
180+
181+
try:
182+
nested_model = create_model(full_model_name, **field_definitions)
183+
self._model_registry[full_model_name] = nested_model
184+
return nested_model
185+
except Exception as e:
186+
print(f"Warning: Could not create nested model {full_model_name}: {e}")
187+
return dict
110188

111189
def _create_field_definition(
112-
self, param_type: type, is_required: bool, is_nullable: bool, param_desc: str
190+
self, field_type: type, is_required: bool, description: str
113191
) -> tuple:
114-
"""Create Pydantic field definition based on type, requirement, and nullability."""
115-
if is_nullable:
116-
return (
117-
Optional[param_type],
118-
Field(default=None, description=param_desc),
119-
)
120-
elif is_required:
121-
return (
122-
param_type,
123-
Field(description=param_desc),
124-
)
192+
"""Create Pydantic field definition based on type and requirement."""
193+
if is_required:
194+
return (field_type, Field(description=description))
125195
else:
126-
return (
127-
Optional[param_type],
128-
Field(default=None, description=param_desc),
129-
)
196+
if get_origin(field_type) is Union:
197+
return (field_type, Field(default=None, description=description))
198+
else:
199+
return (
200+
Optional[field_type],
201+
Field(default=None, description=description),
202+
)
130203

131-
def _map_json_type_to_python(
132-
self, json_type: str, param_details: Dict[str, Any]
133-
) -> type:
134-
"""Map JSON schema types to Python types."""
204+
def _map_json_type_to_python(self, json_type: str) -> type:
205+
"""Map basic JSON schema types to Python types."""
135206
type_mapping = {
136207
"string": str,
137208
"integer": int,
138209
"number": float,
139210
"boolean": bool,
140211
"array": list,
141212
"object": dict,
213+
"null": type(None),
142214
}
143215
return type_mapping.get(json_type, str)
144216

@@ -149,29 +221,37 @@ def _get_required_nullable_fields(self) -> List[str]:
149221
required_nullable_fields = []
150222
for param_name in required:
151223
param_details = schema_props.get(param_name, {})
152-
is_nullable, _ = self._analyze_field_type(param_details)
153-
if is_nullable:
224+
if self._is_nullable_type(param_details):
154225
required_nullable_fields.append(param_name)
155226

156227
return required_nullable_fields
157228

229+
def _is_nullable_type(self, schema: Dict[str, Any]) -> bool:
230+
"""Check if a schema represents a nullable type."""
231+
if "anyOf" in schema:
232+
return any(t.get("type") == "null" for t in schema["anyOf"])
233+
return schema.get("type") == "null"
234+
158235
def _run(self, **kwargs) -> str:
159236
"""Execute the specific enterprise action with validated parameters."""
160237
try:
238+
cleaned_kwargs = {}
239+
for key, value in kwargs.items():
240+
if value is not None:
241+
cleaned_kwargs[key] = value
242+
161243
required_nullable_fields = self._get_required_nullable_fields()
162244

163245
for field_name in required_nullable_fields:
164-
if field_name not in kwargs:
165-
kwargs[field_name] = None
166-
167-
params = {k: v for k, v in kwargs.items() if v is not None}
246+
if field_name not in cleaned_kwargs:
247+
cleaned_kwargs[field_name] = None
168248

169249
api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"
170250
headers = {
171251
"Authorization": f"Bearer {self.enterprise_action_token}",
172252
"Content-Type": "application/json",
173253
}
174-
payload = {"action": self.action_name, "parameters": params}
254+
payload = {"action": self.action_name, "parameters": cleaned_kwargs}
175255

176256
response = requests.post(
177257
url=api_url, headers=headers, json=payload, timeout=60
@@ -198,19 +278,14 @@ def __init__(
198278
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
199279
):
200280
"""Initialize the adapter with an enterprise action token."""
201-
202281
self.enterprise_action_token = enterprise_action_token
203282
self._actions_schema = {}
204283
self._tools = None
205284
self.enterprise_action_kit_project_id = enterprise_action_kit_project_id
206285
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
207286

208287
def tools(self) -> List[BaseTool]:
209-
"""Get the list of tools created from enterprise actions.
210-
211-
Returns:
212-
List of BaseTool instances, one for each enterprise action.
213-
"""
288+
"""Get the list of tools created from enterprise actions."""
214289
if self._tools is None:
215290
self._fetch_actions()
216291
self._create_tools()
@@ -261,6 +336,53 @@ def _fetch_actions(self):
261336

262337
traceback.print_exc()
263338

339+
def _generate_detailed_description(
340+
self, schema: Dict[str, Any], indent: int = 0
341+
) -> List[str]:
342+
"""Generate detailed description for nested schema structures."""
343+
descriptions = []
344+
indent_str = " " * indent
345+
346+
schema_type = schema.get("type", "string")
347+
348+
if schema_type == "object":
349+
properties = schema.get("properties", {})
350+
required_fields = schema.get("required", [])
351+
352+
if properties:
353+
descriptions.append(f"{indent_str}Object with properties:")
354+
for prop_name, prop_schema in properties.items():
355+
prop_desc = prop_schema.get("description", "")
356+
is_required = prop_name in required_fields
357+
req_str = " (required)" if is_required else " (optional)"
358+
descriptions.append(
359+
f"{indent_str} - {prop_name}: {prop_desc}{req_str}"
360+
)
361+
362+
if prop_schema.get("type") == "object":
363+
descriptions.extend(
364+
self._generate_detailed_description(prop_schema, indent + 2)
365+
)
366+
elif prop_schema.get("type") == "array":
367+
items_schema = prop_schema.get("items", {})
368+
if items_schema.get("type") == "object":
369+
descriptions.append(f"{indent_str} Array of objects:")
370+
descriptions.extend(
371+
self._generate_detailed_description(
372+
items_schema, indent + 3
373+
)
374+
)
375+
elif "enum" in items_schema:
376+
descriptions.append(
377+
f"{indent_str} Array of enum values: {items_schema['enum']}"
378+
)
379+
elif "enum" in prop_schema:
380+
descriptions.append(
381+
f"{indent_str} Enum values: {prop_schema['enum']}"
382+
)
383+
384+
return descriptions
385+
264386
def _create_tools(self):
265387
"""Create BaseTool instances for each action."""
266388
tools = []
@@ -269,19 +391,16 @@ def _create_tools(self):
269391
function_details = action_schema.get("function", {})
270392
description = function_details.get("description", f"Execute {action_name}")
271393

272-
# Get parameter info for a better description
273-
parameters = function_details.get("parameters", {}).get("properties", {})
274-
param_info = []
275-
for param_name, param_details in parameters.items():
276-
param_desc = param_details.get("description", "")
277-
required = param_name in function_details.get("parameters", {}).get(
278-
"required", []
279-
)
280-
param_info.append(
281-
f"- {param_name}: {param_desc} {'(required)' if required else '(optional)'}"
394+
parameters = function_details.get("parameters", {})
395+
param_descriptions = []
396+
397+
if parameters.get("properties"):
398+
param_descriptions.append("\nDetailed Parameter Structure:")
399+
param_descriptions.extend(
400+
self._generate_detailed_description(parameters)
282401
)
283402

284-
full_description = f"{description}\n\nParameters:\n" + "\n".join(param_info)
403+
full_description = description + "\n".join(param_descriptions)
285404

286405
tool = EnterpriseActionTool(
287406
name=action_name.lower().replace(" ", "_"),
@@ -297,7 +416,6 @@ def _create_tools(self):
297416

298417
self._tools = tools
299418

300-
# Adding context manager support for convenience, but direct usage is also supported
301419
def __enter__(self):
302420
return self.tools()
303421

0 commit comments

Comments
 (0)