Skip to content

Commit

Permalink
fix issue with image width, enable text truncation, enable viewing images full screen
Browse files Browse the repository at this point in the history
  • Loading branch information
victordibia committed Nov 27, 2024
1 parent aefaf35 commit 2cbc905
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_ext.agents import MultimodalWebSurfer

from autogen_ext.models import OpenAIChatCompletionClient

from ..datamodel.types import (
AgentConfig,
Expand All @@ -37,11 +36,9 @@
AgentComponent = Union[AssistantAgent, MultimodalWebSurfer]
ModelComponent = Union[OpenAIChatCompletionClient]
ToolComponent = Union[FunctionTool] # Will grow with more tool types
TerminationComponent = Union[MaxMessageTermination,
StopMessageTermination, TextMentionTermination]
TerminationComponent = Union[MaxMessageTermination, StopMessageTermination, TextMentionTermination]

Component = Union[TeamComponent, AgentComponent,
ModelComponent, ToolComponent, TerminationComponent]
Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent, TerminationComponent]

ReturnType = Literal["object", "dict", "config"]

Expand Down Expand Up @@ -122,8 +119,7 @@ async def load(

handler = handlers.get(config.component_type)
if not handler:
raise ValueError(
f"Unknown component type: {config.component_type}")
raise ValueError(f"Unknown component type: {config.component_type}")

return await handler(config)

Expand All @@ -147,8 +143,7 @@ async def load_directory(
component = await self.load(path, return_type=return_type)
components.append(component)
except Exception as e:
logger.info(
f"Failed to load component: {str(e)}, {path}")
logger.info(f"Failed to load component: {str(e)}, {path}")

return components
except Exception as e:
Expand Down Expand Up @@ -181,11 +176,9 @@ async def load_termination(self, config: TerminationConfig) -> TerminationCompon
try:
if config.termination_type == TerminationTypes.COMBINATION:
if not config.conditions or len(config.conditions) < 2:
raise ValueError(
"Combination termination requires at least 2 conditions")
raise ValueError("Combination termination requires at least 2 conditions")
if not config.operator:
raise ValueError(
"Combination termination requires an operator (and/or)")
raise ValueError("Combination termination requires an operator (and/or)")

# Load first two conditions
conditions = [await self.load_termination(cond) for cond in config.conditions[:2]]
Expand All @@ -200,27 +193,23 @@ async def load_termination(self, config: TerminationConfig) -> TerminationCompon

elif config.termination_type == TerminationTypes.MAX_MESSAGES:
if config.max_messages is None:
raise ValueError(
"max_messages parameter required for MaxMessageTermination")
raise ValueError("max_messages parameter required for MaxMessageTermination")
return MaxMessageTermination(max_messages=config.max_messages)

elif config.termination_type == TerminationTypes.STOP_MESSAGE:
return StopMessageTermination()

elif config.termination_type == TerminationTypes.TEXT_MENTION:
if not config.text:
raise ValueError(
"text parameter required for TextMentionTermination")
raise ValueError("text parameter required for TextMentionTermination")
return TextMentionTermination(text=config.text)

else:
raise ValueError(
f"Unsupported termination type: {config.termination_type}")
raise ValueError(f"Unsupported termination type: {config.termination_type}")

except Exception as e:
logger.error(f"Failed to create termination condition: {str(e)}")
raise ValueError(
f"Termination condition creation failed: {str(e)}") from e
raise ValueError(f"Termination condition creation failed: {str(e)}") from e

async def load_team(self, config: TeamConfig, input_func: Optional[Callable] = None) -> TeamComponent:
"""Create team instance from configuration."""
Expand All @@ -246,8 +235,7 @@ async def load_team(self, config: TeamConfig, input_func: Optional[Callable] = N
return RoundRobinGroupChat(participants=participants, termination_condition=termination)
elif config.team_type == TeamTypes.SELECTOR:
if not model_client:
raise ValueError(
"SelectorGroupChat requires a model_client")
raise ValueError("SelectorGroupChat requires a model_client")
selector_prompt = config.selector_prompt if config.selector_prompt else DEFAULT_SELECTOR_PROMPT
return SelectorGroupChat(
participants=participants,
Expand Down Expand Up @@ -306,8 +294,7 @@ async def load_agent(self, config: AgentConfig, input_func: Optional[Callable] =
)

else:
raise ValueError(
f"Unsupported agent type: {config.agent_type}")
raise ValueError(f"Unsupported agent type: {config.agent_type}")

except Exception as e:
logger.error(f"Failed to create agent {config.name}: {str(e)}")
Expand All @@ -323,13 +310,11 @@ async def load_model(self, config: ModelConfig) -> ModelComponent:
return self._model_cache[cache_key]

if config.model_type == ModelTypes.OPENAI:
model = OpenAIChatCompletionClient(
model=config.model, api_key=config.api_key, base_url=config.base_url)
model = OpenAIChatCompletionClient(model=config.model, api_key=config.api_key, base_url=config.base_url)
self._model_cache[cache_key] = model
return model
else:
raise ValueError(
f"Unsupported model type: {config.model_type}")
raise ValueError(f"Unsupported model type: {config.model_type}")

except Exception as e:
logger.error(f"Failed to create model {config.model}: {str(e)}")
Expand All @@ -350,8 +335,7 @@ async def load_tool(self, config: ToolConfig) -> ToolComponent:

if config.tool_type == ToolTypes.PYTHON_FUNCTION:
tool = FunctionTool(
name=config.name, description=config.description, func=self._func_from_string(
config.content)
name=config.name, description=config.description, func=self._func_from_string(config.content)
)
self._tool_cache[cache_key] = tool
return tool
Expand Down Expand Up @@ -396,8 +380,7 @@ def _is_version_supported(self, component_type: ComponentTypes, ver: str) -> boo
"""Check if version is supported for component type."""
try:
version = Version(ver)
supported = [Version(v)
for v in self.SUPPORTED_VERSIONS[component_type]]
supported = [Version(v) for v in self.SUPPORTED_VERSIONS[component_type]]
return any(version == v for v in supported)
except ValueError:
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ class SocketMessage(BaseModel):
type: str


ComponentConfig = Union[TeamConfig, AgentConfig,
ModelConfig, ToolConfig, TerminationConfig]
ComponentConfig = Union[TeamConfig, AgentConfig, ModelConfig, ToolConfig, TerminationConfig]

ComponentConfigInput = Union[str, Path, dict, ComponentConfig]
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,14 @@
from uuid import UUID

from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
from ...datamodel import Run, RunStatus, TeamResult
from ...database import DatabaseManager
from ...teammanager import TeamManager
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage, MultiModalMessage
from autogen_agentchat.messages import AgentMessage, ChatMessage, MultiModalMessage, TextMessage
from autogen_core.base import CancellationToken
from autogen_core.components import Image as AGImage
from fastapi import WebSocket, WebSocketDisconnect

from ...database import DatabaseManager
from ...datamodel import Message, MessageConfig, Run, RunStatus, TeamResult
from ...teammanager import TeamManager
from autogen_core.components import Image as AGImage

logger = logging.getLogger(__name__)

Expand All @@ -42,8 +38,7 @@ def __init__(self, db_manager: DatabaseManager):

def _get_stop_message(self, reason: str) -> dict:
return TeamResult(
task_result=TaskResult(messages=[TextMessage(
source="user", content=reason)], stop_reason=reason),
task_result=TaskResult(messages=[TextMessage(source="user", content=reason)], stop_reason=reason),
usage="",
duration=0,
).model_dump()
Expand All @@ -57,8 +52,7 @@ async def connect(self, websocket: WebSocket, run_id: UUID) -> bool:
self._input_responses[run_id] = asyncio.Queue()

await self._send_message(
run_id, {"type": "system", "status": "connected",
"timestamp": datetime.now(timezone.utc).isoformat()}
run_id, {"type": "system", "status": "connected", "timestamp": datetime.now(timezone.utc).isoformat()}
)

return True
Expand All @@ -80,8 +74,7 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None
# Update run with task and status
run = await self._get_run(run_id)
if run:
run.task = MessageConfig(
content=task, source="user").model_dump()
run.task = MessageConfig(content=task, source="user").model_dump()
run.status = RunStatus.ACTIVE
self.db_manager.upsert(run)

Expand All @@ -91,8 +84,7 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None
task=task, team_config=team_config, input_func=input_func, cancellation_token=cancellation_token
):
if cancellation_token.is_cancelled() or run_id in self._closed_connections:
logger.info(
f"Stream cancelled or connection closed for run {run_id}")
logger.info(f"Stream cancelled or connection closed for run {run_id}")
break

formatted_message = self._format_message(message)
Expand All @@ -110,8 +102,7 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None
if final_result:
await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result)
else:
logger.warning(
f"No final result captured for completed run {run_id}")
logger.warning(f"No final result captured for completed run {run_id}")
await self._update_run_status(run_id, RunStatus.COMPLETE)
else:
await self._send_message(
Expand Down Expand Up @@ -191,8 +182,7 @@ async def handle_input_response(self, run_id: UUID, response: str) -> None:
if run_id in self._input_responses:
await self._input_responses[run_id].put(response)
else:
logger.warning(
f"Received input response for inactive run {run_id}")
logger.warning(f"Received input response for inactive run {run_id}")

async def stop_run(self, run_id: UUID, reason: str) -> None:
if run_id in self._cancellation_tokens:
Expand Down Expand Up @@ -247,21 +237,18 @@ async def _send_message(self, run_id: UUID, message: dict) -> None:
message: Message dictionary to send
"""
if run_id in self._closed_connections:
logger.warning(
f"Attempted to send message to closed connection for run {run_id}")
logger.warning(f"Attempted to send message to closed connection for run {run_id}")
return

try:
if run_id in self._connections:
websocket = self._connections[run_id]
await websocket.send_json(message)
except WebSocketDisconnect:
logger.warning(
f"WebSocket disconnected while sending message for run {run_id}")
logger.warning(f"WebSocket disconnected while sending message for run {run_id}")
await self.disconnect(run_id)
except Exception as e:
logger.error(
f"Error sending message for run {run_id}: {e}, {message}")
logger.error(f"Error sending message for run {run_id}: {e}, {message}")
# Don't try to send error message here to avoid potential recursive loop
await self._update_run_status(run_id, RunStatus.ERROR, str(e))
await self.disconnect(run_id)
Expand Down Expand Up @@ -303,7 +290,10 @@ def _format_message(self, message: Any) -> Optional[dict]:
message_dump = message.model_dump()
message_dump["content"] = [
message_dump["content"][0],
{"url": f"data:image/png;base64,{message_dump["content"][1]['data']}", "alt": "WebSurfer Screenshot"},
{
"url": f"data:image/png;base64,{message_dump['content'][1]['data']}",
"alt": "WebSurfer Screenshot",
},
]
return {"type": "message", "data": message_dump}
elif isinstance(message, (AgentMessage, ChatMessage)):
Expand All @@ -329,8 +319,7 @@ async def _get_run(self, run_id: UUID) -> Optional[Run]:
Returns:
Optional[Run]: Run object if found, None otherwise
"""
response = self.db_manager.get(
Run, filters={"id": run_id}, return_json=False)
response = self.db_manager.get(Run, filters={"id": run_id}, return_json=False)
return response.data[0] if response.status and response.data else None

async def _update_run_status(self, run_id: UUID, status: RunStatus, error: Optional[str] = None) -> None:
Expand All @@ -349,8 +338,7 @@ async def _update_run_status(self, run_id: UUID, status: RunStatus, error: Optio

async def cleanup(self) -> None:
"""Clean up all active connections and resources when server is shutting down"""
logger.info(
f"Cleaning up {len(self.active_connections)} active connections")
logger.info(f"Cleaning up {len(self.active_connections)} active connections")

try:
# First cancel all running tasks
Expand All @@ -361,8 +349,7 @@ async def cleanup(self) -> None:
if run and run.status == RunStatus.ACTIVE:
interrupted_result = TeamResult(
task_result=TaskResult(
messages=[TextMessage(
source="system", content="Run interrupted by server shutdown")],
messages=[TextMessage(source="system", content="Run interrupted by server shutdown")],
stop_reason="server_shutdown",
),
usage="",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export interface RequestUsage {
export interface ImageContent {
url: string;
alt?: string;
data?: string;
}

export interface FunctionCall {
Expand Down Expand Up @@ -95,15 +96,6 @@ export interface BaseConfig {
version?: string;
}

// WebSocket message types
export type ThreadStatus =
| "streaming"
| "complete"
| "error"
| "cancelled"
| "awaiting_input"
| "timeout";

export interface WebSocketMessage {
type: "message" | "result" | "completion" | "input_request" | "error";
data?: AgentMessageConfig | TaskResult;
Expand All @@ -119,7 +111,10 @@ export interface TaskResult {

export type ModelTypes = "OpenAIChatCompletionClient";

export type AgentTypes = "AssistantAgent" | "CodingAssistantAgent" | "MultimodalWebSurfer";
export type AgentTypes =
| "AssistantAgent"
| "CodingAssistantAgent"
| "MultimodalWebSurfer";

export type TeamTypes = "RoundRobinGroupChat" | "SelectorGroupChat";

Expand Down Expand Up @@ -197,14 +192,6 @@ export interface Run {
error_message?: string;
}

// Separate transient state
interface TransientRunState {
pendingInput?: {
prompt: string;
isPending: boolean;
};
}

export type RunStatus =
| "created"
| "active" // covers 'streaming'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
Bot,
Flag,
} from "lucide-react";
import { ThreadStatus } from "../../../../types/datamodel";
import { RunStatus } from "../../../../types/datamodel";

export type NodeType = "agent" | "user" | "end";

Expand All @@ -18,7 +18,7 @@ export interface AgentNodeData {
agentType?: string;
description?: string;
isActive?: boolean;
status?: ThreadStatus | null;
status?: RunStatus | null;
reason?: string | null;
draggable: boolean;
}
Expand Down Expand Up @@ -54,7 +54,7 @@ function AgentNode({ data, isConnectable }: AgentNodeProps) {
return <CheckCircle className="text-accent" size={24} />;
case "error":
return <AlertTriangle className="text-red-500" size={24} />;
case "cancelled":
case "stopped":
return <StopCircle className="text-red-500" size={24} />;
default:
return null;
Expand Down
Loading

0 comments on commit 2cbc905

Please sign in to comment.