Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bbot_server/api/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ def make_mcp_server(fastapi_app, config, mcp_endpoints=None):
if mcp_endpoints is None:
mcp_endpoints = MCP_ENDPOINTS
log.debug(f"Creating MCP server with endpoints: {','.join(mcp_endpoints)}")
mcp = FastApiMCP(fastapi_app, include_operations=list(mcp_endpoints))
mcp = FastApiMCP(fastapi_app, include_operations=list(mcp_endpoints), headers=["x-api-key"])
mcp.mount()
8 changes: 6 additions & 2 deletions bbot_server/applets/_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ def wrapped_function(self):
"""

# Define a new async function that wraps the original function
@functools.wraps(self.orig_function)
async def wrapper(*args, **kwargs):
generator = self.orig_function(*args, **kwargs)

Expand All @@ -237,7 +236,12 @@ async def async_generator():

return StreamingResponse(async_generator())

# Set the wrapper's signature to match the original function
# Preserve the original function's name and signature for FastAPI routing,
# but do NOT use @functools.wraps as it preserves async generator markers
# that cause FastAPI to mishandle the StreamingResponse
wrapper.__name__ = self.orig_function.__name__
wrapper.__qualname__ = self.orig_function.__qualname__
wrapper.__signature__ = inspect.signature(self.orig_function)
return wrapper

def add_to_router(self, router, **fastapi_kwargs):
Expand Down
14 changes: 9 additions & 5 deletions bbot_server/modules/activity/activity_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def handle_activity(self, activity: Activity, asset: Asset = None):
await self.collection.insert_one(activity.model_dump())

@api_endpoint(
"/list", methods=["GET"], type="http_stream", response_model=Activity, summary="Stream all activities"
"/list", methods=["GET"], type="http_stream", response_model=Activity, summary="Stream all activities", mcp=True
)
async def list_activities(self, host: str = None, type: str = None):
query = {}
Expand All @@ -27,19 +27,23 @@ async def list_activities(self, host: str = None, type: str = None):
async for activity in self.collection.find(query, sort=[("timestamp", 1), ("created", 1)]):
yield self.model(**activity)

@api_endpoint("/query", methods=["POST"], type="http_stream", response_model=dict, summary="List activities")
async def query_activities(self, query: ActivityQuery):
@api_endpoint("/query", methods=["POST"], type="http_stream", response_model=dict, summary="List activities", mcp=True)
async def query_activities(self, query: ActivityQuery | None = None):
"""
Advanced querying of activities. Choose your own filters and fields.
"""
if query is None:
query = ActivityQuery()
async for activity in query.mongo_iter(self):
yield activity

@api_endpoint("/count", methods=["POST"], summary="Count activities")
async def count_activities(self, query: ActivityQuery) -> int:
@api_endpoint("/count", methods=["POST"], summary="Count activities", mcp=True)
async def count_activities(self, query: ActivityQuery | None = None) -> int:
"""
Same as query_activities, except only returns the count
"""
if query is None:
query = ActivityQuery()
return await query.mongo_count(self)

@api_endpoint("/tail", type="websocket_stream_outgoing", response_model=Activity)
Expand Down
8 changes: 4 additions & 4 deletions bbot_server/modules/agents/agents_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def get_agents(self) -> list[Agent]:
agents.append(agent)
return agents

@api_endpoint("/", methods=["POST"], summary="Create an agent")
@api_endpoint("/", methods=["POST"], summary="Create an agent", mcp=True)
async def create_agent(self, name: str, description: str = "") -> Agent:
agent = Agent(name=name, description=description)
try:
Expand All @@ -52,12 +52,12 @@ async def create_agent(self, name: str, description: str = "") -> Agent:
raise self.BBOTServerError(f"Error creating agent {name}: {e}") from e
return agent

@api_endpoint("/", methods=["DELETE"], summary="Delete an agent")
@api_endpoint("/", methods=["DELETE"], summary="Delete an agent", mcp=True)
async def delete_agent(self, id: str):
agent = await self.get_agent(id)
await self.collection.delete_one({"id": str(agent.id)})

@api_endpoint("/", methods=["GET"], summary="Get an agent by its id")
@api_endpoint("/", methods=["GET"], summary="Get an agent by its id", mcp=True)
async def get_agent(self, id: str) -> Agent:
try:
query = {"id": str(UUID(str(id)))}
Expand Down Expand Up @@ -92,7 +92,7 @@ async def get_agent_status(
agent_status = {"agent_status": "OFFLINE", "scan_status": "UNKNOWN"}
return agent_status

@api_endpoint("/scan_status", methods=["GET"], summary="Get the status of an agent's scan")
@api_endpoint("/scan_status", methods=["GET"], summary="Get the status of an agent's scan", mcp=True)
async def get_scan_status(self, id: UUID, detailed: bool = False) -> dict[str, str]:
command_response = await self.connection_manager.execute_command(
str(id), "get_scan_status", timeout=10, detailed=detailed
Expand Down
14 changes: 9 additions & 5 deletions bbot_server/modules/assets/assets_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class AssetsApplet(BaseApplet):

model = Asset

@api_endpoint("/list", methods=["GET"], type="http_stream", response_model=Asset, summary="Stream all assets")
@api_endpoint("/list", methods=["GET"], type="http_stream", response_model=Asset, summary="Stream all assets", mcp=True)
async def list_assets(
self,
domain: Annotated[str, Query(description="Filter assets by domain or subdomain")] = None,
Expand All @@ -27,22 +27,26 @@ async def list_assets(
async for asset in query.mongo_iter(self):
yield self.model(**asset)

@api_endpoint("/query", methods=["POST"], type="http_stream", response_model=dict, summary="Query assets")
@api_endpoint("/query", methods=["POST"], type="http_stream", response_model=dict, summary="Query assets", mcp=True)
async def query_assets(self, query: AdvancedAssetQuery | None = None):
"""
Advanced querying of assets. Choose your own filters and fields.
"""
if query is None:
query = AdvancedAssetQuery()
async for asset in query.mongo_iter(self):
yield asset

@api_endpoint("/count", methods=["POST"], summary="Count assets")
@api_endpoint("/count", methods=["POST"], summary="Count assets", mcp=True)
async def count_assets(self, query: AdvancedAssetQuery | None = None) -> int:
"""
Same as query_assets, except only returns the count
"""
if query is None:
query = AdvancedAssetQuery()
return await query.mongo_count(self)

@api_endpoint("/{host}/detail", methods=["GET"], summary="Get a single asset by its host")
@api_endpoint("/{host}/detail", methods=["GET"], summary="Get a single asset by its host", mcp=True)
async def get_asset(self, host: Annotated[str, Path(description="The host of the asset to get")]) -> Asset:
asset = await self.collection.find_one({"host": host})
if not asset:
Expand All @@ -63,7 +67,7 @@ async def get_asset_history(self, host: str) -> list[str]:
history.append(activity["description"])
return history

@api_endpoint("/hosts", methods=["GET"], summary="List hosts")
@api_endpoint("/hosts", methods=["GET"], summary="List hosts", mcp=True)
async def get_hosts(self, domain: str = None, target_id: str = None) -> list[str]:
"""
List all hosts.
Expand Down
1 change: 1 addition & 0 deletions bbot_server/modules/cloud/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def get_cloud_providers_for_asset(self, host: str) -> list[dict[str, str]]
"/check/{host}",
methods=["GET"],
summary="Check a hostname or IP address against the cloud provider database",
mcp=True,
)
async def cloudcheck(self, host: str) -> list[dict[str, str]]:
# cloudcheck v9+ uses async API and handles caching internally
Expand Down
17 changes: 9 additions & 8 deletions bbot_server/modules/emails/emails_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# from bbot_server.workers.emails import EmailWorker
from bbot_server.applets.base import BaseApplet, api_endpoint, BaseModel, Field
from bbot_server.assets import CustomAssetFields
from bbot_server.applets.base import BaseApplet, api_endpoint, Field


class EmailsFields(CustomAssetFields):
emails: list[str] = Field(default_factory=list)


class EmailsApplet(BaseApplet):
Expand All @@ -10,15 +15,11 @@ class EmailsApplet(BaseApplet):
# workers = [EmailWorker]
attach_to = "assets"

class AssetFields(BaseModel):
emails: list[str] = Field(default_factory=list)

@api_endpoint("/emails/{domain}", methods=["GET"], summary="Get emails by domain")
@api_endpoint("/emails/{domain}", methods=["GET"], summary="Get emails by domain", mcp=True)
async def get_emails(self, domain: str) -> list[str]:
matching_assets = await self.root.assets.list_assets(host=domain)
emails = set()
for asset in matching_assets:
emails.update(asset.fields.get("emails", []))
async for asset in self.root.assets.list_assets(domain=domain):
emails.update(getattr(asset, "emails", []))
return sorted(emails)

# async def handle_event(self, asset: Asset, event: Event) -> list[Activity]:
Expand Down
16 changes: 10 additions & 6 deletions bbot_server/modules/events/events_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def handle_event(self, event: Event, asset):
# write the event to the database
await self.collection.insert_one(event.model_dump())

@api_endpoint("/", methods=["POST"], summary="Insert a BBOT event into the asset database")
@api_endpoint("/", methods=["POST"], summary="Insert a BBOT event into the asset database", mcp=True)
async def insert_event(self, event: Event):
"""
Insert a BBOT event into the asset database
Expand All @@ -32,14 +32,14 @@ async def insert_event(self, event: Event):
# it will be picked up by the worker and ingested
await self.root.message_queue.publish_event(event)

@api_endpoint("/get/{uuid}", methods=["GET"], summary="Get an event by its UUID")
@api_endpoint("/get/{uuid}", methods=["GET"], summary="Get an event by its UUID", mcp=True)
async def get_event(self, uuid: str) -> Event:
event = await self.collection.find_one({"uuid": uuid})
if event is None:
raise self.BBOTServerNotFoundError(f"Event {uuid} not found")
return Event(**event)

@api_endpoint("/list", methods=["GET"], type="http_stream", response_model=Event, summary="Stream all events")
@api_endpoint("/list", methods=["GET"], type="http_stream", response_model=Event, summary="Stream all events", mcp=True)
async def list_events(
self,
type: str = None,
Expand All @@ -64,27 +64,31 @@ async def list_events(
async for event in query.mongo_iter(self):
yield Event(**event)

@api_endpoint("/query", methods=["POST"], type="http_stream", response_model=dict, summary="Query events")
@api_endpoint("/query", methods=["POST"], type="http_stream", response_model=dict, summary="Query events", mcp=True)
async def query_events(self, query: EventsQuery | None = None):
"""
Advanced querying of events. Choose your own filters and fields.
"""
if query is None:
query = EventsQuery()
async for event in query.mongo_iter(self):
yield event

@api_endpoint("/count", methods=["POST"], summary="Count events")
@api_endpoint("/count", methods=["POST"], summary="Count events", mcp=True)
async def count_events(self, query: EventsQuery | None = None) -> int:
"""
Same as query_events, except only returns the count
"""
if query is None:
query = EventsQuery()
return await query.mongo_count(self)

@api_endpoint("/tail", type="websocket_stream_outgoing", response_model=Event)
async def tail_events(self, n: int = 0):
async for event in self.message_queue.tail_events(n=n):
yield event

@api_endpoint("/archive", methods=["POST"], summary="Archive old events")
@api_endpoint("/archive", methods=["POST"], summary="Archive old events", mcp=True)
async def archive_old_events(
self,
older_than: Annotated[int, Query(description="Archive events older than this many days")],
Expand Down
10 changes: 7 additions & 3 deletions bbot_server/modules/findings/findings_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,23 @@ async def list_findings(
async for finding in query.mongo_iter(self):
yield Finding(**finding)

@api_endpoint("/query", methods=["POST"], type="http_stream", response_model=dict, summary="Query findings")
@api_endpoint("/query", methods=["POST"], type="http_stream", response_model=dict, summary="Query findings", mcp=True)
async def query_findings(self, query: FindingsQuery | None = None):
"""
Advanced querying of findings. Choose your own filters and fields.
"""
if query is None:
query = FindingsQuery()
async for finding in query.mongo_iter(self):
yield finding

@api_endpoint("/count", methods=["POST"], summary="Count findings")
@api_endpoint("/count", methods=["POST"], summary="Count findings", mcp=True)
async def count_findings(self, query: FindingsQuery | None = None) -> int:
"""
Same as query_findings, except only returns the count
"""
if query is None:
query = FindingsQuery()
return await query.mongo_count(self)

@api_endpoint(
Expand Down Expand Up @@ -152,7 +156,7 @@ async def severity_counts(
findings = dict(sorted(findings.items(), key=lambda x: x[1], reverse=True))
return findings

@api_endpoint("/set_risk", methods=["PATCH"], summary="Set or clear a manual risk score for an asset")
@api_endpoint("/set_risk", methods=["PATCH"], summary="Set or clear a manual risk score for an asset", mcp=True)
async def set_risk(
self,
host: Annotated[str, Query(description="The host of the asset to update")],
Expand Down
4 changes: 2 additions & 2 deletions bbot_server/modules/open_ports/open_ports_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def compute_stats(self, asset, statistics):
open_ports_stats = dict(sorted(open_ports_stats.items(), key=lambda x: x[1], reverse=True))
statistics["open_ports"] = open_ports_stats

@api_endpoint("/list", methods=["GET"], summary="Get all the open ports for all hosts")
@api_endpoint("/list", methods=["GET"], summary="Get all the open ports for all hosts", mcp=True)
async def get_open_ports(self, domain: str = None, target_id: str = None) -> dict[str, list[int]]:
open_ports = {}
query = AssetQuery(
Expand All @@ -60,7 +60,7 @@ async def get_open_ports(self, domain: str = None, target_id: str = None) -> dic
open_ports[asset["host"]] = asset["open_ports"]
return open_ports

@api_endpoint("/list/{host}", methods=["GET"], summary="Get all the open ports for a host")
@api_endpoint("/list/{host}", methods=["GET"], summary="Get all the open ports for a host", mcp=True)
async def get_open_ports_by_host(self, host: str) -> list[int]:
asset = await self.collection.find_one({"host": str(host), "type": "Asset"}, {"open_ports": 1})
if asset is None:
Expand Down
10 changes: 5 additions & 5 deletions bbot_server/modules/presets/presets_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class PresetsApplet(BaseApplet):
model = Preset
attach_to = "scans"

@api_endpoint("/get/{preset_id}", methods=["GET"], summary="Get a preset by its name or id")
@api_endpoint("/get/{preset_id}", methods=["GET"], summary="Get a preset by its name or id", mcp=True)
async def get_preset(self, preset_id: UUID | str) -> Preset:
try:
query = {"id": str(UUID(str(preset_id)))}
Expand All @@ -23,12 +23,12 @@ async def get_preset(self, preset_id: UUID | str) -> Preset:
raise self.BBOTServerNotFoundError(f"Preset not found: {query}")
return Preset(**preset)

@api_endpoint("/list", methods=["GET"], summary="List all presets")
@api_endpoint("/list", methods=["GET"], summary="List all presets", mcp=True)
async def get_presets(self) -> list[Preset]:
presets = await self.collection.find().to_list(length=None)
return [Preset(**preset) for preset in presets]

@api_endpoint("/create", methods=["POST"], summary="Create a new preset")
@api_endpoint("/create", methods=["POST"], summary="Create a new preset", mcp=True)
async def create_preset(self, preset: dict[str, Any]) -> Preset:
preset = Preset(preset=preset)
if not preset.name:
Expand All @@ -39,7 +39,7 @@ async def create_preset(self, preset: dict[str, Any]) -> Preset:
raise self.BBOTServerValueError(f"Preset with name '{preset.name}' already exists")
return preset

@api_endpoint("/update/{preset_id}", methods=["PATCH"], summary="Update a preset by its name or id")
@api_endpoint("/update/{preset_id}", methods=["PATCH"], summary="Update a preset by its name or id", mcp=True)
async def update_preset(self, preset_id: UUID | str, preset: dict[str, Any]) -> Preset:
existing_preset = await self.get_preset(preset_id)
# Create new preset with the updated dictionary
Expand All @@ -54,7 +54,7 @@ async def update_preset(self, preset_id: UUID | str, preset: dict[str, Any]) ->
raise self.BBOTServerValueError(f"Preset with name '{new_preset.name}' already exists")
return new_preset

@api_endpoint("/delete/{preset_id}", methods=["DELETE"], summary="Delete a preset by its name or id")
@api_endpoint("/delete/{preset_id}", methods=["DELETE"], summary="Delete a preset by its name or id", mcp=True)
async def delete_preset(self, preset_id: UUID | str) -> None:
existing_preset = await self.get_preset(preset_id)
await self.collection.delete_one({"id": str(existing_preset.id)})
Expand Down
Loading
Loading