Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weather forecast llm tool #137314

Draft
wants to merge 9 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions homeassistant/helpers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@
from homeassistant.components.homeassistant import async_should_expose
from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
from homeassistant.components.weather import INTENT_GET_WEATHER
from homeassistant.components.weather import (
DOMAIN as WEATHER_DOMAIN,
INTENT_GET_WEATHER,
SERVICE_GET_FORECASTS,
)
from homeassistant.const import (
ATTR_DOMAIN,
ATTR_ENTITY_ID,
ATTR_NAME,
ATTR_SERVICE,
ENTITY_MATCH_ALL,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED,
)
Expand Down Expand Up @@ -443,6 +450,9 @@
for intent_handler in intent_handlers
]

if exposed_domains and WEATHER_DOMAIN in exposed_domains:
tools.append(WeatherForecastTool())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure that Weather is no longer included in _get_exposed_entities

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The weather entities report the current weather, and the tool reports the forecast, I think we need both

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could change that though, include "current weather" in the tool as well. That way we don't always include all that info.


if exposed_entities:
if exposed_entities[CALENDAR_DOMAIN]:
names = []
Expand Down Expand Up @@ -745,7 +755,7 @@
"""Init the class."""
self._domain = domain
self._action = action
self.name = f"{domain}.{action}"
self.name = f"{domain}_{action}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found out that some models are not happy with dots in the tool name

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we move this to its own PR

self.description, self.parameters = _get_cached_action_parameters(
hass, domain, action
)
Expand Down Expand Up @@ -884,3 +894,48 @@
]

return {"success": True, "result": events}


class WeatherForecastTool(Tool):
"""LLM Tool wrapper for weather forecast action."""

name = f"{WEATHER_DOMAIN}_{SERVICE_GET_FORECASTS}"
description = "Get weather forecasts"
parameters = vol.Schema(
{
vol.Required("type"): vol.In(("daily", "hourly", "twice_daily")),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would we include twice daily?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To match the service capabilities. I can remove it.
Also while thinking about it I realized that I would also need to add a fallback strategy for entities that do not provide the requested type of forecast.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, we shouldn't match the services as-is. I've noticed with the event tool that LLms are very easily confused.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also wonder if we should follow similar ranges like we do for events. today and week. Today would return the hourly forecast for next 24h, while week will return for next 7 days daily.

vol.Optional(ATTR_NAME, description="Weather entity name"): cv.string,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that for calendar, just limiting this to the actual allowed weather values gives a lot better results.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense, smart homes don't usually have too many weather entities. I will do it.

}
)

async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType:
"""Get the forecast."""
data = self.parameters(tool_input.tool_args)
if ATTR_NAME in data:
result = intent.async_match_targets(
hass,
intent.MatchTargetsConstraints(
name=data[ATTR_NAME],
domains=[WEATHER_DOMAIN],
assistant=llm_context.assistant,
),
)
if not result.is_match:
return {"success": False, "error": "Weather entity not found"}

Check warning on line 926 in homeassistant/helpers/llm.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/helpers/llm.py#L926

Added line #L926 was not covered by tests
data.pop(ATTR_NAME)
data[ATTR_ENTITY_ID] = [state.entity_id for state in result.states]
else:
data[ATTR_ENTITY_ID] = ENTITY_MATCH_ALL
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that we should do this. It's unpredictable.


service_result = await hass.services.async_call(
WEATHER_DOMAIN,
SERVICE_GET_FORECASTS,
data,
context=llm_context.context,
blocking=True,
return_response=True,
)

return {"success": True, "result": service_result}
146 changes: 145 additions & 1 deletion tests/helpers/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import voluptuous as vol

from homeassistant.components import calendar
from homeassistant.components import calendar, weather
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.components.script.config import ScriptConfig
Expand Down Expand Up @@ -1267,3 +1267,147 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
"start_date_time": now,
"end_date_time": dt_util.start_of_local_day() + timedelta(days=7),
}


async def test_weather_forecast_tool(hass: HomeAssistant) -> None:
"""Test the weather forecast tool."""
assert await async_setup_component(hass, "homeassistant", {})
hass.states.async_set(
"weather.test_weather",
"cloudy",
{"friendly_name": "Home", "supported_features": 3},
)
async_expose_entity(hass, "conversation", "weather.test_weather", True)
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
context=context,
user_prompt="test_text",
language="*",
assistant="conversation",
device_id=None,
)
api = await llm.async_get_api(hass, "assist", llm_context)
assert [tool for tool in api.tools if tool.name == "weather_get_forecasts"]

forecast = {
"weather.test_weather": {
"forecast": [
{
"condition": "cloudy",
"datetime": dt_util.start_of_local_day().isoformat(),
"wind_bearing": 200.1,
"uv_index": 0.0,
"temperature": -0.4,
"templow": -0.7,
"wind_speed": 6.5,
"precipitation": 0.0,
"humidity": 76,
},
{
"condition": "partlycloudy",
"datetime": (
dt_util.start_of_local_day() + timedelta(days=1)
).isoformat(),
"wind_bearing": 206.0,
"uv_index": 0.6,
"temperature": 1.1,
"templow": -2.6,
"wind_speed": 8.3,
"precipitation": 0.0,
"humidity": 73,
},
{
"condition": "cloudy",
"datetime": (
dt_util.start_of_local_day() + timedelta(days=2)
).isoformat(),
"wind_bearing": 54.7,
"uv_index": 0.6,
"temperature": 0.2,
"templow": -2.5,
"wind_speed": 16.2,
"precipitation": 0.0,
"humidity": 75,
},
{
"condition": "cloudy",
"datetime": (
dt_util.start_of_local_day() + timedelta(days=3)
).isoformat(),
"wind_bearing": 81.0,
"uv_index": 0.7,
"temperature": 0.1,
"templow": -1.7,
"wind_speed": 8.3,
"precipitation": 0.0,
"humidity": 76,
},
{
"condition": "sunny",
"datetime": (
dt_util.start_of_local_day() + timedelta(days=4)
).isoformat(),
"wind_bearing": 76.6,
"temperature": -1.0,
"templow": -4.4,
"wind_speed": 12.2,
"precipitation": 0.0,
"humidity": 54,
},
{
"condition": "sunny",
"datetime": (
dt_util.start_of_local_day() + timedelta(days=5)
).isoformat(),
"wind_bearing": 87.7,
"temperature": -1.2,
"templow": -7.4,
"wind_speed": 9.0,
"precipitation": 0.0,
"humidity": 52,
},
]
}
}

calls = async_mock_service(
hass,
domain=weather.DOMAIN,
service=weather.SERVICE_GET_FORECASTS,
schema=cv.make_entity_service_schema(
{vol.Required("type"): vol.In(("daily", "hourly", "twice_daily"))}
),
response=forecast,
supports_response=SupportsResponse.ONLY,
)

tool_input = llm.ToolInput(
tool_name="weather_get_forecasts",
tool_args={"type": "daily", "name": "Home"},
)
response = await api.async_call_tool(tool_input)

assert len(calls) == 1
call = calls[0]
assert call.domain == weather.DOMAIN
assert call.service == weather.SERVICE_GET_FORECASTS
assert call.data == {
"entity_id": ["weather.test_weather"],
"type": "daily",
}

assert response == {
"success": True,
"result": forecast,
}

tool_input.tool_args.pop("name")
response = await api.async_call_tool(tool_input)

assert len(calls) == 2
call = calls[1]
assert call.data == {
"entity_id": "all",
"type": "daily",
}