-
Notifications
You must be signed in to change notification settings - Fork 63
Expand file tree
/
Copy pathapp.py
More file actions
171 lines (134 loc) · 5.13 KB
/
app.py
File metadata and controls
171 lines (134 loc) · 5.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import asynccontextmanager
from typing import Callable
from agent import agent_run_config
from fastapi import FastAPI
from fastapi.routing import APIRoute
from fastmcp import FastMCP
from starlette.routing import Route
from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder
from a2a.types import AgentProvider
from veadk.a2a.ve_a2a_server import init_app
from veadk.runner import Runner
from veadk.types import AgentRunConfig
from veadk.utils.logger import get_logger
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry import context
logger = get_logger(__name__)

# Validate explicitly rather than with `assert`: asserts are stripped when
# Python runs with -O, which would let a misconfigured `agent_run_config`
# through and fail later with a confusing AttributeError.
if not isinstance(agent_run_config, AgentRunConfig):
    raise TypeError(
        f"Invalid agent_run_config type: {type(agent_run_config)}, expected `AgentRunConfig`"
    )

# Pieces of the run config used throughout this module.
app_name = agent_run_config.app_name
agent = agent_run_config.agent
short_term_memory = agent_run_config.short_term_memory

# Deployment coordinates used only to build the provider URL below.
# APP_REGION defaults to "cn-beijing"; _FAAS_FUNC_ID defaults to "".
# NOTE(review): _FAAS_FUNC_ID is presumably injected by the veFaaS runtime — confirm.
VEFAAS_REGION = os.getenv("APP_REGION", "cn-beijing")
VEFAAS_FUNC_ID = os.getenv("_FAAS_FUNC_ID", "")

# Builder for the agent card returned by `agent_card()`; the provider URL
# points at the function's detail page in the Volcengine console.
agent_card_builder = AgentCardBuilder(
    agent=agent,
    provider=AgentProvider(
        organization="Volcengine Agent Development Kit (VeADK)",
        url=f"https://console.volcengine.com/vefaas/region:vefaas+{VEFAAS_REGION}/function/detail/{VEFAAS_FUNC_ID}",
    ),
)
def build_mcp_run_agent_func() -> Callable:
    """Build the coroutine registered as the ``run_agent`` route / MCP tool.

    Returns:
        An async function ``run_agent(user_input, user_id="mcp_user",
        session_id="mcp_session")``. It runs the module-level ``agent`` on
        ``user_input`` and returns whatever ``Runner.run`` produces. Its
        ``__doc__`` is set to the agent's description plus an argument
        summary (FastAPI reads endpoint docstrings for the route description).
    """

    async def run_agent(
        user_input: str,
        user_id: str = "mcp_user",
        session_id: str = "mcp_session",
    ) -> str:
        # Build a Runner per call instead of mutating a shared instance.
        # Assigning `runner.user_id` on a shared Runner and then awaiting
        # `runner.run(...)` lets a concurrent request overwrite the user id
        # while this run is still in flight.
        runner = Runner(
            agent=agent,
            short_term_memory=short_term_memory,
            app_name=app_name,
            user_id=user_id,
        )
        # Run the agent and return its final output.
        return await runner.run(
            messages=user_input,
            session_id=session_id,
        )

    # Same text as before; built from concatenated literals for readability.
    run_agent.__doc__ = (
        f"{agent.description}\n"
        "Args:\n"
        "user_input: User's input message (required).\n"
        'user_id: User identifier. Defaults to "mcp_user".\n'
        'session_id: Session identifier. Defaults to "mcp_session".\n'
        "Returns:\n"
        "Final agent response as a string."
    )
    return run_agent
async def agent_card() -> dict:
    """Return this agent's card as a plain dict.

    The card is built by the module-level ``agent_card_builder`` and
    serialized with ``model_dump()``.
    """
    return (await agent_card_builder.build()).model_dump()
async def get_cozeloop_space_id() -> dict:
    """Return ``{"space_id": ...}`` for this deployment.

    The value comes from the OBSERVABILITY_OPENTELEMETRY_COZELOOP_SERVICE_NAME
    environment variable, or "" when it is unset.
    """
    # NOTE(review): the variable is named *_SERVICE_NAME but is reported as the
    # space id — presumably the CozeLoop exporter's service name holds the
    # workspace id; confirm against the exporter config.
    space_id = os.environ.get("OBSERVABILITY_OPENTELEMETRY_COZELOOP_SERVICE_NAME", "")
    return {"space_id": space_id}
# Coroutine backing the /run_agent route (and hence the MCP tool of that name).
run_agent_func = build_mcp_run_agent_func()

# A2A server app. The routes added below carry the "mcp" tag, which is what
# FastMCP.from_fastapi filters on (include_tags={"mcp"}) to pick MCP tools.
a2a_app = init_app(
    agent=agent,
    app_name=app_name,
    server_url="0.0.0.0",
    short_term_memory=short_term_memory,
)

# Equivalent to the @a2a_app.post/@a2a_app.get decorator forms.
a2a_app.add_api_route(
    "/run_agent",
    run_agent_func,
    methods=["POST"],
    operation_id="run_agent",
    tags=["mcp"],
)
a2a_app.add_api_route(
    "/agent_card",
    agent_card,
    methods=["GET"],
    operation_id="agent_card",
    tags=["mcp"],
)
a2a_app.add_api_route(
    "/get_cozeloop_space_id",
    get_cozeloop_space_id,
    methods=["GET"],
    operation_id="get_cozeloop_space_id",
    tags=["mcp"],
)
# === Build mcp server ===
# MCP server generated from the A2A app, limited to routes tagged "mcp".
mcp = FastMCP.from_fastapi(
    app=a2a_app,
    name=app_name,
    include_tags={"mcp"},
)
# Streamable-HTTP ASGI app for that server (mounted under /mcp further below).
mcp_app = mcp.http_app(transport="streamable-http", path="/")
# Lifespan for the main app.
@asynccontextmanager
async def combined_lifespan(app: FastAPI):
    """Run the main app inside the MCP ASGI app's lifespan.

    Only ``mcp_app``'s lifespan is entered here.
    NOTE(review): a2a_app's own lifespan is never entered, because its routes
    are copied onto the main app rather than mounted — confirm it needs none.
    """
    enter_mcp_lifespan = mcp_app.lifespan
    async with enter_mcp_lifespan(app):
        yield
# Main FastAPI app. It reuses the A2A app's title/version and disables its own
# OpenAPI schema and docs UIs; the copied A2A docs routes are filtered out later.
app = FastAPI(
    lifespan=combined_lifespan,
    title=a2a_app.title,
    version=a2a_app.version,
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)
@app.middleware("http")
async def otel_context_middleware(request, call_next):
    """Continue an incoming W3C trace context for the duration of a request.

    If the request has a ``Traceparent`` header, the context extracted from it
    (plus ``Tracestate``, if present) is attached while downstream handlers
    run, then detached. Otherwise the request passes through untouched.
    """
    headers = request.headers
    carrier = dict(
        traceparent=headers.get("Traceparent"),
        tracestate=headers.get("Tracestate"),
    )
    logger.debug(f"carrier: {carrier}")

    # No upstream trace: nothing to attach.
    if carrier["traceparent"] is None:
        return await call_next(request)

    ctx = TraceContextTextMapPropagator().extract(carrier=carrier)
    logger.debug(f"ctx: {ctx}")
    token = context.attach(ctx)
    try:
        return await call_next(request)
    finally:
        # Always restore the previous context, even if the handler raised.
        context.detach(token)
# Copy the A2A app's routes onto the main app. They are copied rather than
# mounted, so they are served at their original paths.
for route in a2a_app.routes:
    app.routes.append(route)
# Serve the MCP server under /mcp.
app.mount("/mcp", mcp_app)

# Strip the OpenAPI/docs routes that came along with the copied A2A routes.
# "/docs/oauth2-redirect" is FastAPI's default Swagger UI OAuth2 redirect
# route, registered together with "/docs"; it was previously left behind.
paths = ["/openapi.json", "/docs", "/docs/oauth2-redirect", "/redoc"]
new_routes = [
    route
    for route in app.router.routes
    if not (isinstance(route, (APIRoute, Route)) and route.path in paths)
]
app.router.routes = new_routes
# === Build mcp server end ===