Skip to content

Commit 8a8514d

Browse files
authored
chore: Update A2A CLI for SDK 0.2.16 (#259)
* Remove dependency on common for cli * remove other usage of common * Update to snake_case
1 parent bcbf73f commit 8a8514d

File tree

6 files changed

+1743
-1582
lines changed

6 files changed

+1743
-1582
lines changed

‎samples/python/hosts/cli/README.md‎

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
## CLI
1+
# A2A CLI
22

3-
The CLI is a small host application that demonstrates the capabilities of an A2AClient. It supports reading a server's AgentCard and text-based collaboration with a remote agent. All content received from the A2A server is printed to the console.
3+
The CLI is a small host application that demonstrates the capabilities of an `A2AClient`. It supports reading a server's `AgentCard` and text-based collaboration with a remote agent. All content received from the A2A server is printed to the console.
44

55
The client will use streaming if the server supports it.
66

@@ -13,19 +13,23 @@ The client will use streaming if the server supports it.
1313
## Running the CLI
1414

1515
1. Navigate to the CLI sample directory:
16+
1617
```bash
1718
cd samples/python/hosts/cli
1819
```
20+
1921
2. Run the example client
20-
```
22+
23+
```sh
2124
uv run . --agent [url-of-your-a2a-server]
2225
```
2326

24-
for example `--agent http://localhost:10000`. More command line options are documented in the source code.
27+
for example `--agent https://sample-a2a-agent-908687846511.us-central1.run.app`. More command line options are documented in the source code.
2528

2629
## Disclaimer
30+
2731
Important: The sample code provided is for demonstration purposes and illustrates the mechanics of the Agent-to-Agent (A2A) protocol. When building production applications, it is critical to treat any agent operating outside of your direct control as a potentially untrusted entity.
2832

2933
All data received from an external agent—including but not limited to its AgentCard, messages, artifacts, and task statuses—should be handled as untrusted input. For example, a malicious agent could provide an AgentCard containing crafted data in its fields (e.g., description, name, skills.description). If this data is used without sanitization to construct prompts for a Large Language Model (LLM), it could expose your application to prompt injection attacks. Failure to properly validate and sanitize this data before use can introduce security vulnerabilities into your application.
3034

31-
Developers are responsible for implementing appropriate security measures, such as input validation and secure handling of credentials to protect their systems and users.
35+
Developers are responsible for implementing appropriate security measures, such as input validation and secure handling of credentials to protect their systems and users.

‎samples/python/hosts/cli/__main__.py‎

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import base64
33
import os
44
import urllib
5+
56
from uuid import uuid4
67

78
import asyncclick as click
@@ -26,12 +27,16 @@
2627
TaskStatusUpdateEvent,
2728
TextPart,
2829
)
29-
from common.utils.push_notification_auth import PushNotificationReceiverAuth
30+
from push_notification_auth import PushNotificationReceiverAuth
3031

3132

3233
@click.command()
3334
@click.option('--agent', default='http://localhost:10000')
34-
@click.option('--bearer-token', help='Bearer token for authentication.', envvar='A2A_CLI_BEARER_TOKEN')
35+
@click.option(
36+
'--bearer-token',
37+
help='Bearer token for authentication.',
38+
envvar='A2A_CLI_BEARER_TOKEN',
39+
)
3540
@click.option('--session', default=0)
3641
@click.option('--history', default=False)
3742
@click.option('--use_push_notifications', default=False)
@@ -88,7 +93,7 @@ async def cli(
8893

8994
while continue_loop:
9095
print('========= starting a new task ======== ')
91-
continue_loop, _, taskId = await completeTask(
96+
continue_loop, _, task_id = await completeTask(
9297
client,
9398
streaming,
9499
use_push_notifications,
@@ -101,7 +106,7 @@ async def cli(
101106
if history and continue_loop:
102107
print('========= history ======== ')
103108
task_response = await client.get_task(
104-
{'id': taskId, 'historyLength': 10}
109+
{'id': task_id, 'historyLength': 10}
105110
)
106111
print(
107112
task_response.model_dump_json(
@@ -116,8 +121,8 @@ async def completeTask(
116121
use_push_notifications: bool,
117122
notification_receiver_host: str,
118123
notification_receiver_port: int,
119-
taskId,
120-
contextId,
124+
task_id,
125+
context_id,
121126
):
122127
prompt = click.prompt(
123128
'\nWhat do you want to send to the agent? (:q or quit to exit)'
@@ -128,9 +133,9 @@ async def completeTask(
128133
message = Message(
129134
role='user',
130135
parts=[TextPart(text=prompt)],
131-
messageId=str(uuid4()),
132-
taskId=taskId,
133-
contextId=contextId,
136+
message_id=str(uuid4()),
137+
task_id=task_id,
138+
context_id=context_id,
134139
)
135140

136141
file_path = click.prompt(
@@ -155,7 +160,7 @@ async def completeTask(
155160
id=str(uuid4()),
156161
message=message,
157162
configuration=MessageSendConfiguration(
158-
acceptedOutputModes=['text'],
163+
accepted_output_modes=['text'],
159164
),
160165
)
161166

@@ -179,32 +184,39 @@ async def completeTask(
179184
)
180185
async for result in response_stream:
181186
if isinstance(result.root, JSONRPCErrorResponse):
182-
print(f'Error: {result.root.error}, contextId: {contextId}, taskId: {taskId}')
183-
return False, contextId, taskId
187+
print(
188+
f'Error: {result.root.error}, context_id: {context_id}, task_id: {task_id}'
189+
)
190+
return False, context_id, task_id
184191
event = result.root.result
185-
contextId = event.contextId
192+
context_id = event.context_id
186193
if isinstance(event, Task):
187-
taskId = event.id
194+
task_id = event.id
188195
elif isinstance(event, TaskStatusUpdateEvent) or isinstance(
189196
event, TaskArtifactUpdateEvent
190197
):
191-
taskId = event.taskId
192-
if isinstance(event, TaskStatusUpdateEvent) and event.status.state == 'completed':
198+
task_id = event.task_id
199+
if (
200+
isinstance(event, TaskStatusUpdateEvent)
201+
and event.status.state == 'completed'
202+
):
193203
task_completed = True
194204
elif isinstance(event, Message):
195205
message = event
196206
print(f'stream event => {event.model_dump_json(exclude_none=True)}')
197207
# Upon completion of the stream. Retrieve the full task if one was made.
198-
if taskId and not task_completed:
208+
if task_id and not task_completed:
199209
taskResultResponse = await client.get_task(
200210
GetTaskRequest(
201211
id=str(uuid4()),
202-
params=TaskQueryParams(id=taskId),
212+
params=TaskQueryParams(id=task_id),
203213
)
204214
)
205215
if isinstance(taskResultResponse.root, JSONRPCErrorResponse):
206-
print(f'Error: {taskResultResponse.root.error}, contextId: {contextId}, taskId: {taskId}')
207-
return False, contextId, taskId
216+
print(
217+
f'Error: {taskResultResponse.root.error}, context_id: {context_id}, task_id: {task_id}'
218+
)
219+
return False, context_id, task_id
208220
taskResult = taskResultResponse.root.result
209221
else:
210222
try:
@@ -218,18 +230,18 @@ async def completeTask(
218230
event = event.root.result
219231
except Exception as e:
220232
print('Failed to complete the call', e)
221-
if not contextId:
222-
contextId = event.contextId
233+
if not context_id:
234+
context_id = event.context_id
223235
if isinstance(event, Task):
224-
if not taskId:
225-
taskId = event.id
236+
if not task_id:
237+
task_id = event.id
226238
taskResult = event
227239
elif isinstance(event, Message):
228240
message = event
229241

230242
if message:
231243
print(f'\n{message.model_dump_json(exclude_none=True)}')
232-
return True, contextId, taskId
244+
return True, context_id, task_id
233245
if taskResult:
234246
# Don't print the contents of a file.
235247
task_content = taskResult.model_dump_json(
@@ -248,19 +260,23 @@ async def completeTask(
248260
## if the result is that more input is required, loop again.
249261
state = TaskState(taskResult.status.state)
250262
if state.name == TaskState.input_required.name:
251-
return await completeTask(
252-
client,
253-
streaming,
254-
use_push_notifications,
255-
notification_receiver_host,
256-
notification_receiver_port,
257-
taskId,
258-
contextId,
259-
), contextId, taskId
263+
return (
264+
await completeTask(
265+
client,
266+
streaming,
267+
use_push_notifications,
268+
notification_receiver_host,
269+
notification_receiver_port,
270+
task_id,
271+
context_id,
272+
),
273+
context_id,
274+
task_id,
275+
)
260276
## task is complete
261-
return True, contextId, taskId
277+
return True, context_id, task_id
262278
## Failure case, shouldn't reach
263-
return True, contextId, taskId
279+
return True, context_id, task_id
264280

265281

266282
if __name__ == '__main__':
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import hashlib
2+
import json
3+
import logging
4+
import time
5+
import uuid
6+
7+
from typing import Any
8+
9+
import httpx
10+
import jwt
11+
12+
from jwcrypto import jwk
13+
from jwt import PyJWK, PyJWKClient
14+
from starlette.requests import Request
15+
from starlette.responses import JSONResponse
16+
17+
18+
logger = logging.getLogger(__name__)
19+
AUTH_HEADER_PREFIX = 'Bearer '
20+
21+
22+
class PushNotificationAuth:
23+
def _calculate_request_body_sha256(self, data: dict[str, Any]):
24+
"""Calculates the SHA256 hash of a request body.
25+
26+
This logic needs to be same for both the agent who signs the payload and the client verifier.
27+
"""
28+
body_str = json.dumps(
29+
data,
30+
ensure_ascii=False,
31+
allow_nan=False,
32+
indent=None,
33+
separators=(',', ':'),
34+
)
35+
return hashlib.sha256(body_str.encode()).hexdigest()
36+
37+
38+
class PushNotificationSenderAuth(PushNotificationAuth):
39+
def __init__(self):
40+
self.public_keys = []
41+
self.private_key_jwk: PyJWK = None
42+
43+
@staticmethod
44+
async def verify_push_notification_url(url: str) -> bool:
45+
async with httpx.AsyncClient(timeout=10) as client:
46+
try:
47+
validation_token = str(uuid.uuid4())
48+
response = await client.get(
49+
url, params={'validationToken': validation_token}
50+
)
51+
response.raise_for_status()
52+
is_verified = response.text == validation_token
53+
54+
logger.info(
55+
f'Verified push-notification URL: {url} => {is_verified}'
56+
)
57+
return is_verified
58+
except Exception as e:
59+
logger.warning(
60+
f'Error during sending push-notification for URL {url}: {e}'
61+
)
62+
63+
return False
64+
65+
def generate_jwk(self):
66+
key = jwk.JWK.generate(
67+
kty='RSA', size=2048, kid=str(uuid.uuid4()), use='sig'
68+
)
69+
self.public_keys.append(key.export_public(as_dict=True))
70+
self.private_key_jwk = PyJWK.from_json(key.export_private())
71+
72+
def handle_jwks_endpoint(self, _request: Request):
73+
"""Allow clients to fetch public keys."""
74+
return JSONResponse({'keys': self.public_keys})
75+
76+
def _generate_jwt(self, data: dict[str, Any]):
77+
"""JWT is generated by signing both the request payload SHA digest and time of token generation.
78+
79+
Payload is signed with private key and it ensures the integrity of payload for client.
80+
Including iat prevents from replay attack.
81+
"""
82+
iat = int(time.time())
83+
84+
return jwt.encode(
85+
{
86+
'iat': iat,
87+
'request_body_sha256': self._calculate_request_body_sha256(
88+
data
89+
),
90+
},
91+
key=self.private_key_jwk,
92+
headers={'kid': self.private_key_jwk.key_id},
93+
algorithm='RS256',
94+
)
95+
96+
async def send_push_notification(self, url: str, data: dict[str, Any]):
97+
jwt_token = self._generate_jwt(data)
98+
headers = {'Authorization': f'Bearer {jwt_token}'}
99+
async with httpx.AsyncClient(timeout=10) as client:
100+
try:
101+
response = await client.post(url, json=data, headers=headers)
102+
response.raise_for_status()
103+
logger.info(f'Push-notification sent for URL: {url}')
104+
except Exception as e:
105+
logger.warning(
106+
f'Error during sending push-notification for URL {url}: {e}'
107+
)
108+
109+
110+
class PushNotificationReceiverAuth(PushNotificationAuth):
111+
def __init__(self):
112+
self.public_keys_jwks = []
113+
self.jwks_client = None
114+
115+
async def load_jwks(self, jwks_url: str):
116+
self.jwks_client = PyJWKClient(jwks_url)
117+
118+
async def verify_push_notification(self, request: Request) -> bool:
119+
auth_header = request.headers.get('Authorization')
120+
if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
121+
print('Invalid authorization header')
122+
return False
123+
124+
token = auth_header[len(AUTH_HEADER_PREFIX) :]
125+
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
126+
127+
decode_token = jwt.decode(
128+
token,
129+
signing_key,
130+
options={'require': ['iat', 'request_body_sha256']},
131+
algorithms=['RS256'],
132+
)
133+
134+
actual_body_sha256 = self._calculate_request_body_sha256(
135+
await request.json()
136+
)
137+
if actual_body_sha256 != decode_token['request_body_sha256']:
138+
# Payload signature does not match the digest in signed token.
139+
raise ValueError('Invalid request body')
140+
141+
if time.time() - decode_token['iat'] > 60 * 5:
142+
# Do not allow push-notifications older than 5 minutes.
143+
# This is to prevent replay attack.
144+
raise ValueError('Token is expired')
145+
146+
return True

‎samples/python/hosts/cli/push_notification_listener.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import threading
33
import traceback
44

5-
from common.utils.push_notification_auth import PushNotificationReceiverAuth
5+
from push_notification_auth import PushNotificationReceiverAuth
66
from starlette.applications import Starlette
77
from starlette.requests import Request
88
from starlette.responses import Response

‎samples/python/hosts/cli/pyproject.toml‎

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ description = "A CLI application that demonstrates the capabilities of an A2ACli
55
readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
8-
"a2a-samples",
8+
"a2a-sdk",
99
"asyncclick>=8.1.8",
1010
"sse-starlette>=2.2.1",
1111
"starlette>=0.46.1",
@@ -14,9 +14,6 @@ dependencies = [
1414
[tool.hatch.build.targets.wheel]
1515
packages = ["."]
1616

17-
[tool.uv.sources]
18-
a2a-samples = { workspace = true }
19-
2017
[build-system]
2118
requires = ["hatchling"]
2219
build-backend = "hatchling.build"

0 commit comments

Comments
 (0)