diff --git a/experiments/generic/run_generic_agent.py b/experiments/generic/run_generic_agent.py
new file mode 100644
index 00000000..cc646436
--- /dev/null
+++ b/experiments/generic/run_generic_agent.py
@@ -0,0 +1,65 @@
+# Load environment variables (e.g. API keys) before importing agentlab modules
+# that may read them at import time.
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import argparse
+import logging
+
+from agentlab.agents.generic_agent.tmlr_config import get_base_agent
+from agentlab.experiments.study import Study
+from bgym import DEFAULT_BENCHMARKS
+
+logging.getLogger().setLevel(logging.WARNING)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--benchmark", required=True)
+ parser.add_argument("--llm-config", required=True)
+ parser.add_argument("--relaunch", action="http://github.com/store_true")
+ parser.add_argument("--n-jobs", type=int, default=5)
+ parser.add_argument("--n-relaunch", type=int, default=3)
+ parser.add_argument("--parallel-backend", type=str, default="ray")
+ parser.add_argument("--reproducibility-mode", action="http://github.com/store_true")
+
+ args = parser.parse_args()
+
+ # instantiate agent
+ agent_args = [get_base_agent(args.llm_config)]
+ benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
+
+    # Debug subset: shuffle the env args list and keep a fixed 33-task sample
+    import numpy as np
+
+    rng = np.random.default_rng(42)
+    rng.shuffle(benchmark.env_args_list)
+    benchmark.env_args_list = benchmark.env_args_list[:33]
+
+ # for env_args in benchmark.env_args_list:
+ # env_args.max_steps = 100
+
+ if args.relaunch:
+ # relaunch an existing study
+ study = Study.load_most_recent(contains=None)
+ study.find_incomplete(include_errors=True)
+
+ else:
+ study = Study(
+ agent_args,
+ benchmark,
+ logging_level=logging.WARNING,
+ logging_level_stdout=logging.WARNING,
+ )
+
+ study.run(
+ n_jobs=args.n_jobs,
+ parallel_backend="ray",
+ strict_reproducibility=args.reproducibility_mode,
+ n_relaunch=args.n_relaunch,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/experiments/generic/run_generic_agent.sh b/experiments/generic/run_generic_agent.sh
new file mode 100644
index 00000000..426af66e
--- /dev/null
+++ b/experiments/generic/run_generic_agent.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+BENCHMARK="workarena_l1"
+
+LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
+# PARALLEL_BACKEND="sequential"
+PARALLEL_BACKEND="ray"
+
+N_JOBS=5
+N_RELAUNCH=3
+
+python experiments/generic/run_generic_agent.py \
+ --benchmark $BENCHMARK \
+ --llm-config $LLM_CONFIG \
+ --parallel-backend $PARALLEL_BACKEND \
+ --n-jobs $N_JOBS \
+ --n-relaunch $N_RELAUNCH
\ No newline at end of file
diff --git a/experiments/hinter/run_hinter_agent.py b/experiments/hinter/run_hinter_agent.py
new file mode 100644
index 00000000..a5a0d544
--- /dev/null
+++ b/experiments/hinter/run_hinter_agent.py
@@ -0,0 +1,84 @@
+# Load environment variables (e.g. API keys) before importing agentlab modules
+# that may read them at import time.
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import argparse
+import logging
+
+from agentlab.agents.generic_agent_hinter.generic_agent import GenericAgentArgs
+from agentlab.agents.generic_agent_hinter.agent_configs import CHAT_MODEL_ARGS_DICT, FLAGS_GPT_4o
+from bgym import DEFAULT_BENCHMARKS
+from agentlab.experiments.study import Study
+
+logging.getLogger().setLevel(logging.WARNING)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--benchmark", required=True)
+ parser.add_argument("--llm-config", required=True)
+ parser.add_argument("--relaunch", action="http://github.com/store_true")
+ parser.add_argument("--n-jobs", type=int, default=6)
+ parser.add_argument("--parallel-backend", type=str, default="ray")
+ parser.add_argument("--reproducibility-mode", action="http://github.com/store_true")
+ # hint flags
+ parser.add_argument("--hint-type", type=str, default="docs")
+ parser.add_argument("--hint-index-type", type=str, default="sparse")
+ parser.add_argument("--hint-query-type", type=str, default="direct")
+ parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25")
+ parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m")
+ parser.add_argument("--hint-num-results", type=int, default=5)
+ parser.add_argument("--debug", action="http://github.com/store_true")
+ args = parser.parse_args()
+
+ flags = FLAGS_GPT_4o
+ flags.use_task_hint = True
+ flags.hint_type = args.hint_type
+ flags.hint_index_type = args.hint_index_type
+ flags.hint_query_type = args.hint_query_type
+ flags.hint_index_path = args.hint_index_path
+ flags.hint_retriever_path = args.hint_retriever_path
+ flags.hint_num_results = args.hint_num_results
+
+ # instantiate agent
+ agent_args = [GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT[args.llm_config],
+ flags=flags,
+ )]
+
+ benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
+
+ if args.debug:
+        # shuffle env_args_list and keep a 6-task subset for quick debugging
+ import numpy as np
+ rng = np.random.default_rng(42)
+ rng.shuffle(benchmark.env_args_list)
+ benchmark.env_args_list = benchmark.env_args_list[:6]
+
+
+ if args.relaunch:
+ # relaunch an existing study
+ study = Study.load_most_recent(contains=None)
+ study.find_incomplete(include_errors=True)
+
+ else:
+ study = Study(
+ agent_args,
+ benchmark,
+ logging_level=logging.WARNING,
+ logging_level_stdout=logging.WARNING,
+ )
+
+ study.run(
+ n_jobs=args.n_jobs,
+ parallel_backend=args.parallel_backend,
+ strict_reproducibility=args.reproducibility_mode,
+ n_relaunch=3,
+ )
+
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/experiments/hinter/run_hinter_agent.sh b/experiments/hinter/run_hinter_agent.sh
new file mode 100644
index 00000000..9d998ef2
--- /dev/null
+++ b/experiments/hinter/run_hinter_agent.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+BENCHMARK="workarena_l1"
+
+LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
+# PARALLEL_BACKEND="sequential"
+PARALLEL_BACKEND="ray"
+
+HINT_TYPE="docs" # human, llm, docs
+HINT_INDEX_TYPE="sparse" # sparse, dense
+HINT_QUERY_TYPE="goal" # goal, llm
+HINT_NUM_RESULTS=3
+
+HINT_INDEX_PATH="indexes/servicenow-docs-bm25"
+# HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m"
+HINT_RETRIEVER_PATH="google/embeddinggemma-300m"
+
+N_JOBS=6
+
+python experiments/hinter/run_hinter_agent.py \
+ --benchmark $BENCHMARK \
+ --llm-config $LLM_CONFIG \
+ --parallel-backend $PARALLEL_BACKEND \
+ --n-jobs $N_JOBS \
+ --hint-type $HINT_TYPE \
+ --hint-index-type $HINT_INDEX_TYPE \
+ --hint-query-type $HINT_QUERY_TYPE \
+ --hint-index-path $HINT_INDEX_PATH \
+ --hint-retriever-path $HINT_RETRIEVER_PATH \
+ --hint-num-results $HINT_NUM_RESULTS
\ No newline at end of file
diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py
index d1f48f76..646a52b2 100644
--- a/src/agentlab/agents/generic_agent/generic_agent.py
+++ b/src/agentlab/agents/generic_agent/generic_agent.py
@@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict:
def get_action(self, obs):
self.obs_history.append(obs)
+
main_prompt = MainPrompt(
action_set=self.action_set,
obs_history=self.obs_history,
diff --git a/src/agentlab/agents/generic_agent_hinter/__init__.py b/src/agentlab/agents/generic_agent_hinter/__init__.py
new file mode 100644
index 00000000..4ad73676
--- /dev/null
+++ b/src/agentlab/agents/generic_agent_hinter/__init__.py
@@ -0,0 +1,19 @@
+import importlib, sys, warnings
+
+OLD = __name__
+NEW = "agentlab.agents.hint_use_agent"
+SUBS = ("agent_configs", "generic_agent_prompt", "generic_agent", "tmlr_config")
+
+warnings.warn(
+ f"{OLD} is renamed to {NEW}. {OLD} will be removed in future",
+ DeprecationWarning,
+ stacklevel=2,
+)
+
+# Alias the top-level
+new_mod = importlib.import_module(NEW)
+sys.modules[OLD] = new_mod
+
+# Alias known submodules
+for sub in SUBS:
+ sys.modules[f"{OLD}.{sub}"] = importlib.import_module(f"{NEW}.{sub}")
diff --git a/src/agentlab/agents/hint_use_agent/__init__.py b/src/agentlab/agents/hint_use_agent/__init__.py
new file mode 100644
index 00000000..08b44255
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/__init__.py
@@ -0,0 +1,55 @@
+"""
+Baseline agent for all ServiceNow papers
+
+This module contains the GenericAgent class, which is the baseline agent for all ServiceNow papers. \
+It is a simple agent that can be run out of the box on all BrowserGym environments, and it ships with \
+a few configurations that can be used to run it on different environments.
+"""
+
+from .agent_configs import (
+ AGENT_3_5,
+ AGENT_8B,
+ AGENT_37_SONNET,
+ AGENT_CLAUDE_SONNET_35,
+ AGENT_CLAUDE_SONNET_35_VISION,
+ AGENT_CUSTOM,
+ AGENT_GPT5_MINI,
+ AGENT_GPT5_NANO,
+ AGENT_LLAMA3_70B,
+ AGENT_LLAMA4_17B_INSTRUCT,
+ AGENT_LLAMA31_70B,
+ CHAT_MODEL_ARGS_DICT,
+ RANDOM_SEARCH_AGENT,
+ AGENT_4o,
+ AGENT_4o_MINI,
+ AGENT_4o_MINI_VISION,
+ AGENT_4o_VISION,
+ AGENT_o1_MINI,
+ AGENT_o3_MINI,
+ FLAGS_GPT_4o,
+ GenericAgentArgs,
+)
+from .generic_agent import GenericAgent, GenericAgentArgs
+
+__all__ = [
+ "AGENT_3_5",
+ "AGENT_4o",
+ "AGENT_4o_MINI",
+ "AGENT_4o_VISION",
+ "AGENT_o3_MINI",
+ "AGENT_o1_MINI",
+ "AGENT_LLAMA4_17B_INSTRUCT",
+ "AGENT_LLAMA3_70B",
+ "AGENT_LLAMA31_70B",
+ "AGENT_8B",
+ "RANDOM_SEARCH_AGENT",
+ "AGENT_CUSTOM",
+ "AGENT_CLAUDE_SONNET_35",
+ "AGENT_37_SONNET",
+ "AGENT_4o_VISION",
+ "AGENT_4o_MINI_VISION",
+ "AGENT_CLAUDE_SONNET_35_VISION",
+ "AGENT_GPT5_MINI",
+ "AGENT_GPT5_NANO",
+]
diff --git a/src/agentlab/agents/hint_use_agent/agent_configs.py b/src/agentlab/agents/hint_use_agent/agent_configs.py
new file mode 100644
index 00000000..031b824c
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/agent_configs.py
@@ -0,0 +1,424 @@
+"""
+Basic flags and agent configurations for generic agents.
+"""
+
+import bgym
+from bgym import HighLevelActionSetArgs
+
+from agentlab.agents import dynamic_prompting as dp
+from agentlab.experiments import args
+from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
+
+from .generic_agent import GenericAgentArgs
+from .generic_agent_prompt import GenericPromptFlags
+from .tmlr_config import BASE_FLAGS
+
+FLAGS_CUSTOM = GenericPromptFlags(
+ obs=dp.ObsFlags(
+ use_html=False,
+ use_ax_tree=True,
+ use_focused_element=True,
+ use_error_logs=True,
+ use_history=True,
+ use_past_error_logs=False,
+ use_action_history=True,
+ use_think_history=False,
+ use_diff=False,
+ html_type="pruned_html",
+ use_screenshot=False,
+ use_som=False,
+ extract_visible_tag=True,
+ extract_clickable_tag=False,
+ extract_coords="False",
+ filter_visible_elements_only=False,
+ ),
+ action=dp.ActionFlags(
+ action_set=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=False,
+ ),
+ long_description=False,
+ individual_examples=True,
+ ),
+ use_plan=False,
+ use_criticise=False,
+ use_thinking=True,
+ use_memory=False,
+ use_concrete_example=True,
+ use_abstract_example=True,
+ use_hints=True,
+ enable_chat=False,
+ max_prompt_tokens=40_000,
+ be_cautious=True,
+ extra_instructions=None,
+)
+
+
+AGENT_CUSTOM = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"],
+ flags=FLAGS_CUSTOM,
+)
+
+
+# GPT-3.5 default config
+FLAGS_GPT_3_5 = GenericPromptFlags(
+ obs=dp.ObsFlags(
+ use_html=False, # too big for most benchmark except miniwob
+ use_ax_tree=True, # very useful
+        use_focused_element=True,  # detrimental on miniwob according to ablation study
+ use_error_logs=True,
+ use_history=True,
+ use_past_error_logs=False, # very detrimental on L1 and miniwob
+ use_action_history=True, # helpful on miniwob
+ use_think_history=False, # detrimental on L1 and miniwob
+ use_diff=False,
+ html_type="pruned_html",
+ use_screenshot=False,
+ use_som=False,
+ extract_visible_tag=True, # doesn't change much
+ extract_clickable_tag=False, # doesn't change much
+ extract_coords="False",
+ filter_visible_elements_only=False,
+ ),
+ action=dp.ActionFlags(
+ action_set=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=False,
+ ),
+ long_description=False,
+ individual_examples=True,
+ ),
+ use_plan=False, # usually detrimental
+ use_criticise=False, # usually detrimental
+ use_thinking=True, # very useful
+ use_memory=False,
+ use_concrete_example=True, # useful
+ use_abstract_example=True, # useful
+ use_hints=True, # useful
+ enable_chat=False,
+ max_prompt_tokens=40_000,
+ be_cautious=True,
+ extra_instructions=None,
+)
+
+
+AGENT_3_5 = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-3.5-turbo-1106"],
+ flags=FLAGS_GPT_3_5,
+)
+
+# llama3-70b default config
+FLAGS_LLAMA3_70B = GenericPromptFlags(
+ obs=dp.ObsFlags(
+ use_html=False,
+ use_ax_tree=True,
+ use_focused_element=True,
+ use_error_logs=False,
+ use_history=True,
+ use_past_error_logs=False,
+ use_action_history=True,
+ use_think_history=True,
+ use_diff=False,
+ html_type="pruned_html",
+ use_screenshot=False,
+ use_som=False,
+ extract_visible_tag=True,
+ extract_clickable_tag=False,
+ extract_coords="False",
+ filter_visible_elements_only=False,
+ ),
+ action=dp.ActionFlags(
+ action_set=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=False,
+ ),
+ long_description=False,
+ individual_examples=True,
+ ),
+ use_plan=False,
+ use_criticise=False,
+ use_thinking=True,
+ use_memory=False,
+ use_concrete_example=True,
+ use_abstract_example=True,
+ use_hints=True,
+ enable_chat=False,
+ max_prompt_tokens=40_000,
+ be_cautious=True,
+ extra_instructions=None,
+ add_missparsed_messages=True,
+)
+
+AGENT_LLAMA3_70B = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3-70b-instruct"],
+ flags=FLAGS_LLAMA3_70B,
+)
+AGENT_LLAMA31_70B = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-70b-instruct"],
+ flags=FLAGS_LLAMA3_70B,
+)
+
+FLAGS_8B = GenericPromptFlags(
+ obs=dp.ObsFlags(
+ use_html=False,
+ use_ax_tree=True,
+ use_focused_element=True,
+ use_error_logs=False,
+ use_history=True,
+ use_past_error_logs=False,
+ use_action_history=True,
+ use_think_history=False,
+ use_diff=False,
+ html_type="pruned_html",
+ use_screenshot=False,
+ use_som=False,
+ extract_visible_tag=False,
+ extract_clickable_tag=False,
+ extract_coords="False",
+ filter_visible_elements_only=False,
+ ),
+ action=dp.ActionFlags(
+ action_set=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=True,
+ ),
+ long_description=False,
+ individual_examples=True,
+ ),
+ use_plan=False,
+ use_criticise=False,
+ use_thinking=True,
+ use_memory=False,
+ use_concrete_example=True,
+ use_abstract_example=True,
+ use_hints=True,
+ enable_chat=False,
+ max_prompt_tokens=40_000,
+ be_cautious=True,
+ extra_instructions=None,
+ add_missparsed_messages=True,
+)
+
+
+AGENT_8B = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["meta-llama/Meta-Llama-3-8B-Instruct"],
+ flags=FLAGS_8B,
+)
+
+
+AGENT_LLAMA31_8B = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"],
+ flags=FLAGS_8B,
+)
+
+
+# GPT-4o default config
+FLAGS_GPT_4o = GenericPromptFlags(
+ obs=dp.ObsFlags(
+ use_html=False,
+ use_ax_tree=True,
+ use_focused_element=True,
+ use_error_logs=True,
+ use_history=True,
+ use_past_error_logs=False,
+ use_action_history=True,
+ use_think_history=False,
+ use_diff=False,
+ html_type="pruned_html",
+ use_screenshot=False,
+ use_som=False,
+ extract_visible_tag=True,
+ extract_clickable_tag=True,
+ extract_coords="False",
+ filter_visible_elements_only=False,
+ ),
+ action=dp.ActionFlags(
+ action_set=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=False,
+ ),
+ long_description=False,
+ individual_examples=False,
+ ),
+ use_plan=False,
+ use_criticise=False,
+ use_thinking=True,
+ use_memory=False,
+ use_concrete_example=True,
+ use_abstract_example=True,
+ use_hints=True,
+ enable_chat=False,
+ max_prompt_tokens=40_000,
+ be_cautious=True,
+ extra_instructions=None,
+)
+
+AGENT_4o = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"],
+ flags=FLAGS_GPT_4o,
+)
+
+AGENT_4o_MINI = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"],
+ flags=FLAGS_GPT_4o,
+)
+AGENT_CLAUDE_SONNET_35 = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
+ flags=FLAGS_GPT_4o,
+)
+
+# Claude Sonnet 4 default config with task hints enabled
+FLAGS_CLAUDE_SONNET_4 = GenericPromptFlags(
+ obs=dp.ObsFlags(
+ use_html=False,
+ use_ax_tree=True,
+ use_focused_element=True,
+ use_error_logs=True,
+ use_history=True,
+ use_past_error_logs=False,
+ use_action_history=True,
+ use_think_history=False,
+ use_diff=False,
+ html_type="pruned_html",
+ use_screenshot=False,
+ use_som=False,
+ extract_visible_tag=True,
+ extract_clickable_tag=True,
+ extract_coords="False",
+ filter_visible_elements_only=False,
+ ),
+ action=dp.ActionFlags(
+ action_set=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=False,
+ ),
+ long_description=False,
+ individual_examples=False,
+ ),
+ use_plan=False,
+ use_criticise=False,
+ use_thinking=True,
+ use_memory=False,
+ use_concrete_example=True,
+ use_abstract_example=True,
+ use_hints=True,
+ use_task_hint=True, # Explicitly enable task hints
+ enable_chat=False,
+ max_prompt_tokens=40_000,
+ be_cautious=True,
+ extra_instructions=None,
+)
+
+AGENT_CLAUDE_SONNET_4 = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["anthropic/claude-sonnet-4-20250514"],
+ flags=FLAGS_CLAUDE_SONNET_4,
+)
+AGENT_37_SONNET = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.7-sonnet"],
+ flags=FLAGS_GPT_4o,
+)
+# AGENT_o3_MINI = GenericAgentArgs(
+# chat_model_args=CHAT_MODEL_ARGS_DICT["openai/o3-mini-2025-01-31"],
+# flags=FLAGS_GPT_4o,
+# )
+AGENT_o3_MINI = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/openai/o3-mini"],
+ flags=FLAGS_GPT_4o,
+)
+
+AGENT_o1_MINI = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/openai/o1-mini-2024-09-12"],
+ flags=FLAGS_GPT_4o,
+)
+# GPT-4o vision default config
+FLAGS_GPT_4o_VISION = FLAGS_GPT_4o.copy()
+FLAGS_GPT_4o_VISION.obs.use_screenshot = True
+FLAGS_GPT_4o_VISION.obs.use_som = True
+
+AGENT_4o_VISION = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"],
+ flags=FLAGS_GPT_4o_VISION,
+)
+
+AGENT_4o_MINI_VISION = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"],
+ flags=FLAGS_GPT_4o_VISION,
+)
+
+AGENT_CLAUDE_SONNET_35_VISION = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
+ flags=FLAGS_GPT_4o_VISION,
+)
+AGENT_LLAMA4_17B_INSTRUCT = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-4-maverick"],
+ flags=BASE_FLAGS,
+)
+GPT5_MINI_FLAGS = BASE_FLAGS.copy()
+GPT5_MINI_FLAGS.action = dp.ActionFlags(  # must be an ActionFlags object (not a str) to work with agentlab-assistant
+ action_set=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=False,
+ )
+)
+
+AGENT_GPT5_MINI = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-mini-2025-08-07"],
+ flags=GPT5_MINI_FLAGS,
+)
+AGENT_GPT5_NANO = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-nano-2025-08-07"],
+ flags=GPT5_MINI_FLAGS,
+)
+
+AGENT_GPT5 = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-2025-08-07"],
+ flags=GPT5_MINI_FLAGS,
+)
+
+DEFAULT_RS_FLAGS = GenericPromptFlags(
+ flag_group="default_rs",
+ obs=dp.ObsFlags(
+ use_html=True,
+ use_ax_tree=args.Choice([True, False]),
+ use_focused_element=False,
+ use_error_logs=True,
+ use_history=True,
+ use_past_error_logs=args.Choice([True, False], p=[0.7, 0.3]),
+ use_action_history=True,
+ use_think_history=args.Choice([True, False], p=[0.7, 0.3]),
+ use_diff=args.Choice([True, False], p=[0.3, 0.7]),
+ html_type="pruned_html",
+ use_screenshot=False,
+ use_som=False,
+ extract_visible_tag=args.Choice([True, False]),
+ extract_clickable_tag=False,
+ extract_coords=args.Choice(["center", "box"]),
+ filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]),
+ ),
+ action=dp.ActionFlags(
+ action_set=HighLevelActionSetArgs(
+ subsets=args.Choice([["bid"], ["bid", "coord"]]),
+ multiaction=args.Choice([True, False], p=[0.7, 0.3]),
+ ),
+ long_description=False,
+ individual_examples=False,
+ ),
+ # drop_ax_tree_first=True, # this flag is no longer active, according to browsergym doc
+ use_plan=args.Choice([True, False]),
+ use_criticise=args.Choice([True, False], p=[0.7, 0.3]),
+ use_thinking=args.Choice([True, False], p=[0.7, 0.3]),
+ use_memory=args.Choice([True, False], p=[0.7, 0.3]),
+ use_concrete_example=True,
+ use_abstract_example=True,
+ use_hints=args.Choice([True, False], p=[0.7, 0.3]),
+ be_cautious=args.Choice([True, False]),
+ enable_chat=False,
+ max_prompt_tokens=40_000,
+ extra_instructions=None,
+)
+
+
+RANDOM_SEARCH_AGENT = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"],
+ flags=DEFAULT_RS_FLAGS,
+)
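A sketch of how these configs are meant to be extended, following the `FLAGS_GPT_4o.copy()` pattern used above (the derived agent name is illustrative):

```python
# Sketch: derive a hint-enabled variant without mutating the shared default flags.
from agentlab.agents.hint_use_agent.agent_configs import (
    CHAT_MODEL_ARGS_DICT,
    FLAGS_GPT_4o,
    GenericAgentArgs,
)

hint_flags = FLAGS_GPT_4o.copy()       # copy first; the module-level flags are shared
hint_flags.use_task_hint = True        # enable hint retrieval
hint_flags.hint_index_type = "sparse"  # BM25 index, as in run_hinter_agent.sh

AGENT_4o_MINI_HINTED = GenericAgentArgs(
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"],
    flags=hint_flags,
)
```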
diff --git a/src/agentlab/agents/hint_use_agent/generic_agent.py b/src/agentlab/agents/hint_use_agent/generic_agent.py
new file mode 100644
index 00000000..afc688ed
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/generic_agent.py
@@ -0,0 +1,392 @@
+"""
+GenericAgent implementation for AgentLab
+
+This module defines a `GenericAgent` class and its associated arguments for use in the AgentLab framework. \
+The `GenericAgent` class is designed to interact with a chat-based model to determine actions based on \
+observations. It includes methods for preprocessing observations, generating actions, and managing internal \
+state such as plans, memories, and thoughts. The `GenericAgentArgs` class provides configuration options for \
+the agent, including model arguments and flags for various behaviors.
+"""
+
+import os
+from copy import deepcopy
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from warnings import warn
+
+import pandas as pd
+from agentlab.agents import dynamic_prompting as dp
+from agentlab.agents.agent_args import AgentArgs
+from agentlab.llm.chat_api import BaseModelArgs
+from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
+from agentlab.llm.tracking import cost_tracker_decorator
+from agentlab.utils.hinting import HintsSource
+from bgym import Benchmark
+from browsergym.experiments.agent import Agent, AgentInfo
+
+from .generic_agent_prompt import (
+ GenericPromptFlags,
+ MainPrompt,
+ StepWiseContextIdentificationPrompt,
+)
+
+
+@dataclass
+class GenericAgentArgs(AgentArgs):
+ chat_model_args: BaseModelArgs = None
+ flags: GenericPromptFlags = None
+ max_retry: int = 4
+
+ def __post_init__(self):
+ try: # some attributes might be temporarily args.CrossProd for hyperparameter generation
+ # TODO: Rename the agent to HintUseAgent when appropriate
+ self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace(
+ "/", "_"
+ )
+ except AttributeError:
+ pass
+
+ def set_benchmark(self, benchmark: Benchmark, demo_mode):
+ """Override Some flags based on the benchmark."""
+ if benchmark.name.startswith("miniwob"):
+ self.flags.obs.use_html = True
+
+ self.flags.obs.use_tabs = benchmark.is_multi_tab
+ self.flags.action.action_set = deepcopy(benchmark.high_level_action_set_args)
+
+ # for backward compatibility with old traces
+ if self.flags.action.multi_actions is not None:
+ self.flags.action.action_set.multiaction = self.flags.action.multi_actions
+ if self.flags.action.is_strict is not None:
+ self.flags.action.action_set.strict = self.flags.action.is_strict
+
+ # verify if we can remove this
+ if demo_mode:
+ self.flags.action.action_set.demo_mode = "all_blue"
+
+ def set_reproducibility_mode(self):
+ self.chat_model_args.temperature = 0
+
+ def prepare(self):
+ return self.chat_model_args.prepare_server()
+
+ def close(self):
+ return self.chat_model_args.close_server()
+
+ def make_agent(self):
+ return GenericAgent(
+ chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
+ )
+
+
+class GenericAgent(Agent):
+
+ def __init__(
+ self,
+ chat_model_args: BaseModelArgs,
+ flags: GenericPromptFlags,
+ max_retry: int = 4,
+ ):
+
+ self.chat_llm = chat_model_args.make_model()
+ self.chat_model_args = chat_model_args
+ self.max_retry = max_retry
+
+ self.flags = flags
+
+ if self.flags.hint_db_path is not None and self.flags.use_task_hint:
+ assert os.path.exists(
+ self.flags.hint_db_path
+ ), f"Hint database path {self.flags.hint_db_path} does not exist."
+ self.action_set = self.flags.action.action_set.make_action_set()
+ self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
+
+ self._init_hints_index()
+
+ self._check_flag_constancy()
+ self.reset(seed=None)
+
+ def obs_preprocessor(self, obs: dict) -> dict:
+ return self._obs_preprocessor(obs)
+
+ def set_task_name(self, task_name: str):
+ """Set the task name for task hints functionality."""
+ self.task_name = task_name
+
+ @cost_tracker_decorator
+ def get_action(self, obs):
+
+ self.obs_history.append(obs)
+
+ system_prompt = SystemMessage(dp.SystemPrompt().prompt)
+
+        # for step-level hinting, generate retrieval queries now and pass them to the prompt
+ self.queries = (
+ self._get_queries()[0]
+ if getattr(self.flags, "hint_level", "episode") == "step"
+ else None
+ )
+
+ # get hints
+ if self.flags.use_hints:
+ task_hints = self._get_task_hints()
+ else:
+ task_hints = []
+
+ main_prompt = MainPrompt(
+ action_set=self.action_set,
+ obs_history=self.obs_history,
+ actions=self.actions,
+ memories=self.memories,
+ thoughts=self.thoughts,
+ previous_plan=self.plan,
+ step=self.plan_step,
+ flags=self.flags,
+ llm=self.chat_llm,
+ task_hints=task_hints,
+ )
+
+ # Set task name for task hints if available
+ if self.flags.use_task_hint and hasattr(self, "task_name"):
+ main_prompt.set_task_name(self.task_name)
+
+ max_prompt_tokens, max_trunc_itr = self._get_maxes()
+
+ human_prompt = dp.fit_tokens(
+ shrinkable=main_prompt,
+ max_prompt_tokens=max_prompt_tokens,
+ model_name=self.chat_model_args.model_name,
+ max_iterations=max_trunc_itr,
+ additional_prompts=system_prompt,
+ )
+ try:
+ # TODO, we would need to further shrink the prompt if the retry
+ # cause it to be too long
+
+ chat_messages = Discussion([system_prompt, human_prompt])
+ ans_dict = retry(
+ self.chat_llm,
+ chat_messages,
+ n_retry=self.max_retry,
+ parser=main_prompt._parse_answer,
+ )
+ ans_dict["busted_retry"] = 0
+ # inferring the number of retries, TODO: make this less hacky
+ ans_dict["n_retry"] = (len(chat_messages) - 3) / 2
+        except ParseError:
+ ans_dict = dict(
+ action=None,
+ n_retry=self.max_retry + 1,
+ busted_retry=1,
+ )
+
+ stats = self.chat_llm.get_stats()
+ stats["n_retry"] = ans_dict["n_retry"]
+ stats["busted_retry"] = ans_dict["busted_retry"]
+
+ self.plan = ans_dict.get("plan", self.plan)
+ self.plan_step = ans_dict.get("step", self.plan_step)
+ self.actions.append(ans_dict["action"])
+ self.memories.append(ans_dict.get("memory", None))
+ self.thoughts.append(ans_dict.get("think", None))
+
+ agent_info = AgentInfo(
+ think=ans_dict.get("think", None),
+ chat_messages=chat_messages,
+ stats=stats,
+ extra_info={"chat_model_args": asdict(self.chat_model_args)},
+ )
+ return ans_dict["action"], agent_info
+
+ def _get_queries(self):
+ """Retrieve queries for hinting."""
+ system_prompt = SystemMessage(dp.SystemPrompt().prompt)
+ query_prompt = StepWiseContextIdentificationPrompt(
+ obs_history=self.obs_history,
+ actions=self.actions,
+ thoughts=self.thoughts,
+ obs_flags=self.flags.obs,
+ n_queries=self.flags.n_retrieval_queries, # TODO
+ )
+
+ chat_messages = Discussion([system_prompt, query_prompt.prompt])
+ # BUG: Parsing fails multiple times.
+ ans_dict = retry(
+ self.chat_llm,
+ chat_messages,
+ n_retry=self.max_retry,
+ parser=query_prompt._parse_answer,
+ )
+
+ queries = ans_dict.get("queries", [])
+ assert len(queries) <= self.flags.n_retrieval_queries
+
+ # TODO: we should probably propagate these chat_messages to be able to see them in xray
+ return queries, ans_dict.get("think", None)
+
+ def reset(self, seed=None):
+ self.seed = seed
+ self.plan = "No plan yet"
+ self.plan_step = -1
+ self.memories = []
+ self.thoughts = []
+ self.actions = []
+ self.obs_history = []
+
+ def _check_flag_constancy(self):
+ flags = self.flags
+ if flags.obs.use_som:
+ if not flags.obs.use_screenshot:
+ warn(
+ """
+Warning: use_som=True requires use_screenshot=True. Disabling use_som."""
+ )
+ flags.obs.use_som = False
+ if flags.obs.use_screenshot:
+ if not self.chat_model_args.vision_support:
+ warn(
+ """
+Warning: use_screenshot is set to True, but the chat model \
+does not support vision. Disabling use_screenshot."""
+ )
+ flags.obs.use_screenshot = False
+ return flags
+
+ def _get_maxes(self):
+ maxes = (
+ self.flags.max_prompt_tokens,
+ self.chat_model_args.max_total_tokens,
+ self.chat_model_args.max_input_tokens,
+ )
+ maxes = [m for m in maxes if m is not None]
+ max_prompt_tokens = min(maxes) if maxes else None
+ max_trunc_itr = (
+ self.flags.max_trunc_itr
+ if self.flags.max_trunc_itr
+ else 20 # dangerous to change the default value here?
+ )
+ return max_prompt_tokens, max_trunc_itr
+
+ def _init_hints_index(self):
+ """Initialize the block."""
+ try:
+ if self.flags.hint_type == "docs":
+ if self.flags.hint_index_type == "sparse":
+ import bm25s
+
+ self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True)
+ elif self.flags.hint_index_type == "dense":
+ from datasets import load_from_disk
+ from sentence_transformers import SentenceTransformer
+
+ self.hint_index = load_from_disk(self.flags.hint_index_path)
+ self.hint_index.load_faiss_index(
+ "embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss"
+ )
+ self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path)
+ else:
+ raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}")
+ else:
+ # Use external path if provided, otherwise fall back to relative path
+ if self.flags.hint_db_path and Path(self.flags.hint_db_path).exists():
+ hint_db_path = Path(self.flags.hint_db_path)
+ else:
+ hint_db_path = Path(__file__).parent / self.flags.hint_db_rel_path
+
+ if hint_db_path.exists():
+ self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
+ # Verify the expected columns exist
+ if (
+ "task_name" not in self.hint_db.columns
+ or "hint" not in self.hint_db.columns
+ ):
+ print(
+ f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
+ )
+ self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+ else:
+ print(f"Warning: Hint database not found at {hint_db_path}")
+ self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+ self.hints_source = HintsSource(
+ hint_db_path=hint_db_path.as_posix(),
+ hint_retrieval_mode=self.flags.hint_retrieval_mode,
+ skip_hints_for_current_task=self.flags.skip_hints_for_current_task,
+ )
+ except Exception as e:
+ # Fallback to empty database on any error
+ print(f"Warning: Could not load hint database: {e}")
+ self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+
+ def _get_task_hints(self) -> list[str]:
+ """Get hints for a specific task."""
+ if not self.flags.use_task_hint:
+ return []
+
+ if self.flags.hint_type == "docs":
+ if not hasattr(self, "hint_index"):
+ print("Initializing hint index new time")
+ # @patricebechard It seems _.init() method is missing do we still need it?
+ # self._init()
+ if self.flags.hint_query_type == "goal":
+ query = self.obs_history[-1]["goal_object"][0]["text"]
+ elif self.flags.hint_query_type == "llm":
+ queries, _ = self._get_queries()
+ # HACK: only 1 query supported
+ query = queries[0]
+ else:
+                # note: only "goal" and "llm" are supported for docs-type hints
+ raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}")
+
+ print(f"Query: {query}")
+ if self.flags.hint_index_type == "sparse":
+ import bm25s
+
+ query_tokens = bm25s.tokenize(query)
+ docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results)
+ docs = [elem["text"] for elem in docs[0]]
+                # HACK: truncate each doc to 20k characters (should cover >99% of the cases).
+                # Build a new list; reassigning the loop variable would not modify `docs`.
+                docs = [
+                    doc[:20000] + " ...[truncated]" if len(doc) > 20000 else doc
+                    for doc in docs
+                ]
+ elif self.flags.hint_index_type == "dense":
+ query_embedding = self.hint_retriever.encode(query)
+ _, docs = self.hint_index.get_nearest_examples(
+ "embeddings", query_embedding, k=self.flags.hint_num_results
+ )
+ docs = docs["text"]
+
+ return docs
+
+ # Check if hint_db has the expected structure
+ if (
+ self.hint_db.empty
+ or "task_name" not in self.hint_db.columns
+ or "hint" not in self.hint_db.columns
+ ):
+ return []
+
+ try:
+ # When step-level, pass queries as goal string to fit the llm_prompt
+ goal_or_queries = self.obs_history[-1]["goal_object"][0]["text"]
+ if self.flags.hint_level == "step" and self.queries:
+ goal_or_queries = "\n".join(self.queries)
+
+ task_hints = self.hints_source.choose_hints(
+ self.chat_llm,
+ self.task_name,
+ goal_or_queries,
+ )
+
+ hints = []
+ for hint in task_hints:
+ hint = hint.strip()
+ if hint:
+ hints.append(f"- {hint}")
+
+ return hints
+ except Exception as e:
+ print(f"Warning: Error getting hints for task {self.task_name}: {e}")
+
+ return []
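For context on `_init_hints_index`: a sketch of building a sparse index the agent can load, based on the `bm25s` calls above (corpus contents are illustrative; check the installed `bm25s` version for exact signatures):

```python
# Sketch: build a BM25 hint index compatible with the sparse branch of _init_hints_index.
import bm25s

corpus = [  # illustrative documentation snippets
    {"text": "To sort a list by multiple columns, open the list controls menu ..."},
    {"text": "Use the condition builder to filter records before exporting ..."},
]
corpus_tokens = bm25s.tokenize([doc["text"] for doc in corpus])

retriever = bm25s.BM25()
retriever.index(corpus_tokens)
retriever.save("indexes/servicenow-docs-bm25", corpus=corpus)

# The agent then loads it as:
#   bm25s.BM25.load("indexes/servicenow-docs-bm25", load_corpus=True)
```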
diff --git a/src/agentlab/agents/hint_use_agent/generic_agent_prompt.py b/src/agentlab/agents/hint_use_agent/generic_agent_prompt.py
new file mode 100644
index 00000000..5ccb73a9
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/generic_agent_prompt.py
@@ -0,0 +1,405 @@
+"""
+Prompt builder for GenericAgent
+
+It is based on the dynamic_prompting module from the agentlab package.
+"""
+
+import json
+import logging
+from dataclasses import dataclass
+
+from agentlab.agents import dynamic_prompting as dp
+from agentlab.llm.chat_api import ChatModel
+from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise
+from browsergym.core.action.base import AbstractActionSet
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class GenericPromptFlags(dp.Flags):
+ """
+ A class to represent various flags used to control features in an application.
+
+ Attributes:
+ use_plan (bool): Ask the LLM to provide a plan.
+ use_criticise (bool): Ask the LLM to first draft and criticise the action before producing it.
+ use_thinking (bool): Enable a chain of thoughts.
+ use_concrete_example (bool): Use a concrete example of the answer in the prompt for a generic task.
+ use_abstract_example (bool): Use an abstract example of the answer in the prompt.
+ use_hints (bool): Add some human-engineered hints to the prompt.
+ use_task_hint (bool): Enable task-specific hints from hint database.
+ hint_db_path (str): Path to the hint database file.
+ enable_chat (bool): Enable chat mode, where the agent can interact with the user.
+ max_prompt_tokens (int): Maximum number of tokens allowed in the prompt.
+ be_cautious (bool): Instruct the agent to be cautious about its actions.
+ extra_instructions (Optional[str]): Extra instructions to provide to the agent.
+ add_missparsed_messages (bool): When retrying, add the missparsed messages to the prompt.
+ flag_group (Optional[str]): Group of flags used.
+ """
+
+ obs: dp.ObsFlags
+ action: dp.ActionFlags
+    use_plan: bool = False
+    use_criticise: bool = False
+    use_thinking: bool = False
+    use_memory: bool = False
+    use_concrete_example: bool = True
+    use_abstract_example: bool = False
+    use_hints: bool = False
+    enable_chat: bool = False
+    max_prompt_tokens: int | None = None
+    be_cautious: bool = True
+    extra_instructions: str | None = None
+    add_missparsed_messages: bool = True
+    max_trunc_itr: int = 20
+    flag_group: str | None = None
+
+ # hint related
+ use_task_hint: bool = False
+ hint_type: str = "docs"
+ hint_index_type: str = "sparse"
+ hint_query_type: str = "direct"
+ hint_index_path: str = "indexes/servicenow-docs-bm25"
+ hint_retriever_path: str = "google/embeddinggemma-300m"
+ hint_num_results: int = 5
+ n_retrieval_queries: int = 1
+
+
+class MainPrompt(dp.Shrinkable):
+ def __init__(
+ self,
+ action_set: AbstractActionSet,
+ obs_history: list[dict],
+ actions: list[str],
+ memories: list[str],
+ thoughts: list[str],
+ previous_plan: str,
+ step: int,
+ flags: GenericPromptFlags,
+ llm: ChatModel,
+ task_hints: list[str] = [],
+ ) -> None:
+ super().__init__()
+ self.flags = flags
+ self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs)
+ goal = obs_history[-1]["goal_object"]
+ if self.flags.enable_chat:
+ self.instructions = dp.ChatInstructions(
+ obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions
+ )
+ else:
+ if sum([msg["role"] == "user" for msg in obs_history[-1].get("chat_messages", [])]) > 1:
+                logger.warning(
+ "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
+ )
+ self.instructions = dp.GoalInstructions(
+ goal, extra_instructions=flags.extra_instructions
+ )
+
+ self.obs = dp.Observation(
+ obs_history[-1],
+ self.flags.obs,
+ )
+
+ self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action)
+
+ def time_for_caution():
+ # no need for caution if we're in single action mode
+ return flags.be_cautious and (
+ flags.action.action_set.multiaction or flags.action.action_set == "python"
+ )
+
+ self.be_cautious = dp.BeCautious(visible=time_for_caution)
+ self.think = dp.Think(visible=lambda: flags.use_thinking)
+ self.hints = dp.Hints(visible=lambda: flags.use_hints)
+ self.task_hints = TaskHint(visible=lambda: flags.use_task_hint, task_hints=task_hints)
+ self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
+ self.criticise = Criticise(visible=lambda: flags.use_criticise)
+ self.memory = Memory(visible=lambda: flags.use_memory)
+
+ @property
+ def _prompt(self) -> HumanMessage:
+ prompt = HumanMessage(self.instructions.prompt)
+
+ prompt.add_text(
+ f"""\
+{self.obs.prompt}\
+{self.history.prompt}\
+{self.action_prompt.prompt}\
+{self.hints.prompt}\
+{self.task_hints.prompt}\
+{self.be_cautious.prompt}\
+{self.think.prompt}\
+{self.plan.prompt}\
+{self.memory.prompt}\
+{self.criticise.prompt}\
+"""
+ )
+
+ if self.flags.use_abstract_example:
+ prompt.add_text(
+ f"""
+# Abstract Example
+
+Here is an abstract version of the answer with description of the content of
+each tag. Make sure you follow this structure, but replace the content with your
+answer:
+{self.think.abstract_ex}\
+{self.plan.abstract_ex}\
+{self.memory.abstract_ex}\
+{self.criticise.abstract_ex}\
+{self.task_hints.abstract_ex}\
+{self.action_prompt.abstract_ex}\
+"""
+ )
+
+ if self.flags.use_concrete_example:
+ prompt.add_text(
+ f"""
+# Concrete Example
+
+Here is a concrete example of how to format your answer.
+Make sure to follow the template with proper tags:
+{self.think.concrete_ex}\
+{self.plan.concrete_ex}\
+{self.memory.concrete_ex}\
+{self.criticise.concrete_ex}\
+{self.task_hints.concrete_ex}\
+{self.action_prompt.concrete_ex}\
+"""
+ )
+ return self.obs.add_screenshot(prompt)
+
+ def shrink(self):
+ self.history.shrink()
+ self.obs.shrink()
+
+ def set_task_name(self, task_name: str):
+ """Set the task name for task hints functionality."""
+ self.task_name = task_name
+
+ def _parse_answer(self, text_answer):
+ ans_dict = {}
+ ans_dict.update(self.think.parse_answer(text_answer))
+ ans_dict.update(self.plan.parse_answer(text_answer))
+ ans_dict.update(self.memory.parse_answer(text_answer))
+ ans_dict.update(self.criticise.parse_answer(text_answer))
+ ans_dict.update(self.action_prompt.parse_answer(text_answer))
+ return ans_dict
+
+
+class Memory(dp.PromptElement):
+ _prompt = "" # provided in the abstract and concrete examples
+
+ _abstract_ex = """
+
+Write down anything you need to remember for next steps. You will be presented
+with the list of previous memories and past actions. Some tasks require to
+remember hints from previous steps in order to solve it.
+
+"""
+
+ _concrete_ex = """
+
+I clicked on bid "32" to activate tab 2. The accessibility tree should mention
+focusable for elements of the form at next step.
+
+"""
+
+ def _parse_answer(self, text_answer):
+ return parse_html_tags_raise(text_answer, optional_keys=["memory"], merge_multiple=True)
+
+
+class Plan(dp.PromptElement):
+ def __init__(self, previous_plan, plan_step, visible: bool = True) -> None:
+ super().__init__(visible=visible)
+ self.previous_plan = previous_plan
+ self._prompt = f"""
+# Plan:
+
+You just executed step {plan_step} of the previously proposed plan:\n{previous_plan}\n
+After reviewing the effect of your previous actions, verify if your plan is still
+relevant and update it if necessary.
+"""
+
+ _abstract_ex = """
+
+Provide a multi step plan that will guide you to accomplish the goal. There
+should always be steps to verify if the previous action had an effect. The plan
+can be revisited at each steps. Specifically, if there was something unexpected.
+The plan should be cautious and favor exploring befor submitting.
+
+
+Integer specifying the step of current action
+
+"""
+
+ _concrete_ex = """
+
+1. fill form (failed)
+ * type first name
+ * type last name
+2. Try to activate the form
+ * click on tab 2
+3. fill form again
+ * type first name
+ * type last name
+4. verify and submit
+ * verify form is filled
+ * submit if filled, if not, replan
+
+
+2
+"""
+
+ def _parse_answer(self, text_answer):
+ return parse_html_tags_raise(text_answer, optional_keys=["plan", "step"])
+
+
+class Criticise(dp.PromptElement):
+ _prompt = ""
+
+ _abstract_ex = """
+
+Write a first version of what you think is the right action.
+
+
+
+Criticise action_draft. What could be wrong with it? Enumerate reasons why it
+could fail. Did your past actions had the expected effect? Make sure you're not
+repeating the same mistakes.
+
+"""
+
+ _concrete_ex = """
+
+click("32")
+
+
+
+click("32") might not work because the element is not visible yet. I need to
+explore the page to find a way to activate the form.
+
+"""
+
+ def _parse_answer(self, text_answer):
+ return parse_html_tags_raise(text_answer, optional_keys=["action_draft", "criticise"])
+
+
+class TaskHint(dp.PromptElement):
+ def __init__(self, visible: bool, task_hints: list[str]) -> None:
+ super().__init__(visible=visible)
+ self.task_hints = task_hints
+
+ @property
+ def _prompt(self):
+ task_hint_str = "# Hints:\nHere are some hints for the task you are working on:\n"
+ for hint in self.task_hints:
+ task_hint_str += f"{hint}\n"
+ return task_hint_str
+
+ _abstract_ex = """
+
+What hint can be relevant for the next action? Only chose from the hints provided in the task description. Or select none.
+
+"""
+
+ _concrete_ex = """
+
+Relevant hint: Based on the hints provided, I should focus on the form elements and use the
+accessibility tree to identify interactive elements before taking actions.
+
+"""
+
+
+class StepWiseContextIdentificationPrompt(dp.Shrinkable):
+ def __init__(
+ self,
+ obs_history: list[dict],
+ actions: list[str],
+ thoughts: list[str],
+ obs_flags: dp.ObsFlags,
+ n_queries: int = 1,
+ ) -> None:
+ super().__init__()
+ self.obs_flags = obs_flags
+ self.n_queries = n_queries
+ self.history = dp.History(obs_history, actions, None, thoughts, obs_flags)
+ self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"])
+ self.obs = dp.Observation(obs_history[-1], obs_flags)
+
+ self.think = dp.Think(visible=True) # To replace with static text maybe
+
+ @property
+ def _prompt(self) -> HumanMessage:
+ prompt = HumanMessage(self.instructions.prompt)
+
+ prompt.add_text(
+ f"""\
+{self.obs.prompt}\
+{self.history.prompt}\
+"""
+ )
+
+ example_queries = [
+ "The user has started sorting a table and needs to apply multiple column criteria simultaneously.",
+ "The user is attempting to configure advanced sorting options but the interface is unclear.",
+ "The user has selected the first sort column and is now looking for how to add a second sort criterion.",
+ "The user is in the middle of a multi-step sorting process and needs guidance on the next action.",
+ ]
+
+ example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)
+
+ prompt.add_text(
+ f"""
+# Querying memory
+
+Before choosing an action, let's search our available documentation and memory for relevant context.
+Generate a brief, general summary of the current status to help identify useful hints. Return your answer in the following format:
+<think>chain of thought</think>
+<queries>json list of strings of queries</queries>
+
+Additional instructions: the list should contain at most {self.n_queries} queries. Both the think and the queries blocks are required!
+
+# Concrete Example
+```
+<think>
+I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
+I will be able to sort by both at the same time.
+</think>
+
+<queries>
+{example_queries_str}
+</queries>
+```
+Note: do not generate backticks.
+Now proceed to generate your own thoughts and queries.
+Always return a non-empty answer; it is very important!
+"""
+ )
+
+ return self.obs.add_screenshot(prompt)
+
+ def shrink(self):
+ self.history.shrink()
+ self.obs.shrink()
+
+ def _parse_answer(self, text_answer):
+ try:
+ ans_dict = parse_html_tags_raise(
+ text_answer, keys=["think", "queries"], merge_multiple=True
+ )
+ except Exception as e:
+ t = text_answer.replace("\n", "\\n")
+ logger.warning(f"Failed to parse llm answer: {e}. RAW answer: '{t}'. Will retry")
+ raise e
+ raw_queries = ans_dict.get("queries", "[]")
+ try:
+ ans_dict["queries"] = json.loads(raw_queries)
+ except Exception as e:
+ t = text_answer.replace("\n", "\\n")
+ logger.warning(
+ f"Failed to parse queries: {e}. Queries block content: '{ans_dict['queries']}'. RAW llm answer: '{t}'. Will retry"
+ )
+ raise e
+ return ans_dict
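To make the expected answer format concrete, here is the shape `_parse_answer` accepts (the answer text is invented):

```python
# Sketch: a well-formed answer for StepWiseContextIdentificationPrompt._parse_answer.
import json

from agentlab.llm.llm_utils import parse_html_tags_raise

answer = """<think>
The user is mid-way through configuring a two-column sort.
</think>
<queries>
["The user has selected the first sort column and needs a second criterion."]
</queries>"""

ans_dict = parse_html_tags_raise(answer, keys=["think", "queries"], merge_multiple=True)
queries = json.loads(ans_dict["queries"])  # -> a list with one query string
```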
diff --git a/src/agentlab/agents/hint_use_agent/tmlr_config.py b/src/agentlab/agents/hint_use_agent/tmlr_config.py
new file mode 100644
index 00000000..5a749721
--- /dev/null
+++ b/src/agentlab/agents/hint_use_agent/tmlr_config.py
@@ -0,0 +1,78 @@
+"""
+Specific configurations for our 2024 TMLR submission.
+"""
+
+from copy import deepcopy
+
+from agentlab.agents import dynamic_prompting as dp
+from agentlab.experiments import args
+from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
+
+from .generic_agent import GenericAgentArgs
+from .generic_agent_prompt import GenericPromptFlags
+
+BASE_FLAGS = GenericPromptFlags(
+ obs=dp.ObsFlags(
+ use_html=False,
+ use_ax_tree=True,
+ use_focused_element=True,
+ use_error_logs=True,
+ use_history=True,
+ use_past_error_logs=False,
+ use_action_history=True,
+ use_think_history=True, # gpt-4o config except for this line
+ use_diff=False,
+ html_type="pruned_html",
+ use_screenshot=False,
+ use_som=False,
+ extract_visible_tag=True,
+ extract_clickable_tag=True,
+ extract_coords="False",
+ filter_visible_elements_only=False,
+ ),
+ action=dp.ActionFlags(
+ multi_actions=False,
+ action_set="bid",
+ long_description=False,
+ individual_examples=False,
+ ),
+ use_plan=False,
+ use_criticise=False,
+ use_thinking=True,
+ use_memory=False,
+ use_concrete_example=True,
+ use_abstract_example=True,
+ use_hints=True,
+ enable_chat=False,
+ max_prompt_tokens=40_000,
+ be_cautious=True,
+ extra_instructions=None,
+)
+
+
+def get_base_agent(llm_config: str):
+ return GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config],
+ flags=BASE_FLAGS,
+ )
+
+
+def get_vision_agent(llm_config: str):
+ flags = deepcopy(BASE_FLAGS)
+ flags.obs.use_screenshot = True
+ agent_args = GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config],
+ flags=flags,
+ )
+ agent_args.agent_name = f"{agent_args.agent_name}_vision"
+ return agent_args
+
+
+def get_som_agent(llm_config: str):
+ flags = deepcopy(BASE_FLAGS)
+ flags.obs.use_screenshot = True
+ flags.obs.use_som = True
+ return GenericAgentArgs(
+ chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config],
+ flags=flags,
+ )
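Usage mirrors `experiments/generic/run_generic_agent.py` above (benchmark and model keys are illustrative):

```python
# Sketch: wiring get_base_agent into a Study, as the runner script does.
from agentlab.agents.hint_use_agent.tmlr_config import get_base_agent
from agentlab.experiments.study import Study
from bgym import DEFAULT_BENCHMARKS

agent_args = [get_base_agent("openai/gpt-4o-2024-05-13")]
benchmark = DEFAULT_BENCHMARKS["workarena_l1"]()

study = Study(agent_args, benchmark)
# study.run(n_jobs=5, parallel_backend="ray", n_relaunch=3)
```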
diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py
index 375c829e..894616a4 100644
--- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py
+++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py
@@ -28,6 +28,7 @@
from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark
from agentlab.benchmarks.osworld import OSWorldActionSet
from agentlab.llm.base_api import BaseModelArgs
+from agentlab.llm.chat_api import ChatModel
from agentlab.llm.llm_utils import image_to_png_base64_url
from agentlab.llm.response_api import (
APIPayload,
@@ -40,8 +41,10 @@
ToolCalls,
)
from agentlab.llm.tracking import cost_tracker_decorator
+from agentlab.utils.hinting import HintsSource
logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
@dataclass
@@ -316,39 +319,22 @@ class TaskHint(Block):
def _init(self):
"""Initialize the block."""
- if Path(self.hint_db_rel_path).is_absolute():
- hint_db_path = Path(self.hint_db_rel_path)
- else:
- hint_db_path = Path(__file__).parent / self.hint_db_rel_path
- self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
- if self.hint_retrieval_mode == "emb":
- self.encode_hints()
-
- def oai_embed(self, text: str):
- response = self._oai_emb.create(input=text, model="text-embedding-3-small")
- return response.data[0].embedding
-
- def encode_hints(self):
- self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
- logger.info(
- f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
- )
- hints = self.uniq_hints["hint"].tolist()
- semantic_keys = self.uniq_hints["semantic_keys"].tolist()
- lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
- emb_path = f"{self.hint_db_rel_path}.embs.npy"
- assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
- logger.info(f"Loading hint embeddings from: {emb_path}")
- emb_dict = np.load(emb_path, allow_pickle=True).item()
- self.hint_embeddings = np.array([emb_dict[k] for k in lines])
- logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
+ if self.use_task_hint:
+ self.hints_source = HintsSource(
+ hint_db_path=self.hint_db_rel_path,
+ hint_retrieval_mode=self.hint_retrieval_mode,
+ top_n=self.top_n,
+ embedder_model=self.embedder_model,
+ embedder_server=self.embedder_server,
+ llm_prompt=self.llm_prompt,
+ )
def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
if not self.use_task_hint:
return {}
goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
- task_hints = self.choose_hints(llm, task_name, goal)
+ task_hints = self.hints_source.choose_hints(llm, task_name, goal)
hints = []
for hint in task_hints:
@@ -358,101 +344,13 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
if len(hints) > 0:
hints_str = (
- "# Hints:\nHere are some hints for the task you are working on:\n"
+ "\n# Hints:\nHere are some hints for the task you are working on:\n"
+ "\n".join(hints)
)
msg = llm.msg.user().add_text(hints_str)
discussion.append(msg)
- def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
- """Choose hints based on the task name."""
- if self.hint_retrieval_mode == "llm":
- return self.choose_hints_llm(llm, goal)
- elif self.hint_retrieval_mode == "direct":
- return self.choose_hints_direct(task_name)
- elif self.hint_retrieval_mode == "emb":
- return self.choose_hints_emb(goal)
- else:
- raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")
-
- def choose_hints_llm(self, llm, goal: str) -> list[str]:
- """Choose hints using LLM to filter the hints."""
- topic_to_hints = defaultdict(list)
- for i, row in self.hint_db.iterrows():
- topic_to_hints[row["semantic_keys"]].append(i)
- hint_topics = list(topic_to_hints.keys())
- topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
- prompt = self.llm_prompt.format(goal=goal, topics=topics)
- response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))
- try:
- hint_topic_idx = json.loads(response.think)
- if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
- logger.error(f"Wrong LLM hint id response: {response.think}, no hints")
- return []
- hint_topic = hint_topics[hint_topic_idx]
- hint_indices = topic_to_hints[hint_topic]
- df = self.hint_db.iloc[hint_indices].copy()
- df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints
- hints = df["hint"].tolist()
- logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
- except json.JSONDecodeError:
- logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints")
- hints = []
- return hints
-
- def choose_hints_emb(self, goal: str) -> list[str]:
- """Choose hints using embeddings to filter the hints."""
- goal_embeddings = self._encode([goal], prompt="task description")
- similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist())
- top_indices = similarities.argsort()[0][-self.top_n :].tolist()
- logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
- hints = self.uniq_hints.iloc[top_indices]
- logger.info(f"Embedding-based hints chosen: {hints}")
- return hints["hint"].tolist()
-
- def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
- """Call the encode API endpoint with timeout and retries"""
- for attempt in range(max_retries):
- try:
- response = requests.post(
- f"{self.embedder_server}/encode",
- json={"texts": texts, "prompt": prompt},
- timeout=timeout,
- )
- embs = response.json()["embeddings"]
- return np.asarray(embs)
- except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
- if attempt == max_retries - 1:
- raise e
- time.sleep(random.uniform(1, timeout))
- continue
-
- def _similarity(
- self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5
- ):
- """Call the similarity API endpoint with timeout and retries"""
- for attempt in range(max_retries):
- try:
- response = requests.post(
- f"{self.embedder_server}/similarity",
- json={"texts1": texts1, "texts2": texts2},
- timeout=timeout,
- )
- similarities = response.json()["similarities"]
- return np.asarray(similarities)
- except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
- if attempt == max_retries - 1:
- raise e
- time.sleep(random.uniform(1, timeout))
- continue
-
- def choose_hints_direct(self, task_name: str) -> list[str]:
- hints = self.hint_db[
- self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
- ]
- return hints["hint"].tolist()
-
@dataclass
class PromptConfig:
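The retrieval logic removed above now lives in `agentlab.utils.hinting.HintsSource` (its diff is truncated at the end of this patch); a sketch of standalone use, with constructor arguments taken from the call sites (values illustrative):

```python
# Sketch: the extracted HintsSource, decoupled from the TaskHint block.
from agentlab.utils.hinting import HintsSource

source = HintsSource(
    hint_db_path="hint_db.csv",    # illustrative path to a task_name/hint CSV
    hint_retrieval_mode="direct",  # "direct", "llm", or "emb", as before
    top_n=5,
)
# hints = source.choose_hints(llm, task_name, goal)  # same call as in apply() above
```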
diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py
index 6dbec117..ce8654dc 100644
--- a/src/agentlab/analyze/agent_xray.py
+++ b/src/agentlab/analyze/agent_xray.py
@@ -537,9 +537,10 @@ def run_gradio(results_dir: Path):
do_share = os.getenv("AGENTXRAY_SHARE_GRADIO", "false").lower() == "true"
port = os.getenv("AGENTXRAY_APP_PORT", None)
+    server_name = "0.0.0.0" if os.getenv("AGENTXRAY_PUBLIC", "false").lower() == "true" else "127.0.0.1"
if isinstance(port, str):
port = int(port)
- demo.launch(server_port=port, share=do_share)
+ demo.launch(server_name=server_name, server_port=port, share=do_share)
def handle_key_event(key_event, step_id: StepId):
diff --git a/src/agentlab/experiments/graph_execution_ray.py b/src/agentlab/experiments/graph_execution_ray.py
index f047f866..f7aad780 100644
--- a/src/agentlab/experiments/graph_execution_ray.py
+++ b/src/agentlab/experiments/graph_execution_ray.py
@@ -3,9 +3,8 @@
import bgym
import ray
-from ray.util import state
-
from agentlab.experiments.exp_utils import _episode_timeout, run_exp
+from ray.util import state
logger = logging.getLogger(__name__)
@@ -79,6 +78,7 @@ def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_inter
try:
result = ray.get(task)
except Exception as e:
+ logger.exception(f"Task failed: {e}")
result = e
results.append(result)
diff --git a/src/agentlab/experiments/launch_exp.py b/src/agentlab/experiments/launch_exp.py
index 3bc6c54e..f14ae4c7 100644
--- a/src/agentlab/experiments/launch_exp.py
+++ b/src/agentlab/experiments/launch_exp.py
@@ -1,4 +1,5 @@
import logging
+import os
from importlib import import_module
from pathlib import Path
@@ -7,6 +8,8 @@
from agentlab.experiments.exp_utils import run_exp
from agentlab.experiments.loop import ExpArgs, yield_all_exp_results
+RAY_PUBLIC_DASHBOARD = os.environ.get("RAY_PUBLIC_DASHBOARD", "false").lower() == "true"
+
def run_experiments(
n_jobs,
@@ -82,7 +85,9 @@ def run_experiments(
elif parallel_backend == "ray":
from agentlab.experiments.graph_execution_ray import execute_task_graph, ray
- ray.init(num_cpus=n_jobs)
+ ray.init(
+ num_cpus=n_jobs, dashboard_host="0.0.0.0" if RAY_PUBLIC_DASHBOARD else "127.0.0.1"
+ )
try:
execute_task_graph(exp_args_list, avg_step_timeout=avg_step_timeout)
finally:
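RAY_PUBLIC_DASHBOARD is evaluated once at module import, so it must be exported
before agentlab.experiments.launch_exp is imported. A standalone sketch of the
equivalent gating:

    import os

    import ray

    public = os.environ.get("RAY_PUBLIC_DASHBOARD", "false").lower() == "true"
    ray.init(num_cpus=5, dashboard_host="0.0.0.0" if public else "127.0.0.1")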
diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py
index de4b976a..f69322a6 100644
--- a/src/agentlab/experiments/loop.py
+++ b/src/agentlab/experiments/loop.py
@@ -907,7 +907,6 @@ def _move_old_exp(exp_dir):
def _get_env_name(task_name: str):
"""Register tasks if needed (lazy import) and return environment name."""
-
# lazy import
if task_name.startswith("miniwob"):
import browsergym.miniwob
@@ -915,6 +914,7 @@ def _get_env_name(task_name: str):
import browsergym.workarena
elif task_name.startswith("webarena"):
import browsergym.webarena
+ import browsergym.webarenalite
elif task_name.startswith("visualwebarena"):
import browsergym.visualwebarena
elif task_name.startswith("assistantbench"):
diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py
index c1ee458f..2d0cb6ed 100644
--- a/src/agentlab/llm/llm_configs.py
+++ b/src/agentlab/llm/llm_configs.py
@@ -20,37 +20,6 @@
]
CHAT_MODEL_ARGS_DICT = {
- "openai/gpt-5-2025-08-07": OpenAIModelArgs(
- model_name="gpt-5-2025-08-07",
- max_total_tokens=400_000,
- max_input_tokens=256_000,
- max_new_tokens=128_000,
- temperature=1, # gpt-5 supports temperature of 1 only
- vision_support=True,
- ),
- "openai/gpt-5-nano-2025-08-07": OpenAIModelArgs(
- model_name="gpt-5-nano-2025-08-07",
- max_total_tokens=400_000,
- max_input_tokens=256_000,
- max_new_tokens=128_000,
- temperature=1, # gpt-5 supports temperature of 1 only
- vision_support=True,
- ),
- "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs(
- model_name="gpt-5-mini-2025-08-07",
- max_total_tokens=400_000,
- max_input_tokens=256_000,
- max_new_tokens=128_000,
- temperature=1, # gpt-5 supports temperature of 1 only
- vision_support=True,
- ),
- "openai/gpt-4.1-nano-2025-04-14": OpenAIModelArgs(
- model_name="gpt-4.1-nano-2025-04-14",
- max_total_tokens=128_000,
- max_input_tokens=128_000,
- max_new_tokens=16_384,
- vision_support=True,
- ),
"openai/gpt-4.1-mini-2025-04-14": OpenAIModelArgs(
model_name="gpt-4.1-mini-2025-04-14",
max_total_tokens=128_000,
@@ -126,6 +95,30 @@
max_new_tokens=64_000,
temperature=1e-1,
),
+ "openai/gpt-5-nano-2025-08-07": OpenAIModelArgs(
+ model_name="gpt-5-nano-2025-08-07",
+ max_total_tokens=400_000,
+ max_input_tokens=400_000 - 4_000,
+ max_new_tokens=4_000,
+        temperature=1,  # gpt-5 accepts only the default temperature of 1
+ vision_support=True,
+ ),
+ "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs(
+ model_name="gpt-5-mini-2025-08-07",
+ max_total_tokens=400_000,
+ max_input_tokens=400_000 - 4_000,
+ max_new_tokens=4_000,
+        temperature=1,  # gpt-5 accepts only the default temperature of 1
+ vision_support=True,
+ ),
+ "openai/gpt-5-2025-08-07": OpenAIModelArgs(
+ model_name="gpt-5-2025-08-07",
+ max_total_tokens=400_000,
+ max_input_tokens=400_000 - 4_000,
+ max_new_tokens=4_000,
+        temperature=1,  # gpt-5 accepts only the default temperature of 1
+ vision_support=True,
+ ),
"azure/gpt-35-turbo/gpt-35-turbo": AzureModelArgs(
model_name="gpt-35-turbo",
max_total_tokens=8_192,
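The re-added gpt-5 entries change the token budget: max_new_tokens drops from
128_000 to 4_000 and max_input_tokens now takes the rest of the 400_000-token
context. A quick sanity check over one entry (field names as in the diff):

    from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

    args = CHAT_MODEL_ARGS_DICT["openai/gpt-5-mini-2025-08-07"]
    assert args.max_input_tokens + args.max_new_tokens <= args.max_total_tokens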
diff --git a/src/agentlab/utils/__init__.py b/src/agentlab/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/agentlab/utils/hinting.py b/src/agentlab/utils/hinting.py
new file mode 100644
index 00000000..506513d5
--- /dev/null
+++ b/src/agentlab/utils/hinting.py
@@ -0,0 +1,221 @@
+import fnmatch
+import json
+import logging
+import os
+import random
+import re
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+import pandas as pd
+import requests
+
+from agentlab.llm.chat_api import ChatModel
+from agentlab.llm.response_api import APIPayload
+
+logger = logging.getLogger(__name__)
+
+
+class HintsSource:
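+    """Load a hint database and retrieve hints directly, via LLM topic selection, or via embedding similarity."""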
+
+ def __init__(
+ self,
+ hint_db_path: str,
+ hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
+ skip_hints_for_current_task: bool = False,
+ skip_hints_for_current_goal: bool = False,
+ top_n: int = 4,
+ embedder_model: str = "Qwen/Qwen3-Embedding-0.6B",
+ embedder_server: str = "http://localhost:5000/",
+        llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
+You need to choose the most relevant hint topic from the following list:\n\nHint topics:\n{topics}\n
+Choose a hint topic for the task and return only its number,
+e.g. 0 for the topic with index 0. If you don't know the answer, return -1""",
+ ) -> None:
+ self.hint_db_path = hint_db_path
+ self.hint_retrieval_mode = hint_retrieval_mode
+ self.skip_hints_for_current_task = skip_hints_for_current_task
+ self.skip_hints_for_current_goal = skip_hints_for_current_goal
+ self.top_n = top_n
+ self.embedder_model = embedder_model
+ self.embedder_server = embedder_server
+ self.llm_prompt = llm_prompt
+
+ if Path(hint_db_path).is_absolute():
+ self.hint_db_path = Path(hint_db_path).as_posix()
+ else:
+ self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix()
+ self.hint_db = pd.read_csv(
+ self.hint_db_path,
+ header=0,
+ index_col=None,
+ dtype=str,
+ converters={
+ "trace_paths_json": lambda x: json.loads(x) if pd.notna(x) else [],
+ "source_trace_goals": lambda x: json.loads(x) if pd.notna(x) else [],
+ },
+ )
+ logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}")
+ if self.hint_retrieval_mode == "emb":
+ self.load_hint_vectors()
+
+ def load_hint_vectors(self):
+ self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
+        logger.info(
+            f"Loading embeddings for {len(self.uniq_hints)} unique hints (semantic keys, model {self.embedder_model})."
+        )
+ hints = self.uniq_hints["hint"].tolist()
+ semantic_keys = self.uniq_hints["semantic_keys"].tolist()
+ lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
+ emb_path = f"{self.hint_db_path}.embs.npy"
+ assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
+ logger.info(f"Loading hint embeddings from: {emb_path}")
+ emb_dict = np.load(emb_path, allow_pickle=True).item()
+ self.hint_embeddings = np.array([emb_dict[k] for k in lines])
+ logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
+
+ def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
+        """Choose hints for the task and goal using the configured retrieval mode."""
+ logger.info(
+ f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}"
+ )
+ if self.hint_retrieval_mode == "llm":
+ return self.choose_hints_llm(llm, goal, task_name)
+ elif self.hint_retrieval_mode == "direct":
+ return self.choose_hints_direct(task_name)
+ elif self.hint_retrieval_mode == "emb":
+ return self.choose_hints_emb(goal, task_name)
+ else:
+ raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")
+
+ def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]:
+        """Choose hints by asking the LLM to pick the most relevant hint topic."""
+ topic_to_hints = defaultdict(list)
+ skip_hints = []
+ if self.skip_hints_for_current_task:
+ skip_hints += self.get_current_task_hints(task_name)
+ if self.skip_hints_for_current_goal:
+ skip_hints += self.get_current_goal_hints(goal)
+ for _, row in self.hint_db.iterrows():
+ hint = row["hint"]
+ if hint in skip_hints:
+ continue
+ topic_to_hints[row["semantic_keys"]].append(hint)
+ logger.info(f"Collected {len(topic_to_hints)} hint topics")
+ hint_topics = list(topic_to_hints.keys())
+ topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
+ prompt = self.llm_prompt.format(goal=goal, topics=topics)
+
+ if isinstance(llm, ChatModel):
+ response: str = llm(messages=[dict(role="user", content=prompt)])["content"]
+ else:
+ response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think
+ try:
+ matches = re.findall(r"(-?\d+)", response)
+ if not matches:
+                logger.error(f"No topic index found in LLM response: {response}")
+                return []
+            if len(matches) > 1:
+                logger.warning("LLM returned multiple topic indices; using only the first one.")
+            topic_number = int(matches[0])
+            if topic_number < 0 or topic_number >= len(hint_topics):
+                logger.error(f"Out-of-range topic index in LLM response: {response}; returning no hints")
+ return []
+ hint_topic = hint_topics[topic_number]
+ hints = list(set(topic_to_hints[hint_topic]))
+ logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}")
+ except Exception as e:
+ logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}")
+ hints = []
+ return hints
+
+ def choose_hints_emb(self, goal: str, task_name: str) -> list[str]:
+        """Choose hints by embedding similarity between the goal and the hints."""
+ try:
+ goal_embeddings = self._encode([goal], prompt="task description")
+            all_hints = self.uniq_hints["hint"].tolist()
+ skip_hints = []
+ if self.skip_hints_for_current_task:
+ skip_hints += self.get_current_task_hints(task_name)
+ if self.skip_hints_for_current_goal:
+ skip_hints += self.get_current_goal_hints(goal)
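+            # rebuild the embedding list without the skipped hints, keeping an
+            # index -> hint map so similarity ranks can be mapped back to hint text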
+ hint_embeddings = []
+ id_to_hint = {}
+ for hint, emb in zip(all_hints, self.hint_embeddings):
+ if hint in skip_hints:
+ continue
+ hint_embeddings.append(emb.tolist())
+ id_to_hint[len(hint_embeddings) - 1] = hint
+ logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints")
+ similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings)
+ top_indices = similarities.argsort()[0][-self.top_n :].tolist()
+ logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
+ hints = [id_to_hint[idx] for idx in top_indices]
+ logger.info(f"Embedding-based hints chosen: {hints}")
+ except Exception as e:
+ logger.exception(f"Failed to choose hints using embeddings: {e}")
+ hints = []
+ return hints
+
+ def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
+ """Call the encode API endpoint with timeout and retries"""
+ for attempt in range(max_retries):
+ try:
+ response = requests.post(
+ f"{self.embedder_server}/encode",
+ json={"texts": texts, "prompt": prompt},
+ timeout=timeout,
+ )
+ embs = response.json()["embeddings"]
+ return np.asarray(embs)
+ except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
+ if attempt == max_retries - 1:
+ raise e
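+                # jittered backoff before the next retry to avoid hammering the embedder server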
+ time.sleep(random.uniform(1, timeout))
+ continue
+ raise ValueError("Failed to encode hints")
+
+ def _similarity(
+ self,
+ texts1: list,
+ texts2: list,
+ timeout: int = 2,
+ max_retries: int = 5,
+ ):
+ """Call the similarity API endpoint with timeout and retries"""
+ for attempt in range(max_retries):
+ try:
+ response = requests.post(
+ f"{self.embedder_server}/similarity",
+ json={"texts1": texts1, "texts2": texts2},
+ timeout=timeout,
+ )
+ similarities = response.json()["similarities"]
+ return np.asarray(similarities)
+ except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
+ if attempt == max_retries - 1:
+ raise e
+ time.sleep(random.uniform(1, timeout))
+ continue
+ raise ValueError("Failed to compute similarity")
+
+ def choose_hints_direct(self, task_name: str) -> list[str]:
+ hints = self.get_current_task_hints(task_name)
+ logger.info(f"Direct hints chosen: {hints}")
+ return hints
+
+ def get_current_task_hints(self, task_name):
+ hints_df = self.hint_db[
+ self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
+ ]
+ return hints_df["hint"].tolist()
+
+ def get_current_goal_hints(self, goal_str: str):
+ mask = self.hint_db["source_trace_goals"].apply(lambda goals: goal_str in goals)
+ return self.hint_db.loc[mask, "hint"].tolist()
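
Since hinting.py is new, a hedged usage sketch may help review (the path and
task name below are illustrative; "direct" mode only consults the task_name and
hint columns, so no LLM or embedder server is required and llm can be None):

    from agentlab.utils.hinting import HintsSource

    source = HintsSource(
        hint_db_path="/abs/path/to/hint_db.csv",  # relative paths resolve next to hinting.py
        hint_retrieval_mode="direct",
    )
    hints = source.choose_hints(llm=None, task_name="workarena.servicenow.sort-asset-list", goal="")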