Skip to content

Commit

Permalink
Fix LLM parameter propagation and refactor BrowserUseTool (#1033)
Browse files Browse the repository at this point in the history
* wip

* Fix browser-use and crawl4ai tools when using ollama with client_host parameter

* wip

* wip

* Browser use tool refactoring WIP

* Browser use tool refactoring WIP

* Fix pre-commit

* wip

* wip

* wip

* wip

* Fix mypy

* WIP

* Fix ci

* Fix ci

* fixes

* fixes

* fixes

* Fix import_utils.py for python3.9

* Fix tests

---------

Co-authored-by: Davor Runje <[email protected]>
  • Loading branch information
rjambrecic and davorrunje authored Feb 20, 2025
1 parent 5ac5fe1 commit 812acf7
Show file tree
Hide file tree
Showing 9 changed files with 501 additions and 194 deletions.
12 changes: 1 addition & 11 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -1187,16 +1187,6 @@
"is_secret": false
}
],
"test/tools/experimental/browser_use/test_browser_use.py": [
{
"type": "Secret Keyword",
"filename": "test/tools/experimental/browser_use/test_browser_use.py",
"hashed_secret": "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3",
"is_verified": false,
"line_number": 47,
"is_secret": false
}
],
"test/tools/experimental/crawl4ai/test_crawl4ai.py": [
{
"type": "Secret Keyword",
Expand Down Expand Up @@ -1616,5 +1606,5 @@
}
]
},
"generated_at": "2025-02-19T11:06:40Z"
"generated_at": "2025-02-19T12:25:16Z"
}
71 changes: 58 additions & 13 deletions autogen/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, o: T, missing_modules: Iterable[str], dep_target: str):
def accept(cls, o: Any) -> bool: ...

@abstractmethod
def patch(self) -> T: ...
def patch(self, except_for: Iterable[str]) -> T: ...

def get_object_with_metadata(self) -> Any:
return self.o
Expand Down Expand Up @@ -122,7 +122,13 @@ def decorator(subclass: type["PatchObject[Any]"]) -> type["PatchObject[Any]"]:
return decorator

@classmethod
def create(cls, o: T, *, missing_modules: Iterable[str], dep_target: str) -> Optional["PatchObject[T]"]:
def create(
cls,
o: T,
*,
missing_modules: Iterable[str],
dep_target: str,
) -> Optional["PatchObject[T]"]:
for subclass in cls._registry:
if subclass.accept(o):
return subclass(o, missing_modules, dep_target)
Expand All @@ -135,7 +141,10 @@ class PatchCallable(PatchObject[F]):
def accept(cls, o: Any) -> bool:
return inspect.isfunction(o) or inspect.ismethod(o)

def patch(self) -> F:
def patch(self, except_for: Iterable[str]) -> F:
if self.o.__name__ in except_for:
return self.o

f: Callable[..., Any] = self.o

@wraps(f.__call__) # type: ignore[operator]
Expand All @@ -154,7 +163,16 @@ def accept(cls, o: Any) -> bool:
# return inspect.ismethoddescriptor(o)
return isinstance(o, staticmethod)

def patch(self) -> F:
def patch(self, except_for: Iterable[str]) -> F:
if hasattr(self.o, "__name__"):
name = self.o.__name__
elif hasattr(self.o, "__func__"):
name = self.o.__func__.__name__
else:
raise ValueError(f"Cannot determine name for object {self.o}")
if name in except_for:
return self.o

f: Callable[..., Any] = self.o.__func__ # type: ignore[attr-defined]

@wraps(f)
Expand All @@ -175,7 +193,10 @@ class PatchInit(PatchObject[F]):
def accept(cls, o: Any) -> bool:
return inspect.ismethoddescriptor(o) and o.__name__ == "__init__"

def patch(self) -> F:
def patch(self, except_for: Iterable[str]) -> F:
if self.o.__name__ in except_for:
return self.o

f: Callable[..., Any] = self.o

@wraps(f)
Expand All @@ -196,11 +217,14 @@ class PatchProperty(PatchObject[Any]):
def accept(cls, o: Any) -> bool:
return inspect.isdatadescriptor(o) and hasattr(o, "fget")

def patch(self) -> property:
def patch(self, except_for: Iterable[str]) -> property:
if not hasattr(self.o, "fget"):
raise ValueError(f"Cannot patch property without getter: {self.o}")
f: Callable[..., Any] = self.o.fget

if f.__name__ in except_for:
return self.o # type: ignore[no-any-return]

@wraps(f)
def _call(*args: Any, **kwargs: Any) -> Any:
raise ImportError(self.msg)
Expand All @@ -219,30 +243,51 @@ class PatchClass(PatchObject[type[Any]]):
def accept(cls, o: Any) -> bool:
return inspect.isclass(o)

def patch(self) -> type[Any]:
# Patch __init__ method if possible
def patch(self, except_for: Iterable[str]) -> type[Any]:
if self.o.__name__ in except_for:
return self.o

for name, member in inspect.getmembers(self.o):
# Patch __init__ method if possible, but not other internal methods
if name.startswith("__") and name != "__init__":
continue
patched = patch_object(
member, missing_modules=self.missing_modules, dep_target=self.dep_target, fail_if_not_patchable=False
member,
missing_modules=self.missing_modules,
dep_target=self.dep_target,
fail_if_not_patchable=False,
except_for=except_for,
)
with suppress(AttributeError):
setattr(self.o, name, patched)

return self.o


def patch_object(
    o: T,
    *,
    missing_modules: Iterable[str],
    dep_target: str,
    fail_if_not_patchable: bool = True,
    except_for: Optional[Union[str, Iterable[str]]] = None,
) -> T:
    """Patch an object so that using it raises an informative error about missing optional deps.

    Args:
        o: The object (function, method, class, property, ...) to patch.
        missing_modules: Names of the modules that were found to be missing.
        dep_target: The pip extra / dependency target the user should install.
        fail_if_not_patchable: If True, raise ``ValueError`` when no registered
            ``PatchObject`` subclass accepts ``o``; if False, return ``o`` unchanged.
        except_for: Name or names of members that must NOT be patched
            (left callable even when the optional dependency is missing).

    Returns:
        The patched object, or ``o`` itself when no patcher applies and
        ``fail_if_not_patchable`` is False.

    Raises:
        ValueError: If ``o`` cannot be patched and ``fail_if_not_patchable`` is True.
    """
    patcher = PatchObject.create(o, missing_modules=missing_modules, dep_target=dep_target)
    if fail_if_not_patchable and patcher is None:
        raise ValueError(f"Cannot patch object of type {type(o)}")

    # Normalize except_for: None -> empty list, single string -> one-element list.
    except_for = except_for if except_for is not None else []
    except_for = [except_for] if isinstance(except_for, str) else except_for

    return patcher.patch(except_for=except_for) if patcher else o


def require_optional_import(modules: Union[str, Iterable[str]], dep_target: str) -> Callable[[T], T]:
def require_optional_import(
modules: Union[str, Iterable[str]],
dep_target: str,
*,
except_for: Optional[Union[str, Iterable[str]]] = None,
) -> Callable[[T], T]:
"""Decorator to handle optional module dependencies
Args:
Expand All @@ -259,7 +304,7 @@ def decorator(o: T) -> T:
else:

def decorator(o: T) -> T:
return patch_object(o, missing_modules=missing_modules, dep_target=dep_target)
return patch_object(o, missing_modules=missing_modules, dep_target=dep_target, except_for=except_for)

return decorator

Expand Down
65 changes: 11 additions & 54 deletions autogen/tools/experimental/browser_use/browser_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
with optional_import_block():
from browser_use import Agent, Controller
from browser_use.browser.browser import Browser, BrowserConfig
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from .langchain_factory import LangchainFactory


__all__ = ["BrowserUseResult", "BrowserUseTool"]
Expand All @@ -37,7 +35,14 @@ class BrowserUseResult(BaseModel):


@require_optional_import(
["langchain_anthropic", "langchain_google_genai", "langchain_ollama", "langchain_openai", "browser_use"],
[
"langchain_anthropic",
"langchain_google_genai",
"langchain_ollama",
"langchain_openai",
"langchain_core",
"browser_use",
],
"browser-use",
)
@export_module("autogen.tools.experimental")
Expand Down Expand Up @@ -88,7 +93,7 @@ async def browser_use( # type: ignore[no-any-unimported]
browser: Annotated[Browser, Depends(on(browser))],
agent_kwargs: Annotated[dict[str, Any], Depends(on(agent_kwargs))],
) -> BrowserUseResult:
llm = BrowserUseTool._get_llm(llm_config)
llm = LangchainFactory.create_base_chat_model(llm_config)

max_steps = agent_kwargs.pop("max_steps", 100)

Expand Down Expand Up @@ -121,51 +126,3 @@ def _get_controller(llm_config: dict[str, Any]) -> Any:
else llm_config.get("response_format")
)
return Controller(output_model=response_format)

@staticmethod
def _get_llm(
    llm_config: dict[str, Any],
) -> Any:
    """Create a langchain chat model from an AG2-style ``llm_config`` dict.

    Args:
        llm_config: Either a dict containing only a top-level ``model`` key
            (treated as OpenAI), or a dict with a ``config_list`` whose first
            entry selects the provider via its ``api_type`` field.

    Returns:
        A langchain chat model instance (``ChatOpenAI``, ``AzureChatOpenAI``,
        ``ChatAnthropic``, ``ChatGoogleGenerativeAI`` or ``ChatOllama``).

    Raises:
        ValueError: If the config is malformed, a required key is missing,
            or the ``api_type`` is not supported.
    """
    if "config_list" not in llm_config:
        # Minimal config: just a model name -> default to OpenAI.
        if "model" in llm_config:
            return ChatOpenAI(model=llm_config["model"])
        raise ValueError("llm_config must be a valid config dictionary.")

    try:
        # Only the first config_list entry is used.
        entry = llm_config["config_list"][0]
        model = entry["model"]
        api_type = entry.get("api_type", "openai")

        # Ollama does not require an api_key
        api_key = None if api_type == "ollama" else entry["api_key"]

        # deepseek and azure need an explicit endpoint; azure additionally
        # needs an api_version.  (Fixes the original condition that compared
        # against "azure" twice.)
        if api_type in ("deepseek", "azure"):
            base_url = entry.get("base_url")
            if not base_url:
                raise ValueError(f"base_url is required for {api_type} api type.")
            if api_type == "azure":
                api_version = entry.get("api_version")
                if not api_version:
                    raise ValueError(f"api_version is required for {api_type} api type.")

    # IndexError added so an empty config_list reports a config error instead
    # of leaking an IndexError; `from e` preserves the original traceback.
    except (KeyError, IndexError, TypeError) as e:
        raise ValueError(f"llm_config must be a valid config dictionary: {e}") from e

    if api_type == "openai":
        return ChatOpenAI(model=model, api_key=api_key)
    elif api_type == "azure":
        return AzureChatOpenAI(
            model=model,
            api_key=api_key,
            azure_endpoint=base_url,
            api_version=api_version,
        )
    elif api_type == "deepseek":
        # DeepSeek exposes an OpenAI-compatible API, hence ChatOpenAI + base_url.
        return ChatOpenAI(model=model, api_key=api_key, base_url=base_url)
    elif api_type == "anthropic":
        return ChatAnthropic(model=model, api_key=api_key)
    elif api_type == "google":
        return ChatGoogleGenerativeAI(model=model, api_key=api_key)
    elif api_type == "ollama":
        # Ollama defaults to a small context window; 32k is set explicitly
        # so browser page content fits in the prompt.
        return ChatOllama(model=model, num_ctx=32000)
    else:
        raise ValueError(f"Currently unsupported language model api type for browser use: {api_type}")
Loading

0 comments on commit 812acf7

Please sign in to comment.