diff --git a/experiments/generic/run_generic_agent.py b/experiments/generic/run_generic_agent.py
new file mode 100644
index 00000000..cc646436
--- /dev/null
+++ b/experiments/generic/run_generic_agent.py
@@ -0,0 +1,65 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import argparse
+import logging
+
+from agentlab.agents.generic_agent.tmlr_config import get_base_agent
+from agentlab.experiments.study import Study
+from bgym import DEFAULT_BENCHMARKS
+
+logging.getLogger().setLevel(logging.WARNING)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--benchmark", required=True)
+    parser.add_argument("--llm-config", required=True)
+    parser.add_argument("--relaunch", action="http://github.com/store_true")
+    parser.add_argument("--n-jobs", type=int, default=5)
+    parser.add_argument("--n-relaunch", type=int, default=3)
+    parser.add_argument("--parallel-backend", type=str, default="ray")
+    parser.add_argument("--reproducibility-mode", action="http://github.com/store_true")
+
+    args = parser.parse_args()
+
+    # instantiate agent
+    agent_args = [get_base_agent(args.llm_config)]
+    benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
+
+    ##################### Shuffle env args list, pick subset
+    import numpy as np
+
+    rng = np.random.default_rng(42)
+    rng.shuffle(benchmark.env_args_list)
+    benchmark.env_args_list = benchmark.env_args_list[:33]
+    #####################
+
+    # for env_args in benchmark.env_args_list:
+    #     env_args.max_steps = 100
+
+    if args.relaunch:
+        # relaunch an existing study
+        study = Study.load_most_recent(contains=None)
+        study.find_incomplete(include_errors=True)
+
+    else:
+        study = Study(
+            agent_args,
+            benchmark,
+            logging_level=logging.WARNING,
+            logging_level_stdout=logging.WARNING,
+        )
+
+    study.run(
+        n_jobs=args.n_jobs,
+        parallel_backend=args.parallel_backend,
+        strict_reproducibility=args.reproducibility_mode,
+        n_relaunch=args.n_relaunch,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/experiments/generic/run_generic_agent.sh b/experiments/generic/run_generic_agent.sh
new file mode 100644
index 00000000..426af66e
--- /dev/null
+++ b/experiments/generic/run_generic_agent.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+BENCHMARK="workarena_l1"
+
+LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
+# PARALLEL_BACKEND="sequential"
+PARALLEL_BACKEND="ray"
+
+N_JOBS=5
+N_RELAUNCH=3
+
+python experiments/generic/run_generic_agent.py \
+    --benchmark $BENCHMARK \
+    --llm-config $LLM_CONFIG \
+    --parallel-backend $PARALLEL_BACKEND \
+    --n-jobs $N_JOBS \
+    --n-relaunch $N_RELAUNCH
\ No newline at end of file
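The seeded shuffle above (`default_rng(42)` plus a slice) is what keeps the 33-task subset stable across reruns and relaunches. If the same trick ends up in several runner scripts, it could be factored into a small helper along these lines (a sketch only, not part of this diff; `subsample_tasks` is a hypothetical name):

```python
import numpy as np


def subsample_tasks(env_args_list: list, n_tasks: int, seed: int = 42) -> list:
    """Return a reproducible random subset of a benchmark's task list.

    A fixed seed makes every run pick the same subset, which is what
    --relaunch relies on to find the matching incomplete experiments.
    """
    rng = np.random.default_rng(seed)
    shuffled = list(env_args_list)  # copy so the caller's list is untouched
    rng.shuffle(shuffled)
    return shuffled[:n_tasks]


# usage sketch:
# benchmark.env_args_list = subsample_tasks(benchmark.env_args_list, 33)
```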
parser.add_argument("--parallel-backend", type=str, default="ray") + parser.add_argument("--reproducibility-mode", action="http://github.com/store_true") + # hint flags + parser.add_argument("--hint-type", type=str, default="docs") + parser.add_argument("--hint-index-type", type=str, default="sparse") + parser.add_argument("--hint-query-type", type=str, default="direct") + parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25") + parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m") + parser.add_argument("--hint-num-results", type=int, default=5) + parser.add_argument("--debug", action="http://github.com/store_true") + args = parser.parse_args() + + flags = FLAGS_GPT_4o + flags.use_task_hint = True + flags.hint_type = args.hint_type + flags.hint_index_type = args.hint_index_type + flags.hint_query_type = args.hint_query_type + flags.hint_index_path = args.hint_index_path + flags.hint_retriever_path = args.hint_retriever_path + flags.hint_num_results = args.hint_num_results + + # instantiate agent + agent_args = [GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[args.llm_config], + flags=flags, + )] + + benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + + if args.debug: + # shuffle env_args_list and + import numpy as np + rng = np.random.default_rng(42) + rng.shuffle(benchmark.env_args_list) + benchmark.env_args_list = benchmark.env_args_list[:6] + + + if args.relaunch: + # relaunch an existing study + study = Study.load_most_recent(contains=None) + study.find_incomplete(include_errors=True) + + else: + study = Study( + agent_args, + benchmark, + logging_level=logging.WARNING, + logging_level_stdout=logging.WARNING, + ) + + study.run( + n_jobs=args.n_jobs, + parallel_backend=args.parallel_backend, + strict_reproducibility=args.reproducibility_mode, + n_relaunch=3, + ) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/experiments/hinter/run_hinter_agent.sh b/experiments/hinter/run_hinter_agent.sh new file mode 100644 index 00000000..9d998ef2 --- /dev/null +++ b/experiments/hinter/run_hinter_agent.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +BENCHMARK="workarena_l1" + +LLM_CONFIG="azure/gpt-5-mini-2025-08-07" +# PARALLEL_BACKEND="sequential" +PARALLEL_BACKEND="ray" + +HINT_TYPE="docs" # human, llm, docs +HINT_INDEX_TYPE="sparse" # sparse, dense +HINT_QUERY_TYPE="goal" # goal, llm +HINT_NUM_RESULTS=3 + +HINT_INDEX_PATH="indexes/servicenow-docs-bm25" +# HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m" +HINT_RETRIEVER_PATH="google/embeddinggemma-300m" + +N_JOBS=6 + +python experiments/hinter/run_hinter_agent.py \ + --benchmark $BENCHMARK \ + --llm-config $LLM_CONFIG \ + --parallel-backend $PARALLEL_BACKEND \ + --n-jobs $N_JOBS \ + --hint-type $HINT_TYPE \ + --hint-index-type $HINT_INDEX_TYPE \ + --hint-query-type $HINT_QUERY_TYPE \ + --hint-index-path $HINT_INDEX_PATH \ + --hint-retriever-path $HINT_RETRIEVER_PATH \ + --hint-num-results $HINT_NUM_RESULTS \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index d1f48f76..646a52b2 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict: def get_action(self, obs): self.obs_history.append(obs) + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, diff --git 
diff --git a/src/agentlab/agents/generic_agent_hinter/__init__.py b/src/agentlab/agents/generic_agent_hinter/__init__.py
new file mode 100644
index 00000000..4ad73676
--- /dev/null
+++ b/src/agentlab/agents/generic_agent_hinter/__init__.py
@@ -0,0 +1,19 @@
+import importlib, sys, warnings
+
+OLD = __name__
+NEW = "agentlab.agents.hint_use_agent"
+SUBS = ("agent_configs", "generic_agent_prompt", "generic_agent", "tmlr_config")
+
+warnings.warn(
+    f"{OLD} has been renamed to {NEW}. {OLD} will be removed in a future release.",
+    DeprecationWarning,
+    stacklevel=2,
+)
+
+# Alias the top-level module
+new_mod = importlib.import_module(NEW)
+sys.modules[OLD] = new_mod
+
+# Alias known submodules
+for sub in SUBS:
+    sys.modules[f"{OLD}.{sub}"] = importlib.import_module(f"{NEW}.{sub}")
diff --git a/src/agentlab/agents/hint_use_agent/__init__.py b/src/agentlab/agents/hint_use_agent/__init__.py
new file mode 100644
index 00000000..08b44255
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/__init__.py
@@ -0,0 +1,54 @@
+"""
+Baseline agent for all ServiceNow papers
+
+This module contains the GenericAgent class, which is the baseline agent for all ServiceNow papers. \
+It is a simple agent that can be run out of the box on all BrowserGym environments. It is also shipped with \
+a few configurations that can be used to run it on different environments.
+"""
+
+from .agent_configs import (
+    AGENT_3_5,
+    AGENT_8B,
+    AGENT_37_SONNET,
+    AGENT_CLAUDE_SONNET_35,
+    AGENT_CLAUDE_SONNET_35_VISION,
+    AGENT_CUSTOM,
+    AGENT_GPT5_MINI,
+    AGENT_GPT5_NANO,
+    AGENT_LLAMA3_70B,
+    AGENT_LLAMA4_17B_INSTRUCT,
+    AGENT_LLAMA31_70B,
+    CHAT_MODEL_ARGS_DICT,
+    RANDOM_SEARCH_AGENT,
+    AGENT_4o,
+    AGENT_4o_MINI,
+    AGENT_4o_MINI_VISION,
+    AGENT_4o_VISION,
+    AGENT_o1_MINI,
+    AGENT_o3_MINI,
+    FLAGS_GPT_4o,
+)
+from .generic_agent import GenericAgent, GenericAgentArgs
+
+__all__ = [
+    "AGENT_3_5",
+    "AGENT_4o",
+    "AGENT_4o_MINI",
+    "AGENT_4o_VISION",
+    "AGENT_4o_MINI_VISION",
+    "AGENT_o3_MINI",
+    "AGENT_o1_MINI",
+    "AGENT_LLAMA4_17B_INSTRUCT",
+    "AGENT_LLAMA3_70B",
+    "AGENT_LLAMA31_70B",
+    "AGENT_8B",
+    "RANDOM_SEARCH_AGENT",
+    "AGENT_CUSTOM",
+    "AGENT_CLAUDE_SONNET_35",
+    "AGENT_37_SONNET",
+    "AGENT_CLAUDE_SONNET_35_VISION",
+    "AGENT_GPT5_MINI",
+    "AGENT_GPT5_NANO",
+    "GenericAgent",
+    "GenericAgentArgs",
+]
diff --git a/src/agentlab/agents/hint_use_agent/agent_configs.py b/src/agentlab/agents/hint_use_agent/agent_configs.py
new file mode 100644
index 00000000..031b824c
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/agent_configs.py
@@ -0,0 +1,424 @@
+"""
+Basic flags and agent configurations for generic agents.
+""" + +import bgym +from bgym import HighLevelActionSetArgs + +from agentlab.agents import dynamic_prompting as dp +from agentlab.experiments import args +from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT + +from .generic_agent import GenericAgentArgs +from .generic_agent_prompt import GenericPromptFlags +from .tmlr_config import BASE_FLAGS + +FLAGS_CUSTOM = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=False, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=False, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=True, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + + +AGENT_CUSTOM = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"], + flags=FLAGS_CUSTOM, +) + + +# GPT-3.5 default config +FLAGS_GPT_3_5 = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, # too big for most benchmark except miniwob + use_ax_tree=True, # very useful + use_focused_element=True, # detrimental on minowob according to ablation study + use_error_logs=True, + use_history=True, + use_past_error_logs=False, # very detrimental on L1 and miniwob + use_action_history=True, # helpful on miniwob + use_think_history=False, # detrimental on L1 and miniwob + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, # doesn't change much + extract_clickable_tag=False, # doesn't change much + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=True, + ), + use_plan=False, # usually detrimental + use_criticise=False, # usually detrimental + use_thinking=True, # very useful + use_memory=False, + use_concrete_example=True, # useful + use_abstract_example=True, # useful + use_hints=True, # useful + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + + +AGENT_3_5 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-3.5-turbo-1106"], + flags=FLAGS_GPT_3_5, +) + +# llama3-70b default config +FLAGS_LLAMA3_70B = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=False, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=True, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=False, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=True, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + 
use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, + add_missparsed_messages=True, +) + +AGENT_LLAMA3_70B = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3-70b-instruct"], + flags=FLAGS_LLAMA3_70B, +) +AGENT_LLAMA31_70B = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-70b-instruct"], + flags=FLAGS_LLAMA3_70B, +) + +FLAGS_8B = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=False, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=False, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=False, + extract_clickable_tag=False, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=True, + ), + long_description=False, + individual_examples=True, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, + add_missparsed_messages=True, +) + + +AGENT_8B = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["meta-llama/Meta-Llama-3-8B-Instruct"], + flags=FLAGS_8B, +) + + +AGENT_LLAMA31_8B = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"], + flags=FLAGS_8B, +) + + +# GPT-4o default config +FLAGS_GPT_4o = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=False, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=True, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=False, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + +AGENT_4o = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"], + flags=FLAGS_GPT_4o, +) + +AGENT_4o_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"], + flags=FLAGS_GPT_4o, +) +AGENT_CLAUDE_SONNET_35 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"], + flags=FLAGS_GPT_4o, +) + +# Claude Sonnet 4 default config with task hints enabled +FLAGS_CLAUDE_SONNET_4 = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=False, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=True, + extract_coords="False", + filter_visible_elements_only=False, + ), + 
action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=False, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + use_task_hint=True, # Explicitly enable task hints + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + +AGENT_CLAUDE_SONNET_4 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["anthropic/claude-sonnet-4-20250514"], + flags=FLAGS_CLAUDE_SONNET_4, +) +AGENT_37_SONNET = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.7-sonnet"], + flags=FLAGS_GPT_4o, +) +# AGENT_o3_MINI = GenericAgentArgs( +# chat_model_args=CHAT_MODEL_ARGS_DICT["openai/o3-mini-2025-01-31"], +# flags=FLAGS_GPT_4o, +# ) +AGENT_o3_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/openai/o3-mini"], + flags=FLAGS_GPT_4o, +) + +AGENT_o1_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/openai/o1-mini-2024-09-12"], + flags=FLAGS_GPT_4o, +) +# GPT-4o vision default config +FLAGS_GPT_4o_VISION = FLAGS_GPT_4o.copy() +FLAGS_GPT_4o_VISION.obs.use_screenshot = True +FLAGS_GPT_4o_VISION.obs.use_som = True + +AGENT_4o_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_4o_MINI_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_CLAUDE_SONNET_35_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"], + flags=FLAGS_GPT_4o_VISION, +) +AGENT_LLAMA4_17B_INSTRUCT = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-4-maverick"], + flags=BASE_FLAGS, +) +GPT5_MINI_FLAGS = BASE_FLAGS.copy() +GPT5_MINI_FLAGS.action = dp.ActionFlags( # action should not be str to work with agentlab-assistant + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ) +) + +AGENT_GPT5_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-mini-2025-08-07"], + flags=GPT5_MINI_FLAGS, +) +AGENT_GPT5_NANO = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-nano-2025-08-07"], + flags=GPT5_MINI_FLAGS, +) + +AGENT_GPT5 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-2025-08-07"], + flags=GPT5_MINI_FLAGS, +) + +DEFAULT_RS_FLAGS = GenericPromptFlags( + flag_group="default_rs", + obs=dp.ObsFlags( + use_html=True, + use_ax_tree=args.Choice([True, False]), + use_focused_element=False, + use_error_logs=True, + use_history=True, + use_past_error_logs=args.Choice([True, False], p=[0.7, 0.3]), + use_action_history=True, + use_think_history=args.Choice([True, False], p=[0.7, 0.3]), + use_diff=args.Choice([True, False], p=[0.3, 0.7]), + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=args.Choice([True, False]), + extract_clickable_tag=False, + extract_coords=args.Choice(["center", "box"]), + filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]), + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=args.Choice([["bid"], ["bid", "coord"]]), + multiaction=args.Choice([True, False], p=[0.7, 0.3]), + ), + long_description=False, + individual_examples=False, + ), + # drop_ax_tree_first=True, # this 
flag is no longer active, according to browsergym doc
+    use_plan=args.Choice([True, False]),
+    use_criticise=args.Choice([True, False], p=[0.7, 0.3]),
+    use_thinking=args.Choice([True, False], p=[0.7, 0.3]),
+    use_memory=args.Choice([True, False], p=[0.7, 0.3]),
+    use_concrete_example=True,
+    use_abstract_example=True,
+    use_hints=args.Choice([True, False], p=[0.7, 0.3]),
+    be_cautious=args.Choice([True, False]),
+    enable_chat=False,
+    max_prompt_tokens=40_000,
+    extra_instructions=None,
+)
+
+
+RANDOM_SEARCH_AGENT = GenericAgentArgs(
+    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"],
+    flags=DEFAULT_RS_FLAGS,
+)
diff --git a/src/agentlab/agents/hint_use_agent/generic_agent.py b/src/agentlab/agents/hint_use_agent/generic_agent.py
new file mode 100644
index 00000000..afc688ed
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/generic_agent.py
@@ -0,0 +1,391 @@
+"""
+GenericAgent implementation for AgentLab
+
+This module defines a `GenericAgent` class and its associated arguments for use in the AgentLab framework. \
+The `GenericAgent` class is designed to interact with a chat-based model to determine actions based on \
+observations. It includes methods for preprocessing observations, generating actions, and managing internal \
+state such as plans, memories, and thoughts. The `GenericAgentArgs` class provides configuration options for \
+the agent, including model arguments and flags for various behaviors.
+"""
+
+import os
+from copy import deepcopy
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from warnings import warn
+
+import pandas as pd
+from agentlab.agents import dynamic_prompting as dp
+from agentlab.agents.agent_args import AgentArgs
+from agentlab.llm.chat_api import BaseModelArgs
+from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
+from agentlab.llm.tracking import cost_tracker_decorator
+from agentlab.utils.hinting import HintsSource
+from bgym import Benchmark
+from browsergym.experiments.agent import Agent, AgentInfo
+
+from .generic_agent_prompt import (
+    GenericPromptFlags,
+    MainPrompt,
+    StepWiseContextIdentificationPrompt,
+)
+
+
+@dataclass
+class GenericAgentArgs(AgentArgs):
+    chat_model_args: BaseModelArgs = None
+    flags: GenericPromptFlags = None
+    max_retry: int = 4
+
+    def __post_init__(self):
+        try:  # some attributes might be temporarily args.CrossProd for hyperparameter generation
+            # TODO: Rename the agent to HintUseAgent when appropriate
+            self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace(
+                "/", "_"
+            )
+        except AttributeError:
+            pass
+
+    def set_benchmark(self, benchmark: Benchmark, demo_mode):
+        """Override some flags based on the benchmark."""
+        if benchmark.name.startswith("miniwob"):
+            self.flags.obs.use_html = True
+
+        self.flags.obs.use_tabs = benchmark.is_multi_tab
+        self.flags.action.action_set = deepcopy(benchmark.high_level_action_set_args)
+
+        # for backward compatibility with old traces
+        if self.flags.action.multi_actions is not None:
+            self.flags.action.action_set.multiaction = self.flags.action.multi_actions
+        if self.flags.action.is_strict is not None:
+            self.flags.action.action_set.strict = self.flags.action.is_strict
+
+        # verify if we can remove this
+        if demo_mode:
+            self.flags.action.action_set.demo_mode = "all_blue"
+
+    def set_reproducibility_mode(self):
+        self.chat_model_args.temperature = 0
+
+    def prepare(self):
+        return self.chat_model_args.prepare_server()
+
+    def close(self):
+        return self.chat_model_args.close_server()
+
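`GenericAgentArgs` bundles the model args and prompt flags, applies benchmark-specific overrides, and manages the model-server lifecycle around `make_agent()` (defined just below). A minimal driving sketch, with the benchmark key and config name as illustrative assumptions (`AGENT_4o_MINI` is defined in `agent_configs.py` in this diff):

```python
from agentlab.agents.hint_use_agent.agent_configs import AGENT_4o_MINI
from bgym import DEFAULT_BENCHMARKS

agent_args = AGENT_4o_MINI  # a GenericAgentArgs instance
benchmark = DEFAULT_BENCHMARKS["miniwob"]()  # assumed benchmark key

agent_args.set_benchmark(benchmark, demo_mode=False)  # e.g. flips use_html on miniwob, copies the action set
agent_args.prepare()  # start the model server, if the backend has one
try:
    agent = agent_args.make_agent()
    # ... run episodes, e.g. through Study([agent_args], benchmark).run(...)
finally:
    agent_args.close()
```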
+ def make_agent(self): + return GenericAgent( + chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry + ) + + +class GenericAgent(Agent): + + def __init__( + self, + chat_model_args: BaseModelArgs, + flags: GenericPromptFlags, + max_retry: int = 4, + ): + + self.chat_llm = chat_model_args.make_model() + self.chat_model_args = chat_model_args + self.max_retry = max_retry + + self.flags = flags + + if self.flags.hint_db_path is not None and self.flags.use_task_hint: + assert os.path.exists( + self.flags.hint_db_path + ), f"Hint database path {self.flags.hint_db_path} does not exist." + self.action_set = self.flags.action.action_set.make_action_set() + self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs) + + self._init_hints_index() + + self._check_flag_constancy() + self.reset(seed=None) + + def obs_preprocessor(self, obs: dict) -> dict: + return self._obs_preprocessor(obs) + + def set_task_name(self, task_name: str): + """Set the task name for task hints functionality.""" + self.task_name = task_name + + @cost_tracker_decorator + def get_action(self, obs): + + self.obs_history.append(obs) + + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + + # use those queries to retrieve from the database and pass to prompt if step-level + self.queries = ( + self._get_queries()[0] + if getattr(self.flags, "hint_level", "episode") == "step" + else None + ) + + # get hints + if self.flags.use_hints: + task_hints = self._get_task_hints() + else: + task_hints = [] + + main_prompt = MainPrompt( + action_set=self.action_set, + obs_history=self.obs_history, + actions=self.actions, + memories=self.memories, + thoughts=self.thoughts, + previous_plan=self.plan, + step=self.plan_step, + flags=self.flags, + llm=self.chat_llm, + task_hints=task_hints, + ) + + # Set task name for task hints if available + if self.flags.use_task_hint and hasattr(self, "task_name"): + main_prompt.set_task_name(self.task_name) + + max_prompt_tokens, max_trunc_itr = self._get_maxes() + + human_prompt = dp.fit_tokens( + shrinkable=main_prompt, + max_prompt_tokens=max_prompt_tokens, + model_name=self.chat_model_args.model_name, + max_iterations=max_trunc_itr, + additional_prompts=system_prompt, + ) + try: + # TODO, we would need to further shrink the prompt if the retry + # cause it to be too long + + chat_messages = Discussion([system_prompt, human_prompt]) + ans_dict = retry( + self.chat_llm, + chat_messages, + n_retry=self.max_retry, + parser=main_prompt._parse_answer, + ) + ans_dict["busted_retry"] = 0 + # inferring the number of retries, TODO: make this less hacky + ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 + except ParseError as e: + ans_dict = dict( + action=None, + n_retry=self.max_retry + 1, + busted_retry=1, + ) + + stats = self.chat_llm.get_stats() + stats["n_retry"] = ans_dict["n_retry"] + stats["busted_retry"] = ans_dict["busted_retry"] + + self.plan = ans_dict.get("plan", self.plan) + self.plan_step = ans_dict.get("step", self.plan_step) + self.actions.append(ans_dict["action"]) + self.memories.append(ans_dict.get("memory", None)) + self.thoughts.append(ans_dict.get("think", None)) + + agent_info = AgentInfo( + think=ans_dict.get("think", None), + chat_messages=chat_messages, + stats=stats, + extra_info={"chat_model_args": asdict(self.chat_model_args)}, + ) + return ans_dict["action"], agent_info + + def _get_queries(self): + """Retrieve queries for hinting.""" + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + query_prompt = StepWiseContextIdentificationPrompt( + 
obs_history=self.obs_history, + actions=self.actions, + thoughts=self.thoughts, + obs_flags=self.flags.obs, + n_queries=self.flags.n_retrieval_queries, # TODO + ) + + chat_messages = Discussion([system_prompt, query_prompt.prompt]) + # BUG: Parsing fails multiple times. + ans_dict = retry( + self.chat_llm, + chat_messages, + n_retry=self.max_retry, + parser=query_prompt._parse_answer, + ) + + queries = ans_dict.get("queries", []) + assert len(queries) <= self.flags.n_retrieval_queries + + # TODO: we should probably propagate these chat_messages to be able to see them in xray + return queries, ans_dict.get("think", None) + + def reset(self, seed=None): + self.seed = seed + self.plan = "No plan yet" + self.plan_step = -1 + self.memories = [] + self.thoughts = [] + self.actions = [] + self.obs_history = [] + + def _check_flag_constancy(self): + flags = self.flags + if flags.obs.use_som: + if not flags.obs.use_screenshot: + warn( + """ +Warning: use_som=True requires use_screenshot=True. Disabling use_som.""" + ) + flags.obs.use_som = False + if flags.obs.use_screenshot: + if not self.chat_model_args.vision_support: + warn( + """ +Warning: use_screenshot is set to True, but the chat model \ +does not support vision. Disabling use_screenshot.""" + ) + flags.obs.use_screenshot = False + return flags + + def _get_maxes(self): + maxes = ( + self.flags.max_prompt_tokens, + self.chat_model_args.max_total_tokens, + self.chat_model_args.max_input_tokens, + ) + maxes = [m for m in maxes if m is not None] + max_prompt_tokens = min(maxes) if maxes else None + max_trunc_itr = ( + self.flags.max_trunc_itr + if self.flags.max_trunc_itr + else 20 # dangerous to change the default value here? + ) + return max_prompt_tokens, max_trunc_itr + + def _init_hints_index(self): + """Initialize the block.""" + try: + if self.flags.hint_type == "docs": + if self.flags.hint_index_type == "sparse": + import bm25s + + self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True) + elif self.flags.hint_index_type == "dense": + from datasets import load_from_disk + from sentence_transformers import SentenceTransformer + + self.hint_index = load_from_disk(self.flags.hint_index_path) + self.hint_index.load_faiss_index( + "embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss" + ) + self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path) + else: + raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}") + else: + # Use external path if provided, otherwise fall back to relative path + if self.flags.hint_db_path and Path(self.flags.hint_db_path).exists(): + hint_db_path = Path(self.flags.hint_db_path) + else: + hint_db_path = Path(__file__).parent / self.flags.hint_db_rel_path + + if hint_db_path.exists(): + self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) + # Verify the expected columns exist + if ( + "task_name" not in self.hint_db.columns + or "hint" not in self.hint_db.columns + ): + print( + f"Warning: Hint database missing expected columns. 
Found: {list(self.hint_db.columns)}"
+                    )
+                    self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+                else:
+                    print(f"Warning: Hint database not found at {hint_db_path}")
+                    self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+                self.hints_source = HintsSource(
+                    hint_db_path=hint_db_path.as_posix(),
+                    hint_retrieval_mode=self.flags.hint_retrieval_mode,
+                    skip_hints_for_current_task=self.flags.skip_hints_for_current_task,
+                )
+        except Exception as e:
+            # Fallback to empty database on any error
+            print(f"Warning: Could not load hint database: {e}")
+            self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+
+    def _get_task_hints(self) -> list[str]:
+        """Get hints for a specific task."""
+        if not self.flags.use_task_hint:
+            return []
+
+        if self.flags.hint_type == "docs":
+            if not hasattr(self, "hint_index"):
+                print("Warning: hint index not initialized")
+                # @patricebechard It seems _.init() method is missing do we still need it?
+                # self._init()
+            if self.flags.hint_query_type == "goal":
+                query = self.obs_history[-1]["goal_object"][0]["text"]
+            elif self.flags.hint_query_type == "llm":
+                queries, _ = self._get_queries()
+                # HACK: only 1 query supported
+                query = queries[0]
+            else:
+                raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}")
+
+            print(f"Query: {query}")
+            if self.flags.hint_index_type == "sparse":
+                import bm25s
+
+                query_tokens = bm25s.tokenize(query)
+                docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results)
+                docs = [elem["text"] for elem in docs[0]]
+                # HACK: truncate each doc to 20k characters (should cover >99% of the cases)
+                docs = [
+                    doc[:20000] + " ...[truncated]" if len(doc) > 20000 else doc for doc in docs
+                ]
+            elif self.flags.hint_index_type == "dense":
+                query_embedding = self.hint_retriever.encode(query)
+                _, docs = self.hint_index.get_nearest_examples(
+                    "embeddings", query_embedding, k=self.flags.hint_num_results
+                )
+                docs = docs["text"]
+            else:
+                raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}")
+
+            return docs
+
+        # Check if hint_db has the expected structure
+        if (
+            self.hint_db.empty
+            or "task_name" not in self.hint_db.columns
+            or "hint" not in self.hint_db.columns
+        ):
+            return []
+
+        try:
+            # When step-level, pass queries as goal string to fit the llm_prompt
+            goal_or_queries = self.obs_history[-1]["goal_object"][0]["text"]
+            if self.flags.hint_level == "step" and self.queries:
+                goal_or_queries = "\n".join(self.queries)
+
+            task_hints = self.hints_source.choose_hints(
+                self.chat_llm,
+                self.task_name,
+                goal_or_queries,
+            )
+
+            hints = []
+            for hint in task_hints:
+                hint = hint.strip()
+                if hint:
+                    hints.append(f"- {hint}")
+
+            return hints
+        except Exception as e:
+            print(f"Warning: Error getting hints for task {self.task_name}: {e}")
+
+        return []
diff --git a/src/agentlab/agents/hint_use_agent/generic_agent_prompt.py b/src/agentlab/agents/hint_use_agent/generic_agent_prompt.py
new file mode 100644
index 00000000..5ccb73a9
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/generic_agent_prompt.py
@@ -0,0 +1,411 @@
+"""
+Prompt builder for GenericAgent
+
+It is based on the dynamic_prompting module from the agentlab package.
+""" + +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import pandas as pd +from agentlab.agents import dynamic_prompting as dp +from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource +from agentlab.llm.chat_api import ChatModel +from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise +from browsergym.core.action.base import AbstractActionSet + +logger = logging.getLogger(__name__) + + +@dataclass +class GenericPromptFlags(dp.Flags): + """ + A class to represent various flags used to control features in an application. + + Attributes: + use_plan (bool): Ask the LLM to provide a plan. + use_criticise (bool): Ask the LLM to first draft and criticise the action before producing it. + use_thinking (bool): Enable a chain of thoughts. + use_concrete_example (bool): Use a concrete example of the answer in the prompt for a generic task. + use_abstract_example (bool): Use an abstract example of the answer in the prompt. + use_hints (bool): Add some human-engineered hints to the prompt. + use_task_hint (bool): Enable task-specific hints from hint database. + hint_db_path (str): Path to the hint database file. + enable_chat (bool): Enable chat mode, where the agent can interact with the user. + max_prompt_tokens (int): Maximum number of tokens allowed in the prompt. + be_cautious (bool): Instruct the agent to be cautious about its actions. + extra_instructions (Optional[str]): Extra instructions to provide to the agent. + add_missparsed_messages (bool): When retrying, add the missparsed messages to the prompt. + flag_group (Optional[str]): Group of flags used. + """ + + obs: dp.ObsFlags + action: dp.ActionFlags + use_plan: bool = False # + use_criticise: bool = False # + use_thinking: bool = False + use_memory: bool = False # + use_concrete_example: bool = True + use_abstract_example: bool = False + use_hints: bool = False + use_task_hint: bool = False + enable_chat: bool = False + max_prompt_tokens: int = None + be_cautious: bool = True + extra_instructions: str | None = None + add_missparsed_messages: bool = True + max_trunc_itr: int = 20 + flag_group: str = None + + # hint related + use_task_hint: bool = False + hint_type: str = "docs" + hint_index_type: str = "sparse" + hint_query_type: str = "direct" + hint_index_path: str = "indexes/servicenow-docs-bm25" + hint_retriever_path: str = "google/embeddinggemma-300m" + hint_num_results: int = 5 + n_retrieval_queries: int = 1 + + +class MainPrompt(dp.Shrinkable): + def __init__( + self, + action_set: AbstractActionSet, + obs_history: list[dict], + actions: list[str], + memories: list[str], + thoughts: list[str], + previous_plan: str, + step: int, + flags: GenericPromptFlags, + llm: ChatModel, + task_hints: list[str] = [], + ) -> None: + super().__init__() + self.flags = flags + self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs) + goal = obs_history[-1]["goal_object"] + if self.flags.enable_chat: + self.instructions = dp.ChatInstructions( + obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions + ) + else: + if sum([msg["role"] == "user" for msg in obs_history[-1].get("chat_messages", [])]) > 1: + logging.warning( + "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." 
+                )
+            self.instructions = dp.GoalInstructions(
+                goal, extra_instructions=flags.extra_instructions
+            )
+
+        self.obs = dp.Observation(
+            obs_history[-1],
+            self.flags.obs,
+        )
+
+        self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action)
+
+        def time_for_caution():
+            # no need for caution if we're in single action mode
+            return flags.be_cautious and (
+                flags.action.action_set.multiaction or flags.action.action_set == "python"
+            )
+
+        self.be_cautious = dp.BeCautious(visible=time_for_caution)
+        self.think = dp.Think(visible=lambda: flags.use_thinking)
+        self.hints = dp.Hints(visible=lambda: flags.use_hints)
+        self.task_hints = TaskHint(visible=lambda: flags.use_task_hint, task_hints=task_hints)
+        self.plan = Plan(previous_plan, step, lambda: flags.use_plan)  # TODO add previous plan
+        self.criticise = Criticise(visible=lambda: flags.use_criticise)
+        self.memory = Memory(visible=lambda: flags.use_memory)
+
+    @property
+    def _prompt(self) -> HumanMessage:
+        prompt = HumanMessage(self.instructions.prompt)
+
+        prompt.add_text(
+            f"""\
+{self.obs.prompt}\
+{self.history.prompt}\
+{self.action_prompt.prompt}\
+{self.hints.prompt}\
+{self.task_hints.prompt}\
+{self.be_cautious.prompt}\
+{self.think.prompt}\
+{self.plan.prompt}\
+{self.memory.prompt}\
+{self.criticise.prompt}\
+"""
+        )
+
+        if self.flags.use_abstract_example:
+            prompt.add_text(
+                f"""
+# Abstract Example
+
+Here is an abstract version of the answer with description of the content of
+each tag. Make sure you follow this structure, but replace the content with your
+answer:
+{self.think.abstract_ex}\
+{self.plan.abstract_ex}\
+{self.memory.abstract_ex}\
+{self.criticise.abstract_ex}\
+{self.task_hints.abstract_ex}\
+{self.action_prompt.abstract_ex}\
+"""
+            )
+
+        if self.flags.use_concrete_example:
+            prompt.add_text(
+                f"""
+# Concrete Example
+
+Here is a concrete example of how to format your answer.
+Make sure to follow the template with proper tags:
+{self.think.concrete_ex}\
+{self.plan.concrete_ex}\
+{self.memory.concrete_ex}\
+{self.criticise.concrete_ex}\
+{self.task_hints.concrete_ex}\
+{self.action_prompt.concrete_ex}\
+"""
+            )
+        return self.obs.add_screenshot(prompt)
+
+    def shrink(self):
+        self.history.shrink()
+        self.obs.shrink()
+
+    def set_task_name(self, task_name: str):
+        """Set the task name for task hints functionality."""
+        self.task_name = task_name
+
+    def _parse_answer(self, text_answer):
+        ans_dict = {}
+        ans_dict.update(self.think.parse_answer(text_answer))
+        ans_dict.update(self.plan.parse_answer(text_answer))
+        ans_dict.update(self.memory.parse_answer(text_answer))
+        ans_dict.update(self.criticise.parse_answer(text_answer))
+        ans_dict.update(self.action_prompt.parse_answer(text_answer))
+        return ans_dict
+
+
+class Memory(dp.PromptElement):
+    _prompt = ""  # provided in the abstract and concrete examples
+
+    _abstract_ex = """
+<memory>
+Write down anything you need to remember for next steps. You will be presented
+with the list of previous memories and past actions. Some tasks require
+remembering hints from previous steps in order to solve them.
+</memory>
+"""
+
+    _concrete_ex = """
+<memory>
+I clicked on bid "32" to activate tab 2. The accessibility tree should mention
+focusable for elements of the form at next step.
+</memory>
+"""
+
+    def _parse_answer(self, text_answer):
+        return parse_html_tags_raise(text_answer, optional_keys=["memory"], merge_multiple=True)
+
+
+class Plan(dp.PromptElement):
+    def __init__(self, previous_plan, plan_step, visible: bool = True) -> None:
+        super().__init__(visible=visible)
+        self.previous_plan = previous_plan
+        self._prompt = f"""
+# Plan:
+
+You just executed step {plan_step} of the previously proposed plan:\n{previous_plan}\n
+After reviewing the effect of your previous actions, verify if your plan is still
+relevant and update it if necessary.
+"""
+
+    _abstract_ex = """
+<plan>
+Provide a multi step plan that will guide you to accomplish the goal. There
+should always be steps to verify if the previous action had an effect. The plan
+can be revisited at each step, specifically if there was something unexpected.
+The plan should be cautious and favor exploring before submitting.
+</plan>
+
+<step>
+Integer specifying the step of current action
+</step>
+"""
+
+    _concrete_ex = """
+<plan>
+1. fill form (failed)
+    * type first name
+    * type last name
+2. Try to activate the form
+    * click on tab 2
+3. fill form again
+    * type first name
+    * type last name
+4. verify and submit
+    * verify form is filled
+    * submit if filled, if not, replan
+</plan>
+
+<step>
+2
+</step>
+"""
+
+    def _parse_answer(self, text_answer):
+        return parse_html_tags_raise(text_answer, optional_keys=["plan", "step"])
+
+
+class Criticise(dp.PromptElement):
+    _prompt = ""
+
+    _abstract_ex = """
+<action_draft>
+Write a first version of what you think is the right action.
+</action_draft>
+
+<criticise>
+Criticise action_draft. What could be wrong with it? Enumerate reasons why it
+could fail. Did your past actions have the expected effect? Make sure you're not
+repeating the same mistakes.
+</criticise>
+"""
+
+    _concrete_ex = """
+<action_draft>
+click("32")
+</action_draft>
+
+<criticise>
+click("32") might not work because the element is not visible yet. I need to
+explore the page to find a way to activate the form.
+</criticise>
+"""
+
+    def _parse_answer(self, text_answer):
+        return parse_html_tags_raise(text_answer, optional_keys=["action_draft", "criticise"])
+
+
+class TaskHint(dp.PromptElement):
+    def __init__(self, visible: bool, task_hints: list[str]) -> None:
+        super().__init__(visible=visible)
+        self.task_hints = task_hints
+
+    @property
+    def _prompt(self):
+        task_hint_str = "# Hints:\nHere are some hints for the task you are working on:\n"
+        for hint in self.task_hints:
+            task_hint_str += f"{hint}\n"
+        return task_hint_str
+
+    _abstract_ex = """
+<think>
+What hint can be relevant for the next action? Only choose from the hints provided in the task description. Or select none.
+</think>
+"""
+
+    _concrete_ex = """
+<think>
+Relevant hint: Based on the hints provided, I should focus on the form elements and use the
+accessibility tree to identify interactive elements before taking actions.
+</think>
+"""
+
+
+class StepWiseContextIdentificationPrompt(dp.Shrinkable):
+    def __init__(
+        self,
+        obs_history: list[dict],
+        actions: list[str],
+        thoughts: list[str],
+        obs_flags: dp.ObsFlags,
+        n_queries: int = 1,
+    ) -> None:
+        super().__init__()
+        self.obs_flags = obs_flags
+        self.n_queries = n_queries
+        self.history = dp.History(obs_history, actions, None, thoughts, obs_flags)
+        self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"])
+        self.obs = dp.Observation(obs_history[-1], obs_flags)
+
+        self.think = dp.Think(visible=True)  # To replace with static text maybe
+
+    @property
+    def _prompt(self) -> HumanMessage:
+        prompt = HumanMessage(self.instructions.prompt)
+
+        prompt.add_text(
+            f"""\
+{self.obs.prompt}\
+{self.history.prompt}\
+"""
+        )
+
+        example_queries = [
+            "The user has started sorting a table and needs to apply multiple column criteria simultaneously.",
+            "The user is attempting to configure advanced sorting options but the interface is unclear.",
+            "The user has selected the first sort column and is now looking for how to add a second sort criterion.",
+            "The user is in the middle of a multi-step sorting process and needs guidance on the next action.",
+        ]
+
+        example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)
+
+        prompt.add_text(
+            f"""
+# Querying memory
+
+Before choosing an action, let's search our available documentation and memory for relevant context.
+Generate a brief, general summary of the current status to help identify useful hints. Return your answer in the following format:
+<think>chain of thought</think>
+<queries>json list of strings of queries</queries>
+
+Additional instructions: List of queries should contain up to {self.n_queries} queries. Both the think and the queries blocks are required!
+
+# Concrete Example
+```
+<think>
+I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
+I will be able to sort by both at the same time.
+</think>
+
+<queries>
+{example_queries_str}
+</queries>
+```
+Note: do not generate backticks.
+Now proceed to generate your own thoughts and queries.
+Always return a non-empty answer, it's very important!
+"""
+        )
+
+        return self.obs.add_screenshot(prompt)
+
+    def shrink(self):
+        self.history.shrink()
+        self.obs.shrink()
+
+    def _parse_answer(self, text_answer):
+        try:
+            ans_dict = parse_html_tags_raise(
+                text_answer, keys=["think", "queries"], merge_multiple=True
+            )
+        except Exception as e:
+            t = text_answer.replace("\n", "\\n")
+            logger.warning(f"Failed to parse llm answer: {e}. RAW answer: '{t}'. Will retry")
+            raise e
+        raw_queries = ans_dict.get("queries", "[]")
+        try:
+            ans_dict["queries"] = json.loads(raw_queries)
+        except Exception as e:
+            t = text_answer.replace("\n", "\\n")
+            logger.warning(
+                f"Failed to parse queries: {e}. Queries block content: '{ans_dict['queries']}'. RAW llm answer: '{t}'. Will retry"
+            )
+            raise e
+        return ans_dict
diff --git a/src/agentlab/agents/hint_use_agent/tmlr_config.py b/src/agentlab/agents/hint_use_agent/tmlr_config.py
new file mode 100644
index 00000000..5a749721
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/tmlr_config.py
@@ -0,0 +1,78 @@
+"""
+Specific configurations for our 2024 TMLR submission.
+""" + +from copy import deepcopy + +from agentlab.agents import dynamic_prompting as dp +from agentlab.experiments import args +from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT + +from .generic_agent import GenericAgentArgs +from .generic_agent_prompt import GenericPromptFlags + +BASE_FLAGS = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=True, # gpt-4o config except for this line + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=True, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + multi_actions=False, + action_set="bid", + long_description=False, + individual_examples=False, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + + +def get_base_agent(llm_config: str): + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=BASE_FLAGS, + ) + + +def get_vision_agent(llm_config: str): + flags = deepcopy(BASE_FLAGS) + flags.obs.use_screenshot = True + agent_args = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=flags, + ) + agent_args.agent_name = f"{agent_args.agent_name}_vision" + return agent_args + + +def get_som_agent(llm_config: str): + flags = deepcopy(BASE_FLAGS) + flags.obs.use_screenshot = True + flags.obs.use_som = True + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=flags, + ) diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 375c829e..894616a4 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -28,6 +28,7 @@ from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark from agentlab.benchmarks.osworld import OSWorldActionSet from agentlab.llm.base_api import BaseModelArgs +from agentlab.llm.chat_api import ChatModel from agentlab.llm.llm_utils import image_to_png_base64_url from agentlab.llm.response_api import ( APIPayload, @@ -40,8 +41,10 @@ ToolCalls, ) from agentlab.llm.tracking import cost_tracker_decorator +from agentlab.utils.hinting import HintsSource logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) @dataclass @@ -316,39 +319,22 @@ class TaskHint(Block): def _init(self): """Initialize the block.""" - if Path(self.hint_db_rel_path).is_absolute(): - hint_db_path = Path(self.hint_db_rel_path) - else: - hint_db_path = Path(__file__).parent / self.hint_db_rel_path - self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) - if self.hint_retrieval_mode == "emb": - self.encode_hints() - - def oai_embed(self, text: str): - response = self._oai_emb.create(input=text, model="text-embedding-3-small") - return response.data[0].embedding - - def encode_hints(self): - self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") - logger.info( - f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." 
- ) - hints = self.uniq_hints["hint"].tolist() - semantic_keys = self.uniq_hints["semantic_keys"].tolist() - lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] - emb_path = f"{self.hint_db_rel_path}.embs.npy" - assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" - logger.info(f"Loading hint embeddings from: {emb_path}") - emb_dict = np.load(emb_path, allow_pickle=True).item() - self.hint_embeddings = np.array([emb_dict[k] for k in lines]) - logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") + if self.use_task_hint: + self.hints_source = HintsSource( + hint_db_path=self.hint_db_rel_path, + hint_retrieval_mode=self.hint_retrieval_mode, + top_n=self.top_n, + embedder_model=self.embedder_model, + embedder_server=self.embedder_server, + llm_prompt=self.llm_prompt, + ) def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if not self.use_task_hint: return {} goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content]) - task_hints = self.choose_hints(llm, task_name, goal) + task_hints = self.hints_source.choose_hints(llm, task_name, goal) hints = [] for hint in task_hints: @@ -358,101 +344,13 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if len(hints) > 0: hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" + "\n# Hints:\nHere are some hints for the task you are working on:\n" + "\n".join(hints) ) msg = llm.msg.user().add_text(hints_str) discussion.append(msg) - def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: - """Choose hints based on the task name.""" - if self.hint_retrieval_mode == "llm": - return self.choose_hints_llm(llm, goal) - elif self.hint_retrieval_mode == "direct": - return self.choose_hints_direct(task_name) - elif self.hint_retrieval_mode == "emb": - return self.choose_hints_emb(goal) - else: - raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") - - def choose_hints_llm(self, llm, goal: str) -> list[str]: - """Choose hints using LLM to filter the hints.""" - topic_to_hints = defaultdict(list) - for i, row in self.hint_db.iterrows(): - topic_to_hints[row["semantic_keys"]].append(i) - hint_topics = list(topic_to_hints.keys()) - topics = "\n".join([f"{i}. 
{h}" for i, h in enumerate(hint_topics)]) - prompt = self.llm_prompt.format(goal=goal, topics=topics) - response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])) - try: - hint_topic_idx = json.loads(response.think) - if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics): - logger.error(f"Wrong LLM hint id response: {response.think}, no hints") - return [] - hint_topic = hint_topics[hint_topic_idx] - hint_indices = topic_to_hints[hint_topic] - df = self.hint_db.iloc[hint_indices].copy() - df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints - hints = df["hint"].tolist() - logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}") - except json.JSONDecodeError: - logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints") - hints = [] - return hints - - def choose_hints_emb(self, goal: str) -> list[str]: - """Choose hints using embeddings to filter the hints.""" - goal_embeddings = self._encode([goal], prompt="task description") - similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist()) - top_indices = similarities.argsort()[0][-self.top_n :].tolist() - logger.info(f"Top hint indices based on embedding similarity: {top_indices}") - hints = self.uniq_hints.iloc[top_indices] - logger.info(f"Embedding-based hints chosen: {hints}") - return hints["hint"].tolist() - - def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): - """Call the encode API endpoint with timeout and retries""" - for attempt in range(max_retries): - try: - response = requests.post( - f"{self.embedder_server}/encode", - json={"texts": texts, "prompt": prompt}, - timeout=timeout, - ) - embs = response.json()["embeddings"] - return np.asarray(embs) - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: - if attempt == max_retries - 1: - raise e - time.sleep(random.uniform(1, timeout)) - continue - - def _similarity( - self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5 - ): - """Call the similarity API endpoint with timeout and retries""" - for attempt in range(max_retries): - try: - response = requests.post( - f"{self.embedder_server}/similarity", - json={"texts1": texts1, "texts2": texts2}, - timeout=timeout, - ) - similarities = response.json()["similarities"] - return np.asarray(similarities) - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: - if attempt == max_retries - 1: - raise e - time.sleep(random.uniform(1, timeout)) - continue - - def choose_hints_direct(self, task_name: str) -> list[str]: - hints = self.hint_db[ - self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) - ] - return hints["hint"].tolist() - @dataclass class PromptConfig: diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 6dbec117..ce8654dc 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -537,9 +537,10 @@ def run_gradio(results_dir: Path): do_share = os.getenv("AGENTXRAY_SHARE_GRADIO", "false").lower() == "true" port = os.getenv("AGENTXRAY_APP_PORT", None) + server_name = "0.0.0.0" if os.getenv("AGENTXRAY_PUBLIC", "false") == "true" else "127.0.0.1" if isinstance(port, str): port = int(port) - demo.launch(server_port=port, share=do_share) + demo.launch(server_name=server_name, server_port=port, share=do_share) def handle_key_event(key_event, step_id: StepId): diff --git 
a/src/agentlab/experiments/graph_execution_ray.py b/src/agentlab/experiments/graph_execution_ray.py index f047f866..f7aad780 100644 --- a/src/agentlab/experiments/graph_execution_ray.py +++ b/src/agentlab/experiments/graph_execution_ray.py @@ -3,9 +3,8 @@ import bgym import ray -from ray.util import state - from agentlab.experiments.exp_utils import _episode_timeout, run_exp +from ray.util import state logger = logging.getLogger(__name__) @@ -79,6 +78,7 @@ def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_inter try: result = ray.get(task) except Exception as e: + logger.exception(f"Task failed: {e}") result = e results.append(result) diff --git a/src/agentlab/experiments/launch_exp.py b/src/agentlab/experiments/launch_exp.py index 3bc6c54e..f14ae4c7 100644 --- a/src/agentlab/experiments/launch_exp.py +++ b/src/agentlab/experiments/launch_exp.py @@ -1,4 +1,5 @@ import logging +import os from importlib import import_module from pathlib import Path @@ -7,6 +8,8 @@ from agentlab.experiments.exp_utils import run_exp from agentlab.experiments.loop import ExpArgs, yield_all_exp_results +RAY_PUBLIC_DASHBOARD = os.environ.get("RAY_PUBLIC_DASHBOARD", "false") == "true" + def run_experiments( n_jobs, @@ -82,7 +85,9 @@ def run_experiments( elif parallel_backend == "ray": from agentlab.experiments.graph_execution_ray import execute_task_graph, ray - ray.init(num_cpus=n_jobs) + ray.init( + num_cpus=n_jobs, dashboard_host="0.0.0.0" if RAY_PUBLIC_DASHBOARD else "127.0.0.1" + ) try: execute_task_graph(exp_args_list, avg_step_timeout=avg_step_timeout) finally: diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index de4b976a..f69322a6 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -907,7 +907,6 @@ def _move_old_exp(exp_dir): def _get_env_name(task_name: str): """Register tasks if needed (lazy import) and return environment name.""" - # lazy import if task_name.startswith("miniwob"): import browsergym.miniwob @@ -915,6 +914,7 @@ def _get_env_name(task_name: str): import browsergym.workarena elif task_name.startswith("webarena"): import browsergym.webarena + import browsergym.webarenalite elif task_name.startswith("visualwebarena"): import browsergym.visualwebarena elif task_name.startswith("assistantbench"): diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index c1ee458f..2d0cb6ed 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -20,37 +20,6 @@ ] CHAT_MODEL_ARGS_DICT = { - "openai/gpt-5-2025-08-07": OpenAIModelArgs( - model_name="gpt-5-2025-08-07", - max_total_tokens=400_000, - max_input_tokens=256_000, - max_new_tokens=128_000, - temperature=1, # gpt-5 supports temperature of 1 only - vision_support=True, - ), - "openai/gpt-5-nano-2025-08-07": OpenAIModelArgs( - model_name="gpt-5-nano-2025-08-07", - max_total_tokens=400_000, - max_input_tokens=256_000, - max_new_tokens=128_000, - temperature=1, # gpt-5 supports temperature of 1 only - vision_support=True, - ), - "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs( - model_name="gpt-5-mini-2025-08-07", - max_total_tokens=400_000, - max_input_tokens=256_000, - max_new_tokens=128_000, - temperature=1, # gpt-5 supports temperature of 1 only - vision_support=True, - ), - "openai/gpt-4.1-nano-2025-04-14": OpenAIModelArgs( - model_name="gpt-4.1-nano-2025-04-14", - max_total_tokens=128_000, - max_input_tokens=128_000, - max_new_tokens=16_384, - vision_support=True, - ), "openai/gpt-4.1-mini-2025-04-14": 
diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py
index c1ee458f..2d0cb6ed 100644
--- a/src/agentlab/llm/llm_configs.py
+++ b/src/agentlab/llm/llm_configs.py
@@ -20,37 +20,6 @@
 ]

 CHAT_MODEL_ARGS_DICT = {
-    "openai/gpt-5-2025-08-07": OpenAIModelArgs(
-        model_name="gpt-5-2025-08-07",
-        max_total_tokens=400_000,
-        max_input_tokens=256_000,
-        max_new_tokens=128_000,
-        temperature=1,  # gpt-5 supports temperature of 1 only
-        vision_support=True,
-    ),
-    "openai/gpt-5-nano-2025-08-07": OpenAIModelArgs(
-        model_name="gpt-5-nano-2025-08-07",
-        max_total_tokens=400_000,
-        max_input_tokens=256_000,
-        max_new_tokens=128_000,
-        temperature=1,  # gpt-5 supports temperature of 1 only
-        vision_support=True,
-    ),
-    "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs(
-        model_name="gpt-5-mini-2025-08-07",
-        max_total_tokens=400_000,
-        max_input_tokens=256_000,
-        max_new_tokens=128_000,
-        temperature=1,  # gpt-5 supports temperature of 1 only
-        vision_support=True,
-    ),
-    "openai/gpt-4.1-nano-2025-04-14": OpenAIModelArgs(
-        model_name="gpt-4.1-nano-2025-04-14",
-        max_total_tokens=128_000,
-        max_input_tokens=128_000,
-        max_new_tokens=16_384,
-        vision_support=True,
-    ),
     "openai/gpt-4.1-mini-2025-04-14": OpenAIModelArgs(
         model_name="gpt-4.1-mini-2025-04-14",
         max_total_tokens=128_000,
@@ -126,6 +95,30 @@
         max_new_tokens=64_000,
         temperature=1e-1,
     ),
+    "openai/gpt-5-nano-2025-08-07": OpenAIModelArgs(
+        model_name="gpt-5-nano-2025-08-07",
+        max_total_tokens=400_000,
+        max_input_tokens=400_000 - 4_000,
+        max_new_tokens=4_000,
+        temperature=1,  # gpt-5 only accepts the default temperature of 1
+        vision_support=True,
+    ),
+    "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs(
+        model_name="gpt-5-mini-2025-08-07",
+        max_total_tokens=400_000,
+        max_input_tokens=400_000 - 4_000,
+        max_new_tokens=4_000,
+        temperature=1,  # gpt-5 only accepts the default temperature of 1
+        vision_support=True,
+    ),
+    "openai/gpt-5-2025-08-07": OpenAIModelArgs(
+        model_name="gpt-5-2025-08-07",
+        max_total_tokens=400_000,
+        max_input_tokens=400_000 - 4_000,
+        max_new_tokens=4_000,
+        temperature=1,  # gpt-5 only accepts the default temperature of 1
+        vision_support=True,
+    ),
     "azure/gpt-35-turbo/gpt-35-turbo": AzureModelArgs(
         model_name="gpt-35-turbo",
         max_total_tokens=8_192,
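Reviewer note: the re-added gpt-5 entries change the token split, and the input budget is now derived from the total, so the invariant max_input + max_new == max_total holds by construction. A quick check with the values above:

    max_total_tokens = 400_000
    max_new_tokens = 4_000
    max_input_tokens = max_total_tokens - max_new_tokens  # 396_000

    assert max_input_tokens + max_new_tokens == max_total_tokens

The removed entries reserved 128_000 of a 400_000 total for generation but capped the input at 256_000, leaving 16_000 tokens unusable; deriving the input budget removes that gap.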
diff --git a/src/agentlab/utils/__init__.py b/src/agentlab/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/agentlab/utils/hinting.py b/src/agentlab/utils/hinting.py
new file mode 100644
index 00000000..506513d5
--- /dev/null
+++ b/src/agentlab/utils/hinting.py
@@ -0,0 +1,221 @@
+import fnmatch
+import json
+import logging
+import os
+import random
+import re
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+import pandas as pd
+import requests
+
+from agentlab.llm.chat_api import ChatModel
+from agentlab.llm.response_api import APIPayload
+
+logger = logging.getLogger(__name__)
+
+
+class HintsSource:
+
+    def __init__(
+        self,
+        hint_db_path: str,
+        hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
+        skip_hints_for_current_task: bool = False,
+        skip_hints_for_current_goal: bool = False,
+        top_n: int = 4,
+        embedder_model: str = "Qwen/Qwen3-Embedding-0.6B",
+        embedder_server: str = "http://localhost:5000/",
+        llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
+You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
+Choose the hint topic for the task and return only its index, e.g. 0 for the topic with index 0.
+If you don't know the answer, return -1""",
+    ) -> None:
+        self.hint_db_path = hint_db_path
+        self.hint_retrieval_mode = hint_retrieval_mode
+        self.skip_hints_for_current_task = skip_hints_for_current_task
+        self.skip_hints_for_current_goal = skip_hints_for_current_goal
+        self.top_n = top_n
+        self.embedder_model = embedder_model
+        self.embedder_server = embedder_server
+        self.llm_prompt = llm_prompt
+
+        if Path(hint_db_path).is_absolute():
+            self.hint_db_path = Path(hint_db_path).as_posix()
+        else:
+            self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix()
+        self.hint_db = pd.read_csv(
+            self.hint_db_path,
+            header=0,
+            index_col=None,
+            dtype=str,
+            converters={
+                "trace_paths_json": lambda x: json.loads(x) if pd.notna(x) else [],
+                "source_trace_goals": lambda x: json.loads(x) if pd.notna(x) else [],
+            },
+        )
+        logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}")
+        if self.hint_retrieval_mode == "emb":
+            self.load_hint_vectors()
+
+    def load_hint_vectors(self):
+        self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
+        logger.info(
+            f"Loading embeddings for {len(self.uniq_hints)} unique hints with semantic keys, encoded with the {self.embedder_model} model."
+        )
+        hints = self.uniq_hints["hint"].tolist()
+        semantic_keys = self.uniq_hints["semantic_keys"].tolist()
+        lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
+        emb_path = f"{self.hint_db_path}.embs.npy"
+        assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
+        logger.info(f"Loading hint embeddings from: {emb_path}")
+        emb_dict = np.load(emb_path, allow_pickle=True).item()
+        self.hint_embeddings = np.array([emb_dict[k] for k in lines])
+        logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
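Reviewer note: load_hint_vectors only consumes embeddings; the {hint_db_path}.embs.npy sidecar must exist beforehand. A sketch of producing a compatible file offline (the embed function is a placeholder; only the key format and the pickled-dict layout are dictated by the loader above):

    import numpy as np

    def embed(text: str) -> np.ndarray:
        return np.zeros(8, dtype=np.float32)  # stand-in for the real encoder

    # Keys must be "semantic_keys: hint" lines, matching what the loader builds.
    lines = ["navigation: prefer the left-hand menu over full-text search"]
    emb_dict = {line: embed(line) for line in lines}

    # np.save wraps the dict in a 0-d object array, which the loader
    # unwraps with np.load(..., allow_pickle=True).item().
    np.save("hint_db.csv.embs.npy", emb_dict, allow_pickle=True)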
{h}" for i, h in enumerate(hint_topics)]) + prompt = self.llm_prompt.format(goal=goal, topics=topics) + + if isinstance(llm, ChatModel): + response: str = llm(messages=[dict(role="user", content=prompt)])["content"] + else: + response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think + try: + matches = re.findall(r"(-?\d+)", response) + if not matches: + logger.error(f"No choice tags found in LLM response: {response}") + return [] + if len(matches) > 1: + logger.warning( + f"LLM selected multiple topics for retrieval using only the first one." + ) + topic_number = int(matches[0]) + if topic_number < 0 or topic_number >= len(hint_topics): + logger.error(f"Wrong LLM hint id response: {response}, no hints") + return [] + hint_topic = hint_topics[topic_number] + hints = list(set(topic_to_hints[hint_topic])) + logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") + except Exception as e: + logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") + hints = [] + return hints + + def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: + """Choose hints using embeddings to filter the hints.""" + try: + goal_embeddings = self._encode([goal], prompt="task description") + hint_embeddings = self.hint_embeddings.copy() + all_hints = self.uniq_hints["hint"].tolist() + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints += self.get_current_task_hints(task_name) + if self.skip_hints_for_current_goal: + skip_hints += self.get_current_goal_hints(goal) + hint_embeddings = [] + id_to_hint = {} + for hint, emb in zip(all_hints, self.hint_embeddings): + if hint in skip_hints: + continue + hint_embeddings.append(emb.tolist()) + id_to_hint[len(hint_embeddings) - 1] = hint + logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") + similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) + top_indices = similarities.argsort()[0][-self.top_n :].tolist() + logger.info(f"Top hint indices based on embedding similarity: {top_indices}") + hints = [id_to_hint[idx] for idx in top_indices] + logger.info(f"Embedding-based hints chosen: {hints}") + except Exception as e: + logger.exception(f"Failed to choose hints using embeddings: {e}") + hints = [] + return hints + + def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): + """Call the encode API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/encode", + json={"texts": texts, "prompt": prompt}, + timeout=timeout, + ) + embs = response.json()["embeddings"] + return np.asarray(embs) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + raise ValueError("Failed to encode hints") + + def _similarity( + self, + texts1: list, + texts2: list, + timeout: int = 2, + max_retries: int = 5, + ): + """Call the similarity API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/similarity", + json={"texts1": texts1, "texts2": texts2}, + timeout=timeout, + ) + similarities = response.json()["similarities"] + return np.asarray(similarities) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + raise 
ValueError("Failed to compute similarity") + + def choose_hints_direct(self, task_name: str) -> list[str]: + hints = self.get_current_task_hints(task_name) + logger.info(f"Direct hints chosen: {hints}") + return hints + + def get_current_task_hints(self, task_name): + hints_df = self.hint_db[ + self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) + ] + return hints_df["hint"].tolist() + + def get_current_goal_hints(self, goal_str: str): + mask = self.hint_db["source_trace_goals"].apply(lambda goals: goal_str in goals) + return self.hint_db.loc[mask, "hint"].tolist()