From 7af2d152adeed3d48c7a074e175125823e76f203 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Wed, 13 Aug 2025 12:16:50 +0200 Subject: [PATCH 01/53] fixes --- .../agents/tool_use_agent/tool_use_agent.py | 139 ++++++++++++++++-- src/agentlab/analyze/agent_xray.py | 2 +- src/agentlab/llm/tracking.py | 12 +- 3 files changed, 137 insertions(+), 16 deletions(-) diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 6ac61180..b1407a87 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -1,10 +1,12 @@ import fnmatch import json +import logging from abc import ABC, abstractmethod +from collections import defaultdict from copy import copy from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Literal import bgym import pandas as pd @@ -16,6 +18,7 @@ overlay_som, prune_html, ) +from sentence_transformers import SentenceTransformer from agentlab.agents.agent_args import AgentArgs from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark @@ -34,6 +37,8 @@ ) from agentlab.llm.tracking import cost_tracker_decorator +logger = logging.getLogger(__name__) + @dataclass class Block(ABC): @@ -298,22 +303,45 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict: class TaskHint(Block): use_task_hint: bool = True hint_db_rel_path: str = "hint_db.csv" + hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" + top_n: int = 4 # Number of top hints to return when using embedding retrieval + embedder_model: str = "Qwen/Qwen3-Embedding-0.6B" # Model for embedding hints + llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n +You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n +Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""" def _init(self): """Initialize the block.""" - hint_db_path = Path(__file__).parent / self.hint_db_rel_path + if Path(self.hint_db_rel_path).is_absolute(): + hint_db_path = Path(self.hint_db_rel_path) + else: + hint_db_path = Path(__file__).parent / self.hint_db_rel_path self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) + if self.hint_retrieval_mode == "emb": + logger.info("Load sentence transformer model for hint embeddings.") + self.emb_model = SentenceTransformer( + "Qwen/Qwen3-Embedding-0.6B", model_kwargs={"torch_dtype": "bfloat16"} + ) + self.encode_hints() + + def encode_hints(self): + self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") + logger.info( + f"Encoding {len(self.uniq_hints)} unique hints using {self.embedder_model} model." + ) + self.hint_embeddings = self.emb_model.encode( + self.uniq_hints["hint"].tolist(), prompt="task hint" + ) def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if not self.use_task_hint: - return + return {} - task_hints = self.hint_db[ - self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) - ] + goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content]) + task_hints = self.choose_hints(llm, task_name, goal) hints = [] - for hint in task_hints["hint"]: + for hint in task_hints: hint = hint.strip() if hint: hints.append(f"- {hint}") @@ -327,6 +355,58 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: discussion.append(msg) + def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: + """Choose hints based on the task name.""" + if self.hint_retrieval_mode == "llm": + return self.choose_hints_llm(llm, goal) + elif self.hint_retrieval_mode == "direct": + return self.choose_hints_direct(task_name) + elif self.hint_retrieval_mode == "emb": + return self.choose_hints_emb(goal) + else: + raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") + + def choose_hints_llm(self, llm, goal: str) -> list[str]: + """Choose hints using LLM to filter the hints.""" + topic_to_hints = defaultdict(list) + for i, row in self.hint_db.iterrows(): + topic_to_hints[row["semantic_keys"]].append(i) + hint_topics = list(topic_to_hints.keys()) + topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) + prompt = self.llm_prompt.format(goal=goal, topics=topics) + response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])) + try: + hint_topic_idx = json.loads(response.think) + if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics): + logger.error(f"Wrong LLM hint id response: {response.think}, no hints") + return [] + hint_topic = hint_topics[hint_topic_idx] + hint_indices = topic_to_hints[hint_topic] + df = self.hint_db.iloc[hint_indices].copy() + df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints + hints = df["hint"].tolist() + logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}") + except json.JSONDecodeError: + logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints") + hints = [] + return hints + + def choose_hints_emb(self, goal: str) -> list[str]: + """Choose hints using embeddings to filter the hints.""" + goal_embeddings = self.emb_model.encode([goal], prompt="task description") + similarities = self.emb_model.similarity(goal_embeddings, self.hint_embeddings) + top_indices = similarities.argsort()[0][-self.top_n :].tolist() + logger.info(f"Top hint indices based on embedding similarity: {top_indices}") + hints = self.uniq_hints.iloc[top_indices] + logger.info(f"Embedding-based hints chosen: {hints}") + return hints["hint"].tolist() + + def choose_hints_direct(self, task_name: str) -> list[str]: + hints = self.hint_db[ + self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) + ] + return hints["hint"].tolist() + @dataclass class PromptConfig: @@ -510,6 +590,15 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) +GPT_4_1_CC_API = OpenAIChatModelArgs( + model_name="gpt-4.1", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + GPT_4_1_MINI = OpenAIResponseModelArgs( model_name="gpt-4.1-mini", max_total_tokens=200_000, @@ -528,7 +617,7 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) -CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs( +CLAUDE_SONNET_37 = ClaudeResponseModelArgs( model_name="claude-3-7-sonnet-20250219", max_total_tokens=200_000, max_input_tokens=200_000, @@ -537,6 +626,15 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) +CLAUDE_SONNET_4 = ClaudeResponseModelArgs( + model_name="claude-sonnet-4-20250514", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + O3_RESPONSE_MODEL = OpenAIResponseModelArgs( model_name="o3-2025-04-16", max_total_tokens=200_000, @@ -554,6 +652,25 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) +GPT_5 = OpenAIChatModelArgs( + model_name="gpt-5", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=None, + vision_support=True, +) + + +GPT_5_MINI = OpenAIChatModelArgs( + model_name="gpt-5-mini-2025-08-07", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=1.0, + vision_support=True, +) + GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs( model_name="openai/gpt-4.1", max_total_tokens=200_000, @@ -580,12 +697,12 @@ def get_action(self, obs: Any) -> float: keep_last_n_obs=None, multiaction=True, # whether to use multi-action or not # action_subsets=("bid",), - action_subsets=("coord"), + action_subsets=("coord",), # action_subsets=("coord", "bid"), ) AGENT_CONFIG = ToolUseAgentArgs( - model_args=CLAUDE_MODEL_CONFIG, + model_args=CLAUDE_SONNET_37, config=DEFAULT_PROMPT_CONFIG, ) @@ -605,7 +722,7 @@ def get_action(self, obs: Any) -> float: ) OSWORLD_CLAUDE = ToolUseAgentArgs( - model_args=CLAUDE_MODEL_CONFIG, + model_args=CLAUDE_SONNET_37, config=PromptConfig( tag_screenshot=True, goal=Goal(goal_as_system_msg=True), diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 84dc423d..37ead1c3 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -735,7 +735,7 @@ def dict_msg_to_markdown(d: dict): case _: parts.append(f"\n```\n{str(item)}\n```\n") - markdown = f"### {d["role"].capitalize()}\n" + markdown = f"### {d['role'].capitalize()}\n" markdown += "\n".join(parts) return markdown diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index e761a7f6..afcf5e07 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -178,9 +178,9 @@ def __call__(self, *args, **kwargs): # 'self' here calls ._call_api() method of the subclass response = self._call_api(*args, **kwargs) usage = dict(getattr(response, "usage", {})) - if "prompt_tokens_details" in usage: + if "prompt_tokens_details" in usage and usage["prompt_tokens_details"]: usage["cached_tokens"] = usage["prompt_tokens_details"].cached_tokens - if "input_tokens_details" in usage: + if "input_tokens_details" in usage and usage["input_tokens_details"]: usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))} usage |= {"n_api_calls": 1} @@ -332,12 +332,16 @@ def get_effective_cost_from_openai_api(self, response) -> float: if api_type == "chatcompletion": total_input_tokens = usage.prompt_tokens # (cache read tokens + new input tokens) output_tokens = usage.completion_tokens - cached_input_tokens = usage.prompt_tokens_details.cached_tokens + cached_input_tokens = ( + usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0 + ) new_input_tokens = total_input_tokens - cached_input_tokens elif api_type == "response": total_input_tokens = usage.input_tokens # (cache read tokens + new input tokens) output_tokens = usage.output_tokens - cached_input_tokens = usage.input_tokens_details.cached_tokens + cached_input_tokens = ( + usage.input_tokens_details.cached_tokens if usage.input_tokens_details else 0 + ) new_input_tokens = total_input_tokens - cached_input_tokens else: logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.") From 3f9e4a2191f81d1a177e9be3d6eea734924754cd Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Wed, 13 Aug 2025 12:16:57 +0200 Subject: [PATCH 02/53] add new deps --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 6322ffd3..a2798f2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,5 @@ ray[default] python-slugify pillow gymnasium>=0.27 +sentence-transformers>=5.0.0 +python-dotenv>=1.1.1 \ No newline at end of file From c88d7f3fd0f7942e700d5d79ee79555f45cf3f6b Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Tue, 19 Aug 2025 14:14:38 +0200 Subject: [PATCH 03/53] use external embedding service in task hints retrieval --- .../agents/tool_use_agent/tool_use_agent.py | 75 +++++++++++++++---- 1 file changed, 61 insertions(+), 14 deletions(-) diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index b1407a87..f6ace3a8 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -1,6 +1,9 @@ import fnmatch import json import logging +import os +import random +import time from abc import ABC, abstractmethod from collections import defaultdict from copy import copy @@ -9,7 +12,9 @@ from typing import Any, Literal import bgym +import numpy as np import pandas as pd +import requests from bgym import Benchmark as BgymBenchmark from browsergym.core.observation import extract_screenshot from browsergym.utils.obs import ( @@ -18,7 +23,6 @@ overlay_som, prune_html, ) -from sentence_transformers import SentenceTransformer from agentlab.agents.agent_args import AgentArgs from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark @@ -181,7 +185,6 @@ class Obs(Block): def apply( self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput ) -> dict: - obs_msg = llm.msg.user() tool_calls = last_llm_output.tool_calls if self.use_last_error: @@ -306,6 +309,7 @@ class TaskHint(Block): hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" top_n: int = 4 # Number of top hints to return when using embedding retrieval embedder_model: str = "Qwen/Qwen3-Embedding-0.6B" # Model for embedding hints + embedder_server: str = "http://localhost:5000/" llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""" @@ -318,20 +322,26 @@ def _init(self): hint_db_path = Path(__file__).parent / self.hint_db_rel_path self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) if self.hint_retrieval_mode == "emb": - logger.info("Load sentence transformer model for hint embeddings.") - self.emb_model = SentenceTransformer( - "Qwen/Qwen3-Embedding-0.6B", model_kwargs={"torch_dtype": "bfloat16"} - ) self.encode_hints() + def oai_embed(self, text: str): + response = self._oai_emb.create(input=text, model="text-embedding-3-small") + return response.data[0].embedding + def encode_hints(self): self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") logger.info( - f"Encoding {len(self.uniq_hints)} unique hints using {self.embedder_model} model." - ) - self.hint_embeddings = self.emb_model.encode( - self.uniq_hints["hint"].tolist(), prompt="task hint" + f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." ) + hints = self.uniq_hints["hint"].tolist() + semantic_keys = self.uniq_hints["semantic_keys"].tolist() + lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] + emb_path = f"{self.hint_db_rel_path}.embs.npy" + assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" + logger.info(f"Loading hint embeddings from: {emb_path}") + emb_dict = np.load(emb_path, allow_pickle=True).item() + self.hint_embeddings = np.array([emb_dict[k] for k in lines]) + logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if not self.use_task_hint: @@ -393,14 +403,50 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]: def choose_hints_emb(self, goal: str) -> list[str]: """Choose hints using embeddings to filter the hints.""" - goal_embeddings = self.emb_model.encode([goal], prompt="task description") - similarities = self.emb_model.similarity(goal_embeddings, self.hint_embeddings) + goal_embeddings = self._encode([goal], prompt="task description") + similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist()) top_indices = similarities.argsort()[0][-self.top_n :].tolist() logger.info(f"Top hint indices based on embedding similarity: {top_indices}") hints = self.uniq_hints.iloc[top_indices] logger.info(f"Embedding-based hints chosen: {hints}") return hints["hint"].tolist() + def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): + """Call the encode API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/encode", + json={"texts": texts, "prompt": prompt}, + timeout=timeout, + ) + embs = response.json()["embeddings"] + return np.asarray(embs) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + + def _similarity( + self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5 + ): + """Call the similarity API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/similarity", + json={"texts1": texts1, "texts2": texts2}, + timeout=timeout, + ) + similarities = response.json()["similarities"] + return np.asarray(similarities) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + def choose_hints_direct(self, task_name: str) -> list[str]: hints = self.hint_db[ self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) @@ -466,7 +512,8 @@ def __init__( self.model_args = model_args self.config = config self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet( - self.config.action_subsets, multiaction=self.config.multiaction # type: ignore + self.config.action_subsets, + multiaction=self.config.multiaction, # type: ignore ) self.tools = self.action_set.to_tool_description(api=model_args.api) @@ -656,7 +703,7 @@ def get_action(self, obs: Any) -> float: model_name="gpt-5", max_total_tokens=200_000, max_input_tokens=200_000, - max_new_tokens=2_000, + max_new_tokens=8_000, temperature=None, vision_support=True, ) From 74fc47f2820ec6dde79035a4d3bb5e5949d2c2bf Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Tue, 19 Aug 2025 14:14:49 +0200 Subject: [PATCH 04/53] gpt5 fixes --- src/agentlab/llm/chat_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index ff341356..dc9667b5 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -292,7 +292,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float messages=messages, n=n_samples, temperature=temperature, - max_tokens=self.max_tokens, + max_completion_tokens=self.max_tokens, logprobs=self.log_probs, ) @@ -359,7 +359,7 @@ def __init__( min_retry_wait_time=min_retry_wait_time, api_key_env_var="OPENAI_API_KEY", client_class=OpenAI, - pricing_func=tracking.get_pricing_openai, + pricing_func=partial(tracking.get_pricing_litellm, model_name=model_name), log_probs=log_probs, ) From 1de1e519f2adb307d5affb4f51e000db0cc72914 Mon Sep 17 00:00:00 2001 From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com> Date: Tue, 19 Aug 2025 20:27:52 -0400 Subject: [PATCH 05/53] first cut --- .../agents/human_trace_recorder/agent.py | 215 ++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 src/agentlab/agents/human_trace_recorder/agent.py diff --git a/src/agentlab/agents/human_trace_recorder/agent.py b/src/agentlab/agents/human_trace_recorder/agent.py new file mode 100644 index 00000000..52496b7e --- /dev/null +++ b/src/agentlab/agents/human_trace_recorder/agent.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import logging +import textwrap +from dataclasses import dataclass + +import bgym +from playwright.sync_api import Page + +from agentlab.agents.agent_args import AgentArgs + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Simplified variant: capture human step (trace + screenshot + html) only +# --------------------------------------------------------------------------- + + +@dataclass +class SimpleHumanTraceCaptureAgentArgs(AgentArgs): + """Args for SimpleHumanTraceCaptureAgent. + + This version ONLY captures what the human does in the paused browser per step. + It does NOT attempt to map or translate actions. Always returns noop(). + Set use_raw_page_output=True in loop/env so that obs contains a Playwright Page. + """ + + agent_name: str = "SimpleHumanTraceCapture" + trace_dir: str = "human_traces" + screenshots: bool = True + snapshots: bool = True # playwright tracing snapshots (DOM/Sources) + sources: bool = False # include source files (bigger trace) + # Ensure the raw Playwright Page object is present in observations so we can pause. + use_raw_page_output: bool = True + + def make_agent(self) -> bgym.Agent: + return SimpleHumanTraceCaptureAgent( + trace_dir=self.trace_dir, + screenshots=self.screenshots, + snapshots=self.snapshots, + sources=self.sources, + ) + + def set_reproducibility_mode(self): + pass + + +class SimpleHumanTraceCaptureAgent(bgym.Agent): + """Minimal human-in-the-loop recorder. + + On each get_action: + 1. Start a Playwright tracing capture (if not already running for this step). + 2. Call page.pause() to open Inspector; user performs EXACTLY one logical action. + 3. Stop tracing, save trace zip, screenshot (after action), and HTML snapshot. + 4. Return noop() so the environment advances. + + Artifacts are stored under trace_dir/step_/ + """ + + def __init__(self, trace_dir: str, screenshots: bool, snapshots: bool, sources: bool): + self.action_set = bgym.HighLevelActionSet(["bid"], multiaction=False) + self._step_idx = 0 + from pathlib import Path + + self._root = Path(trace_dir) + self._root.mkdir(parents=True, exist_ok=True) + # Store trace config booleans; Playwright tracing.start expects them as named params. + self._trace_conf = dict(screenshots=screenshots, snapshots=snapshots, sources=sources) + self._tracing_started = False # track if global tracing has been started + self._page: Page | None = None # optional persistent page ref (when not in obs) + + def set_page(self, page: Page): + """Manually inject a Playwright Page so the agent can function without it in obs. + + Call this once after you create / reset the environment if you prefer not to + expose the page through observations (e.g., for safety or serialization reasons). + """ + self._page = page + + def obs_preprocessor(self, obs): # keep original obs so page is available + return obs + + def get_action(self, obs: dict): # type: ignore[override] + import json + import time + + # Resolve page priority: observation > stored page + page: Page | None = obs.get("page") or self._page + if page is None: + raise RuntimeError( + "No Playwright Page available. Provide use_raw_page_output=True OR call set_page(page)." + ) + # Cache page if first time we see it via obs so later steps can omit it + if self._page is None: + self._page = page + + step_dir = self._root / f"step_{self._step_idx:04d}" + step_dir.mkdir(parents=True, exist_ok=True) + trace_path = step_dir / "trace.zip" + screenshot_path = step_dir / "after.png" + html_path = step_dir / "after.html" + + # Lazy start of tracing (once per context) then per-step chunk + if not self._tracing_started: + try: + page.context.tracing.start( + screenshots=self._trace_conf["screenshots"], + snapshots=self._trace_conf["snapshots"], + sources=self._trace_conf["sources"], + ) + self._tracing_started = True + except Exception as e: # pragma: no cover + print(f"[SimpleHumanTraceCapture][WARN] initial tracing.start failed: {e}") + + try: + page.context.tracing.start_chunk() + except Exception as e: # pragma: no cover + print(f"[SimpleHumanTraceCapture][WARN] tracing.start_chunk failed: {e}") + + print("\n[SimpleHumanTraceCapture] Perform ONE action then resume Inspector.") + print("[SimpleHumanTraceCapture] A trace will be saved to:", trace_path) + try: + page.pause() + except Exception as e: # pragma: no cover + print(f"[SimpleHumanTraceCapture][WARN] page.pause failed: {e}") + + # Stop current chunk & save + try: + page.context.tracing.stop_chunk(path=str(trace_path)) + except Exception as e: # pragma: no cover + print(f"[SimpleHumanTraceCapture][WARN] tracing.stop_chunk failed: {e}") + + # Post-action artifacts + try: + page.screenshot(path=str(screenshot_path)) + except Exception as e: # pragma: no cover + print(f"[SimpleHumanTraceCapture][WARN] screenshot failed: {e}") + try: + html = page.content() + html_path.write_text(html) + except Exception as e: # pragma: no cover + print(f"[SimpleHumanTraceCapture][WARN] html capture failed: {e}") + + meta = { + "url": page.url, + "timestamp": time.time(), + "step": self._step_idx, + "trace_path": str(trace_path), + "screenshot_path": str(screenshot_path), + "html_path": str(html_path), + } + (step_dir / "meta.json").write_text(json.dumps(meta, indent=2)) + + # --- Derive a lightweight human-readable script summary from the trace --- + script_summary_lines: list[str] = [] + try: + import json as _json + import zipfile + + with zipfile.ZipFile(trace_path, "r") as zf: + # Playwright trace usually contains one or more *.trace files (jsonl) + trace_files = [n for n in zf.namelist() if n.endswith(".trace")] + for tf in trace_files: + with zf.open(tf, "r") as fh: + for raw_line in fh: + try: + evt = _json.loads(raw_line.decode("utf-8")) + except Exception: + continue + if evt.get("type") != "action": + continue + a = evt.get("action", {}) + api_name = a.get("apiName") or a.get("name") or "action" + selector = a.get("selector") or a.get("locator") or "" + value = a.get("value") or a.get("text") or "" + line = f"{api_name}" + if selector: + line += f" selector={selector!r}" + if value and isinstance(value, str) and len(value) < 200: + line += f" value={value!r}" + script_summary_lines.append(line) + if not script_summary_lines: + script_summary_lines.append("(no action events parsed from trace chunk)") + except Exception as e: # pragma: no cover + script_summary_lines.append(f"(failed to parse trace for script summary: {e})") + + # Prepare chat messages (simple list of strings for easy viewing) + chat_messages = [ + "PLAYWRIGHT TRACE STEP SUMMARY:", + f"Step {self._step_idx} URL: {page.url}", + "Actions:", + *script_summary_lines, + f"Trace file: {trace_path}", + "Open with: npx playwright show-trace " + str(trace_path), + ] + + self._step_idx += 1 + + agent_info = bgym.AgentInfo( + think="human-recorded", + chat_messages=chat_messages, + stats={"step": self._step_idx}, + markdown_page=textwrap.dedent( + f"""### Simple Human Trace Capture\nSaved artifacts for step {meta['step']}:\n- URL: {meta['url']}\n- Trace: {meta['trace_path']}\n- Screenshot: {meta['screenshot_path']}\n- HTML: {meta['html_path']}\n""" + ), + extra_info=meta, + ) + return "noop()", agent_info + + +SIMPLE_TRACE_CAPTURE_AGENT = SimpleHumanTraceCaptureAgentArgs() + +##1. Simple debug agent +# 2. Instead of using the page object Launch codegen directly in a subprocess using the playwright codegen --url or somethiing From 2b4633a95c0e18724565d2a5ffa489f4c7ad220c Mon Sep 17 00:00:00 2001 From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com> Date: Tue, 19 Aug 2025 21:30:00 -0400 Subject: [PATCH 06/53] update --- .../agents/human_trace_recorder/agent.py | 321 ++++++++---------- 1 file changed, 149 insertions(+), 172 deletions(-) diff --git a/src/agentlab/agents/human_trace_recorder/agent.py b/src/agentlab/agents/human_trace_recorder/agent.py index 52496b7e..fd5aa554 100644 --- a/src/agentlab/agents/human_trace_recorder/agent.py +++ b/src/agentlab/agents/human_trace_recorder/agent.py @@ -1,215 +1,192 @@ +"""Minimal Human Trace Agent (<200 lines) + +Per step we capture ONLY: + - axtree_txt, pruned_html, actions.json, after.html + - Auto-resume after detecting user action + - Visible recording indicator +""" + from __future__ import annotations -import logging -import textwrap +import json +import time +import zipfile from dataclasses import dataclass +from pathlib import Path import bgym from playwright.sync_api import Page from agentlab.agents.agent_args import AgentArgs - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Simplified variant: capture human step (trace + screenshot + html) only -# --------------------------------------------------------------------------- +from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html @dataclass -class SimpleHumanTraceCaptureAgentArgs(AgentArgs): - """Args for SimpleHumanTraceCaptureAgent. - - This version ONLY captures what the human does in the paused browser per step. - It does NOT attempt to map or translate actions. Always returns noop(). - Set use_raw_page_output=True in loop/env so that obs contains a Playwright Page. - """ - - agent_name: str = "SimpleHumanTraceCapture" +class HumanTraceAgentArgs(AgentArgs): + agent_name: str = "HumanTraceAgent" trace_dir: str = "human_traces" - screenshots: bool = True - snapshots: bool = True # playwright tracing snapshots (DOM/Sources) - sources: bool = False # include source files (bigger trace) - # Ensure the raw Playwright Page object is present in observations so we can pause. use_raw_page_output: bool = True - def make_agent(self) -> bgym.Agent: - return SimpleHumanTraceCaptureAgent( - trace_dir=self.trace_dir, - screenshots=self.screenshots, - snapshots=self.snapshots, - sources=self.sources, - ) + def make_agent(self) -> bgym.Agent: # type: ignore[override] + return HumanTraceAgent(self.trace_dir) def set_reproducibility_mode(self): pass -class SimpleHumanTraceCaptureAgent(bgym.Agent): - """Minimal human-in-the-loop recorder. - - On each get_action: - 1. Start a Playwright tracing capture (if not already running for this step). - 2. Call page.pause() to open Inspector; user performs EXACTLY one logical action. - 3. Stop tracing, save trace zip, screenshot (after action), and HTML snapshot. - 4. Return noop() so the environment advances. - - Artifacts are stored under trace_dir/step_/ - """ - - def __init__(self, trace_dir: str, screenshots: bool, snapshots: bool, sources: bool): +class HumanTraceAgent(bgym.Agent): + def __init__(self, trace_dir: str): self.action_set = bgym.HighLevelActionSet(["bid"], multiaction=False) - self._step_idx = 0 - from pathlib import Path - self._root = Path(trace_dir) self._root.mkdir(parents=True, exist_ok=True) - # Store trace config booleans; Playwright tracing.start expects them as named params. - self._trace_conf = dict(screenshots=screenshots, snapshots=snapshots, sources=sources) - self._tracing_started = False # track if global tracing has been started - self._page: Page | None = None # optional persistent page ref (when not in obs) - - def set_page(self, page: Page): - """Manually inject a Playwright Page so the agent can function without it in obs. - - Call this once after you create / reset the environment if you prefer not to - expose the page through observations (e.g., for safety or serialization reasons). - """ - self._page = page - - def obs_preprocessor(self, obs): # keep original obs so page is available + self._page: Page | None = None + self._step = 0 + self._task_name = None + self._seed = None + + def obs_preprocessor(self, obs: dict): # type: ignore[override] + if isinstance(obs, dict): + if self._page is None and "page" in obs: + self._page = obs["page"] + + # Extract task name and seed from obs if available + if self._task_name is None: + self._task_name = obs.get("task_name") or obs.get("task", {}).get( + "task_name", "unknown_task" + ) + if self._seed is None: + self._seed = obs.get("seed") or obs.get("task", {}).get("seed", "unknown_seed") + + dom = obs.get("dom_object") + axt = obs.get("axtree_object") + if axt is not None: + try: + obs["axtree_txt"] = flatten_axtree_to_str(axt) + except Exception: + pass + if dom is not None: + try: + obs["pruned_html"] = prune_html(flatten_dom_to_str(dom)) + except Exception: + pass + for k in ("dom_object", "axtree_object", "page"): + obs.pop(k, None) return obs def get_action(self, obs: dict): # type: ignore[override] - import json - import time - - # Resolve page priority: observation > stored page - page: Page | None = obs.get("page") or self._page - if page is None: - raise RuntimeError( - "No Playwright Page available. Provide use_raw_page_output=True OR call set_page(page)." - ) - # Cache page if first time we see it via obs so later steps can omit it if self._page is None: - self._page = page + raise RuntimeError("Playwright Page missing; ensure use_raw_page_output=True") - step_dir = self._root / f"step_{self._step_idx:04d}" + page = self._page + + # Create directory structure: trace_dir/task_name/seed/step_XXXX + task_dir = self._root / str(self._task_name or "unknown_task") + seed_dir = task_dir / str(self._seed or "unknown_seed") + step_dir = seed_dir / f"step_{self._step:04d}" step_dir.mkdir(parents=True, exist_ok=True) - trace_path = step_dir / "trace.zip" - screenshot_path = step_dir / "after.png" - html_path = step_dir / "after.html" - # Lazy start of tracing (once per context) then per-step chunk - if not self._tracing_started: - try: - page.context.tracing.start( - screenshots=self._trace_conf["screenshots"], - snapshots=self._trace_conf["snapshots"], - sources=self._trace_conf["sources"], - ) - self._tracing_started = True - except Exception as e: # pragma: no cover - print(f"[SimpleHumanTraceCapture][WARN] initial tracing.start failed: {e}") + trace_path = step_dir / "temp_trace.zip" + actions_path = step_dir / "actions.json" + + print( + f"[HumanTrace] Task: {self._task_name}, Seed: {self._seed}, Step {self._step}: Perform ONE action" + ) + # Small recording indicator + page.evaluate( + """ + const div = document.createElement('div'); + div.id = '__rec'; + div.innerHTML = '🔴 REC'; + div.style.cssText = 'position:fixed;top:5px;right:5px;background:#f44;color:#fff;padding:5px 8px;border-radius:4px;font:bold 12px monospace;z-index:99999'; + document.body.appendChild(div); + """ + ) + + # Start tracing try: + page.context.tracing.start(screenshots=True, snapshots=True) page.context.tracing.start_chunk() - except Exception as e: # pragma: no cover - print(f"[SimpleHumanTraceCapture][WARN] tracing.start_chunk failed: {e}") + except Exception: + pass - print("\n[SimpleHumanTraceCapture] Perform ONE action then resume Inspector.") - print("[SimpleHumanTraceCapture] A trace will be saved to:", trace_path) - try: - page.pause() - except Exception as e: # pragma: no cover - print(f"[SimpleHumanTraceCapture][WARN] page.pause failed: {e}") + # Wait for action + self._wait_for_action(page) - # Stop current chunk & save + # Stop tracing and save try: page.context.tracing.stop_chunk(path=str(trace_path)) - except Exception as e: # pragma: no cover - print(f"[SimpleHumanTraceCapture][WARN] tracing.stop_chunk failed: {e}") + actions = self._extract_trace(str(trace_path)) + actions_path.write_text(json.dumps(actions, indent=2)) + trace_path.unlink(missing_ok=True) + except Exception: + pass - # Post-action artifacts + # Remove indicator + page.evaluate("document.getElementById('__rec')?.remove()") + + # Save screenshot try: - page.screenshot(path=str(screenshot_path)) - except Exception as e: # pragma: no cover - print(f"[SimpleHumanTraceCapture][WARN] screenshot failed: {e}") + page.screenshot(path=str(step_dir / "screenshot.png")) + except Exception: + pass + + # Save HTML try: - html = page.content() - html_path.write_text(html) - except Exception as e: # pragma: no cover - print(f"[SimpleHumanTraceCapture][WARN] html capture failed: {e}") - - meta = { - "url": page.url, - "timestamp": time.time(), - "step": self._step_idx, - "trace_path": str(trace_path), - "screenshot_path": str(screenshot_path), - "html_path": str(html_path), + (step_dir / "after.html").write_text(page.content()) + except Exception: + pass + + self._step += 1 + return "noop()", { + "extra_info": { + "step": self._step - 1, + "task_name": self._task_name, + "seed": self._seed, + "trace_dir": str(step_dir), + } } - (step_dir / "meta.json").write_text(json.dumps(meta, indent=2)) - # --- Derive a lightweight human-readable script summary from the trace --- - script_summary_lines: list[str] = [] - try: - import json as _json - import zipfile - - with zipfile.ZipFile(trace_path, "r") as zf: - # Playwright trace usually contains one or more *.trace files (jsonl) - trace_files = [n for n in zf.namelist() if n.endswith(".trace")] - for tf in trace_files: - with zf.open(tf, "r") as fh: - for raw_line in fh: - try: - evt = _json.loads(raw_line.decode("utf-8")) - except Exception: - continue - if evt.get("type") != "action": - continue - a = evt.get("action", {}) - api_name = a.get("apiName") or a.get("name") or "action" - selector = a.get("selector") or a.get("locator") or "" - value = a.get("value") or a.get("text") or "" - line = f"{api_name}" - if selector: - line += f" selector={selector!r}" - if value and isinstance(value, str) and len(value) < 200: - line += f" value={value!r}" - script_summary_lines.append(line) - if not script_summary_lines: - script_summary_lines.append("(no action events parsed from trace chunk)") - except Exception as e: # pragma: no cover - script_summary_lines.append(f"(failed to parse trace for script summary: {e})") - - # Prepare chat messages (simple list of strings for easy viewing) - chat_messages = [ - "PLAYWRIGHT TRACE STEP SUMMARY:", - f"Step {self._step_idx} URL: {page.url}", - "Actions:", - *script_summary_lines, - f"Trace file: {trace_path}", - "Open with: npx playwright show-trace " + str(trace_path), - ] - - self._step_idx += 1 - - agent_info = bgym.AgentInfo( - think="human-recorded", - chat_messages=chat_messages, - stats={"step": self._step_idx}, - markdown_page=textwrap.dedent( - f"""### Simple Human Trace Capture\nSaved artifacts for step {meta['step']}:\n- URL: {meta['url']}\n- Trace: {meta['trace_path']}\n- Screenshot: {meta['screenshot_path']}\n- HTML: {meta['html_path']}\n""" - ), - extra_info=meta, + def _wait_for_action(self, page): + """Wait for user action with auto-resume.""" + page.evaluate( + """ + window.__acted = false; + ['click','keydown','input','change'].forEach(e => + document.addEventListener(e, () => window.__acted = true, true) + ); + """ ) - return "noop()", agent_info - -SIMPLE_TRACE_CAPTURE_AGENT = SimpleHumanTraceCaptureAgentArgs() - -##1. Simple debug agent -# 2. Instead of using the page object Launch codegen directly in a subprocess using the playwright codegen --url or somethiing + start = time.time() + while time.time() - start < 300: # 5 min max + try: + if page.evaluate("window.__acted"): + page.evaluate("document.getElementById('__rec').innerHTML = '💾 SAVING'") + time.sleep(0.3) + return + except Exception: + pass + time.sleep(0.1) + + def _extract_trace(self, trace_file: str): + """Extract ALL events from trace zip.""" + all_events = [] + try: + with zipfile.ZipFile(trace_file, "r") as zf: + for name in zf.namelist(): + if name.endswith(".trace"): + with zf.open(name) as f: + for line in f: + try: + event = json.loads(line.decode()) + # Save everything - don't filter + all_events.append(event) + except Exception: + continue + except Exception: + pass + return all_events + + +HUMAN_TRACE_AGENT = HumanTraceAgentArgs() From 380c69f4708f6c172b9408bd1b55cbaa0edf5556 Mon Sep 17 00:00:00 2001 From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:27:39 -0400 Subject: [PATCH 07/53] add event listeners and launcher --- .../agents/human_trace_recorder/agent.py | 368 ++++++++---- .../human_trace_recorder/event_listeners.py | 563 ++++++++++++++++++ 2 files changed, 802 insertions(+), 129 deletions(-) create mode 100644 src/agentlab/agents/human_trace_recorder/event_listeners.py diff --git a/src/agentlab/agents/human_trace_recorder/agent.py b/src/agentlab/agents/human_trace_recorder/agent.py index fd5aa554..556922af 100644 --- a/src/agentlab/agents/human_trace_recorder/agent.py +++ b/src/agentlab/agents/human_trace_recorder/agent.py @@ -1,16 +1,14 @@ -"""Minimal Human Trace Agent (<200 lines) +"""Human Trace Agent for Browser Automation Training Data -Per step we capture ONLY: - - axtree_txt, pruned_html, actions.json, after.html - - Auto-resume after detecting user action - - Visible recording indicator +Captures human interactions at each step including: + - Comprehensive action tracking (clicks, input, navigation, etc.) + - Saves only human_action.json files in simple numbered folders """ from __future__ import annotations import json import time -import zipfile from dataclasses import dataclass from pathlib import Path @@ -18,6 +16,17 @@ from playwright.sync_api import Page from agentlab.agents.agent_args import AgentArgs +from agentlab.agents.human_trace_recorder.event_listeners import ( + get_interaction_tracking_script, + get_recording_indicators_script, +) +from browsergym.core.observation import ( + extract_dom_extra_properties, + extract_dom_snapshot, + extract_focused_element_bid, + extract_merged_axtree, + extract_screenshot, +) from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html @@ -41,36 +50,33 @@ def __init__(self, trace_dir: str): self._root.mkdir(parents=True, exist_ok=True) self._page: Page | None = None self._step = 0 - self._task_name = None - self._seed = None def obs_preprocessor(self, obs: dict): # type: ignore[override] if isinstance(obs, dict): - if self._page is None and "page" in obs: - self._page = obs["page"] - - # Extract task name and seed from obs if available - if self._task_name is None: - self._task_name = obs.get("task_name") or obs.get("task", {}).get( - "task_name", "unknown_task" - ) - if self._seed is None: - self._seed = obs.get("seed") or obs.get("task", {}).get("seed", "unknown_seed") - - dom = obs.get("dom_object") - axt = obs.get("axtree_object") - if axt is not None: - try: + self._page = obs.get("page") + # Remove the page object from obs to avoid pickle issues + if "page" in obs: + del obs["page"] + + obs["screenshot"] = extract_screenshot(self._page) + obs["dom_object"] = extract_dom_snapshot(self._page) + obs["axtree_object"] = extract_merged_axtree(self._page) + scale_factor = getattr(self._page, "_bgym_scale_factor", 1.0) + extra_properties = extract_dom_extra_properties( + obs["dom_object"], scale_factor=scale_factor + ) + obs["extra_element_properties"] = extra_properties + obs["focused_element_bid"] = extract_focused_element_bid(self._page) + + # Add text representations for easier analysis + if obs["axtree_object"]: + axt = obs["axtree_object"] + if extra_properties: obs["axtree_txt"] = flatten_axtree_to_str(axt) - except Exception: - pass - if dom is not None: - try: - obs["pruned_html"] = prune_html(flatten_dom_to_str(dom)) - except Exception: - pass - for k in ("dom_object", "axtree_object", "page"): - obs.pop(k, None) + + if obs["dom_object"]: + obs["dom_txt"] = flatten_dom_to_str(obs["dom_object"]) + obs["pruned_html"] = prune_html(obs["dom_txt"]) return obs def get_action(self, obs: dict): # type: ignore[override] @@ -78,115 +84,219 @@ def get_action(self, obs: dict): # type: ignore[override] raise RuntimeError("Playwright Page missing; ensure use_raw_page_output=True") page = self._page - - # Create directory structure: trace_dir/task_name/seed/step_XXXX - task_dir = self._root / str(self._task_name or "unknown_task") - seed_dir = task_dir / str(self._seed or "unknown_seed") - step_dir = seed_dir / f"step_{self._step:04d}" - step_dir.mkdir(parents=True, exist_ok=True) - - trace_path = step_dir / "temp_trace.zip" - actions_path = step_dir / "actions.json" - - print( - f"[HumanTrace] Task: {self._task_name}, Seed: {self._seed}, Step {self._step}: Perform ONE action" - ) - - # Small recording indicator - page.evaluate( - """ - const div = document.createElement('div'); - div.id = '__rec'; - div.innerHTML = '🔴 REC'; - div.style.cssText = 'position:fixed;top:5px;right:5px;background:#f44;color:#fff;padding:5px 8px;border-radius:4px;font:bold 12px monospace;z-index:99999'; - document.body.appendChild(div); - """ - ) - - # Start tracing - try: - page.context.tracing.start(screenshots=True, snapshots=True) - page.context.tracing.start_chunk() - except Exception: - pass - - # Wait for action - self._wait_for_action(page) - - # Stop tracing and save - try: - page.context.tracing.stop_chunk(path=str(trace_path)) - actions = self._extract_trace(str(trace_path)) - actions_path.write_text(json.dumps(actions, indent=2)) - trace_path.unlink(missing_ok=True) - except Exception: - pass - - # Remove indicator - page.evaluate("document.getElementById('__rec')?.remove()") - - # Save screenshot - try: - page.screenshot(path=str(step_dir / "screenshot.png")) - except Exception: - pass - - # Save HTML - try: - (step_dir / "after.html").write_text(page.content()) - except Exception: - pass - + step_dir = self._create_step_directory() + + self._display_recording_prompt() + self._show_recording_indicators(page) + + # Capture human interactions + captured_action, human_interactions = self._capture_interactions_with_js(page, step_dir) + + # Save and cleanup + self._save_human_action(captured_action, step_dir) + self._cleanup_indicators(page) + self._step += 1 return "noop()", { "extra_info": { "step": self._step - 1, - "task_name": self._task_name, - "seed": self._seed, - "trace_dir": str(step_dir), + "human_interactions": human_interactions, } } - def _wait_for_action(self, page): - """Wait for user action with auto-resume.""" - page.evaluate( - """ - window.__acted = false; - ['click','keydown','input','change'].forEach(e => - document.addEventListener(e, () => window.__acted = true, true) - ); - """ - ) - - start = time.time() - while time.time() - start < 300: # 5 min max + def _create_step_directory(self) -> Path: + """Create directory for current step.""" + step_dir = self._root / str(self._step) + step_dir.mkdir(parents=True, exist_ok=True) + return step_dir + + def _display_recording_prompt(self): + """Display prompt messages to user.""" + print(f"[HumanTrace] Step {self._step}: Perform ONE action") + print("[HumanTrace] ⚠️ WAIT FOR THE RED BORDER TO APPEAR BEFORE PERFORMING ANY ACTION ⚠️") + print("[HumanTrace] The system will automatically save after detecting your action") + + def _show_recording_indicators(self, page: Page): + """Show visual recording indicators on the page.""" + page.evaluate(get_recording_indicators_script()) + + def _save_human_action(self, captured_action: dict, step_dir: Path): + """Save the captured human action to JSON file.""" + try: + human_action_path = step_dir / "human_action.json" + if captured_action and isinstance(captured_action, dict): + human_action_path.write_text(json.dumps(captured_action, indent=2)) + action_type = captured_action.get("type", "unknown") + else: + # Create empty action record for consistency + empty_action = { + "type": "no_action", + "timestamp": time.time() * 1000, + "reason": "No meaningful human action captured in this step", + } + human_action_path.write_text(json.dumps(empty_action, indent=2)) + action_type = "no_action" + + print(f"[HumanTrace] Step {self._step} complete - Action: {action_type}") + + except Exception as e: + print(f"[HumanTrace] Warning: Failed to save human action: {e}") + + def _cleanup_indicators(self, page: Page): + """Remove recording indicators from the page.""" + page.evaluate("document.getElementById('__rec')?.remove(); document.getElementById('__rec_border')?.remove()") + + def _capture_interactions_with_js(self, page: Page, step_dir: Path) -> tuple[dict, str]: + """Capture human interactions using JavaScript injection.""" + try: + print("[HumanTrace] JavaScript interaction tracking enabled") + initial_url, initial_title = page.url, page.title() + + # Inject interaction tracking + self._inject_interaction_tracking(page) + + # Wait for user action + self._wait_for_user_action(page) + + # Collect and process interaction data + return self._collect_interaction_data(page, initial_url, initial_title) + + except Exception as e: + print(f"[HumanTrace] Error: {e}") + return { + "type": "error", + "timestamp": time.time() * 1000, + "error": str(e), + }, f"Error: {e}" + + def _inject_interaction_tracking(self, page: Page): + """Inject JavaScript code for comprehensive interaction tracking.""" + tracking_script = get_interaction_tracking_script() + page.evaluate(tracking_script) + + def _wait_for_user_action(self, page: Page): + """Wait for user to perform an action.""" + start_time = time.time() + while time.time() - start_time < 300: try: - if page.evaluate("window.__acted"): - page.evaluate("document.getElementById('__rec').innerHTML = '💾 SAVING'") - time.sleep(0.3) - return - except Exception: + action_detected = page.evaluate("window.__acted || false") + if action_detected: + print(f"[HumanTrace] Action detected! Exiting immediately...") + break + except Exception as e: + print(f"[HumanTrace] Debug: Error checking actions: {e}") pass time.sleep(0.1) - def _extract_trace(self, trace_file: str): - """Extract ALL events from trace zip.""" - all_events = [] + def _collect_interaction_data(self, page: Page, initial_url: str, initial_title: str) -> tuple[dict, str]: + """Collect and format interaction data.""" try: - with zipfile.ZipFile(trace_file, "r") as zf: - for name in zf.namelist(): - if name.endswith(".trace"): - with zf.open(name) as f: - for line in f: - try: - event = json.loads(line.decode()) - # Save everything - don't filter - all_events.append(event) - except Exception: - continue - except Exception: - pass - return all_events + action_detected = page.evaluate("window.__acted || false") + interactions = page.evaluate("window.__interactions || []") + + action_data = { + "type": "human_interactions" if action_detected else "no_action", + "timestamp": time.time() * 1000, + "detected": action_detected, + "interactions": interactions, + "interaction_count": len(interactions) + } + + summary = self._create_interaction_summary(interactions) + self._add_page_change_info(action_data, initial_url, initial_title, page) + + print(f"[HumanTrace] {summary}") + return action_data, summary + + except Exception as e: + return { + "type": "error", + "timestamp": time.time() * 1000, + "detected": False, + "error": str(e), + "interactions": [], + "interaction_count": 0 + }, f"Error collecting interactions: {e}" + + def _create_interaction_summary(self, interactions: list) -> str: + """Create a summary string of captured interactions.""" + if interactions: + interaction_types = {} + for interaction in interactions: + itype = interaction.get('type', 'unknown') + interaction_types[itype] = interaction_types.get(itype, 0) + 1 + + summary_parts = [] + for itype, count in interaction_types.items(): + summary_parts.append(f"{itype}:{count}") + return f"Captured {len(interactions)} interactions: {', '.join(summary_parts)}" + else: + return "No interactions detected" + + def _add_page_change_info(self, action_data: dict, initial_url: str, initial_title: str, page: Page): + """Add page change information to action data.""" + final_url, final_title = page.url, page.title() + if initial_url != final_url or initial_title != final_title: + action_data["page_changed"] = True + action_data["url_change"] = {"from": initial_url, "to": final_url} + action_data["title_change"] = {"from": initial_title, "to": final_title} + + def _format_js_interaction_summary(self, action_data, interaction_log): + """Format JavaScript-captured interactions into readable summary.""" + lines = ["Human Interactions (JavaScript Tracking):"] + + if action_data["interactions"]: + lines.append(f"Total Actions: {len(action_data['interactions'])}") + lines.append("") + + # Group interactions by type + by_type = {} + for interaction in action_data["interactions"]: + interaction_type = interaction["type"] + if interaction_type not in by_type: + by_type[interaction_type] = [] + by_type[interaction_type].append(interaction) + + # Show summary by type + for interaction_type, interactions in by_type.items(): + lines.append(f"{interaction_type.title()}: {len(interactions)} actions") + + lines.append("") + lines.append("Detailed Actions:") + + # Add each interaction from the log + for log_entry in interaction_log: + lines.append(f" {log_entry}") + else: + lines.append("No interactions detected - user may have just observed the page") + + # Add page state changes if URL changed + if action_data.get("page_changed"): + url_info = action_data.get("url") + if url_info: + lines.append("") + lines.append("� Page Navigation:") + lines.append(f" From: {url_info['from']}") + lines.append(f" To: {url_info['to']}") + + return "\n".join(lines) HUMAN_TRACE_AGENT = HumanTraceAgentArgs() + + +if __name__ == "__main__": + from agentlab.agents.human_trace_recorder.agent import HUMAN_TRACE_AGENT + from agentlab.experiments.study import Study + + agent_configs = [HUMAN_TRACE_AGENT] + benchmark = bgym.DEFAULT_BENCHMARKS["workarena_l1"](n_repeats=1) # type: bgym.Benchmark + benchmark = benchmark.subset_from_glob("task_name", "*filter*") + benchmark.env_args_list = benchmark.env_args_list[:1] + for env_args in benchmark.env_args_list: + print(env_args.task_name) + env_args.max_steps = 15 + env_args.headless = False + + study = Study(agent_configs, benchmark) + study.run(n_jobs=1, parallel_backend="sequential") diff --git a/src/agentlab/agents/human_trace_recorder/event_listeners.py b/src/agentlab/agents/human_trace_recorder/event_listeners.py new file mode 100644 index 00000000..2fd8453c --- /dev/null +++ b/src/agentlab/agents/human_trace_recorder/event_listeners.py @@ -0,0 +1,563 @@ +"""JavaScript Event Listeners for Human Trace Capture + +This module contains all the JavaScript code for capturing comprehensive +browser interactions including mouse, keyboard, form, scroll, and focus events. +""" + + +def get_interaction_tracking_script() -> str: + """Get the complete JavaScript code for interaction tracking.""" + return ( + """ + window.__acted = false; + window.__interactions = []; + + // Debug mode - set to true to see all events in console + window.__debug_events = false; + + function captureInteraction(type, event, extra = {}) { + // Skip our own recording indicators + if (event.target.id === '__rec' || event.target.id === '__rec_border' || + event.target.closest('#__rec') || event.target.closest('#__rec_border')) { + return; + } + + const interaction = { + type: type, + timestamp: Date.now(), + coords: { + x: event.clientX || 0, + y: event.clientY || 0 + }, + target: { + tagName: event.target.tagName, + id: event.target.id || null, + className: event.target.className || null, + text: event.target.textContent?.slice(0, 50) || null, + bid: event.target.getAttribute('bid') || null + }, + ...extra + }; + + window.__interactions.push(interaction); + window.__acted = true; + + // Debug logging + if (window.__debug_events) { + console.log(`🎯 Captured: ${type}`, interaction); + } + + // Update indicators immediately + const indicator = document.getElementById('__rec'); + const border = document.getElementById('__rec_border'); + if (indicator) { + indicator.innerHTML = '✅ ACTION DETECTED - SAVING...'; + indicator.style.background = '#28a745'; + indicator.style.animation = 'none'; + } + if (border) { + border.style.border = '8px solid #28a745'; + border.style.animation = 'none'; + } + } + + // Debug function - add this temporarily to see what events fire + if (window.__debug_events) { + ['input', 'change', 'select', 'focus', 'click', 'keydown', 'paste', 'cut', 'copy'].forEach(eventType => { + document.addEventListener(eventType, (e) => { + console.log(`🔍 DEBUG: ${eventType} on`, e.target.tagName, e.target.type, e.target); + }, true); + }); + } + + """ + + get_mouse_event_listeners() + + """ + """ + + get_keyboard_event_listeners() + + """ + """ + + get_form_event_listeners() + + """ + """ + + get_scroll_event_listeners() + + """ + """ + + get_focus_event_listeners() + + """ + + console.log('Comprehensive interaction tracking initialized'); + """ + ) + + +def get_mouse_event_listeners() -> str: + """Get JavaScript code for mouse event listeners.""" + return """ + // Mouse events with comprehensive button tracking and performance optimizations + let lastClickTime = 0; + + document.addEventListener('click', (e) => { + const now = Date.now(); + // Prevent spam clicking from creating too many events (minimum 50ms between clicks) + if (now - lastClickTime < 50) return; + lastClickTime = now; + + captureInteraction('click', e, { + button: e.button, // 0=left, 1=middle, 2=right + buttons: e.buttons, // bitmask of pressed buttons + buttonName: ['left', 'middle', 'right'][e.button] || 'unknown', + detail: e.detail, // click count (single, double, etc.) + clickType: e.detail === 1 ? 'single' : e.detail === 2 ? 'double' : `${e.detail}x` + }); + }, true); + + document.addEventListener('dblclick', (e) => { + captureInteraction('dblclick', e, { + button: e.button, + buttonName: ['left', 'middle', 'right'][e.button] || 'unknown' + }); + }, true); + + document.addEventListener('mousedown', (e) => { + captureInteraction('mousedown', e, { + button: e.button, + buttons: e.buttons, + buttonName: ['left', 'middle', 'right'][e.button] || 'unknown' + }); + }, true); + + document.addEventListener('mouseup', (e) => { + captureInteraction('mouseup', e, { + button: e.button, + buttons: e.buttons, + buttonName: ['left', 'middle', 'right'][e.button] || 'unknown' + }); + }, true); + + // Context menu (right-click menu) + document.addEventListener('contextmenu', (e) => { + captureInteraction('contextmenu', e, { + button: e.button, + buttonName: 'right' + }); + }, true); + + // Middle mouse button events (often used for scrolling/opening in new tab) + document.addEventListener('auxclick', (e) => { + captureInteraction('auxclick', e, { + button: e.button, + buttonName: e.button === 1 ? 'middle' : (e.button === 2 ? 'right' : 'other'), + detail: e.detail + }); + }, true); + + // Enhanced drag tracking (without redundant mousedown) + let isDragging = false; + let dragStart = null; + let dragButton = null; + let hasDraggedSignificantly = false; + + document.addEventListener('mousedown', (e) => { + isDragging = true; + dragButton = e.button; + hasDraggedSignificantly = false; + dragStart = { + x: e.clientX, + y: e.clientY, + time: Date.now(), + button: e.button, + buttonName: ['left', 'middle', 'right'][e.button] || 'unknown' + }; + }, true); + + document.addEventListener('mousemove', (e) => { + if (isDragging && dragStart) { + const distance = Math.sqrt( + Math.pow(e.clientX - dragStart.x, 2) + + Math.pow(e.clientY - dragStart.y, 2) + ); + if (distance > 5 && !hasDraggedSignificantly) { + // Only capture the start of a significant drag, not every movement + hasDraggedSignificantly = true; + captureInteraction('drag_start', e, { + startX: dragStart.x, + startY: dragStart.y, + endX: e.clientX, + endY: e.clientY, + distance: distance, + button: dragButton, + buttonName: dragStart.buttonName, + duration: Date.now() - dragStart.time + }); + } + } + // Note: Removed general mousemove tracking to reduce noise + }, true); + + document.addEventListener('mouseup', (e) => { + if (isDragging && dragStart && hasDraggedSignificantly) { + const distance = Math.sqrt( + Math.pow(e.clientX - dragStart.x, 2) + + Math.pow(e.clientY - dragStart.y, 2) + ); + captureInteraction('drag_end', e, { + startX: dragStart.x, + startY: dragStart.y, + endX: e.clientX, + endY: e.clientY, + distance: distance, + duration: Date.now() - dragStart.time, + button: dragButton, + buttonName: dragStart.buttonName, + totalDistance: distance + }); + } + isDragging = false; + dragStart = null; + dragButton = null; + hasDraggedSignificantly = false; + }, true); + + // Drag and drop events + document.addEventListener('dragstart', (e) => { + captureInteraction('dragstart', e, { + dataTransfer: { + effectAllowed: e.dataTransfer.effectAllowed, + types: Array.from(e.dataTransfer.types) + } + }); + }, true); + + document.addEventListener('dragend', (e) => { + captureInteraction('dragend', e, { + dataTransfer: { + dropEffect: e.dataTransfer.dropEffect + } + }); + }, true); + + document.addEventListener('drop', (e) => { + captureInteraction('drop', e, { + dataTransfer: { + dropEffect: e.dataTransfer.dropEffect, + types: Array.from(e.dataTransfer.types) + }, + files: e.dataTransfer.files.length > 0 ? Array.from(e.dataTransfer.files).map(f => ({ + name: f.name, + type: f.type, + size: f.size + })) : null + }); + }, true); + """ + + +def get_keyboard_event_listeners() -> str: + """Get JavaScript code for keyboard event listeners.""" + return """ + // Keyboard events with shortcut detection + document.addEventListener('keydown', (e) => { + let shortcut = null; + if (e.ctrlKey || e.metaKey) { + const modifier = e.ctrlKey ? 'Ctrl' : 'Cmd'; + const key = e.key.length === 1 ? e.key.toUpperCase() : e.key; + shortcut = `${modifier}+${key}`; + } else if (e.altKey && e.key.length === 1) { + shortcut = `Alt+${e.key.toUpperCase()}`; + } else if (e.shiftKey && e.key.length === 1) { + shortcut = `Shift+${e.key.toUpperCase()}`; + } + + captureInteraction('keydown', e, { + key: e.key, + code: e.code, + ctrlKey: e.ctrlKey, + shiftKey: e.shiftKey, + altKey: e.altKey, + metaKey: e.metaKey, + shortcut: shortcut + }); + }, true); + + document.addEventListener('keyup', (e) => { + captureInteraction('keyup', e, { + key: e.key, + code: e.code + }); + }, true); + """ + + +def get_form_event_listeners() -> str: + """Get JavaScript code for form event listeners.""" + return """ + // Input events with throttling to prevent spam during fast typing + let inputTimeout; + let lastInputValue = ''; + + document.addEventListener('input', (e) => { + if (['INPUT', 'TEXTAREA'].includes(e.target.tagName) || e.target.contentEditable === 'true') { + clearTimeout(inputTimeout); + inputTimeout = setTimeout(() => { + const currentValue = e.target.value || e.target.textContent; + // Only capture if value actually changed significantly + if (currentValue !== lastInputValue) { + lastInputValue = currentValue; + captureInteraction('input', e, { + value: currentValue, + inputType: e.inputType || null, + valueLength: currentValue.length + }); + } + }, 50); // Reduced from 300ms to 50ms for better responsiveness + } + }, true); + + // Immediate input capture (without throttling) for certain cases + document.addEventListener('input', (e) => { + // Immediate capture for dropdown/select-like inputs or when selection changes + if (e.target.tagName === 'SELECT' || + e.inputType === 'deleteContentBackward' || + e.inputType === 'insertFromPaste' || + e.inputType === 'insertFromDrop') { + captureInteraction('input_immediate', e, { + value: e.target.value || e.target.textContent, + inputType: e.inputType || null, + immediate: true + }); + } + }, true); + + // Text selection events + document.addEventListener('select', (e) => { + if (['INPUT', 'TEXTAREA'].includes(e.target.tagName)) { + const selectedText = e.target.value.substring(e.target.selectionStart, e.target.selectionEnd); + captureInteraction('select', e, { + selectedText: selectedText, + selectionStart: e.target.selectionStart, + selectionEnd: e.target.selectionEnd, + value: e.target.value, + selectionLength: selectedText.length + }); + } + }, true); + + // Clipboard events + document.addEventListener('cut', (e) => { + captureInteraction('cut', e, { + clipboardData: e.clipboardData ? Array.from(e.clipboardData.types) : null, + targetValue: e.target.value || e.target.textContent + }); + }, true); + + document.addEventListener('copy', (e) => { + captureInteraction('copy', e, { + clipboardData: e.clipboardData ? Array.from(e.clipboardData.types) : null, + targetValue: e.target.value || e.target.textContent + }); + }, true); + + document.addEventListener('paste', (e) => { + captureInteraction('paste', e, { + clipboardData: e.clipboardData ? Array.from(e.clipboardData.types) : null, + targetValue: e.target.value || e.target.textContent + }); + }, true); + + // Enhanced form change events with better dropdown handling + document.addEventListener('change', (e) => { + let extra = {}; + if (e.target.tagName === 'SELECT') { + const option = e.target.options[e.target.selectedIndex]; + extra = { + selectedValue: e.target.value, + selectedText: option?.text || '', + selectedIndex: e.target.selectedIndex, + allOptions: Array.from(e.target.options).map(opt => ({ + value: opt.value, + text: opt.text, + selected: opt.selected + })), + optionsCount: e.target.options.length + }; + } else if (['checkbox', 'radio'].includes(e.target.type)) { + extra = { + checked: e.target.checked, + value: e.target.value, + name: e.target.name + }; + } else { + extra = { + value: e.target.value, + previousValue: e.target.defaultValue, // Capture what it was before + inputType: e.target.type + }; + } + captureInteraction('change', e, extra); + }, true); + + document.addEventListener('submit', (e) => { + captureInteraction('submit', e, { + formAction: e.target.action || null, + formMethod: e.target.method || 'GET', + formElements: Array.from(e.target.elements).length + }); + }, true); + + // Additional events for better field interaction capture + + // Option selection in datalists + document.addEventListener('input', (e) => { + if (e.target.list) { // Has datalist + captureInteraction('datalist_input', e, { + value: e.target.value, + listId: e.target.list.id, + optionsCount: e.target.list.options.length + }); + } + }, true); + + // File input changes + document.addEventListener('change', (e) => { + if (e.target.type === 'file') { + captureInteraction('file_select', e, { + filesCount: e.target.files.length, + files: Array.from(e.target.files).map(file => ({ + name: file.name, + type: file.type, + size: file.size, + lastModified: file.lastModified + })) + }); + } + }, true); + """ + + +def get_scroll_event_listeners() -> str: + """Get JavaScript code for scroll event listeners.""" + return """ + // Scroll events with debouncing to reduce noise + let scrollTimeout; + let lastScrollTime = 0; + + document.addEventListener('scroll', (e) => { + clearTimeout(scrollTimeout); + scrollTimeout = setTimeout(() => { + const now = Date.now(); + // Only capture scroll if it's been at least 200ms since last scroll capture + if (now - lastScrollTime > 200) { + lastScrollTime = now; + captureInteraction('scroll', e, { + scrollX: window.scrollX, + scrollY: window.scrollY, + scrollLeft: e.target.scrollLeft || 0, + scrollTop: e.target.scrollTop || 0 + }); + } + }, 150); // Increased debounce time + }, true); + + // Wheel events (for detailed scroll tracking) with throttling + let lastWheelTime = 0; + document.addEventListener('wheel', (e) => { + const now = Date.now(); + // Only capture wheel events every 100ms to reduce noise + if (now - lastWheelTime > 100) { + lastWheelTime = now; + captureInteraction('wheel', e, { + deltaX: e.deltaX, + deltaY: e.deltaY, + deltaZ: e.deltaZ, + deltaMode: e.deltaMode + }); + } + }, true); + """ + + +def get_focus_event_listeners() -> str: + """Get JavaScript code for focus event listeners.""" + return """ + // Focus events - only for interactive elements to reduce noise + document.addEventListener('focus', (e) => { + // Only capture focus on interactive elements + const interactiveElements = ['INPUT', 'TEXTAREA', 'SELECT', 'BUTTON', 'A']; + if (interactiveElements.includes(e.target.tagName) || + e.target.contentEditable === 'true' || + e.target.tabIndex >= 0) { + captureInteraction('focus', e); + } + }, true); + + document.addEventListener('blur', (e) => { + // Only capture blur on interactive elements + const interactiveElements = ['INPUT', 'TEXTAREA', 'SELECT', 'BUTTON', 'A']; + if (interactiveElements.includes(e.target.tagName) || + e.target.contentEditable === 'true' || + e.target.tabIndex >= 0) { + captureInteraction('blur', e); + } + }, true); + """ + + +def get_recording_indicators_script() -> str: + """Get JavaScript code for recording indicators.""" + return """ + // Remove any existing indicators + const existingBorder = document.getElementById('__rec_border'); + if (existingBorder) existingBorder.remove(); + const existingIndicator = document.getElementById('__rec'); + if (existingIndicator) existingIndicator.remove(); + + // Create border overlay + const border = document.createElement('div'); + border.id = '__rec_border'; + border.style.cssText = ` + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + border: 8px solid #ff0000; + box-sizing: border-box; + pointer-events: none; + z-index: 999999; + animation: pulse 1.5s infinite; + `; + + // Create status indicator + const indicator = document.createElement('div'); + indicator.id = '__rec'; + indicator.innerHTML = '🔴 RECORDING - Perform your action now'; + indicator.style.cssText = ` + position: fixed; + top: 10px; + left: 50%; + transform: translateX(-50%); + background: #ff0000; + color: #fff; + padding: 12px 20px; + border-radius: 8px; + font: bold 10px -apple-system, BlinkMacSystemFont, sans-serif; + z-index: 9999999; + box-shadow: 0 4px 12px rgba(255,0,0,0.4); + animation: pulse 1.5s infinite; + `; + + // Add pulsing animation + const style = document.createElement('style'); + style.textContent = ` + @keyframes pulse { + 0% { opacity: 1; } + 50% { opacity: 0.4; } + 100% { opacity: 0.8; } + } + `; + document.head.appendChild(style); + + document.body.appendChild(border); + document.body.appendChild(indicator); + """ From d3054cd15d2f6eb492c29531d0479b4ae61377b5 Mon Sep 17 00:00:00 2001 From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com> Date: Thu, 21 Aug 2025 18:54:55 -0400 Subject: [PATCH 08/53] Add codegen step-wise recoder agent --- .../human_trace_recorder/codegen_agent.py | 192 ++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 src/agentlab/agents/human_trace_recorder/codegen_agent.py diff --git a/src/agentlab/agents/human_trace_recorder/codegen_agent.py b/src/agentlab/agents/human_trace_recorder/codegen_agent.py new file mode 100644 index 00000000..16d0222c --- /dev/null +++ b/src/agentlab/agents/human_trace_recorder/codegen_agent.py @@ -0,0 +1,192 @@ +"""Simple Codegen Agent + +Captures human interactions using playwright inspector. +Playwright trace logs are stored in "think" messages and can be viewed in Agentlab Xray. +""" + +from __future__ import annotations + +import json +import logging +import tempfile +import zipfile +from dataclasses import dataclass +from pathlib import Path + +import bgym +from playwright.sync_api import Page + +from agentlab.agents.agent_args import AgentArgs +from browsergym.core.observation import ( + extract_dom_extra_properties, + extract_dom_snapshot, + extract_focused_element_bid, + extract_merged_axtree, + extract_screenshot, +) +from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html + + +def extract_log_message_from_pw_trace(pw_trace_file_path): + zip_file = zipfile.ZipFile(pw_trace_file_path, "r") + trace_lines = zip_file.read("trace.trace").decode("utf-8").splitlines() + + actions = [] + for line in trace_lines: + if line.strip(): + event = json.loads(line) + if event.get("type") == "log": + actions.append(event) + # Extract log messages from the trace + return [log["message"].strip() for log in sorted(actions, key=lambda x: x.get("time", 0))] + + +def clean_pw_logs(logs, exclude_blacklist=True, use_substitutions=True): + clean_logs = list(logs) + blacklist = { + "attempting click action", + "waiting for element to be visible, enabled and stable", + "element is visible, enabled and stable", + "scrolling into view if needed", + "done scrolling", + "performing click action", + "click action done", + "waiting for scheduled navigations to finish", + "navigations have finished", + } + + substitutions = [("waiting for ", "")] + + def apply_substitutions(log): + for old, new in substitutions: + log = log.replace(old, new) + return log + + if exclude_blacklist: + clean_logs = [log for log in clean_logs if log not in blacklist] + if use_substitutions: + clean_logs = [apply_substitutions(log) for log in clean_logs] + + return clean_logs + + +@dataclass +class PlayWrightCodeGenAgentArgs(AgentArgs): + agent_name: str = "PlayWrightCodeGenAgent" + trace_dir: str = "playwright_codegen_traces" + use_raw_page_output: bool = True + store_raw_trace: bool = False + + def make_agent(self) -> bgym.Agent: # type: ignore[override] + return PlayWrightCodeGenAgent(self.trace_dir, self.store_raw_trace) + + def set_reproducibility_mode(self): + pass + + +class PlayWrightCodeGenAgent(bgym.Agent): + def __init__(self, trace_dir: str, store_raw_trace: bool): + self.action_set = bgym.HighLevelActionSet(["bid"], multiaction=False) + self._root = Path(trace_dir) + self._page: Page | None = None + self._step = 0 + self.store_raw_trace = store_raw_trace + self._episode_trace_dir = None # Cache for single episode + + def _get_trace_dir(self): + """Return the trace directory based on store_raw_trace setting.""" + if self._episode_trace_dir is None: + if self.store_raw_trace: + import datetime + + dt_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + self._episode_trace_dir = self._root / f"codegen_traces_{dt_str}" + self._episode_trace_dir.mkdir(parents=True, exist_ok=True) + else: + self._episode_trace_dir = Path(tempfile.mkdtemp()) + return self._episode_trace_dir + + def obs_preprocessor(self, obs: dict): # type: ignore[override] + if isinstance(obs, dict): + self._page = obs.get("page") + obs["screenshot"] = extract_screenshot(self._page) + obs["dom_object"] = extract_dom_snapshot(self._page) + obs["axtree_object"] = extract_merged_axtree(self._page) + scale_factor = getattr(self._page, "_bgym_scale_factor", 1.0) + extra_properties = extract_dom_extra_properties( + obs["dom_object"], scale_factor=scale_factor + ) + obs["extra_element_properties"] = extra_properties + obs["focused_element_bid"] = extract_focused_element_bid(self._page) + + if obs["axtree_object"]: + obs["axtree_txt"] = flatten_axtree_to_str(obs["axtree_object"]) + + if obs["dom_object"]: + obs["dom_txt"] = flatten_dom_to_str(obs["dom_object"]) + obs["pruned_html"] = prune_html(obs["dom_txt"]) + + if "page" in obs: # unpickable + del obs["page"] + + return obs + + def get_action(self, obs: dict): # type: ignore[override] + + if self._page is None: + raise RuntimeError("Playwright Page missing; ensure use_raw_page_output=True") + + page = self._page + trace_dir = self._get_trace_dir() + trace_path = trace_dir / f"step_{self._step}.zip" + page.context.tracing.start(screenshots=True, snapshots=True, sources=True) + page.context.tracing.start_chunk(name=f"step_{self._step}") + + print( + f"{'─'*60}\n" f"Step {self._step}\n", + f"{'─'*60}\n", + "1. 🔴 Start Recording (Press 'Record' in the Playwright Inspector.)\n", + "2. ✨ Perform actions for a single step.\n", + "3. ⚫ Stop Recording (Press 'Record' again to stop recording.)\n", + "4. ▶️ Press 'Resume' in the Playwright Inspector.", + ) + + page.pause() # Launch Inspector and record actions + page.context.tracing.stop_chunk(path=trace_path) + page.context.tracing.stop() + + pw_logs = extract_log_message_from_pw_trace(trace_path) + pw_logs = clean_pw_logs(pw_logs, exclude_blacklist=True) + pw_logs_str = "\n".join([f"{i}. {log}" for i, log in enumerate(pw_logs, 1)]) + + print(f"\n Playwright logs for step {self._step}:\n{pw_logs_str}") + + self._step += 1 + + agent_info = bgym.AgentInfo( + think=pw_logs_str, + chat_messages=[], + stats={}, + ) + + return "noop()", agent_info + + +PW_CODEGEN_AGENT = PlayWrightCodeGenAgentArgs(store_raw_trace=True) + + +if __name__ == "__main__": + from agentlab.agents.human_trace_recorder.codegen_agent import PW_CODEGEN_AGENT + from agentlab.experiments.study import Study + + agent_configs = [PW_CODEGEN_AGENT] + benchmark = bgym.DEFAULT_BENCHMARKS["workarena_l1"]() # type: bgym.Benchmark + benchmark = benchmark.subset_from_glob("task_name", "*create*") + benchmark.env_args_list = benchmark.env_args_list[:1] + for env_args in benchmark.env_args_list: + print(env_args.task_name) + env_args.max_steps = 15 + env_args.headless = False + + study = Study(agent_configs, benchmark, logging_level_stdout=logging.INFO) + study.run(n_jobs=1, parallel_backend="sequential", n_relaunch=1) From 825effbbbbf718842a08cd014700f43d7cf1d87c Mon Sep 17 00:00:00 2001 From: Hadi Nekoei Date: Thu, 28 Aug 2025 12:23:52 +0000 Subject: [PATCH 09/53] adding task hints to generic agent --- .../agents/generic_agent_hinter/__init__.py | 53 +++ .../generic_agent_hinter/agent_configs.py | 420 ++++++++++++++++++ .../generic_agent_hinter/generic_agent.py | 211 +++++++++ .../generic_agent_prompt.py | 368 +++++++++++++++ .../generic_agent_hinter/tmlr_config.py | 78 ++++ 5 files changed, 1130 insertions(+) create mode 100644 src/agentlab/agents/generic_agent_hinter/__init__.py create mode 100644 src/agentlab/agents/generic_agent_hinter/agent_configs.py create mode 100644 src/agentlab/agents/generic_agent_hinter/generic_agent.py create mode 100644 src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py create mode 100644 src/agentlab/agents/generic_agent_hinter/tmlr_config.py diff --git a/src/agentlab/agents/generic_agent_hinter/__init__.py b/src/agentlab/agents/generic_agent_hinter/__init__.py new file mode 100644 index 00000000..659aa35a --- /dev/null +++ b/src/agentlab/agents/generic_agent_hinter/__init__.py @@ -0,0 +1,53 @@ +""" +Baseline agent for all ServiceNow papers + +This module contains the GenericAgent class, which is the baseline agent for all ServiceNow papers. \ +It is a simple agent that can be ran OOB on all BrowserGym environments. It is also shipped with \ +a few configurations that can be used to run it on different environments. +""" + +from .agent_configs import ( + AGENT_3_5, + AGENT_8B, + AGENT_37_SONNET, + AGENT_CLAUDE_SONNET_35, + AGENT_CLAUDE_SONNET_35_VISION, + AGENT_CUSTOM, + AGENT_LLAMA3_70B, + AGENT_LLAMA4_17B_INSTRUCT, + AGENT_LLAMA31_70B, + CHAT_MODEL_ARGS_DICT, + RANDOM_SEARCH_AGENT, + AGENT_4o, + AGENT_4o_MINI, + AGENT_4o_MINI_VISION, + AGENT_4o_VISION, + AGENT_o1_MINI, + AGENT_o3_MINI, + FLAGS_GPT_4o, + GenericAgentArgs, + AGENT_GPT5_MINI, +) + +from .generic_agent import GenericAgent, GenericAgentArgs + +__all__ = [ + "AGENT_3_5", + "AGENT_4o", + "AGENT_4o_MINI", + "AGENT_4o_VISION", + "AGENT_o3_MINI", + "AGENT_o1_MINI", + "AGENT_LLAMA4_17B_INSTRUCT", + "AGENT_LLAMA3_70B", + "AGENT_LLAMA31_70B", + "AGENT_8B", + "RANDOM_SEARCH_AGENT", + "AGENT_CUSTOM", + "AGENT_CLAUDE_SONNET_35", + "AGENT_37_SONNET", + "AGENT_4o_VISION", + "AGENT_4o_MINI_VISION", + "AGENT_CLAUDE_SONNET_35_VISION", + "AGENT_GPT5_MINI", +] diff --git a/src/agentlab/agents/generic_agent_hinter/agent_configs.py b/src/agentlab/agents/generic_agent_hinter/agent_configs.py new file mode 100644 index 00000000..798445db --- /dev/null +++ b/src/agentlab/agents/generic_agent_hinter/agent_configs.py @@ -0,0 +1,420 @@ +""" +Basic flags and agent configurations for generic agents. +""" + +import bgym +from bgym import HighLevelActionSetArgs + +from agentlab.agents import dynamic_prompting as dp +from agentlab.experiments import args +from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT + +from .generic_agent import GenericAgentArgs +from .generic_agent_prompt import GenericPromptFlags +from .tmlr_config import BASE_FLAGS + +FLAGS_CUSTOM = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=False, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=False, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=True, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + + +AGENT_CUSTOM = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"], + flags=FLAGS_CUSTOM, +) + + +# GPT-3.5 default config +FLAGS_GPT_3_5 = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, # too big for most benchmark except miniwob + use_ax_tree=True, # very useful + use_focused_element=True, # detrimental on minowob according to ablation study + use_error_logs=True, + use_history=True, + use_past_error_logs=False, # very detrimental on L1 and miniwob + use_action_history=True, # helpful on miniwob + use_think_history=False, # detrimental on L1 and miniwob + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, # doesn't change much + extract_clickable_tag=False, # doesn't change much + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=True, + ), + use_plan=False, # usually detrimental + use_criticise=False, # usually detrimental + use_thinking=True, # very useful + use_memory=False, + use_concrete_example=True, # useful + use_abstract_example=True, # useful + use_hints=True, # useful + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + + +AGENT_3_5 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-3.5-turbo-1106"], + flags=FLAGS_GPT_3_5, +) + +# llama3-70b default config +FLAGS_LLAMA3_70B = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=False, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=True, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=False, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=True, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, + add_missparsed_messages=True, +) + +AGENT_LLAMA3_70B = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3-70b-instruct"], + flags=FLAGS_LLAMA3_70B, +) +AGENT_LLAMA31_70B = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-70b-instruct"], + flags=FLAGS_LLAMA3_70B, +) + +FLAGS_8B = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=False, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=False, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=False, + extract_clickable_tag=False, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=True, + ), + long_description=False, + individual_examples=True, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, + add_missparsed_messages=True, +) + + +AGENT_8B = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["meta-llama/Meta-Llama-3-8B-Instruct"], + flags=FLAGS_8B, +) + + +AGENT_LLAMA31_8B = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"], + flags=FLAGS_8B, +) + + +# GPT-4o default config +FLAGS_GPT_4o = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=False, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=True, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=False, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + +AGENT_4o = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"], + flags=FLAGS_GPT_4o, +) + +AGENT_4o_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"], + flags=FLAGS_GPT_4o, +) +AGENT_CLAUDE_SONNET_35 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"], + flags=FLAGS_GPT_4o, +) + +# Claude Sonnet 4 default config with task hints enabled +FLAGS_CLAUDE_SONNET_4 = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=False, + use_diff=False, + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=True, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ), + long_description=False, + individual_examples=False, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + use_task_hint=True, # Explicitly enable task hints + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + +AGENT_CLAUDE_SONNET_4 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["anthropic/claude-sonnet-4-20250514"], + flags=FLAGS_CLAUDE_SONNET_4, +) +AGENT_37_SONNET = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.7-sonnet"], + flags=FLAGS_GPT_4o, +) +# AGENT_o3_MINI = GenericAgentArgs( +# chat_model_args=CHAT_MODEL_ARGS_DICT["openai/o3-mini-2025-01-31"], +# flags=FLAGS_GPT_4o, +# ) +AGENT_o3_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/openai/o3-mini"], + flags=FLAGS_GPT_4o, +) + +AGENT_o1_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/openai/o1-mini-2024-09-12"], + flags=FLAGS_GPT_4o, +) +# GPT-4o vision default config +FLAGS_GPT_4o_VISION = FLAGS_GPT_4o.copy() +FLAGS_GPT_4o_VISION.obs.use_screenshot = True +FLAGS_GPT_4o_VISION.obs.use_som = True + +AGENT_4o_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_4o_MINI_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_CLAUDE_SONNET_35_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"], + flags=FLAGS_GPT_4o_VISION, +) +AGENT_LLAMA4_17B_INSTRUCT = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-4-maverick"], + flags=BASE_FLAGS, +) +GPT5_MINI_FLAGS = BASE_FLAGS.copy() +GPT5_MINI_FLAGS.action = dp.ActionFlags( # action should not be str to work with agentlab-assistant + action_set=HighLevelActionSetArgs( + subsets=["bid"], + multiaction=False, + ) +) + +AGENT_GPT5_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-mini-2025-08-07"], + flags=GPT5_MINI_FLAGS, +) + +AGENT_GPT5 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-2025-08-07"], + flags=GPT5_MINI_FLAGS, +) + +DEFAULT_RS_FLAGS = GenericPromptFlags( + flag_group="default_rs", + obs=dp.ObsFlags( + use_html=True, + use_ax_tree=args.Choice([True, False]), + use_focused_element=False, + use_error_logs=True, + use_history=True, + use_past_error_logs=args.Choice([True, False], p=[0.7, 0.3]), + use_action_history=True, + use_think_history=args.Choice([True, False], p=[0.7, 0.3]), + use_diff=args.Choice([True, False], p=[0.3, 0.7]), + html_type="pruned_html", + use_screenshot=False, + use_som=False, + extract_visible_tag=args.Choice([True, False]), + extract_clickable_tag=False, + extract_coords=args.Choice(["center", "box"]), + filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]), + ), + action=dp.ActionFlags( + action_set=HighLevelActionSetArgs( + subsets=args.Choice([["bid"], ["bid", "coord"]]), + multiaction=args.Choice([True, False], p=[0.7, 0.3]), + ), + long_description=False, + individual_examples=False, + ), + # drop_ax_tree_first=True, # this flag is no longer active, according to browsergym doc + use_plan=args.Choice([True, False]), + use_criticise=args.Choice([True, False], p=[0.7, 0.3]), + use_thinking=args.Choice([True, False], p=[0.7, 0.3]), + use_memory=args.Choice([True, False], p=[0.7, 0.3]), + use_concrete_example=True, + use_abstract_example=True, + use_hints=args.Choice([True, False], p=[0.7, 0.3]), + be_cautious=args.Choice([True, False]), + enable_chat=False, + max_prompt_tokens=40_000, + extra_instructions=None, +) + + +RANDOM_SEARCH_AGENT = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"], + flags=DEFAULT_RS_FLAGS, +) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py new file mode 100644 index 00000000..91b2f70f --- /dev/null +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -0,0 +1,211 @@ +""" +GenericAgent implementation for AgentLab + +This module defines a `GenericAgent` class and its associated arguments for use in the AgentLab framework. \ +The `GenericAgent` class is designed to interact with a chat-based model to determine actions based on \ +observations. It includes methods for preprocessing observations, generating actions, and managing internal \ +state such as plans, memories, and thoughts. The `GenericAgentArgs` class provides configuration options for \ +the agent, including model arguments and flags for various behaviors. +""" + +from copy import deepcopy +from dataclasses import asdict, dataclass +from functools import partial +from warnings import warn + +import bgym +from bgym import Benchmark +from browsergym.experiments.agent import Agent, AgentInfo + +from agentlab.agents import dynamic_prompting as dp +from agentlab.agents.agent_args import AgentArgs +from agentlab.llm.chat_api import BaseModelArgs +from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry +from agentlab.llm.tracking import cost_tracker_decorator + +from .generic_agent_prompt import GenericPromptFlags, MainPrompt + + +@dataclass +class GenericAgentArgs(AgentArgs): + chat_model_args: BaseModelArgs = None + flags: GenericPromptFlags = None + max_retry: int = 4 + + def __post_init__(self): + try: # some attributes might be temporarily args.CrossProd for hyperparameter generation + self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_") + except AttributeError: + pass + + def set_benchmark(self, benchmark: Benchmark, demo_mode): + """Override Some flags based on the benchmark.""" + if benchmark.name.startswith("miniwob"): + self.flags.obs.use_html = True + + self.flags.obs.use_tabs = benchmark.is_multi_tab + self.flags.action.action_set = deepcopy(benchmark.high_level_action_set_args) + + # for backward compatibility with old traces + if self.flags.action.multi_actions is not None: + self.flags.action.action_set.multiaction = self.flags.action.multi_actions + if self.flags.action.is_strict is not None: + self.flags.action.action_set.strict = self.flags.action.is_strict + + # verify if we can remove this + if demo_mode: + self.flags.action.action_set.demo_mode = "all_blue" + + def set_reproducibility_mode(self): + self.chat_model_args.temperature = 0 + + def prepare(self): + return self.chat_model_args.prepare_server() + + def close(self): + return self.chat_model_args.close_server() + + def make_agent(self): + return GenericAgent( + chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry + ) + + +class GenericAgent(Agent): + + def __init__( + self, + chat_model_args: BaseModelArgs, + flags: GenericPromptFlags, + max_retry: int = 4, + ): + + self.chat_llm = chat_model_args.make_model() + self.chat_model_args = chat_model_args + self.max_retry = max_retry + + self.flags = flags + self.action_set = self.flags.action.action_set.make_action_set() + self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs) + + self._check_flag_constancy() + self.reset(seed=None) + + def obs_preprocessor(self, obs: dict) -> dict: + return self._obs_preprocessor(obs) + + def set_task_name(self, task_name: str): + """Set the task name for task hints functionality.""" + self.task_name = task_name + + @cost_tracker_decorator + def get_action(self, obs): + + self.obs_history.append(obs) + main_prompt = MainPrompt( + action_set=self.action_set, + obs_history=self.obs_history, + actions=self.actions, + memories=self.memories, + thoughts=self.thoughts, + previous_plan=self.plan, + step=self.plan_step, + flags=self.flags, + ) + + # Set task name for task hints if available + if self.flags.use_task_hint and hasattr(self, 'task_name'): + main_prompt.set_task_name(self.task_name) + + max_prompt_tokens, max_trunc_itr = self._get_maxes() + + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + + human_prompt = dp.fit_tokens( + shrinkable=main_prompt, + max_prompt_tokens=max_prompt_tokens, + model_name=self.chat_model_args.model_name, + max_iterations=max_trunc_itr, + additional_prompts=system_prompt, + ) + try: + # TODO, we would need to further shrink the prompt if the retry + # cause it to be too long + + chat_messages = Discussion([system_prompt, human_prompt]) + ans_dict = retry( + self.chat_llm, + chat_messages, + n_retry=self.max_retry, + parser=main_prompt._parse_answer, + ) + ans_dict["busted_retry"] = 0 + # inferring the number of retries, TODO: make this less hacky + ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 + except ParseError as e: + ans_dict = dict( + action=None, + n_retry=self.max_retry + 1, + busted_retry=1, + ) + + stats = self.chat_llm.get_stats() + stats["n_retry"] = ans_dict["n_retry"] + stats["busted_retry"] = ans_dict["busted_retry"] + + self.plan = ans_dict.get("plan", self.plan) + self.plan_step = ans_dict.get("step", self.plan_step) + self.actions.append(ans_dict["action"]) + self.memories.append(ans_dict.get("memory", None)) + self.thoughts.append(ans_dict.get("think", None)) + + agent_info = AgentInfo( + think=ans_dict.get("think", None), + chat_messages=chat_messages, + stats=stats, + extra_info={"chat_model_args": asdict(self.chat_model_args)}, + ) + return ans_dict["action"], agent_info + + def reset(self, seed=None): + self.seed = seed + self.plan = "No plan yet" + self.plan_step = -1 + self.memories = [] + self.thoughts = [] + self.actions = [] + self.obs_history = [] + + def _check_flag_constancy(self): + flags = self.flags + if flags.obs.use_som: + if not flags.obs.use_screenshot: + warn( + """ +Warning: use_som=True requires use_screenshot=True. Disabling use_som.""" + ) + flags.obs.use_som = False + if flags.obs.use_screenshot: + if not self.chat_model_args.vision_support: + warn( + """ +Warning: use_screenshot is set to True, but the chat model \ +does not support vision. Disabling use_screenshot.""" + ) + flags.obs.use_screenshot = False + return flags + + def _get_maxes(self): + maxes = ( + self.flags.max_prompt_tokens, + self.chat_model_args.max_total_tokens, + self.chat_model_args.max_input_tokens, + ) + maxes = [m for m in maxes if m is not None] + max_prompt_tokens = min(maxes) if maxes else None + max_trunc_itr = ( + self.flags.max_trunc_itr + if self.flags.max_trunc_itr + else 20 # dangerous to change the default value here? + ) + return max_prompt_tokens, max_trunc_itr diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py new file mode 100644 index 00000000..bc12cc2c --- /dev/null +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -0,0 +1,368 @@ +""" +Prompt builder for GenericAgent + +It is based on the dynamic_prompting module from the agentlab package. +""" + +import logging +from dataclasses import dataclass + +from browsergym.core import action +from browsergym.core.action.base import AbstractActionSet + +from agentlab.agents import dynamic_prompting as dp +from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise +import fnmatch +import pandas as pd +from pathlib import Path + + +@dataclass +class GenericPromptFlags(dp.Flags): + """ + A class to represent various flags used to control features in an application. + + Attributes: + use_plan (bool): Ask the LLM to provide a plan. + use_criticise (bool): Ask the LLM to first draft and criticise the action before producing it. + use_thinking (bool): Enable a chain of thoughts. + use_concrete_example (bool): Use a concrete example of the answer in the prompt for a generic task. + use_abstract_example (bool): Use an abstract example of the answer in the prompt. + use_hints (bool): Add some human-engineered hints to the prompt. + use_task_hint (bool): Enable task-specific hints from hint database. + hint_db_path (str): Path to the hint database file. + enable_chat (bool): Enable chat mode, where the agent can interact with the user. + max_prompt_tokens (int): Maximum number of tokens allowed in the prompt. + be_cautious (bool): Instruct the agent to be cautious about its actions. + extra_instructions (Optional[str]): Extra instructions to provide to the agent. + add_missparsed_messages (bool): When retrying, add the missparsed messages to the prompt. + flag_group (Optional[str]): Group of flags used. + """ + + obs: dp.ObsFlags + action: dp.ActionFlags + use_plan: bool = False # + use_criticise: bool = False # + use_thinking: bool = False + use_memory: bool = False # + use_concrete_example: bool = True + use_abstract_example: bool = False + use_hints: bool = False + use_task_hint: bool = False + hint_db_path: str = None + enable_chat: bool = False + max_prompt_tokens: int = None + be_cautious: bool = True + extra_instructions: str | None = None + add_missparsed_messages: bool = True + max_trunc_itr: int = 20 + flag_group: str = None + + +class MainPrompt(dp.Shrinkable): + def __init__( + self, + action_set: AbstractActionSet, + obs_history: list[dict], + actions: list[str], + memories: list[str], + thoughts: list[str], + previous_plan: str, + step: int, + flags: GenericPromptFlags, + ) -> None: + super().__init__() + self.flags = flags + self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs) + if self.flags.enable_chat: + self.instructions = dp.ChatInstructions( + obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions + ) + else: + if sum([msg["role"] == "user" for msg in obs_history[-1].get("chat_messages", [])]) > 1: + logging.warning( + "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." + ) + self.instructions = dp.GoalInstructions( + obs_history[-1]["goal_object"], extra_instructions=flags.extra_instructions + ) + + self.obs = dp.Observation( + obs_history[-1], + self.flags.obs, + ) + + self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action) + + def time_for_caution(): + # no need for caution if we're in single action mode + return flags.be_cautious and ( + flags.action.action_set.multiaction or flags.action.action_set == "python" + ) + + self.be_cautious = dp.BeCautious(visible=time_for_caution) + self.think = dp.Think(visible=lambda: flags.use_thinking) + self.hints = dp.Hints(visible=lambda: flags.use_hints) + self.task_hint = TaskHint( + use_task_hint=flags.use_task_hint, + hint_db_path=flags.hint_db_path + ) + self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan + self.criticise = Criticise(visible=lambda: flags.use_criticise) + self.memory = Memory(visible=lambda: flags.use_memory) + + @property + def _prompt(self) -> HumanMessage: + prompt = HumanMessage(self.instructions.prompt) + + # Add task hints if enabled + task_hints_text = "" + if self.flags.use_task_hint and hasattr(self, 'task_name'): + task_hints_text = self.task_hint.get_hints_for_task(self.task_name) + + prompt.add_text( + f"""\ +{self.obs.prompt}\ +{self.history.prompt}\ +{self.action_prompt.prompt}\ +{self.hints.prompt}\ +{task_hints_text}\ +{self.be_cautious.prompt}\ +{self.think.prompt}\ +{self.plan.prompt}\ +{self.memory.prompt}\ +{self.criticise.prompt}\ +""" + ) + + if self.flags.use_abstract_example: + prompt.add_text( + f""" +# Abstract Example + +Here is an abstract version of the answer with description of the content of +each tag. Make sure you follow this structure, but replace the content with your +answer: +{self.think.abstract_ex}\ +{self.plan.abstract_ex}\ +{self.memory.abstract_ex}\ +{self.criticise.abstract_ex}\ +{self.task_hint.abstract_ex}\ +{self.action_prompt.abstract_ex}\ +""" + ) + + if self.flags.use_concrete_example: + prompt.add_text( + f""" +# Concrete Example + +Here is a concrete example of how to format your answer. +Make sure to follow the template with proper tags: +{self.think.concrete_ex}\ +{self.plan.concrete_ex}\ +{self.memory.concrete_ex}\ +{self.criticise.concrete_ex}\ +{self.task_hint.concrete_ex}\ +{self.action_prompt.concrete_ex}\ +""" + ) + return self.obs.add_screenshot(prompt) + + def shrink(self): + self.history.shrink() + self.obs.shrink() + + def set_task_name(self, task_name: str): + """Set the task name for task hints functionality.""" + self.task_name = task_name + + def _parse_answer(self, text_answer): + ans_dict = {} + ans_dict.update(self.think.parse_answer(text_answer)) + ans_dict.update(self.plan.parse_answer(text_answer)) + ans_dict.update(self.memory.parse_answer(text_answer)) + ans_dict.update(self.criticise.parse_answer(text_answer)) + ans_dict.update(self.action_prompt.parse_answer(text_answer)) + return ans_dict + + +class Memory(dp.PromptElement): + _prompt = "" # provided in the abstract and concrete examples + + _abstract_ex = """ + +Write down anything you need to remember for next steps. You will be presented +with the list of previous memories and past actions. Some tasks require to +remember hints from previous steps in order to solve it. + +""" + + _concrete_ex = """ + +I clicked on bid "32" to activate tab 2. The accessibility tree should mention +focusable for elements of the form at next step. + +""" + + def _parse_answer(self, text_answer): + return parse_html_tags_raise(text_answer, optional_keys=["memory"], merge_multiple=True) + + +class Plan(dp.PromptElement): + def __init__(self, previous_plan, plan_step, visible: bool = True) -> None: + super().__init__(visible=visible) + self.previous_plan = previous_plan + self._prompt = f""" +# Plan: + +You just executed step {plan_step} of the previously proposed plan:\n{previous_plan}\n +After reviewing the effect of your previous actions, verify if your plan is still +relevant and update it if necessary. +""" + + _abstract_ex = """ + +Provide a multi step plan that will guide you to accomplish the goal. There +should always be steps to verify if the previous action had an effect. The plan +can be revisited at each steps. Specifically, if there was something unexpected. +The plan should be cautious and favor exploring befor submitting. + + +Integer specifying the step of current action + +""" + + _concrete_ex = """ + +1. fill form (failed) + * type first name + * type last name +2. Try to activate the form + * click on tab 2 +3. fill form again + * type first name + * type last name +4. verify and submit + * verify form is filled + * submit if filled, if not, replan + + +2 +""" + + def _parse_answer(self, text_answer): + return parse_html_tags_raise(text_answer, optional_keys=["plan", "step"]) + + +class Criticise(dp.PromptElement): + _prompt = "" + + _abstract_ex = """ + +Write a first version of what you think is the right action. + + + +Criticise action_draft. What could be wrong with it? Enumerate reasons why it +could fail. Did your past actions had the expected effect? Make sure you're not +repeating the same mistakes. + +""" + + _concrete_ex = """ + +click("32") + + + +click("32") might not work because the element is not visible yet. I need to +explore the page to find a way to activate the form. + +""" + + def _parse_answer(self, text_answer): + return parse_html_tags_raise(text_answer, optional_keys=["action_draft", "criticise"]) + + +class TaskHint(dp.PromptElement): + def __init__(self, use_task_hint: bool = True, hint_db_path: str = None) -> None: + super().__init__(visible=use_task_hint) + self.use_task_hint = use_task_hint + self.hint_db_rel_path = "hint_db.csv" + self.hint_db_path = hint_db_path # Allow external path override + self._init() + + _prompt = "" # Task hints are added dynamically in MainPrompt + + _abstract_ex = """ + +What hint can be relevant for the next action? Only chose from the hints provided in the task description. Or select none. + +""" + + _concrete_ex = """ + +Relevant hint: Based on the hints provided, I should focus on the form elements and use the +accessibility tree to identify interactive elements before taking actions. + +""" + + def _init(self): + """Initialize the block.""" + try: + # Use external path if provided, otherwise fall back to relative path + if self.hint_db_path and Path(self.hint_db_path).exists(): + hint_db_path = Path(self.hint_db_path) + else: + hint_db_path = Path(__file__).parent / self.hint_db_rel_path + + if hint_db_path.exists(): + self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) + # Verify the expected columns exist + if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + print(f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}") + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + else: + print(f"Warning: Hint database not found at {hint_db_path}") + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + except Exception as e: + # Fallback to empty database on any error + print(f"Warning: Could not load hint database: {e}") + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + + + def get_hints_for_task(self, task_name: str) -> str: + """Get hints for a specific task.""" + if not self.use_task_hint: + return "" + + # Ensure hint_db is initialized + if not hasattr(self, 'hint_db'): + self._init() + + # Check if hint_db has the expected structure + if self.hint_db.empty or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + return "" + + try: + task_hints = self.hint_db[ + self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) + ] + + hints = [] + for hint in task_hints["hint"]: + hint = hint.strip() + if hint: + hints.append(f"- {hint}") + + if len(hints) > 0: + hints_str = ( + "# Hints:\nHere are some hints for the task you are working on:\n" + + "\n".join(hints) + ) + return hints_str + except Exception as e: + print(f"Warning: Error getting hints for task {task_name}: {e}") + + return "" diff --git a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py new file mode 100644 index 00000000..d222b7c0 --- /dev/null +++ b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py @@ -0,0 +1,78 @@ +""" +Specific configurations for our 2024 TMLR submission. +""" + +from copy import deepcopy + +from agentlab.agents import dynamic_prompting as dp +from agentlab.experiments import args +from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT + +from .generic_agent import GenericAgentArgs +from .generic_agent_prompt import GenericPromptFlags + +BASE_FLAGS = GenericPromptFlags( + obs=dp.ObsFlags( + use_html=False, + use_ax_tree=True, + use_focused_element=True, + use_error_logs=True, + use_history=True, + use_past_error_logs=False, + use_action_history=True, + use_think_history=True, # gpt-4o config except for this line + use_diff=False, + html_type="pruned_html", + use_screenshot=True, + use_som=False, + extract_visible_tag=True, + extract_clickable_tag=True, + extract_coords="False", + filter_visible_elements_only=False, + ), + action=dp.ActionFlags( + multi_actions=False, + action_set="bid", + long_description=False, + individual_examples=False, + ), + use_plan=False, + use_criticise=False, + use_thinking=True, + use_memory=False, + use_concrete_example=True, + use_abstract_example=True, + use_hints=True, + enable_chat=False, + max_prompt_tokens=40_000, + be_cautious=True, + extra_instructions=None, +) + + +def get_base_agent(llm_config: str): + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=BASE_FLAGS, + ) + + +def get_vision_agent(llm_config: str): + flags = deepcopy(BASE_FLAGS) + flags.obs.use_screenshot = True + agent_args = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=flags, + ) + agent_args.agent_name = f"{agent_args.agent_name}_vision" + return agent_args + + +def get_som_agent(llm_config: str): + flags = deepcopy(BASE_FLAGS) + flags.obs.use_screenshot = True + flags.obs.use_som = True + return GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config], + flags=flags, + ) From bf0b6e71ebc9aeb9508e1bf8375212283cf38166 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Thu, 28 Aug 2025 15:05:18 +0200 Subject: [PATCH 10/53] fix repeated llm configs --- src/agentlab/llm/llm_configs.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 3d5828b9..afbf094f 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -20,22 +20,6 @@ ] CHAT_MODEL_ARGS_DICT = { - "openai/gpt-5-nano-2025-08-07": OpenAIModelArgs( - model_name="gpt-5-nano-2025-08-07", - max_total_tokens=128_000, - max_input_tokens=128_000, - max_new_tokens=16_384, - temperature=1, # gpt-5 supports temperature of 1 only - vision_support=True, - ), - "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs( - model_name="gpt-5-mini-2025-08-07", - max_total_tokens=128_000, - max_input_tokens=128_000, - max_new_tokens=16_384, - temperature=1, # gpt-5 supports temperature of 1 only - vision_support=True, - ), "openai/gpt-4.1-mini-2025-04-14": OpenAIModelArgs( model_name="gpt-4.1-mini-2025-04-14", max_total_tokens=128_000, @@ -117,6 +101,7 @@ max_input_tokens=400_000 - 4_000, max_new_tokens=4_000, temperature=1, # temperature param not supported by gpt-5 + vision_support=True, ), "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs( model_name="gpt-5-mini-2025-08-07", @@ -124,6 +109,7 @@ max_input_tokens=400_000 - 4_000, max_new_tokens=4_000, temperature=1, # temperature param not supported by gpt-5 + vision_support=True, ), "azure/gpt-35-turbo/gpt-35-turbo": AzureModelArgs( model_name="gpt-35-turbo", From f7d154551c03bc427343af4e22426b87c040274e Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Thu, 28 Aug 2025 15:06:00 +0200 Subject: [PATCH 11/53] load env vars in codegen agent --- src/agentlab/agents/human_trace_recorder/codegen_agent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/agentlab/agents/human_trace_recorder/codegen_agent.py b/src/agentlab/agents/human_trace_recorder/codegen_agent.py index 16d0222c..cd902bd2 100644 --- a/src/agentlab/agents/human_trace_recorder/codegen_agent.py +++ b/src/agentlab/agents/human_trace_recorder/codegen_agent.py @@ -14,8 +14,6 @@ from pathlib import Path import bgym -from playwright.sync_api import Page - from agentlab.agents.agent_args import AgentArgs from browsergym.core.observation import ( extract_dom_extra_properties, @@ -25,7 +23,10 @@ extract_screenshot, ) from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html +from dotenv import load_dotenv +from playwright.sync_api import Page +load_dotenv() def extract_log_message_from_pw_trace(pw_trace_file_path): zip_file = zipfile.ZipFile(pw_trace_file_path, "r") From 55ce26a2f85e02b965c31a06660aa4f2518937b5 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Fri, 29 Aug 2025 17:45:41 +0200 Subject: [PATCH 12/53] same hints retrieval for both generic and tooluse agents --- .../generic_agent_hinter/generic_agent.py | 3 +- .../generic_agent_prompt.py | 65 +++++++++---- .../agents/tool_use_agent/tool_use_agent.py | 91 +++++++++++++------ 3 files changed, 109 insertions(+), 50 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 91b2f70f..cfbd19bd 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -111,10 +111,11 @@ def get_action(self, obs): previous_plan=self.plan, step=self.plan_step, flags=self.flags, + llm=self.chat_llm, ) # Set task name for task hints if available - if self.flags.use_task_hint and hasattr(self, 'task_name'): + if self.flags.use_task_hint and hasattr(self, "task_name"): main_prompt.set_task_name(self.task_name) max_prompt_tokens, max_trunc_itr = self._get_maxes() diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index bc12cc2c..f529fd78 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -6,15 +6,16 @@ import logging from dataclasses import dataclass +from pathlib import Path +from typing import Literal -from browsergym.core import action +import pandas as pd from browsergym.core.action.base import AbstractActionSet from agentlab.agents import dynamic_prompting as dp +from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource +from agentlab.llm.chat_api import ChatModel from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise -import fnmatch -import pandas as pd -from pathlib import Path @dataclass @@ -49,6 +50,7 @@ class GenericPromptFlags(dp.Flags): use_abstract_example: bool = False use_hints: bool = False use_task_hint: bool = False + task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" hint_db_path: str = None enable_chat: bool = False max_prompt_tokens: int = None @@ -70,10 +72,12 @@ def __init__( previous_plan: str, step: int, flags: GenericPromptFlags, + llm: ChatModel, ) -> None: super().__init__() self.flags = flags self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs) + goal = obs_history[-1]["goal_object"] if self.flags.enable_chat: self.instructions = dp.ChatInstructions( obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions @@ -84,7 +88,7 @@ def __init__( "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." ) self.instructions = dp.GoalInstructions( - obs_history[-1]["goal_object"], extra_instructions=flags.extra_instructions + goal, extra_instructions=flags.extra_instructions ) self.obs = dp.Observation( @@ -105,7 +109,10 @@ def time_for_caution(): self.hints = dp.Hints(visible=lambda: flags.use_hints) self.task_hint = TaskHint( use_task_hint=flags.use_task_hint, - hint_db_path=flags.hint_db_path + hint_db_path=flags.hint_db_path, + goal=goal, + hint_retrieval_mode=flags.task_hint_retrieval_mode, + llm=llm, ) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) @@ -114,12 +121,12 @@ def time_for_caution(): @property def _prompt(self) -> HumanMessage: prompt = HumanMessage(self.instructions.prompt) - + # Add task hints if enabled task_hints_text = "" - if self.flags.use_task_hint and hasattr(self, 'task_name'): + if self.flags.use_task_hint and hasattr(self, "task_name"): task_hints_text = self.task_hint.get_hints_for_task(self.task_name) - + prompt.add_text( f"""\ {self.obs.prompt}\ @@ -286,11 +293,21 @@ def _parse_answer(self, text_answer): class TaskHint(dp.PromptElement): - def __init__(self, use_task_hint: bool = True, hint_db_path: str = None) -> None: + def __init__( + self, + use_task_hint: bool, + hint_db_path: str, + goal: str, + hint_retrieval_mode: Literal["direct", "llm", "emb"], + llm: ChatModel, + ) -> None: super().__init__(visible=use_task_hint) self.use_task_hint = use_task_hint self.hint_db_rel_path = "hint_db.csv" self.hint_db_path = hint_db_path # Allow external path override + self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode + self.goal = goal + self.llm = llm self._init() _prompt = "" # Task hints are added dynamically in MainPrompt @@ -316,39 +333,49 @@ def _init(self): hint_db_path = Path(self.hint_db_path) else: hint_db_path = Path(__file__).parent / self.hint_db_rel_path - + if hint_db_path.exists(): self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) # Verify the expected columns exist if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: - print(f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}") + print( + f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}" + ) self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) else: print(f"Warning: Hint database not found at {hint_db_path}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + self.hints_source = HintsSource( + hint_db_path=self.hint_db_rel_path, + hint_retrieval_mode=self.hint_retrieval_mode, + ) except Exception as e: # Fallback to empty database on any error print(f"Warning: Could not load hint database: {e}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - def get_hints_for_task(self, task_name: str) -> str: """Get hints for a specific task.""" if not self.use_task_hint: return "" # Ensure hint_db is initialized - if not hasattr(self, 'hint_db'): + if not hasattr(self, "hint_db"): self._init() # Check if hint_db has the expected structure - if self.hint_db.empty or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + if ( + self.hint_db.empty + or "task_name" not in self.hint_db.columns + or "hint" not in self.hint_db.columns + ): return "" try: - task_hints = self.hint_db[ - self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) - ] + # task_hints = self.hint_db[ + # self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) + # ] + task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal) hints = [] for hint in task_hints["hint"]: @@ -364,5 +391,5 @@ def get_hints_for_task(self, task_name: str) -> str: return hints_str except Exception as e: print(f"Warning: Error getting hints for task {task_name}: {e}") - + return "" diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 375c829e..9025107e 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -28,6 +28,7 @@ from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark from agentlab.benchmarks.osworld import OSWorldActionSet from agentlab.llm.base_api import BaseModelArgs +from agentlab.llm.chat_api import ChatModel from agentlab.llm.llm_utils import image_to_png_base64_url from agentlab.llm.response_api import ( APIPayload, @@ -316,39 +317,21 @@ class TaskHint(Block): def _init(self): """Initialize the block.""" - if Path(self.hint_db_rel_path).is_absolute(): - hint_db_path = Path(self.hint_db_rel_path) - else: - hint_db_path = Path(__file__).parent / self.hint_db_rel_path - self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) - if self.hint_retrieval_mode == "emb": - self.encode_hints() - - def oai_embed(self, text: str): - response = self._oai_emb.create(input=text, model="text-embedding-3-small") - return response.data[0].embedding - - def encode_hints(self): - self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") - logger.info( - f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." + self.hints_source = HintsSource( + hint_db_path=self.hint_db_rel_path, + hint_retrieval_mode=self.hint_retrieval_mode, + top_n=self.top_n, + embedder_model=self.embedder_model, + embedder_server=self.embedder_server, + llm_prompt=self.llm_prompt, ) - hints = self.uniq_hints["hint"].tolist() - semantic_keys = self.uniq_hints["semantic_keys"].tolist() - lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] - emb_path = f"{self.hint_db_rel_path}.embs.npy" - assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" - logger.info(f"Loading hint embeddings from: {emb_path}") - emb_dict = np.load(emb_path, allow_pickle=True).item() - self.hint_embeddings = np.array([emb_dict[k] for k in lines]) - logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if not self.use_task_hint: return {} goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content]) - task_hints = self.choose_hints(llm, task_name, goal) + task_hints = self.hints_source.choose_hints(llm, task_name, goal) hints = [] for hint in task_hints: @@ -365,6 +348,49 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: discussion.append(msg) + +class HintsSource: + def __init__( + self, + hint_db_path: str, + hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", + top_n: int = 4, + embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", + embedder_server: str = "http://localhost:5000/", + llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n +You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n +Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""", + ) -> None: + self.hint_db_path = hint_db_path + self.hint_retrieval_mode = hint_retrieval_mode + self.top_n = top_n + self.embedder_model = embedder_model + self.embedder_server = embedder_server + self.llm_prompt = llm_prompt + + if Path(hint_db_path).is_absolute(): + self.hint_db_path = Path(hint_db_path).as_posix() + else: + self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() + self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) + if self.hint_retrieval_mode == "emb": + self.load_hint_vectors() + + def load_hint_vectors(self): + self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") + logger.info( + f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." + ) + hints = self.uniq_hints["hint"].tolist() + semantic_keys = self.uniq_hints["semantic_keys"].tolist() + lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] + emb_path = f"{self.hint_db_path}.embs.npy" + assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" + logger.info(f"Loading hint embeddings from: {emb_path}") + emb_dict = np.load(emb_path, allow_pickle=True).item() + self.hint_embeddings = np.array([emb_dict[k] for k in lines]) + logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") + def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: """Choose hints based on the task name.""" if self.hint_retrieval_mode == "llm": @@ -384,11 +410,14 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]: hint_topics = list(topic_to_hints.keys()) topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) prompt = self.llm_prompt.format(goal=goal, topics=topics) - response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])) + if isinstance(llm, ChatModel): + response: str = llm(messages=[dict(role="user", content=prompt)])["content"] + else: + response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think try: - hint_topic_idx = json.loads(response.think) + hint_topic_idx = json.loads(response) if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics): - logger.error(f"Wrong LLM hint id response: {response.think}, no hints") + logger.error(f"Wrong LLM hint id response: {response}, no hints") return [] hint_topic = hint_topics[hint_topic_idx] hint_indices = topic_to_hints[hint_topic] @@ -397,7 +426,7 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]: hints = df["hint"].tolist() logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}") except json.JSONDecodeError: - logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints") + logger.error(f"Failed to parse LLM hint id response: {response}, no hints") hints = [] return hints @@ -427,6 +456,7 @@ def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_ret raise e time.sleep(random.uniform(1, timeout)) continue + raise ValueError("Failed to encode hints") def _similarity( self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5 @@ -446,6 +476,7 @@ def _similarity( raise e time.sleep(random.uniform(1, timeout)) continue + raise ValueError("Failed to compute similarity") def choose_hints_direct(self, task_name: str) -> list[str]: hints = self.hint_db[ From cad12096f312cfd74de24f0b50ba4010f12953f3 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 1 Sep 2025 13:51:53 +0200 Subject: [PATCH 13/53] filter out current task hints if needed --- .../agents/tool_use_agent/tool_use_agent.py | 41 +++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 9025107e..4e6de3b3 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -341,7 +341,7 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if len(hints) > 0: hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" + "\n# Hints:\nHere are some hints for the task you are working on:\n" + "\n".join(hints) ) msg = llm.msg.user().add_text(hints_str) @@ -354,6 +354,7 @@ def __init__( self, hint_db_path: str, hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", + skip_hints_for_current_task: bool = False, top_n: int = 4, embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", embedder_server: str = "http://localhost:5000/", @@ -363,6 +364,7 @@ def __init__( ) -> None: self.hint_db_path = hint_db_path self.hint_retrieval_mode = hint_retrieval_mode + self.skip_hints_for_current_task = skip_hints_for_current_task self.top_n = top_n self.embedder_model = embedder_model self.embedder_server = embedder_server @@ -405,7 +407,14 @@ def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: def choose_hints_llm(self, llm, goal: str) -> list[str]: """Choose hints using LLM to filter the hints.""" topic_to_hints = defaultdict(list) - for i, row in self.hint_db.iterrows(): + hints_df = self.hint_db + if self.skip_hints_for_current_task: + current_task_hints = self.get_current_task_hints(task_name) + hints_df = hints_df[~hints_df["hint"].isin(current_task_hints)] + logger.info( + f"Filtered out current task hints, remaining hints: {hints_df.shape[0]} out of {self.hint_db.shape[0]}" + ) + for i, row in hints_df.iterrows(): topic_to_hints[row["semantic_keys"]].append(i) hint_topics = list(topic_to_hints.keys()) topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) @@ -421,10 +430,10 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]: return [] hint_topic = hint_topics[hint_topic_idx] hint_indices = topic_to_hints[hint_topic] - df = self.hint_db.iloc[hint_indices].copy() + df = hints_df.iloc[hint_indices].copy() df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints hints = df["hint"].tolist() - logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}") + logger.info(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}") except json.JSONDecodeError: logger.error(f"Failed to parse LLM hint id response: {response}, no hints") hints = [] @@ -433,10 +442,21 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]: def choose_hints_emb(self, goal: str) -> list[str]: """Choose hints using embeddings to filter the hints.""" goal_embeddings = self._encode([goal], prompt="task description") - similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist()) + hint_embeddings = self.hint_embeddings + hints_df = self.uniq_hints + if self.skip_hints_for_current_task: + current_task_hints = self.get_current_task_hints(task_name) + mask = ~hints_df["hint"].isin(current_task_hints) + hints_df = hints_df[mask] + filtered_indices = hints_df.index.tolist() + hint_embeddings = hint_embeddings[filtered_indices] + logger.info( + f"Filtered same task hint, remained: {len(hint_embeddings)} out of {len(self.hint_embeddings)} embeddings" + ) + similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings.tolist()) top_indices = similarities.argsort()[0][-self.top_n :].tolist() logger.info(f"Top hint indices based on embedding similarity: {top_indices}") - hints = self.uniq_hints.iloc[top_indices] + hints = hints_df.iloc[top_indices] logger.info(f"Embedding-based hints chosen: {hints}") return hints["hint"].tolist() @@ -479,10 +499,15 @@ def _similarity( raise ValueError("Failed to compute similarity") def choose_hints_direct(self, task_name: str) -> list[str]: - hints = self.hint_db[ + hints = self.get_current_task_hints(task_name) + logger.info(f"Direct hints chosen: {hints}") + return hints + + def get_current_task_hints(self, task_name): + hints_df = self.hint_db[ self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) ] - return hints["hint"].tolist() + return hints_df["hint"].tolist() @dataclass From d920b8eb6cae5e39ba5f1a49bd1b73b633294e6c Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 1 Sep 2025 13:52:08 +0200 Subject: [PATCH 14/53] fix llm config, add gpt-5 --- src/agentlab/llm/llm_configs.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index afbf094f..7ac2450a 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -111,6 +111,14 @@ temperature=1, # temperature param not supported by gpt-5 vision_support=True, ), + "openai/gpt-5-2025-08-07": OpenAIModelArgs( + model_name="gpt-5-2025-08-07", + max_total_tokens=400_000, + max_input_tokens=400_000 - 4_000, + max_new_tokens=4_000, + temperature=1, # temperature param not supported by gpt-5 + vision_support=True, + ), "azure/gpt-35-turbo/gpt-35-turbo": AzureModelArgs( model_name="gpt-35-turbo", deployment_name="gpt-35-turbo", From 5315f14b2b5b57f43e23a0da0eec6b31f273ce99 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 1 Sep 2025 13:52:21 +0200 Subject: [PATCH 15/53] fix --- .../agents/generic_agent_hinter/generic_agent_prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index f529fd78..19f0efda 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -378,7 +378,7 @@ def get_hints_for_task(self, task_name: str) -> str: task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal) hints = [] - for hint in task_hints["hint"]: + for hint in task_hints: hint = hint.strip() if hint: hints.append(f"- {hint}") From 26f0abb36fc80999576cc7beec065f3da07dbb1e Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 1 Sep 2025 14:35:05 +0200 Subject: [PATCH 16/53] pass new flag and fix db path passing issue --- .../generic_agent_hinter/generic_agent_prompt.py | 7 ++++++- src/agentlab/agents/tool_use_agent/tool_use_agent.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 19f0efda..84b5d332 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -51,6 +51,7 @@ class GenericPromptFlags(dp.Flags): use_hints: bool = False use_task_hint: bool = False task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" + skip_hints_for_current_task: bool = False hint_db_path: str = None enable_chat: bool = False max_prompt_tokens: int = None @@ -113,6 +114,7 @@ def time_for_caution(): goal=goal, hint_retrieval_mode=flags.task_hint_retrieval_mode, llm=llm, + skip_hints_for_current_task=flags.skip_hints_for_current_task, ) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) @@ -299,6 +301,7 @@ def __init__( hint_db_path: str, goal: str, hint_retrieval_mode: Literal["direct", "llm", "emb"], + skip_hints_for_current_task: bool, llm: ChatModel, ) -> None: super().__init__(visible=use_task_hint) @@ -306,6 +309,7 @@ def __init__( self.hint_db_rel_path = "hint_db.csv" self.hint_db_path = hint_db_path # Allow external path override self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode + self.skip_hints_for_current_task = skip_hints_for_current_task self.goal = goal self.llm = llm self._init() @@ -346,8 +350,9 @@ def _init(self): print(f"Warning: Hint database not found at {hint_db_path}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) self.hints_source = HintsSource( - hint_db_path=self.hint_db_rel_path, + hint_db_path=hint_db_path.as_posix(), hint_retrieval_mode=self.hint_retrieval_mode, + skip_hints_for_current_task=self.skip_hints_for_current_task, ) except Exception as e: # Fallback to empty database on any error diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 4e6de3b3..b8f21431 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -375,6 +375,7 @@ def __init__( else: self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) + logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") if self.hint_retrieval_mode == "emb": self.load_hint_vectors() @@ -395,16 +396,19 @@ def load_hint_vectors(self): def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: """Choose hints based on the task name.""" + logger.info( + f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" + ) if self.hint_retrieval_mode == "llm": - return self.choose_hints_llm(llm, goal) + return self.choose_hints_llm(llm, goal, task_name) elif self.hint_retrieval_mode == "direct": return self.choose_hints_direct(task_name) elif self.hint_retrieval_mode == "emb": - return self.choose_hints_emb(goal) + return self.choose_hints_emb(goal, task_name) else: raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") - def choose_hints_llm(self, llm, goal: str) -> list[str]: + def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: """Choose hints using LLM to filter the hints.""" topic_to_hints = defaultdict(list) hints_df = self.hint_db @@ -439,7 +443,7 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]: hints = [] return hints - def choose_hints_emb(self, goal: str) -> list[str]: + def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: """Choose hints using embeddings to filter the hints.""" goal_embeddings = self._encode([goal], prompt="task description") hint_embeddings = self.hint_embeddings From 5393a34112beab3e92e339d429947c871bfeb67e Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 1 Sep 2025 14:59:54 +0200 Subject: [PATCH 17/53] fix goal text --- .../agents/generic_agent_hinter/generic_agent_prompt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 84b5d332..b684b6c9 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -108,10 +108,11 @@ def time_for_caution(): self.be_cautious = dp.BeCautious(visible=time_for_caution) self.think = dp.Think(visible=lambda: flags.use_thinking) self.hints = dp.Hints(visible=lambda: flags.use_hints) + goal_str: str = goal[0]["text"] self.task_hint = TaskHint( use_task_hint=flags.use_task_hint, hint_db_path=flags.hint_db_path, - goal=goal, + goal=goal_str, hint_retrieval_mode=flags.task_hint_retrieval_mode, llm=llm, skip_hints_for_current_task=flags.skip_hints_for_current_task, From deddc50697b3871077d8000b2fa3fe0b48649b5d Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 1 Sep 2025 17:35:22 +0200 Subject: [PATCH 18/53] fix current task hints exclusion --- .../agents/tool_use_agent/tool_use_agent.py | 82 ++++++++++--------- 1 file changed, 45 insertions(+), 37 deletions(-) diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index b8f21431..bd200da3 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -411,58 +411,62 @@ def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: """Choose hints using LLM to filter the hints.""" topic_to_hints = defaultdict(list) - hints_df = self.hint_db + skip_hints = [] if self.skip_hints_for_current_task: - current_task_hints = self.get_current_task_hints(task_name) - hints_df = hints_df[~hints_df["hint"].isin(current_task_hints)] - logger.info( - f"Filtered out current task hints, remaining hints: {hints_df.shape[0]} out of {self.hint_db.shape[0]}" - ) - for i, row in hints_df.iterrows(): - topic_to_hints[row["semantic_keys"]].append(i) + skip_hints = self.get_current_task_hints(task_name) + for _, row in self.hint_db.iterrows(): + hint = row["hint"] + if hint in skip_hints: + continue + topic_to_hints[row["semantic_keys"]].append(hint) + logger.info(f"Collected {len(topic_to_hints)} hint topics") hint_topics = list(topic_to_hints.keys()) topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) prompt = self.llm_prompt.format(goal=goal, topics=topics) + if isinstance(llm, ChatModel): response: str = llm(messages=[dict(role="user", content=prompt)])["content"] else: response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think try: - hint_topic_idx = json.loads(response) - if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics): + topic_number = json.loads(response) + if topic_number < 0 or topic_number >= len(hint_topics): logger.error(f"Wrong LLM hint id response: {response}, no hints") return [] - hint_topic = hint_topics[hint_topic_idx] - hint_indices = topic_to_hints[hint_topic] - df = hints_df.iloc[hint_indices].copy() - df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints - hints = df["hint"].tolist() - logger.info(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}") - except json.JSONDecodeError: - logger.error(f"Failed to parse LLM hint id response: {response}, no hints") + hint_topic = hint_topics[topic_number] + hints = list(set(topic_to_hints[hint_topic])) + logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") + except Exception as e: + logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") hints = [] return hints def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: """Choose hints using embeddings to filter the hints.""" - goal_embeddings = self._encode([goal], prompt="task description") - hint_embeddings = self.hint_embeddings - hints_df = self.uniq_hints - if self.skip_hints_for_current_task: - current_task_hints = self.get_current_task_hints(task_name) - mask = ~hints_df["hint"].isin(current_task_hints) - hints_df = hints_df[mask] - filtered_indices = hints_df.index.tolist() - hint_embeddings = hint_embeddings[filtered_indices] - logger.info( - f"Filtered same task hint, remained: {len(hint_embeddings)} out of {len(self.hint_embeddings)} embeddings" - ) - similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings.tolist()) - top_indices = similarities.argsort()[0][-self.top_n :].tolist() - logger.info(f"Top hint indices based on embedding similarity: {top_indices}") - hints = hints_df.iloc[top_indices] - logger.info(f"Embedding-based hints chosen: {hints}") - return hints["hint"].tolist() + try: + goal_embeddings = self._encode([goal], prompt="task description") + hint_embeddings = self.hint_embeddings.copy() + all_hints = self.uniq_hints["hint"].tolist() + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + hint_embeddings = [] + id_to_hint = {} + for hint, emb in zip(all_hints, self.hint_embeddings): + if hint in skip_hints: + continue + hint_embeddings.append(emb.tolist()) + id_to_hint[len(hint_embeddings) - 1] = hint + logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") + similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) + top_indices = similarities.argsort()[0][-self.top_n :].tolist() + logger.info(f"Top hint indices based on embedding similarity: {top_indices}") + hints = [id_to_hint[idx] for idx in top_indices] + logger.info(f"Embedding-based hints chosen: {hints}") + except Exception as e: + logger.exception(f"Failed to choose hints using embeddings: {e}") + hints = [] + return hints def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): """Call the encode API endpoint with timeout and retries""" @@ -483,7 +487,11 @@ def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_ret raise ValueError("Failed to encode hints") def _similarity( - self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5 + self, + texts1: list, + texts2: list, + timeout: int = 2, + max_retries: int = 5, ): """Call the similarity API endpoint with timeout and retries""" for attempt in range(max_retries): From b9d09d4d8d2ee557a04b76c358496c21cc1657cd Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Tue, 2 Sep 2025 11:44:38 +0200 Subject: [PATCH 19/53] remove old reqs --- requirements.txt | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index a2798f2e..00000000 --- a/requirements.txt +++ /dev/null @@ -1,31 +0,0 @@ -black[jupyter]>=24.2.0 -blacken-docs -pre-commit -pytest==7.3.2 -flaky -pytest-xdist -pytest-playwright -pydantic~=2.9 -dask -distributed -browsergym>=0.7.1 -joblib>=1.2.0 -openai>=1.7,<2 -langchain_community -tiktoken -tapeagents[converters] -huggingface_hub -contexttimer -ipython -pyyaml>=6 -pandas -gradio>=5.5 # issue with DataFrame scrolling before 5.5 -gitpython # for the reproducibility script -requests -matplotlib -ray[default] -python-slugify -pillow -gymnasium>=0.27 -sentence-transformers>=5.0.0 -python-dotenv>=1.1.1 \ No newline at end of file From 725e7a03750780263cb6ce0190ef252fc2e3d688 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Tue, 2 Sep 2025 11:45:54 +0200 Subject: [PATCH 20/53] remove recorder from that brach --- .../agents/human_trace_recorder/agent.py | 302 ---------- .../human_trace_recorder/codegen_agent.py | 193 ------ .../human_trace_recorder/event_listeners.py | 563 ------------------ 3 files changed, 1058 deletions(-) delete mode 100644 src/agentlab/agents/human_trace_recorder/agent.py delete mode 100644 src/agentlab/agents/human_trace_recorder/codegen_agent.py delete mode 100644 src/agentlab/agents/human_trace_recorder/event_listeners.py diff --git a/src/agentlab/agents/human_trace_recorder/agent.py b/src/agentlab/agents/human_trace_recorder/agent.py deleted file mode 100644 index 556922af..00000000 --- a/src/agentlab/agents/human_trace_recorder/agent.py +++ /dev/null @@ -1,302 +0,0 @@ -"""Human Trace Agent for Browser Automation Training Data - -Captures human interactions at each step including: - - Comprehensive action tracking (clicks, input, navigation, etc.) - - Saves only human_action.json files in simple numbered folders -""" - -from __future__ import annotations - -import json -import time -from dataclasses import dataclass -from pathlib import Path - -import bgym -from playwright.sync_api import Page - -from agentlab.agents.agent_args import AgentArgs -from agentlab.agents.human_trace_recorder.event_listeners import ( - get_interaction_tracking_script, - get_recording_indicators_script, -) -from browsergym.core.observation import ( - extract_dom_extra_properties, - extract_dom_snapshot, - extract_focused_element_bid, - extract_merged_axtree, - extract_screenshot, -) -from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html - - -@dataclass -class HumanTraceAgentArgs(AgentArgs): - agent_name: str = "HumanTraceAgent" - trace_dir: str = "human_traces" - use_raw_page_output: bool = True - - def make_agent(self) -> bgym.Agent: # type: ignore[override] - return HumanTraceAgent(self.trace_dir) - - def set_reproducibility_mode(self): - pass - - -class HumanTraceAgent(bgym.Agent): - def __init__(self, trace_dir: str): - self.action_set = bgym.HighLevelActionSet(["bid"], multiaction=False) - self._root = Path(trace_dir) - self._root.mkdir(parents=True, exist_ok=True) - self._page: Page | None = None - self._step = 0 - - def obs_preprocessor(self, obs: dict): # type: ignore[override] - if isinstance(obs, dict): - self._page = obs.get("page") - # Remove the page object from obs to avoid pickle issues - if "page" in obs: - del obs["page"] - - obs["screenshot"] = extract_screenshot(self._page) - obs["dom_object"] = extract_dom_snapshot(self._page) - obs["axtree_object"] = extract_merged_axtree(self._page) - scale_factor = getattr(self._page, "_bgym_scale_factor", 1.0) - extra_properties = extract_dom_extra_properties( - obs["dom_object"], scale_factor=scale_factor - ) - obs["extra_element_properties"] = extra_properties - obs["focused_element_bid"] = extract_focused_element_bid(self._page) - - # Add text representations for easier analysis - if obs["axtree_object"]: - axt = obs["axtree_object"] - if extra_properties: - obs["axtree_txt"] = flatten_axtree_to_str(axt) - - if obs["dom_object"]: - obs["dom_txt"] = flatten_dom_to_str(obs["dom_object"]) - obs["pruned_html"] = prune_html(obs["dom_txt"]) - return obs - - def get_action(self, obs: dict): # type: ignore[override] - if self._page is None: - raise RuntimeError("Playwright Page missing; ensure use_raw_page_output=True") - - page = self._page - step_dir = self._create_step_directory() - - self._display_recording_prompt() - self._show_recording_indicators(page) - - # Capture human interactions - captured_action, human_interactions = self._capture_interactions_with_js(page, step_dir) - - # Save and cleanup - self._save_human_action(captured_action, step_dir) - self._cleanup_indicators(page) - - self._step += 1 - return "noop()", { - "extra_info": { - "step": self._step - 1, - "human_interactions": human_interactions, - } - } - - def _create_step_directory(self) -> Path: - """Create directory for current step.""" - step_dir = self._root / str(self._step) - step_dir.mkdir(parents=True, exist_ok=True) - return step_dir - - def _display_recording_prompt(self): - """Display prompt messages to user.""" - print(f"[HumanTrace] Step {self._step}: Perform ONE action") - print("[HumanTrace] ⚠️ WAIT FOR THE RED BORDER TO APPEAR BEFORE PERFORMING ANY ACTION ⚠️") - print("[HumanTrace] The system will automatically save after detecting your action") - - def _show_recording_indicators(self, page: Page): - """Show visual recording indicators on the page.""" - page.evaluate(get_recording_indicators_script()) - - def _save_human_action(self, captured_action: dict, step_dir: Path): - """Save the captured human action to JSON file.""" - try: - human_action_path = step_dir / "human_action.json" - if captured_action and isinstance(captured_action, dict): - human_action_path.write_text(json.dumps(captured_action, indent=2)) - action_type = captured_action.get("type", "unknown") - else: - # Create empty action record for consistency - empty_action = { - "type": "no_action", - "timestamp": time.time() * 1000, - "reason": "No meaningful human action captured in this step", - } - human_action_path.write_text(json.dumps(empty_action, indent=2)) - action_type = "no_action" - - print(f"[HumanTrace] Step {self._step} complete - Action: {action_type}") - - except Exception as e: - print(f"[HumanTrace] Warning: Failed to save human action: {e}") - - def _cleanup_indicators(self, page: Page): - """Remove recording indicators from the page.""" - page.evaluate("document.getElementById('__rec')?.remove(); document.getElementById('__rec_border')?.remove()") - - def _capture_interactions_with_js(self, page: Page, step_dir: Path) -> tuple[dict, str]: - """Capture human interactions using JavaScript injection.""" - try: - print("[HumanTrace] JavaScript interaction tracking enabled") - initial_url, initial_title = page.url, page.title() - - # Inject interaction tracking - self._inject_interaction_tracking(page) - - # Wait for user action - self._wait_for_user_action(page) - - # Collect and process interaction data - return self._collect_interaction_data(page, initial_url, initial_title) - - except Exception as e: - print(f"[HumanTrace] Error: {e}") - return { - "type": "error", - "timestamp": time.time() * 1000, - "error": str(e), - }, f"Error: {e}" - - def _inject_interaction_tracking(self, page: Page): - """Inject JavaScript code for comprehensive interaction tracking.""" - tracking_script = get_interaction_tracking_script() - page.evaluate(tracking_script) - - def _wait_for_user_action(self, page: Page): - """Wait for user to perform an action.""" - start_time = time.time() - while time.time() - start_time < 300: - try: - action_detected = page.evaluate("window.__acted || false") - if action_detected: - print(f"[HumanTrace] Action detected! Exiting immediately...") - break - except Exception as e: - print(f"[HumanTrace] Debug: Error checking actions: {e}") - pass - time.sleep(0.1) - - def _collect_interaction_data(self, page: Page, initial_url: str, initial_title: str) -> tuple[dict, str]: - """Collect and format interaction data.""" - try: - action_detected = page.evaluate("window.__acted || false") - interactions = page.evaluate("window.__interactions || []") - - action_data = { - "type": "human_interactions" if action_detected else "no_action", - "timestamp": time.time() * 1000, - "detected": action_detected, - "interactions": interactions, - "interaction_count": len(interactions) - } - - summary = self._create_interaction_summary(interactions) - self._add_page_change_info(action_data, initial_url, initial_title, page) - - print(f"[HumanTrace] {summary}") - return action_data, summary - - except Exception as e: - return { - "type": "error", - "timestamp": time.time() * 1000, - "detected": False, - "error": str(e), - "interactions": [], - "interaction_count": 0 - }, f"Error collecting interactions: {e}" - - def _create_interaction_summary(self, interactions: list) -> str: - """Create a summary string of captured interactions.""" - if interactions: - interaction_types = {} - for interaction in interactions: - itype = interaction.get('type', 'unknown') - interaction_types[itype] = interaction_types.get(itype, 0) + 1 - - summary_parts = [] - for itype, count in interaction_types.items(): - summary_parts.append(f"{itype}:{count}") - return f"Captured {len(interactions)} interactions: {', '.join(summary_parts)}" - else: - return "No interactions detected" - - def _add_page_change_info(self, action_data: dict, initial_url: str, initial_title: str, page: Page): - """Add page change information to action data.""" - final_url, final_title = page.url, page.title() - if initial_url != final_url or initial_title != final_title: - action_data["page_changed"] = True - action_data["url_change"] = {"from": initial_url, "to": final_url} - action_data["title_change"] = {"from": initial_title, "to": final_title} - - def _format_js_interaction_summary(self, action_data, interaction_log): - """Format JavaScript-captured interactions into readable summary.""" - lines = ["Human Interactions (JavaScript Tracking):"] - - if action_data["interactions"]: - lines.append(f"Total Actions: {len(action_data['interactions'])}") - lines.append("") - - # Group interactions by type - by_type = {} - for interaction in action_data["interactions"]: - interaction_type = interaction["type"] - if interaction_type not in by_type: - by_type[interaction_type] = [] - by_type[interaction_type].append(interaction) - - # Show summary by type - for interaction_type, interactions in by_type.items(): - lines.append(f"{interaction_type.title()}: {len(interactions)} actions") - - lines.append("") - lines.append("Detailed Actions:") - - # Add each interaction from the log - for log_entry in interaction_log: - lines.append(f" {log_entry}") - else: - lines.append("No interactions detected - user may have just observed the page") - - # Add page state changes if URL changed - if action_data.get("page_changed"): - url_info = action_data.get("url") - if url_info: - lines.append("") - lines.append("� Page Navigation:") - lines.append(f" From: {url_info['from']}") - lines.append(f" To: {url_info['to']}") - - return "\n".join(lines) - - -HUMAN_TRACE_AGENT = HumanTraceAgentArgs() - - -if __name__ == "__main__": - from agentlab.agents.human_trace_recorder.agent import HUMAN_TRACE_AGENT - from agentlab.experiments.study import Study - - agent_configs = [HUMAN_TRACE_AGENT] - benchmark = bgym.DEFAULT_BENCHMARKS["workarena_l1"](n_repeats=1) # type: bgym.Benchmark - benchmark = benchmark.subset_from_glob("task_name", "*filter*") - benchmark.env_args_list = benchmark.env_args_list[:1] - for env_args in benchmark.env_args_list: - print(env_args.task_name) - env_args.max_steps = 15 - env_args.headless = False - - study = Study(agent_configs, benchmark) - study.run(n_jobs=1, parallel_backend="sequential") diff --git a/src/agentlab/agents/human_trace_recorder/codegen_agent.py b/src/agentlab/agents/human_trace_recorder/codegen_agent.py deleted file mode 100644 index cd902bd2..00000000 --- a/src/agentlab/agents/human_trace_recorder/codegen_agent.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Simple Codegen Agent - -Captures human interactions using playwright inspector. -Playwright trace logs are stored in "think" messages and can be viewed in Agentlab Xray. -""" - -from __future__ import annotations - -import json -import logging -import tempfile -import zipfile -from dataclasses import dataclass -from pathlib import Path - -import bgym -from agentlab.agents.agent_args import AgentArgs -from browsergym.core.observation import ( - extract_dom_extra_properties, - extract_dom_snapshot, - extract_focused_element_bid, - extract_merged_axtree, - extract_screenshot, -) -from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html -from dotenv import load_dotenv -from playwright.sync_api import Page - -load_dotenv() - -def extract_log_message_from_pw_trace(pw_trace_file_path): - zip_file = zipfile.ZipFile(pw_trace_file_path, "r") - trace_lines = zip_file.read("trace.trace").decode("utf-8").splitlines() - - actions = [] - for line in trace_lines: - if line.strip(): - event = json.loads(line) - if event.get("type") == "log": - actions.append(event) - # Extract log messages from the trace - return [log["message"].strip() for log in sorted(actions, key=lambda x: x.get("time", 0))] - - -def clean_pw_logs(logs, exclude_blacklist=True, use_substitutions=True): - clean_logs = list(logs) - blacklist = { - "attempting click action", - "waiting for element to be visible, enabled and stable", - "element is visible, enabled and stable", - "scrolling into view if needed", - "done scrolling", - "performing click action", - "click action done", - "waiting for scheduled navigations to finish", - "navigations have finished", - } - - substitutions = [("waiting for ", "")] - - def apply_substitutions(log): - for old, new in substitutions: - log = log.replace(old, new) - return log - - if exclude_blacklist: - clean_logs = [log for log in clean_logs if log not in blacklist] - if use_substitutions: - clean_logs = [apply_substitutions(log) for log in clean_logs] - - return clean_logs - - -@dataclass -class PlayWrightCodeGenAgentArgs(AgentArgs): - agent_name: str = "PlayWrightCodeGenAgent" - trace_dir: str = "playwright_codegen_traces" - use_raw_page_output: bool = True - store_raw_trace: bool = False - - def make_agent(self) -> bgym.Agent: # type: ignore[override] - return PlayWrightCodeGenAgent(self.trace_dir, self.store_raw_trace) - - def set_reproducibility_mode(self): - pass - - -class PlayWrightCodeGenAgent(bgym.Agent): - def __init__(self, trace_dir: str, store_raw_trace: bool): - self.action_set = bgym.HighLevelActionSet(["bid"], multiaction=False) - self._root = Path(trace_dir) - self._page: Page | None = None - self._step = 0 - self.store_raw_trace = store_raw_trace - self._episode_trace_dir = None # Cache for single episode - - def _get_trace_dir(self): - """Return the trace directory based on store_raw_trace setting.""" - if self._episode_trace_dir is None: - if self.store_raw_trace: - import datetime - - dt_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - self._episode_trace_dir = self._root / f"codegen_traces_{dt_str}" - self._episode_trace_dir.mkdir(parents=True, exist_ok=True) - else: - self._episode_trace_dir = Path(tempfile.mkdtemp()) - return self._episode_trace_dir - - def obs_preprocessor(self, obs: dict): # type: ignore[override] - if isinstance(obs, dict): - self._page = obs.get("page") - obs["screenshot"] = extract_screenshot(self._page) - obs["dom_object"] = extract_dom_snapshot(self._page) - obs["axtree_object"] = extract_merged_axtree(self._page) - scale_factor = getattr(self._page, "_bgym_scale_factor", 1.0) - extra_properties = extract_dom_extra_properties( - obs["dom_object"], scale_factor=scale_factor - ) - obs["extra_element_properties"] = extra_properties - obs["focused_element_bid"] = extract_focused_element_bid(self._page) - - if obs["axtree_object"]: - obs["axtree_txt"] = flatten_axtree_to_str(obs["axtree_object"]) - - if obs["dom_object"]: - obs["dom_txt"] = flatten_dom_to_str(obs["dom_object"]) - obs["pruned_html"] = prune_html(obs["dom_txt"]) - - if "page" in obs: # unpickable - del obs["page"] - - return obs - - def get_action(self, obs: dict): # type: ignore[override] - - if self._page is None: - raise RuntimeError("Playwright Page missing; ensure use_raw_page_output=True") - - page = self._page - trace_dir = self._get_trace_dir() - trace_path = trace_dir / f"step_{self._step}.zip" - page.context.tracing.start(screenshots=True, snapshots=True, sources=True) - page.context.tracing.start_chunk(name=f"step_{self._step}") - - print( - f"{'─'*60}\n" f"Step {self._step}\n", - f"{'─'*60}\n", - "1. 🔴 Start Recording (Press 'Record' in the Playwright Inspector.)\n", - "2. ✨ Perform actions for a single step.\n", - "3. ⚫ Stop Recording (Press 'Record' again to stop recording.)\n", - "4. ▶️ Press 'Resume' in the Playwright Inspector.", - ) - - page.pause() # Launch Inspector and record actions - page.context.tracing.stop_chunk(path=trace_path) - page.context.tracing.stop() - - pw_logs = extract_log_message_from_pw_trace(trace_path) - pw_logs = clean_pw_logs(pw_logs, exclude_blacklist=True) - pw_logs_str = "\n".join([f"{i}. {log}" for i, log in enumerate(pw_logs, 1)]) - - print(f"\n Playwright logs for step {self._step}:\n{pw_logs_str}") - - self._step += 1 - - agent_info = bgym.AgentInfo( - think=pw_logs_str, - chat_messages=[], - stats={}, - ) - - return "noop()", agent_info - - -PW_CODEGEN_AGENT = PlayWrightCodeGenAgentArgs(store_raw_trace=True) - - -if __name__ == "__main__": - from agentlab.agents.human_trace_recorder.codegen_agent import PW_CODEGEN_AGENT - from agentlab.experiments.study import Study - - agent_configs = [PW_CODEGEN_AGENT] - benchmark = bgym.DEFAULT_BENCHMARKS["workarena_l1"]() # type: bgym.Benchmark - benchmark = benchmark.subset_from_glob("task_name", "*create*") - benchmark.env_args_list = benchmark.env_args_list[:1] - for env_args in benchmark.env_args_list: - print(env_args.task_name) - env_args.max_steps = 15 - env_args.headless = False - - study = Study(agent_configs, benchmark, logging_level_stdout=logging.INFO) - study.run(n_jobs=1, parallel_backend="sequential", n_relaunch=1) diff --git a/src/agentlab/agents/human_trace_recorder/event_listeners.py b/src/agentlab/agents/human_trace_recorder/event_listeners.py deleted file mode 100644 index 2fd8453c..00000000 --- a/src/agentlab/agents/human_trace_recorder/event_listeners.py +++ /dev/null @@ -1,563 +0,0 @@ -"""JavaScript Event Listeners for Human Trace Capture - -This module contains all the JavaScript code for capturing comprehensive -browser interactions including mouse, keyboard, form, scroll, and focus events. -""" - - -def get_interaction_tracking_script() -> str: - """Get the complete JavaScript code for interaction tracking.""" - return ( - """ - window.__acted = false; - window.__interactions = []; - - // Debug mode - set to true to see all events in console - window.__debug_events = false; - - function captureInteraction(type, event, extra = {}) { - // Skip our own recording indicators - if (event.target.id === '__rec' || event.target.id === '__rec_border' || - event.target.closest('#__rec') || event.target.closest('#__rec_border')) { - return; - } - - const interaction = { - type: type, - timestamp: Date.now(), - coords: { - x: event.clientX || 0, - y: event.clientY || 0 - }, - target: { - tagName: event.target.tagName, - id: event.target.id || null, - className: event.target.className || null, - text: event.target.textContent?.slice(0, 50) || null, - bid: event.target.getAttribute('bid') || null - }, - ...extra - }; - - window.__interactions.push(interaction); - window.__acted = true; - - // Debug logging - if (window.__debug_events) { - console.log(`🎯 Captured: ${type}`, interaction); - } - - // Update indicators immediately - const indicator = document.getElementById('__rec'); - const border = document.getElementById('__rec_border'); - if (indicator) { - indicator.innerHTML = '✅ ACTION DETECTED - SAVING...'; - indicator.style.background = '#28a745'; - indicator.style.animation = 'none'; - } - if (border) { - border.style.border = '8px solid #28a745'; - border.style.animation = 'none'; - } - } - - // Debug function - add this temporarily to see what events fire - if (window.__debug_events) { - ['input', 'change', 'select', 'focus', 'click', 'keydown', 'paste', 'cut', 'copy'].forEach(eventType => { - document.addEventListener(eventType, (e) => { - console.log(`🔍 DEBUG: ${eventType} on`, e.target.tagName, e.target.type, e.target); - }, true); - }); - } - - """ - + get_mouse_event_listeners() - + """ - """ - + get_keyboard_event_listeners() - + """ - """ - + get_form_event_listeners() - + """ - """ - + get_scroll_event_listeners() - + """ - """ - + get_focus_event_listeners() - + """ - - console.log('Comprehensive interaction tracking initialized'); - """ - ) - - -def get_mouse_event_listeners() -> str: - """Get JavaScript code for mouse event listeners.""" - return """ - // Mouse events with comprehensive button tracking and performance optimizations - let lastClickTime = 0; - - document.addEventListener('click', (e) => { - const now = Date.now(); - // Prevent spam clicking from creating too many events (minimum 50ms between clicks) - if (now - lastClickTime < 50) return; - lastClickTime = now; - - captureInteraction('click', e, { - button: e.button, // 0=left, 1=middle, 2=right - buttons: e.buttons, // bitmask of pressed buttons - buttonName: ['left', 'middle', 'right'][e.button] || 'unknown', - detail: e.detail, // click count (single, double, etc.) - clickType: e.detail === 1 ? 'single' : e.detail === 2 ? 'double' : `${e.detail}x` - }); - }, true); - - document.addEventListener('dblclick', (e) => { - captureInteraction('dblclick', e, { - button: e.button, - buttonName: ['left', 'middle', 'right'][e.button] || 'unknown' - }); - }, true); - - document.addEventListener('mousedown', (e) => { - captureInteraction('mousedown', e, { - button: e.button, - buttons: e.buttons, - buttonName: ['left', 'middle', 'right'][e.button] || 'unknown' - }); - }, true); - - document.addEventListener('mouseup', (e) => { - captureInteraction('mouseup', e, { - button: e.button, - buttons: e.buttons, - buttonName: ['left', 'middle', 'right'][e.button] || 'unknown' - }); - }, true); - - // Context menu (right-click menu) - document.addEventListener('contextmenu', (e) => { - captureInteraction('contextmenu', e, { - button: e.button, - buttonName: 'right' - }); - }, true); - - // Middle mouse button events (often used for scrolling/opening in new tab) - document.addEventListener('auxclick', (e) => { - captureInteraction('auxclick', e, { - button: e.button, - buttonName: e.button === 1 ? 'middle' : (e.button === 2 ? 'right' : 'other'), - detail: e.detail - }); - }, true); - - // Enhanced drag tracking (without redundant mousedown) - let isDragging = false; - let dragStart = null; - let dragButton = null; - let hasDraggedSignificantly = false; - - document.addEventListener('mousedown', (e) => { - isDragging = true; - dragButton = e.button; - hasDraggedSignificantly = false; - dragStart = { - x: e.clientX, - y: e.clientY, - time: Date.now(), - button: e.button, - buttonName: ['left', 'middle', 'right'][e.button] || 'unknown' - }; - }, true); - - document.addEventListener('mousemove', (e) => { - if (isDragging && dragStart) { - const distance = Math.sqrt( - Math.pow(e.clientX - dragStart.x, 2) + - Math.pow(e.clientY - dragStart.y, 2) - ); - if (distance > 5 && !hasDraggedSignificantly) { - // Only capture the start of a significant drag, not every movement - hasDraggedSignificantly = true; - captureInteraction('drag_start', e, { - startX: dragStart.x, - startY: dragStart.y, - endX: e.clientX, - endY: e.clientY, - distance: distance, - button: dragButton, - buttonName: dragStart.buttonName, - duration: Date.now() - dragStart.time - }); - } - } - // Note: Removed general mousemove tracking to reduce noise - }, true); - - document.addEventListener('mouseup', (e) => { - if (isDragging && dragStart && hasDraggedSignificantly) { - const distance = Math.sqrt( - Math.pow(e.clientX - dragStart.x, 2) + - Math.pow(e.clientY - dragStart.y, 2) - ); - captureInteraction('drag_end', e, { - startX: dragStart.x, - startY: dragStart.y, - endX: e.clientX, - endY: e.clientY, - distance: distance, - duration: Date.now() - dragStart.time, - button: dragButton, - buttonName: dragStart.buttonName, - totalDistance: distance - }); - } - isDragging = false; - dragStart = null; - dragButton = null; - hasDraggedSignificantly = false; - }, true); - - // Drag and drop events - document.addEventListener('dragstart', (e) => { - captureInteraction('dragstart', e, { - dataTransfer: { - effectAllowed: e.dataTransfer.effectAllowed, - types: Array.from(e.dataTransfer.types) - } - }); - }, true); - - document.addEventListener('dragend', (e) => { - captureInteraction('dragend', e, { - dataTransfer: { - dropEffect: e.dataTransfer.dropEffect - } - }); - }, true); - - document.addEventListener('drop', (e) => { - captureInteraction('drop', e, { - dataTransfer: { - dropEffect: e.dataTransfer.dropEffect, - types: Array.from(e.dataTransfer.types) - }, - files: e.dataTransfer.files.length > 0 ? Array.from(e.dataTransfer.files).map(f => ({ - name: f.name, - type: f.type, - size: f.size - })) : null - }); - }, true); - """ - - -def get_keyboard_event_listeners() -> str: - """Get JavaScript code for keyboard event listeners.""" - return """ - // Keyboard events with shortcut detection - document.addEventListener('keydown', (e) => { - let shortcut = null; - if (e.ctrlKey || e.metaKey) { - const modifier = e.ctrlKey ? 'Ctrl' : 'Cmd'; - const key = e.key.length === 1 ? e.key.toUpperCase() : e.key; - shortcut = `${modifier}+${key}`; - } else if (e.altKey && e.key.length === 1) { - shortcut = `Alt+${e.key.toUpperCase()}`; - } else if (e.shiftKey && e.key.length === 1) { - shortcut = `Shift+${e.key.toUpperCase()}`; - } - - captureInteraction('keydown', e, { - key: e.key, - code: e.code, - ctrlKey: e.ctrlKey, - shiftKey: e.shiftKey, - altKey: e.altKey, - metaKey: e.metaKey, - shortcut: shortcut - }); - }, true); - - document.addEventListener('keyup', (e) => { - captureInteraction('keyup', e, { - key: e.key, - code: e.code - }); - }, true); - """ - - -def get_form_event_listeners() -> str: - """Get JavaScript code for form event listeners.""" - return """ - // Input events with throttling to prevent spam during fast typing - let inputTimeout; - let lastInputValue = ''; - - document.addEventListener('input', (e) => { - if (['INPUT', 'TEXTAREA'].includes(e.target.tagName) || e.target.contentEditable === 'true') { - clearTimeout(inputTimeout); - inputTimeout = setTimeout(() => { - const currentValue = e.target.value || e.target.textContent; - // Only capture if value actually changed significantly - if (currentValue !== lastInputValue) { - lastInputValue = currentValue; - captureInteraction('input', e, { - value: currentValue, - inputType: e.inputType || null, - valueLength: currentValue.length - }); - } - }, 50); // Reduced from 300ms to 50ms for better responsiveness - } - }, true); - - // Immediate input capture (without throttling) for certain cases - document.addEventListener('input', (e) => { - // Immediate capture for dropdown/select-like inputs or when selection changes - if (e.target.tagName === 'SELECT' || - e.inputType === 'deleteContentBackward' || - e.inputType === 'insertFromPaste' || - e.inputType === 'insertFromDrop') { - captureInteraction('input_immediate', e, { - value: e.target.value || e.target.textContent, - inputType: e.inputType || null, - immediate: true - }); - } - }, true); - - // Text selection events - document.addEventListener('select', (e) => { - if (['INPUT', 'TEXTAREA'].includes(e.target.tagName)) { - const selectedText = e.target.value.substring(e.target.selectionStart, e.target.selectionEnd); - captureInteraction('select', e, { - selectedText: selectedText, - selectionStart: e.target.selectionStart, - selectionEnd: e.target.selectionEnd, - value: e.target.value, - selectionLength: selectedText.length - }); - } - }, true); - - // Clipboard events - document.addEventListener('cut', (e) => { - captureInteraction('cut', e, { - clipboardData: e.clipboardData ? Array.from(e.clipboardData.types) : null, - targetValue: e.target.value || e.target.textContent - }); - }, true); - - document.addEventListener('copy', (e) => { - captureInteraction('copy', e, { - clipboardData: e.clipboardData ? Array.from(e.clipboardData.types) : null, - targetValue: e.target.value || e.target.textContent - }); - }, true); - - document.addEventListener('paste', (e) => { - captureInteraction('paste', e, { - clipboardData: e.clipboardData ? Array.from(e.clipboardData.types) : null, - targetValue: e.target.value || e.target.textContent - }); - }, true); - - // Enhanced form change events with better dropdown handling - document.addEventListener('change', (e) => { - let extra = {}; - if (e.target.tagName === 'SELECT') { - const option = e.target.options[e.target.selectedIndex]; - extra = { - selectedValue: e.target.value, - selectedText: option?.text || '', - selectedIndex: e.target.selectedIndex, - allOptions: Array.from(e.target.options).map(opt => ({ - value: opt.value, - text: opt.text, - selected: opt.selected - })), - optionsCount: e.target.options.length - }; - } else if (['checkbox', 'radio'].includes(e.target.type)) { - extra = { - checked: e.target.checked, - value: e.target.value, - name: e.target.name - }; - } else { - extra = { - value: e.target.value, - previousValue: e.target.defaultValue, // Capture what it was before - inputType: e.target.type - }; - } - captureInteraction('change', e, extra); - }, true); - - document.addEventListener('submit', (e) => { - captureInteraction('submit', e, { - formAction: e.target.action || null, - formMethod: e.target.method || 'GET', - formElements: Array.from(e.target.elements).length - }); - }, true); - - // Additional events for better field interaction capture - - // Option selection in datalists - document.addEventListener('input', (e) => { - if (e.target.list) { // Has datalist - captureInteraction('datalist_input', e, { - value: e.target.value, - listId: e.target.list.id, - optionsCount: e.target.list.options.length - }); - } - }, true); - - // File input changes - document.addEventListener('change', (e) => { - if (e.target.type === 'file') { - captureInteraction('file_select', e, { - filesCount: e.target.files.length, - files: Array.from(e.target.files).map(file => ({ - name: file.name, - type: file.type, - size: file.size, - lastModified: file.lastModified - })) - }); - } - }, true); - """ - - -def get_scroll_event_listeners() -> str: - """Get JavaScript code for scroll event listeners.""" - return """ - // Scroll events with debouncing to reduce noise - let scrollTimeout; - let lastScrollTime = 0; - - document.addEventListener('scroll', (e) => { - clearTimeout(scrollTimeout); - scrollTimeout = setTimeout(() => { - const now = Date.now(); - // Only capture scroll if it's been at least 200ms since last scroll capture - if (now - lastScrollTime > 200) { - lastScrollTime = now; - captureInteraction('scroll', e, { - scrollX: window.scrollX, - scrollY: window.scrollY, - scrollLeft: e.target.scrollLeft || 0, - scrollTop: e.target.scrollTop || 0 - }); - } - }, 150); // Increased debounce time - }, true); - - // Wheel events (for detailed scroll tracking) with throttling - let lastWheelTime = 0; - document.addEventListener('wheel', (e) => { - const now = Date.now(); - // Only capture wheel events every 100ms to reduce noise - if (now - lastWheelTime > 100) { - lastWheelTime = now; - captureInteraction('wheel', e, { - deltaX: e.deltaX, - deltaY: e.deltaY, - deltaZ: e.deltaZ, - deltaMode: e.deltaMode - }); - } - }, true); - """ - - -def get_focus_event_listeners() -> str: - """Get JavaScript code for focus event listeners.""" - return """ - // Focus events - only for interactive elements to reduce noise - document.addEventListener('focus', (e) => { - // Only capture focus on interactive elements - const interactiveElements = ['INPUT', 'TEXTAREA', 'SELECT', 'BUTTON', 'A']; - if (interactiveElements.includes(e.target.tagName) || - e.target.contentEditable === 'true' || - e.target.tabIndex >= 0) { - captureInteraction('focus', e); - } - }, true); - - document.addEventListener('blur', (e) => { - // Only capture blur on interactive elements - const interactiveElements = ['INPUT', 'TEXTAREA', 'SELECT', 'BUTTON', 'A']; - if (interactiveElements.includes(e.target.tagName) || - e.target.contentEditable === 'true' || - e.target.tabIndex >= 0) { - captureInteraction('blur', e); - } - }, true); - """ - - -def get_recording_indicators_script() -> str: - """Get JavaScript code for recording indicators.""" - return """ - // Remove any existing indicators - const existingBorder = document.getElementById('__rec_border'); - if (existingBorder) existingBorder.remove(); - const existingIndicator = document.getElementById('__rec'); - if (existingIndicator) existingIndicator.remove(); - - // Create border overlay - const border = document.createElement('div'); - border.id = '__rec_border'; - border.style.cssText = ` - position: fixed; - top: 0; - left: 0; - width: 100vw; - height: 100vh; - border: 8px solid #ff0000; - box-sizing: border-box; - pointer-events: none; - z-index: 999999; - animation: pulse 1.5s infinite; - `; - - // Create status indicator - const indicator = document.createElement('div'); - indicator.id = '__rec'; - indicator.innerHTML = '🔴 RECORDING - Perform your action now'; - indicator.style.cssText = ` - position: fixed; - top: 10px; - left: 50%; - transform: translateX(-50%); - background: #ff0000; - color: #fff; - padding: 12px 20px; - border-radius: 8px; - font: bold 10px -apple-system, BlinkMacSystemFont, sans-serif; - z-index: 9999999; - box-shadow: 0 4px 12px rgba(255,0,0,0.4); - animation: pulse 1.5s infinite; - `; - - // Add pulsing animation - const style = document.createElement('style'); - style.textContent = ` - @keyframes pulse { - 0% { opacity: 1; } - 50% { opacity: 0.4; } - 100% { opacity: 0.8; } - } - `; - document.head.appendChild(style); - - document.body.appendChild(border); - document.body.appendChild(indicator); - """ From e93fde52dbd3f97e4072a2bd624b115365cb3b17 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Tue, 2 Sep 2025 15:58:38 +0000 Subject: [PATCH 21/53] log task errors --- src/agentlab/experiments/graph_execution_ray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agentlab/experiments/graph_execution_ray.py b/src/agentlab/experiments/graph_execution_ray.py index f047f866..f7aad780 100644 --- a/src/agentlab/experiments/graph_execution_ray.py +++ b/src/agentlab/experiments/graph_execution_ray.py @@ -3,9 +3,8 @@ import bgym import ray -from ray.util import state - from agentlab.experiments.exp_utils import _episode_timeout, run_exp +from ray.util import state logger = logging.getLogger(__name__) @@ -79,6 +78,7 @@ def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_inter try: result = ray.get(task) except Exception as e: + logger.exception(f"Task failed: {e}") result = e results.append(result) From 5604ac36c861128b296a9b894497297d1e749146 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Tue, 2 Sep 2025 15:59:11 +0000 Subject: [PATCH 22/53] expore agentlabxray --- src/agentlab/analyze/agent_xray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 8accbfd6..b60c0dcb 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -539,7 +539,7 @@ def run_gradio(results_dir: Path): port = os.getenv("AGENTXRAY_APP_PORT", None) if isinstance(port, str): port = int(port) - demo.launch(server_port=port, share=do_share) + demo.launch(server_name="0.0.0.0", server_port=port, share=do_share) def handle_key_event(key_event, step_id: StepId): From 0e68bcab654ffe052a4445aafb338f74ae7400a2 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Wed, 3 Sep 2025 09:39:32 +0000 Subject: [PATCH 23/53] remove commented old chunk --- .../agents/generic_agent_hinter/generic_agent_prompt.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index b684b6c9..983c9d48 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -378,9 +378,6 @@ def get_hints_for_task(self, task_name: str) -> str: return "" try: - # task_hints = self.hint_db[ - # self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) - # ] task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal) hints = [] From e4cad16a9dd83401945624623ed4871dc30cc5dd Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Wed, 3 Sep 2025 09:43:19 +0000 Subject: [PATCH 24/53] share xray only when env flag present --- src/agentlab/analyze/agent_xray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index b60c0dcb..fed78b3e 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -537,9 +537,10 @@ def run_gradio(results_dir: Path): do_share = os.getenv("AGENTXRAY_SHARE_GRADIO", "false").lower() == "true" port = os.getenv("AGENTXRAY_APP_PORT", None) + server_name = "0.0.0.0" if os.getenv("AGENTXRAY_PUBLIC", "false") == "true" else "127.0.0.1" if isinstance(port, str): port = int(port) - demo.launch(server_name="0.0.0.0", server_port=port, share=do_share) + demo.launch(server_name=server_name, server_port=port, share=do_share) def handle_key_event(key_event, step_id: StepId): From 94fa1ab7fb7a69e1ae75a6964a4d77098f0050f0 Mon Sep 17 00:00:00 2001 From: recursix Date: Thu, 4 Sep 2025 16:46:46 -0400 Subject: [PATCH 25/53] Add StepWiseQueriesPrompt for enhanced query handling in GenericAgent --- .../agents/generic_agent/generic_agent.py | 1 + .../generic_agent_hinter/generic_agent.py | 43 ++++++++++- .../generic_agent_prompt.py | 73 +++++++++++++++++++ 3 files changed, 114 insertions(+), 3 deletions(-) diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index d1f48f76..646a52b2 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict: def get_action(self, obs): self.obs_history.append(obs) + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index cfbd19bd..c8368039 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -23,7 +23,11 @@ from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator -from .generic_agent_prompt import GenericPromptFlags, MainPrompt +from .generic_agent_prompt import ( + GenericPromptFlags, + MainPrompt, + StepWiseRetrievalPrompt, +) @dataclass @@ -102,6 +106,16 @@ def set_task_name(self, task_name: str): def get_action(self, obs): self.obs_history.append(obs) + + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + + queries, think_queries = self._get_queries() + + # TODO + # use those queries to retreive from the database. e.g.: + # hints = self.hint_db.get_hints(queries) + # then add those hints to the main prompt + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, @@ -120,8 +134,6 @@ def get_action(self, obs): max_prompt_tokens, max_trunc_itr = self._get_maxes() - system_prompt = SystemMessage(dp.SystemPrompt().prompt) - human_prompt = dp.fit_tokens( shrinkable=main_prompt, max_prompt_tokens=max_prompt_tokens, @@ -168,6 +180,31 @@ def get_action(self, obs): ) return ans_dict["action"], agent_info + def _get_queries(self): + """Retrieve queries for hinting.""" + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + query_prompt = StepWiseRetrievalPrompt( + obs_history=self.obs_history, + actions=self.actions, + thoughts=self.thoughts, + obs_flags=self.flags.obs, + n_queries=self.flags.n_retrieval_queries, # TODO + ) + + chat_messages = Discussion([system_prompt, query_prompt.prompt]) + ans_dict = retry( + self.chat_llm, + chat_messages, + n_retry=self.max_retry, + parser=query_prompt._parse_answer, + ) + + queries = ans_dict.get("queries", []) + assert len(queries) == self.flags.n_retrieval_queries + + # TODO: we should probably propagate these chat_messages to be able to see them in xray + return queries, ans_dict.get("think", None) + def reset(self, seed=None): self.seed = seed self.plan = "No plan yet" diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 983c9d48..44d17845 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -4,6 +4,7 @@ It is based on the dynamic_prompting module from the agentlab package. """ +import json import logging from dataclasses import dataclass from pathlib import Path @@ -60,6 +61,7 @@ class GenericPromptFlags(dp.Flags): add_missparsed_messages: bool = True max_trunc_itr: int = 20 flag_group: str = None + n_retrieval_queries: int = 3 class MainPrompt(dp.Shrinkable): @@ -396,3 +398,74 @@ def get_hints_for_task(self, task_name: str) -> str: print(f"Warning: Error getting hints for task {task_name}: {e}") return "" + + +class StepWiseRetrievalPrompt(dp.Shrinkable): + def __init__( + self, + obs_history: list[dict], + actions: list[str], + thoughts: list[str], + obs_flags: dp.ObsFlags, + n_queries: int = 3, + ) -> None: + super().__init__() + self.obs_flags = obs_flags + self.n_queries = n_queries + self.history = dp.History(obs_history, actions, None, thoughts, obs_flags) + self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"]) + self.obs = dp.Observation(obs_history[-1], obs_flags) + + self.think = dp.Think(visible=True) # To replace with static text maybe + + @property + def _prompt(self) -> HumanMessage: + prompt = HumanMessage(self.instructions.prompt) + + prompt.add_text( + f"""\ +{self.obs.prompt}\ +{self.history.prompt}\ +""" + ) + + example_queries = [ + "How to sort with multiple columns on the ServiceNow platform?", + "What are the potential challenges of sorting by multiple columns?", + "How to handle sorting by multiple columns in a table?", + "Can I use the filter tool to sort by multiple columns?", + ] + + example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2) + + prompt.add_text( + f""" +# Querying memory + +Before choosing an action, let's search our available documentation and memory on how to approach this step. +This could provide valuable hints on how to properly solve this task. Return your answer as follow +chain of thought +json list of strings for the queries. Return exactly {self.n_queries} +queries in the list. + +# Concrete Example + + +I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if +I will be able to sort by both at the same time. + + + +{example_queries_str} + +""" + ) + + return self.obs.add_screenshot(prompt) + + def _parse_answer(self, text_answer): + ans_dict = parse_html_tags_raise( + text_answer, keys=["think", "queries"], merge_multiple=True + ) + ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]")) + return ans_dict From 69048c48dd26848cd329103c8131bca72567071c Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Mon, 8 Sep 2025 18:01:51 -0400 Subject: [PATCH 26/53] update hinting agent retrieval --- .../generic_agent_hinter/generic_agent.py | 2 +- .../generic_agent_prompt.py | 114 ++++++++++++++---- 2 files changed, 93 insertions(+), 23 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index cfbd19bd..5e04df0d 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -34,7 +34,7 @@ class GenericAgentArgs(AgentArgs): def __post_init__(self): try: # some attributes might be temporarily args.CrossProd for hyperparameter generation - self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_") + self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace("/", "_") except AttributeError: pass diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 983c9d48..76205341 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -60,6 +60,13 @@ class GenericPromptFlags(dp.Flags): add_missparsed_messages: bool = True max_trunc_itr: int = 20 flag_group: str = None + # hint flags + hint_type: Literal["human", "llm", "docs"] = "human" + hint_index_type: Literal["sparse", "dense"] = "sparse" + hint_query_type: Literal["direct", "llm", "emb"] = "direct" + hint_index_path: str = None + hint_retriever_path: str = None + hint_num_results: int = 5 class MainPrompt(dp.Shrinkable): @@ -116,6 +123,13 @@ def time_for_caution(): hint_retrieval_mode=flags.task_hint_retrieval_mode, llm=llm, skip_hints_for_current_task=flags.skip_hints_for_current_task, + # hint related + hint_type=flags.hint_type, + hint_index_type=flags.hint_index_type, + hint_query_type=flags.hint_query_type, + hint_index_path=flags.hint_index_path, + hint_retriever_path=flags.hint_retriever_path, + hint_num_results=flags.hint_num_results, ) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) @@ -301,12 +315,24 @@ def __init__( use_task_hint: bool, hint_db_path: str, goal: str, - hint_retrieval_mode: Literal["direct", "llm", "emb"], - skip_hints_for_current_task: bool, llm: ChatModel, + hint_type: Literal["human", "llm", "docs"] = "human", + hint_index_type: Literal["sparse", "dense"] = "sparse", + hint_query_type: Literal["direct", "llm", "emb"] = "direct", + hint_index_path: str = None, + hint_retriever_path: str = None, + hint_num_results: int = 5, + skip_hints_for_current_task: bool = False, + hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", ) -> None: super().__init__(visible=use_task_hint) self.use_task_hint = use_task_hint + self.hint_type = hint_type + self.hint_index_type = hint_index_type + self.hint_query_type = hint_query_type + self.hint_index_path = hint_index_path + self.hint_retriever_path = hint_retriever_path + self.hint_num_results = hint_num_results self.hint_db_rel_path = "hint_db.csv" self.hint_db_path = hint_db_path # Allow external path override self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode @@ -333,28 +359,46 @@ def __init__( def _init(self): """Initialize the block.""" try: - # Use external path if provided, otherwise fall back to relative path - if self.hint_db_path and Path(self.hint_db_path).exists(): - hint_db_path = Path(self.hint_db_path) + if self.hint_type == "docs": + if self.hint_index_type == "sparse": + print("Loading sparse hint index") + import bm25s + self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True) + print("Sparse hint index loaded successfully") + elif self.hint_index_type == "dense": + print("Loading dense hint index and retriever") + from datasets import load_from_disk + from sentence_transformers import SentenceTransformer + self.hint_index = load_from_disk(self.hint_index_path) + self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss") + print("Dense hint index loaded successfully") + self.hint_retriever = SentenceTransformer(self.hint_retriever_path) + print("Hint retriever loaded successfully") + else: + raise ValueError(f"Unknown hint index type: {self.hint_index_type}") else: - hint_db_path = Path(__file__).parent / self.hint_db_rel_path - - if hint_db_path.exists(): - self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) - # Verify the expected columns exist - if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: - print( - f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}" - ) + # Use external path if provided, otherwise fall back to relative path + if self.hint_db_path and Path(self.hint_db_path).exists(): + hint_db_path = Path(self.hint_db_path) + else: + hint_db_path = Path(__file__).parent / self.hint_db_rel_path + + if hint_db_path.exists(): + self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) + # Verify the expected columns exist + if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + print( + f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}" + ) + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + else: + print(f"Warning: Hint database not found at {hint_db_path}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - else: - print(f"Warning: Hint database not found at {hint_db_path}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - self.hints_source = HintsSource( - hint_db_path=hint_db_path.as_posix(), - hint_retrieval_mode=self.hint_retrieval_mode, - skip_hints_for_current_task=self.skip_hints_for_current_task, - ) + self.hints_source = HintsSource( + hint_db_path=hint_db_path.as_posix(), + hint_retrieval_mode=self.hint_retrieval_mode, + skip_hints_for_current_task=self.skip_hints_for_current_task, + ) except Exception as e: # Fallback to empty database on any error print(f"Warning: Could not load hint database: {e}") @@ -365,6 +409,32 @@ def get_hints_for_task(self, task_name: str) -> str: if not self.use_task_hint: return "" + if self.hint_type == "docs": + if not hasattr(self, "hint_index"): + self._init() + + if self.hint_query_type == "goal": + query = self.goal + elif self.hint_query_type == "llm": + query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) + else: + raise ValueError(f"Unknown hint query type: {self.hint_query_type}") + + if self.hint_index_type == "sparse": + query_tokens = bm25s.tokenize(query) + docs = self.hint_index.search(query_tokens, k=self.hint_num_results) + docs = docs["text"] + elif self.hint_index_type == "dense": + query_embedding = self.hint_retriever.encode(query) + _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) + docs = docs["text"] + + hints_str = ( + "# Hints:\nHere are some hints for the task you are working on:\n" + + "\n".join(docs) + ) + return hints_str + # Ensure hint_db is initialized if not hasattr(self, "hint_db"): self._init() From ee2653a0925b39da7ac9f1db3b3490ee3c7e8c17 Mon Sep 17 00:00:00 2001 From: Hadi Nekoei Date: Mon, 8 Sep 2025 19:41:27 -0400 Subject: [PATCH 27/53] stepwise hint retrieval --- .../generic_agent_hinter/generic_agent.py | 11 +++--- .../generic_agent_prompt.py | 36 ++++++++++++++----- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index c8368039..50e6d399 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -26,7 +26,7 @@ from .generic_agent_prompt import ( GenericPromptFlags, MainPrompt, - StepWiseRetrievalPrompt, + StepWiseContextIdentificationPrompt, ) @@ -111,10 +111,8 @@ def get_action(self, obs): queries, think_queries = self._get_queries() - # TODO - # use those queries to retreive from the database. e.g.: - # hints = self.hint_db.get_hints(queries) - # then add those hints to the main prompt + # use those queries to retrieve from the database and pass to prompt if step-level + queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None main_prompt = MainPrompt( action_set=self.action_set, @@ -126,6 +124,7 @@ def get_action(self, obs): step=self.plan_step, flags=self.flags, llm=self.chat_llm, + queries=queries_for_hints, ) # Set task name for task hints if available @@ -183,7 +182,7 @@ def get_action(self, obs): def _get_queries(self): """Retrieve queries for hinting.""" system_prompt = SystemMessage(dp.SystemPrompt().prompt) - query_prompt = StepWiseRetrievalPrompt( + query_prompt = StepWiseContextIdentificationPrompt( obs_history=self.obs_history, actions=self.actions, thoughts=self.thoughts, diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 44d17845..cf87a326 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -62,6 +62,7 @@ class GenericPromptFlags(dp.Flags): max_trunc_itr: int = 20 flag_group: str = None n_retrieval_queries: int = 3 + hint_level: Literal["episode", "step"] = "episode" class MainPrompt(dp.Shrinkable): @@ -76,6 +77,7 @@ def __init__( step: int, flags: GenericPromptFlags, llm: ChatModel, + queries: list[str] | None = None, ) -> None: super().__init__() self.flags = flags @@ -118,6 +120,8 @@ def time_for_caution(): hint_retrieval_mode=flags.task_hint_retrieval_mode, llm=llm, skip_hints_for_current_task=flags.skip_hints_for_current_task, + hint_level=flags.hint_level, + queries=queries, ) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) @@ -306,6 +310,8 @@ def __init__( hint_retrieval_mode: Literal["direct", "llm", "emb"], skip_hints_for_current_task: bool, llm: ChatModel, + hint_level: Literal["episode", "step"] = "episode", + queries: list[str] | None = None, ) -> None: super().__init__(visible=use_task_hint) self.use_task_hint = use_task_hint @@ -315,6 +321,8 @@ def __init__( self.skip_hints_for_current_task = skip_hints_for_current_task self.goal = goal self.llm = llm + self.hint_level: Literal["episode", "step"] = hint_level + self.queries: list[str] | None = queries self._init() _prompt = "" # Task hints are added dynamically in MainPrompt @@ -352,6 +360,7 @@ def _init(self): else: print(f"Warning: Hint database not found at {hint_db_path}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + self.hints_source = HintsSource( hint_db_path=hint_db_path.as_posix(), hint_retrieval_mode=self.hint_retrieval_mode, @@ -380,7 +389,16 @@ def get_hints_for_task(self, task_name: str) -> str: return "" try: - task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal) + # When step-level, pass queries as goal string to fit the llm_prompt + goal_or_queries = self.goal + if self.hint_level == "step" and self.queries: + goal_or_queries = "\n".join(self.queries) + + task_hints = self.hints_source.choose_hints( + self.llm, + task_name, + goal_or_queries, + ) hints = [] for hint in task_hints: @@ -400,14 +418,14 @@ def get_hints_for_task(self, task_name: str) -> str: return "" -class StepWiseRetrievalPrompt(dp.Shrinkable): +class StepWiseContextIdentificationPrompt(dp.Shrinkable): def __init__( self, obs_history: list[dict], actions: list[str], thoughts: list[str], obs_flags: dp.ObsFlags, - n_queries: int = 3, + n_queries: int = 1, ) -> None: super().__init__() self.obs_flags = obs_flags @@ -430,10 +448,10 @@ def _prompt(self) -> HumanMessage: ) example_queries = [ - "How to sort with multiple columns on the ServiceNow platform?", - "What are the potential challenges of sorting by multiple columns?", - "How to handle sorting by multiple columns in a table?", - "Can I use the filter tool to sort by multiple columns?", + "The user has started sorting a table and needs to apply multiple column criteria simultaneously.", + "The user is attempting to configure advanced sorting options but the interface is unclear.", + "The user has selected the first sort column and is now looking for how to add a second sort criterion.", + "The user is in the middle of a multi-step sorting process and needs guidance on the next action.", ] example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2) @@ -442,8 +460,8 @@ def _prompt(self) -> HumanMessage: f""" # Querying memory -Before choosing an action, let's search our available documentation and memory on how to approach this step. -This could provide valuable hints on how to properly solve this task. Return your answer as follow +Before choosing an action, let's search our available documentation and memory for relevant context. +Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follow chain of thought json list of strings for the queries. Return exactly {self.n_queries} queries in the list. From ca11170a7d7e7d44669bba2369c9b5dddf6a75ec Mon Sep 17 00:00:00 2001 From: Hadi Nekoei Date: Tue, 9 Sep 2025 00:05:14 -0400 Subject: [PATCH 28/53] added shrink method --- .../agents/generic_agent_hinter/generic_agent_prompt.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 21ed167d..10cfeef6 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -551,6 +551,10 @@ def _prompt(self) -> HumanMessage: return self.obs.add_screenshot(prompt) + def shrink(self): + self.history.shrink() + self.obs.shrink() + def _parse_answer(self, text_answer): ans_dict = parse_html_tags_raise( text_answer, keys=["think", "queries"], merge_multiple=True From c86873b65858fccd39cd54a78428b2da51e67986 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Tue, 9 Sep 2025 16:50:24 -0400 Subject: [PATCH 29/53] (wip) refactor hinting index --- .../generic_agent_hinter/generic_agent.py | 49 ++++++++++++++++++- .../generic_agent_prompt.py | 22 +++++---- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 0cbdb6b3..540c4a5e 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -16,12 +16,14 @@ import bgym from bgym import Benchmark from browsergym.experiments.agent import Agent, AgentInfo - +import pandas as pd +from pathlib import Path from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs from agentlab.llm.chat_api import BaseModelArgs from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator +from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource from .generic_agent_prompt import ( GenericPromptFlags, @@ -92,6 +94,8 @@ def __init__( self.action_set = self.flags.action.action_set.make_action_set() self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs) + self._init_hints_index() + self._check_flag_constancy() self.reset(seed=None) @@ -246,3 +250,46 @@ def _get_maxes(self): else 20 # dangerous to change the default value here? ) return max_prompt_tokens, max_trunc_itr + + def _init_hints_index(self): + """Initialize the block.""" + try: + if self.flags.hint_type == "docs": + if self.flags.hint_index_type == "sparse": + import bm25s + self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True) + elif self.flags.hint_index_type == "dense": + from datasets import load_from_disk + from sentence_transformers import SentenceTransformer + self.hint_index = load_from_disk(self.flags.hint_index_path) + self.hint_index.load_faiss_index("embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss") + self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path) + else: + raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}") + else: + # Use external path if provided, otherwise fall back to relative path + if self.flags.hint_db_path and Path(self.flags.hint_db_path).exists(): + hint_db_path = Path(self.flags.hint_db_path) + else: + hint_db_path = Path(__file__).parent / self.flags.hint_db_rel_path + + if hint_db_path.exists(): + self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) + # Verify the expected columns exist + if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + print( + f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}" + ) + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + else: + print(f"Warning: Hint database not found at {hint_db_path}") + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + self.hints_source = HintsSource( + hint_db_path=hint_db_path.as_posix(), + hint_retrieval_mode=self.flags.hint_retrieval_mode, + skip_hints_for_current_task=self.flags.skip_hints_for_current_task, + ) + except Exception as e: + # Fallback to empty database on any error + print(f"Warning: Could not load hint database: {e}") + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 10cfeef6..d3f6ace7 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -80,6 +80,7 @@ def __init__( actions: list[str], memories: list[str], thoughts: list[str], + hints: list[str], previous_plan: str, step: int, flags: GenericPromptFlags, @@ -120,6 +121,7 @@ def time_for_caution(): self.think = dp.Think(visible=lambda: flags.use_thinking) self.hints = dp.Hints(visible=lambda: flags.use_hints) goal_str: str = goal[0]["text"] + # TODO: This design is not very good as we will instantiate the loop up at every step self.task_hint = TaskHint( use_task_hint=flags.use_task_hint, hint_db_path=flags.hint_db_path, @@ -147,7 +149,8 @@ def _prompt(self) -> HumanMessage: # Add task hints if enabled task_hints_text = "" - if self.flags.use_task_hint and hasattr(self, "task_name"): + # if self.flags.use_task_hint and hasattr(self, "task_name"): + if self.flags.use_task_hint: task_hints_text = self.task_hint.get_hints_for_task(self.task_name) prompt.add_text( @@ -371,19 +374,14 @@ def _init(self): try: if self.hint_type == "docs": if self.hint_index_type == "sparse": - print("Loading sparse hint index") import bm25s self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True) - print("Sparse hint index loaded successfully") elif self.hint_index_type == "dense": - print("Loading dense hint index and retriever") from datasets import load_from_disk from sentence_transformers import SentenceTransformer self.hint_index = load_from_disk(self.hint_index_path) self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss") - print("Dense hint index loaded successfully") self.hint_retriever = SentenceTransformer(self.hint_retriever_path) - print("Hint retriever loaded successfully") else: raise ValueError(f"Unknown hint index type: {self.hint_index_type}") else: @@ -422,8 +420,8 @@ def get_hints_for_task(self, task_name: str) -> str: if self.hint_type == "docs": if not hasattr(self, "hint_index"): + print("Initializing hint index new time") self._init() - if self.hint_query_type == "goal": query = self.goal elif self.hint_query_type == "llm": @@ -432,9 +430,15 @@ def get_hints_for_task(self, task_name: str) -> str: raise ValueError(f"Unknown hint query type: {self.hint_query_type}") if self.hint_index_type == "sparse": + import bm25s query_tokens = bm25s.tokenize(query) - docs = self.hint_index.search(query_tokens, k=self.hint_num_results) - docs = docs["text"] + docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results) + docs = [elem["text"] for elem in docs[0]] + # HACK: truncate to 20k characters (should cover >99% of the cases) + for doc in docs: + if len(doc) > 20000: + doc = doc[:20000] + doc += " ...[truncated]" elif self.hint_index_type == "dense": query_embedding = self.hint_retriever.encode(query) _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) From 7e55cd786b8ccd12aa642e9471590b7605ab4132 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Tue, 9 Sep 2025 16:57:09 -0400 Subject: [PATCH 30/53] (wip) clean up prompt file --- .../generic_agent_prompt.py | 49 ------------------- 1 file changed, 49 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index d3f6ace7..599df838 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -352,7 +352,6 @@ def __init__( self.llm = llm self.hint_level: Literal["episode", "step"] = hint_level self.queries: list[str] | None = queries - self._init() _prompt = "" # Task hints are added dynamically in MainPrompt @@ -369,50 +368,6 @@ def __init__( """ - def _init(self): - """Initialize the block.""" - try: - if self.hint_type == "docs": - if self.hint_index_type == "sparse": - import bm25s - self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True) - elif self.hint_index_type == "dense": - from datasets import load_from_disk - from sentence_transformers import SentenceTransformer - self.hint_index = load_from_disk(self.hint_index_path) - self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss") - self.hint_retriever = SentenceTransformer(self.hint_retriever_path) - else: - raise ValueError(f"Unknown hint index type: {self.hint_index_type}") - else: - # Use external path if provided, otherwise fall back to relative path - if self.hint_db_path and Path(self.hint_db_path).exists(): - hint_db_path = Path(self.hint_db_path) - else: - hint_db_path = Path(__file__).parent / self.hint_db_rel_path - - if hint_db_path.exists(): - self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) - # Verify the expected columns exist - if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: - print( - f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}" - ) - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - else: - print(f"Warning: Hint database not found at {hint_db_path}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - - self.hints_source = HintsSource( - hint_db_path=hint_db_path.as_posix(), - hint_retrieval_mode=self.hint_retrieval_mode, - skip_hints_for_current_task=self.skip_hints_for_current_task, - ) - except Exception as e: - # Fallback to empty database on any error - print(f"Warning: Could not load hint database: {e}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - def get_hints_for_task(self, task_name: str) -> str: """Get hints for a specific task.""" if not self.use_task_hint: @@ -450,10 +405,6 @@ def get_hints_for_task(self, task_name: str) -> str: ) return hints_str - # Ensure hint_db is initialized - if not hasattr(self, "hint_db"): - self._init() - # Check if hint_db has the expected structure if ( self.hint_db.empty From 66b969204b987ba0dcfd6c2bf5884814020b7ad1 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 13:10:07 -0400 Subject: [PATCH 31/53] add scripts to run generic and hinter agents, update tmlr config for hinter --- experiments/generic/run_generic_agent.py | 55 ++++++++++++++ experiments/generic/run_generic_agent.sh | 17 +++++ experiments/hinter/run_hinter_agent.py | 76 +++++++++++++++++++ experiments/hinter/run_hinter_agent.sh | 31 ++++++++ .../generic_agent_hinter/tmlr_config.py | 10 +++ 5 files changed, 189 insertions(+) create mode 100644 experiments/generic/run_generic_agent.py create mode 100644 experiments/generic/run_generic_agent.sh create mode 100644 experiments/hinter/run_hinter_agent.py create mode 100644 experiments/hinter/run_hinter_agent.sh diff --git a/experiments/generic/run_generic_agent.py b/experiments/generic/run_generic_agent.py new file mode 100644 index 00000000..cdeb3eaf --- /dev/null +++ b/experiments/generic/run_generic_agent.py @@ -0,0 +1,55 @@ +import argparse + +from dotenv import load_dotenv + +load_dotenv() + +import argparse +import logging + +from agentlab.agents.generic_agent.tmlr_config import get_base_agent +from agentlab.experiments.study import Study +from bgym import DEFAULT_BENCHMARKS + +logging.getLogger().setLevel(logging.WARNING) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--benchmark", required=True) + parser.add_argument("--llm-config", required=True) + parser.add_argument("--relaunch", action="http://github.com/store_true") + parser.add_argument("--n-jobs", type=int, default=5) + parser.add_argument("--n-relaunch", type=int, default=3) + parser.add_argument("--parallel-backend", type=str, default="ray") + parser.add_argument("--reproducibility-mode", action="http://github.com/store_true") + + args = parser.parse_args() + + # instantiate agent + agent_args = [get_base_agent(args.llm_config)] + benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + + if args.relaunch: + # relaunch an existing study + study = Study.load_most_recent(contains=None) + study.find_incomplete(include_errors=True) + + else: + study = Study( + agent_args, + benchmark, + logging_level=logging.WARNING, + logging_level_stdout=logging.WARNING, + ) + + study.run( + n_jobs=args.n_jobs, + parallel_backend="ray", + strict_reproducibility=args.reproducibility_mode, + n_relaunch=args.n_relaunch, + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/generic/run_generic_agent.sh b/experiments/generic/run_generic_agent.sh new file mode 100644 index 00000000..426af66e --- /dev/null +++ b/experiments/generic/run_generic_agent.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +BENCHMARK="workarena_l1" + +LLM_CONFIG="azure/gpt-5-mini-2025-08-07" +# PARALLEL_BACKEND="sequential" +PARALLEL_BACKEND="ray" + +N_JOBS=5 +N_RELAUNCH=3 + +python experiments/generic/run_generic_agent.py \ + --benchmark $BENCHMARK \ + --llm-config $LLM_CONFIG \ + --parallel-backend $PARALLEL_BACKEND \ + --n-jobs $N_JOBS \ + --n-relaunch $N_RELAUNCH \ No newline at end of file diff --git a/experiments/hinter/run_hinter_agent.py b/experiments/hinter/run_hinter_agent.py new file mode 100644 index 00000000..b08283ba --- /dev/null +++ b/experiments/hinter/run_hinter_agent.py @@ -0,0 +1,76 @@ + +from dotenv import load_dotenv +import argparse + +load_dotenv() + +import logging +import argparse + +from agentlab.agents.generic_agent_hinter.generic_agent import GenericAgentArgs +from agentlab.agents.generic_agent_hinter.agent_configs import CHAT_MODEL_ARGS_DICT, FLAGS_GPT_4o +from bgym import DEFAULT_BENCHMARKS +from agentlab.experiments.study import Study + +logging.getLogger().setLevel(logging.WARNING) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--benchmark", required=True) + parser.add_argument("--llm-config", required=True) + parser.add_argument("--relaunch", action="http://github.com/store_true") + parser.add_argument("--n-jobs", type=int, default=6) + parser.add_argument("--parallel-backend", type=str, default="ray") + parser.add_argument("--reproducibility-mode", action="http://github.com/store_true") + # hint flags + parser.add_argument("--hint-type", type=str, default="docs") + parser.add_argument("--hint-index-type", type=str, default="sparse") + parser.add_argument("--hint-query-type", type=str, default="direct") + parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25") + parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m") + parser.add_argument("--hint-num-results", type=int, default=5) + args = parser.parse_args() + + flags = FLAGS_GPT_4o + flags.use_task_hint = True + flags.hint_type = args.hint_type + flags.hint_index_type = args.hint_index_type + flags.hint_query_type = args.hint_query_type + flags.hint_index_path = args.hint_index_path + flags.hint_retriever_path = args.hint_retriever_path + flags.hint_num_results = args.hint_num_results + + # instantiate agent + agent_args = [GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT[args.llm_config], + flags=flags, + )] + + benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + + + if args.relaunch: + # relaunch an existing study + study = Study.load_most_recent(contains=None) + study.find_incomplete(include_errors=True) + + else: + study = Study( + agent_args, + benchmark, + logging_level=logging.WARNING, + logging_level_stdout=logging.WARNING, + ) + + study.run( + n_jobs=args.n_jobs, + parallel_backend=args.parallel_backend, + strict_reproducibility=args.reproducibility_mode, + n_relaunch=1, + ) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/experiments/hinter/run_hinter_agent.sh b/experiments/hinter/run_hinter_agent.sh new file mode 100644 index 00000000..ab35f35b --- /dev/null +++ b/experiments/hinter/run_hinter_agent.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +BENCHMARK="workarena_l1" + +LLM_CONFIG="azure/gpt-5-mini-2025-08-07" +# PARALLEL_BACKEND="sequential" +PARALLEL_BACKEND="ray" + +HINT_TYPE="docs" # human, llm, docs +HINT_INDEX_TYPE="sparse" # sparse, dense +HINT_QUERY_TYPE="goal" # goal, llm +HINT_NUM_RESULTS=5 + +HINT_INDEX_PATH="indexes/servicenow-docs-bm25" +# HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m" +HINT_RETRIEVER_PATH="google/embeddinggemma-300m" + +N_JOBS=6 + +python experiments/hint/run_hinter_agent.py \ + --benchmark $BENCHMARK \ + --llm-config $LLM_CONFIG \ + --parallel-backend $PARALLEL_BACKEND \ + --n-jobs $N_JOBS \ + --hint-type $HINT_TYPE \ + --hint-index-type $HINT_INDEX_TYPE \ + --hint-query-type $HINT_QUERY_TYPE \ + --hint-index-path $HINT_INDEX_PATH \ + --hint-retriever-path $HINT_RETRIEVER_PATH \ + --hint-num-results $HINT_NUM_RESULTS \ + --relaunch \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py index d222b7c0..b6f16058 100644 --- a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py +++ b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py @@ -47,6 +47,16 @@ max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, + + # hint flags + hint_type="human", + hint_index_type="sparse", + hint_query_type="direct", + hint_index_path=None, + hint_retriever_path=None, + hint_num_results=5, + n_retrieval_queries=3, + hint_level="episode", ) From d2166b3e74550ca67ec9c48cfe500e115a2d05b8 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 13:18:28 -0400 Subject: [PATCH 32/53] move HintsSource to separate hinting file --- .../agents/tool_use_agent/tool_use_agent.py | 174 +--------------- src/agentlab/utils/__init__.py | 0 src/agentlab/utils/hinting.py | 189 ++++++++++++++++++ 3 files changed, 190 insertions(+), 173 deletions(-) create mode 100644 src/agentlab/utils/__init__.py create mode 100644 src/agentlab/utils/hinting.py diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index bd200da3..c17b5c23 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -41,6 +41,7 @@ ToolCalls, ) from agentlab.llm.tracking import cost_tracker_decorator +from agentlab.utils.hinting import HintsSource logger = logging.getLogger(__name__) @@ -349,179 +350,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: discussion.append(msg) -class HintsSource: - def __init__( - self, - hint_db_path: str, - hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", - skip_hints_for_current_task: bool = False, - top_n: int = 4, - embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", - embedder_server: str = "http://localhost:5000/", - llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n -You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n -Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""", - ) -> None: - self.hint_db_path = hint_db_path - self.hint_retrieval_mode = hint_retrieval_mode - self.skip_hints_for_current_task = skip_hints_for_current_task - self.top_n = top_n - self.embedder_model = embedder_model - self.embedder_server = embedder_server - self.llm_prompt = llm_prompt - - if Path(hint_db_path).is_absolute(): - self.hint_db_path = Path(hint_db_path).as_posix() - else: - self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() - self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) - logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") - if self.hint_retrieval_mode == "emb": - self.load_hint_vectors() - - def load_hint_vectors(self): - self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") - logger.info( - f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." - ) - hints = self.uniq_hints["hint"].tolist() - semantic_keys = self.uniq_hints["semantic_keys"].tolist() - lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] - emb_path = f"{self.hint_db_path}.embs.npy" - assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" - logger.info(f"Loading hint embeddings from: {emb_path}") - emb_dict = np.load(emb_path, allow_pickle=True).item() - self.hint_embeddings = np.array([emb_dict[k] for k in lines]) - logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") - - def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: - """Choose hints based on the task name.""" - logger.info( - f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" - ) - if self.hint_retrieval_mode == "llm": - return self.choose_hints_llm(llm, goal, task_name) - elif self.hint_retrieval_mode == "direct": - return self.choose_hints_direct(task_name) - elif self.hint_retrieval_mode == "emb": - return self.choose_hints_emb(goal, task_name) - else: - raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") - - def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: - """Choose hints using LLM to filter the hints.""" - topic_to_hints = defaultdict(list) - skip_hints = [] - if self.skip_hints_for_current_task: - skip_hints = self.get_current_task_hints(task_name) - for _, row in self.hint_db.iterrows(): - hint = row["hint"] - if hint in skip_hints: - continue - topic_to_hints[row["semantic_keys"]].append(hint) - logger.info(f"Collected {len(topic_to_hints)} hint topics") - hint_topics = list(topic_to_hints.keys()) - topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) - prompt = self.llm_prompt.format(goal=goal, topics=topics) - - if isinstance(llm, ChatModel): - response: str = llm(messages=[dict(role="user", content=prompt)])["content"] - else: - response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think - try: - topic_number = json.loads(response) - if topic_number < 0 or topic_number >= len(hint_topics): - logger.error(f"Wrong LLM hint id response: {response}, no hints") - return [] - hint_topic = hint_topics[topic_number] - hints = list(set(topic_to_hints[hint_topic])) - logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") - except Exception as e: - logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") - hints = [] - return hints - - def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: - """Choose hints using embeddings to filter the hints.""" - try: - goal_embeddings = self._encode([goal], prompt="task description") - hint_embeddings = self.hint_embeddings.copy() - all_hints = self.uniq_hints["hint"].tolist() - skip_hints = [] - if self.skip_hints_for_current_task: - skip_hints = self.get_current_task_hints(task_name) - hint_embeddings = [] - id_to_hint = {} - for hint, emb in zip(all_hints, self.hint_embeddings): - if hint in skip_hints: - continue - hint_embeddings.append(emb.tolist()) - id_to_hint[len(hint_embeddings) - 1] = hint - logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") - similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) - top_indices = similarities.argsort()[0][-self.top_n :].tolist() - logger.info(f"Top hint indices based on embedding similarity: {top_indices}") - hints = [id_to_hint[idx] for idx in top_indices] - logger.info(f"Embedding-based hints chosen: {hints}") - except Exception as e: - logger.exception(f"Failed to choose hints using embeddings: {e}") - hints = [] - return hints - - def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): - """Call the encode API endpoint with timeout and retries""" - for attempt in range(max_retries): - try: - response = requests.post( - f"{self.embedder_server}/encode", - json={"texts": texts, "prompt": prompt}, - timeout=timeout, - ) - embs = response.json()["embeddings"] - return np.asarray(embs) - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: - if attempt == max_retries - 1: - raise e - time.sleep(random.uniform(1, timeout)) - continue - raise ValueError("Failed to encode hints") - - def _similarity( - self, - texts1: list, - texts2: list, - timeout: int = 2, - max_retries: int = 5, - ): - """Call the similarity API endpoint with timeout and retries""" - for attempt in range(max_retries): - try: - response = requests.post( - f"{self.embedder_server}/similarity", - json={"texts1": texts1, "texts2": texts2}, - timeout=timeout, - ) - similarities = response.json()["similarities"] - return np.asarray(similarities) - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: - if attempt == max_retries - 1: - raise e - time.sleep(random.uniform(1, timeout)) - continue - raise ValueError("Failed to compute similarity") - - def choose_hints_direct(self, task_name: str) -> list[str]: - hints = self.get_current_task_hints(task_name) - logger.info(f"Direct hints chosen: {hints}") - return hints - - def get_current_task_hints(self, task_name): - hints_df = self.hint_db[ - self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) - ] - return hints_df["hint"].tolist() - - @dataclass class PromptConfig: tag_screenshot: bool = True # Whether to tag the screenshot with the last action. diff --git a/src/agentlab/utils/__init__.py b/src/agentlab/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agentlab/utils/hinting.py b/src/agentlab/utils/hinting.py new file mode 100644 index 00000000..6ba1f2d5 --- /dev/null +++ b/src/agentlab/utils/hinting.py @@ -0,0 +1,189 @@ +import fnmatch +import json +import logging +import os +import random +import time +from collections import defaultdict +from pathlib import Path +from typing import Literal + +import numpy as np +import pandas as pd +import requests +from agentlab.llm.chat_api import ChatModel + +logger = logging.getLogger(__name__) + + +class HintsSource: + def __init__( + self, + hint_db_path: str, + hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", + skip_hints_for_current_task: bool = False, + top_n: int = 4, + embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", + embedder_server: str = "http://localhost:5000/", + llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n +You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n +Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""", + ) -> None: + self.hint_db_path = hint_db_path + self.hint_retrieval_mode = hint_retrieval_mode + self.skip_hints_for_current_task = skip_hints_for_current_task + self.top_n = top_n + self.embedder_model = embedder_model + self.embedder_server = embedder_server + self.llm_prompt = llm_prompt + + if Path(hint_db_path).is_absolute(): + self.hint_db_path = Path(hint_db_path).as_posix() + else: + self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() + self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) + logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") + if self.hint_retrieval_mode == "emb": + self.load_hint_vectors() + + def load_hint_vectors(self): + self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") + logger.info( + f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." + ) + hints = self.uniq_hints["hint"].tolist() + semantic_keys = self.uniq_hints["semantic_keys"].tolist() + lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] + emb_path = f"{self.hint_db_path}.embs.npy" + assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" + logger.info(f"Loading hint embeddings from: {emb_path}") + emb_dict = np.load(emb_path, allow_pickle=True).item() + self.hint_embeddings = np.array([emb_dict[k] for k in lines]) + logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") + + def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: + """Choose hints based on the task name.""" + logger.info( + f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" + ) + if self.hint_retrieval_mode == "llm": + return self.choose_hints_llm(llm, goal, task_name) + elif self.hint_retrieval_mode == "direct": + return self.choose_hints_direct(task_name) + elif self.hint_retrieval_mode == "emb": + return self.choose_hints_emb(goal, task_name) + else: + raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") + + def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: + """Choose hints using LLM to filter the hints.""" + topic_to_hints = defaultdict(list) + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + for _, row in self.hint_db.iterrows(): + hint = row["hint"] + if hint in skip_hints: + continue + topic_to_hints[row["semantic_keys"]].append(hint) + logger.info(f"Collected {len(topic_to_hints)} hint topics") + hint_topics = list(topic_to_hints.keys()) + topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) + prompt = self.llm_prompt.format(goal=goal, topics=topics) + + if isinstance(llm, ChatModel): + response: str = llm(messages=[dict(role="user", content=prompt)])["content"] + else: + response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think + try: + topic_number = json.loads(response) + if topic_number < 0 or topic_number >= len(hint_topics): + logger.error(f"Wrong LLM hint id response: {response}, no hints") + return [] + hint_topic = hint_topics[topic_number] + hints = list(set(topic_to_hints[hint_topic])) + logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") + except Exception as e: + logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") + hints = [] + return hints + + def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: + """Choose hints using embeddings to filter the hints.""" + try: + goal_embeddings = self._encode([goal], prompt="task description") + hint_embeddings = self.hint_embeddings.copy() + all_hints = self.uniq_hints["hint"].tolist() + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + hint_embeddings = [] + id_to_hint = {} + for hint, emb in zip(all_hints, self.hint_embeddings): + if hint in skip_hints: + continue + hint_embeddings.append(emb.tolist()) + id_to_hint[len(hint_embeddings) - 1] = hint + logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") + similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) + top_indices = similarities.argsort()[0][-self.top_n :].tolist() + logger.info(f"Top hint indices based on embedding similarity: {top_indices}") + hints = [id_to_hint[idx] for idx in top_indices] + logger.info(f"Embedding-based hints chosen: {hints}") + except Exception as e: + logger.exception(f"Failed to choose hints using embeddings: {e}") + hints = [] + return hints + + def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): + """Call the encode API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/encode", + json={"texts": texts, "prompt": prompt}, + timeout=timeout, + ) + embs = response.json()["embeddings"] + return np.asarray(embs) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + raise ValueError("Failed to encode hints") + + def _similarity( + self, + texts1: list, + texts2: list, + timeout: int = 2, + max_retries: int = 5, + ): + """Call the similarity API endpoint with timeout and retries""" + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.embedder_server}/similarity", + json={"texts1": texts1, "texts2": texts2}, + timeout=timeout, + ) + similarities = response.json()["similarities"] + return np.asarray(similarities) + except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: + if attempt == max_retries - 1: + raise e + time.sleep(random.uniform(1, timeout)) + continue + raise ValueError("Failed to compute similarity") + + def choose_hints_direct(self, task_name: str) -> list[str]: + hints = self.get_current_task_hints(task_name) + logger.info(f"Direct hints chosen: {hints}") + return hints + + def get_current_task_hints(self, task_name): + hints_df = self.hint_db[ + self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) + ] + return hints_df["hint"].tolist() From 60ad8e43e431a925b09101a58ac6fc68ddbcf567 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 16:06:42 -0400 Subject: [PATCH 33/53] update hinter agent and prompt --- .../generic_agent_hinter/generic_agent.py | 105 +++++++++++-- .../generic_agent_prompt.py | 138 +----------------- 2 files changed, 101 insertions(+), 142 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 540c4a5e..d9694a36 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -10,20 +10,18 @@ from copy import deepcopy from dataclasses import asdict, dataclass -from functools import partial +from pathlib import Path from warnings import warn -import bgym -from bgym import Benchmark -from browsergym.experiments.agent import Agent, AgentInfo import pandas as pd -from pathlib import Path from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs from agentlab.llm.chat_api import BaseModelArgs from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator -from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource +from agentlab.utils.hinting import HintsSource +from bgym import Benchmark +from browsergym.experiments.agent import Agent, AgentInfo from .generic_agent_prompt import ( GenericPromptFlags, @@ -40,7 +38,9 @@ class GenericAgentArgs(AgentArgs): def __post_init__(self): try: # some attributes might be temporarily args.CrossProd for hyperparameter generation - self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace("/", "_") + self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace( + "/", "_" + ) except AttributeError: pass @@ -116,7 +116,9 @@ def get_action(self, obs): queries, think_queries = self._get_queries() # use those queries to retrieve from the database and pass to prompt if step-level - queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None + queries_for_hints = ( + queries if getattr(self.flags, "hint_level", "episode") == "step" else None + ) main_prompt = MainPrompt( action_set=self.action_set, @@ -257,12 +259,16 @@ def _init_hints_index(self): if self.flags.hint_type == "docs": if self.flags.hint_index_type == "sparse": import bm25s + self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True) elif self.flags.hint_index_type == "dense": from datasets import load_from_disk from sentence_transformers import SentenceTransformer + self.hint_index = load_from_disk(self.flags.hint_index_path) - self.hint_index.load_faiss_index("embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss") + self.hint_index.load_faiss_index( + "embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss" + ) self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path) else: raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}") @@ -276,7 +282,10 @@ def _init_hints_index(self): if hint_db_path.exists(): self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) # Verify the expected columns exist - if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + if ( + "task_name" not in self.hint_db.columns + or "hint" not in self.hint_db.columns + ): print( f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}" ) @@ -292,4 +301,78 @@ def _init_hints_index(self): except Exception as e: # Fallback to empty database on any error print(f"Warning: Could not load hint database: {e}") - self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) \ No newline at end of file + self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + + def get_hints_for_task(self, task_name: str) -> str: + """Get hints for a specific task.""" + if not self.use_task_hint: + return "" + + if self.hint_type == "docs": + if not hasattr(self, "hint_index"): + print("Initializing hint index new time") + self._init() + if self.hint_query_type == "goal": + query = self.goal + elif self.hint_query_type == "llm": + query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) + else: + raise ValueError(f"Unknown hint query type: {self.hint_query_type}") + + if self.hint_index_type == "sparse": + import bm25s + query_tokens = bm25s.tokenize(query) + docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results) + docs = [elem["text"] for elem in docs[0]] + # HACK: truncate to 20k characters (should cover >99% of the cases) + for doc in docs: + if len(doc) > 20000: + doc = doc[:20000] + doc += " ...[truncated]" + elif self.hint_index_type == "dense": + query_embedding = self.hint_retriever.encode(query) + _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) + docs = docs["text"] + + hints_str = ( + "# Hints:\nHere are some hints for the task you are working on:\n" + + "\n".join(docs) + ) + return hints_str + + # Check if hint_db has the expected structure + if ( + self.hint_db.empty + or "task_name" not in self.hint_db.columns + or "hint" not in self.hint_db.columns + ): + return "" + + try: + # When step-level, pass queries as goal string to fit the llm_prompt + goal_or_queries = self.goal + if self.hint_level == "step" and self.queries: + goal_or_queries = "\n".join(self.queries) + + task_hints = self.hints_source.choose_hints( + self.llm, + task_name, + goal_or_queries, + ) + + hints = [] + for hint in task_hints: + hint = hint.strip() + if hint: + hints.append(f"- {hint}") + + if len(hints) > 0: + hints_str = ( + "# Hints:\nHere are some hints for the task you are working on:\n" + + "\n".join(hints) + ) + return hints_str + except Exception as e: + print(f"Warning: Error getting hints for task {task_name}: {e}") + + return "" \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 599df838..3536a71b 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -85,7 +85,7 @@ def __init__( step: int, flags: GenericPromptFlags, llm: ChatModel, - queries: list[str] | None = None, + task_hints: list[str] = [], ) -> None: super().__init__() self.flags = flags @@ -120,25 +120,7 @@ def time_for_caution(): self.be_cautious = dp.BeCautious(visible=time_for_caution) self.think = dp.Think(visible=lambda: flags.use_thinking) self.hints = dp.Hints(visible=lambda: flags.use_hints) - goal_str: str = goal[0]["text"] - # TODO: This design is not very good as we will instantiate the loop up at every step - self.task_hint = TaskHint( - use_task_hint=flags.use_task_hint, - hint_db_path=flags.hint_db_path, - goal=goal_str, - hint_retrieval_mode=flags.task_hint_retrieval_mode, - llm=llm, - skip_hints_for_current_task=flags.skip_hints_for_current_task, - # hint related - hint_type=flags.hint_type, - hint_index_type=flags.hint_index_type, - hint_query_type=flags.hint_query_type, - hint_index_path=flags.hint_index_path, - hint_retriever_path=flags.hint_retriever_path, - hint_num_results=flags.hint_num_results, - hint_level=flags.hint_level, - queries=queries, - ) + self.task_hints = TaskHint(visible=lambda: flags.use_task_hint, task_hints=task_hints) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) self.memory = Memory(visible=lambda: flags.use_memory) @@ -147,19 +129,13 @@ def time_for_caution(): def _prompt(self) -> HumanMessage: prompt = HumanMessage(self.instructions.prompt) - # Add task hints if enabled - task_hints_text = "" - # if self.flags.use_task_hint and hasattr(self, "task_name"): - if self.flags.use_task_hint: - task_hints_text = self.task_hint.get_hints_for_task(self.task_name) - prompt.add_text( f"""\ {self.obs.prompt}\ {self.history.prompt}\ {self.action_prompt.prompt}\ {self.hints.prompt}\ -{task_hints_text}\ +{self.task_hint.prompt}\ {self.be_cautious.prompt}\ {self.think.prompt}\ {self.plan.prompt}\ @@ -321,37 +297,11 @@ def _parse_answer(self, text_answer): class TaskHint(dp.PromptElement): def __init__( self, - use_task_hint: bool, - hint_db_path: str, - goal: str, - llm: ChatModel, - hint_type: Literal["human", "llm", "docs"] = "human", - hint_index_type: Literal["sparse", "dense"] = "sparse", - hint_query_type: Literal["direct", "llm", "emb"] = "direct", - hint_index_path: str = None, - hint_retriever_path: str = None, - hint_num_results: int = 5, - skip_hints_for_current_task: bool = False, - hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", - hint_level: Literal["episode", "step"] = "episode", - queries: list[str] | None = None, + visible: bool, + task_hints: list[str] ) -> None: - super().__init__(visible=use_task_hint) - self.use_task_hint = use_task_hint - self.hint_type = hint_type - self.hint_index_type = hint_index_type - self.hint_query_type = hint_query_type - self.hint_index_path = hint_index_path - self.hint_retriever_path = hint_retriever_path - self.hint_num_results = hint_num_results - self.hint_db_rel_path = "hint_db.csv" - self.hint_db_path = hint_db_path # Allow external path override - self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode - self.skip_hints_for_current_task = skip_hints_for_current_task - self.goal = goal - self.llm = llm - self.hint_level: Literal["episode", "step"] = hint_level - self.queries: list[str] | None = queries + super().__init__(visible=visible) + self.task_hints = task_hints _prompt = "" # Task hints are added dynamically in MainPrompt @@ -368,80 +318,6 @@ def __init__( """ - def get_hints_for_task(self, task_name: str) -> str: - """Get hints for a specific task.""" - if not self.use_task_hint: - return "" - - if self.hint_type == "docs": - if not hasattr(self, "hint_index"): - print("Initializing hint index new time") - self._init() - if self.hint_query_type == "goal": - query = self.goal - elif self.hint_query_type == "llm": - query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) - else: - raise ValueError(f"Unknown hint query type: {self.hint_query_type}") - - if self.hint_index_type == "sparse": - import bm25s - query_tokens = bm25s.tokenize(query) - docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results) - docs = [elem["text"] for elem in docs[0]] - # HACK: truncate to 20k characters (should cover >99% of the cases) - for doc in docs: - if len(doc) > 20000: - doc = doc[:20000] - doc += " ...[truncated]" - elif self.hint_index_type == "dense": - query_embedding = self.hint_retriever.encode(query) - _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) - docs = docs["text"] - - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(docs) - ) - return hints_str - - # Check if hint_db has the expected structure - if ( - self.hint_db.empty - or "task_name" not in self.hint_db.columns - or "hint" not in self.hint_db.columns - ): - return "" - - try: - # When step-level, pass queries as goal string to fit the llm_prompt - goal_or_queries = self.goal - if self.hint_level == "step" and self.queries: - goal_or_queries = "\n".join(self.queries) - - task_hints = self.hints_source.choose_hints( - self.llm, - task_name, - goal_or_queries, - ) - - hints = [] - for hint in task_hints: - hint = hint.strip() - if hint: - hints.append(f"- {hint}") - - if len(hints) > 0: - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(hints) - ) - return hints_str - except Exception as e: - print(f"Warning: Error getting hints for task {task_name}: {e}") - - return "" - class StepWiseContextIdentificationPrompt(dp.Shrinkable): def __init__( From 4a2c7de5921668e7e54a060cca792a85e65c2961 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 16:34:52 -0400 Subject: [PATCH 34/53] fix prompt for task hint --- .../generic_agent_prompt.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 3536a71b..fddd48f2 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -51,9 +51,6 @@ class GenericPromptFlags(dp.Flags): use_abstract_example: bool = False use_hints: bool = False use_task_hint: bool = False - task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" - skip_hints_for_current_task: bool = False - hint_db_path: str = None enable_chat: bool = False max_prompt_tokens: int = None be_cautious: bool = True @@ -61,15 +58,6 @@ class GenericPromptFlags(dp.Flags): add_missparsed_messages: bool = True max_trunc_itr: int = 20 flag_group: str = None - # hint flags - hint_type: Literal["human", "llm", "docs"] = "human" - hint_index_type: Literal["sparse", "dense"] = "sparse" - hint_query_type: Literal["direct", "llm", "emb"] = "direct" - hint_index_path: str = None - hint_retriever_path: str = None - hint_num_results: int = 5 - n_retrieval_queries: int = 3 - hint_level: Literal["episode", "step"] = "episode" class MainPrompt(dp.Shrinkable): @@ -135,7 +123,7 @@ def _prompt(self) -> HumanMessage: {self.history.prompt}\ {self.action_prompt.prompt}\ {self.hints.prompt}\ -{self.task_hint.prompt}\ +{self.task_hints.prompt}\ {self.be_cautious.prompt}\ {self.think.prompt}\ {self.plan.prompt}\ @@ -156,7 +144,7 @@ def _prompt(self) -> HumanMessage: {self.plan.abstract_ex}\ {self.memory.abstract_ex}\ {self.criticise.abstract_ex}\ -{self.task_hint.abstract_ex}\ +{self.task_hints.abstract_ex}\ {self.action_prompt.abstract_ex}\ """ ) @@ -172,7 +160,7 @@ def _prompt(self) -> HumanMessage: {self.plan.concrete_ex}\ {self.memory.concrete_ex}\ {self.criticise.concrete_ex}\ -{self.task_hint.concrete_ex}\ +{self.task_hints.concrete_ex}\ {self.action_prompt.concrete_ex}\ """ ) @@ -303,7 +291,12 @@ def __init__( super().__init__(visible=visible) self.task_hints = task_hints - _prompt = "" # Task hints are added dynamically in MainPrompt + @property + def _prompt(self): + task_hint_str = "# Hints:\nHere are some hints for the task you are working on:\n" + for hint in self.task_hints: + task_hint_str += f"{hint}\n" + return task_hint_str _abstract_ex = """ From eafd5fc6207802c2b55e35cae64f3fd68a264bd4 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 10 Sep 2025 16:49:43 -0400 Subject: [PATCH 35/53] undo changes to tmlr config --- .../agents/generic_agent_hinter/tmlr_config.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py index b6f16058..d222b7c0 100644 --- a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py +++ b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py @@ -47,16 +47,6 @@ max_prompt_tokens=40_000, be_cautious=True, extra_instructions=None, - - # hint flags - hint_type="human", - hint_index_type="sparse", - hint_query_type="direct", - hint_index_path=None, - hint_retriever_path=None, - hint_num_results=5, - n_retrieval_queries=3, - hint_level="episode", ) From 70d701e9dce7e614aa55b2b4940a26048876cead Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 11 Sep 2025 13:37:53 -0400 Subject: [PATCH 36/53] update hinter agent --- experiments/generic/run_generic_agent.py | 10 ++++ experiments/hinter/run_hinter_agent.py | 6 ++ experiments/hinter/run_hinter_agent.sh | 7 +-- .../generic_agent_hinter/generic_agent.py | 58 +++++++++---------- .../generic_agent_prompt.py | 11 +++- 5 files changed, 57 insertions(+), 35 deletions(-) diff --git a/experiments/generic/run_generic_agent.py b/experiments/generic/run_generic_agent.py index cdeb3eaf..cc646436 100644 --- a/experiments/generic/run_generic_agent.py +++ b/experiments/generic/run_generic_agent.py @@ -30,6 +30,16 @@ def main(): agent_args = [get_base_agent(args.llm_config)] benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + ##################### Shuffle env args list, pick subset + import numpy as np + rng = np.random.default_rng(42) + rng.shuffle(benchmark.env_args_list) + benchmark.env_args_list = benchmark.env_args_list[:33] + ##################### + + # for env_args in benchmark.env_args_list: + # env_args.max_steps = 100 + if args.relaunch: # relaunch an existing study study = Study.load_most_recent(contains=None) diff --git a/experiments/hinter/run_hinter_agent.py b/experiments/hinter/run_hinter_agent.py index b08283ba..fb2e4d57 100644 --- a/experiments/hinter/run_hinter_agent.py +++ b/experiments/hinter/run_hinter_agent.py @@ -49,6 +49,12 @@ def main(): benchmark = DEFAULT_BENCHMARKS[args.benchmark]() + # # shuffle env_args_list and pick first 33 + # import numpy as np + # rng = np.random.default_rng(42) + # rng.shuffle(benchmark.env_args_list) + # benchmark.env_args_list = benchmark.env_args_list[:33] + if args.relaunch: # relaunch an existing study diff --git a/experiments/hinter/run_hinter_agent.sh b/experiments/hinter/run_hinter_agent.sh index ab35f35b..9d998ef2 100644 --- a/experiments/hinter/run_hinter_agent.sh +++ b/experiments/hinter/run_hinter_agent.sh @@ -9,7 +9,7 @@ PARALLEL_BACKEND="ray" HINT_TYPE="docs" # human, llm, docs HINT_INDEX_TYPE="sparse" # sparse, dense HINT_QUERY_TYPE="goal" # goal, llm -HINT_NUM_RESULTS=5 +HINT_NUM_RESULTS=3 HINT_INDEX_PATH="indexes/servicenow-docs-bm25" # HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m" @@ -17,7 +17,7 @@ HINT_RETRIEVER_PATH="google/embeddinggemma-300m" N_JOBS=6 -python experiments/hint/run_hinter_agent.py \ +python experiments/hinter/run_hinter_agent.py \ --benchmark $BENCHMARK \ --llm-config $LLM_CONFIG \ --parallel-backend $PARALLEL_BACKEND \ @@ -27,5 +27,4 @@ python experiments/hint/run_hinter_agent.py \ --hint-query-type $HINT_QUERY_TYPE \ --hint-index-path $HINT_INDEX_PATH \ --hint-retriever-path $HINT_RETRIEVER_PATH \ - --hint-num-results $HINT_NUM_RESULTS \ - --relaunch \ No newline at end of file + --hint-num-results $HINT_NUM_RESULTS \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index d9694a36..18a24468 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -120,6 +120,12 @@ def get_action(self, obs): queries if getattr(self.flags, "hint_level", "episode") == "step" else None ) + # get hints + if self.flags.use_hints: + task_hints = self._get_task_hints() + else: + task_hints = [] + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, @@ -130,7 +136,7 @@ def get_action(self, obs): step=self.plan_step, flags=self.flags, llm=self.chat_llm, - queries=queries_for_hints, + task_hints=task_hints, ) # Set task name for task hints if available @@ -303,42 +309,39 @@ def _init_hints_index(self): print(f"Warning: Could not load hint database: {e}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - def get_hints_for_task(self, task_name: str) -> str: + def _get_task_hints(self) -> list[str]: """Get hints for a specific task.""" - if not self.use_task_hint: - return "" + if not self.flags.use_task_hint: + return [] - if self.hint_type == "docs": + if self.flags.hint_type == "docs": if not hasattr(self, "hint_index"): print("Initializing hint index new time") self._init() - if self.hint_query_type == "goal": - query = self.goal - elif self.hint_query_type == "llm": + if self.flags.hint_query_type == "goal": + query = self.obs_history[-1]["goal_object"][0]["text"] + elif self.flags.hint_query_type == "llm": query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) else: - raise ValueError(f"Unknown hint query type: {self.hint_query_type}") + raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}") - if self.hint_index_type == "sparse": + print(f"Query: {query}") + if self.flags.hint_index_type == "sparse": import bm25s query_tokens = bm25s.tokenize(query) - docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results) + docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results) docs = [elem["text"] for elem in docs[0]] # HACK: truncate to 20k characters (should cover >99% of the cases) for doc in docs: if len(doc) > 20000: doc = doc[:20000] doc += " ...[truncated]" - elif self.hint_index_type == "dense": + elif self.flags.hint_index_type == "dense": query_embedding = self.hint_retriever.encode(query) - _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results) + _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.flags.hint_num_results) docs = docs["text"] - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(docs) - ) - return hints_str + return docs # Check if hint_db has the expected structure if ( @@ -346,17 +349,17 @@ def get_hints_for_task(self, task_name: str) -> str: or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns ): - return "" + return [] try: # When step-level, pass queries as goal string to fit the llm_prompt - goal_or_queries = self.goal - if self.hint_level == "step" and self.queries: + goal_or_queries = self.obs_history[-1]["goal_object"][0]["text"] + if self.flags.hint_level == "step" and self.queries: goal_or_queries = "\n".join(self.queries) task_hints = self.hints_source.choose_hints( self.llm, - task_name, + self.task_name, goal_or_queries, ) @@ -366,13 +369,8 @@ def get_hints_for_task(self, task_name: str) -> str: if hint: hints.append(f"- {hint}") - if len(hints) > 0: - hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" - + "\n".join(hints) - ) - return hints_str + return hints except Exception as e: - print(f"Warning: Error getting hints for task {task_name}: {e}") + print(f"Warning: Error getting hints for task {self.task_name}: {e}") - return "" \ No newline at end of file + return [] \ No newline at end of file diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index fddd48f2..2699024f 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -59,6 +59,16 @@ class GenericPromptFlags(dp.Flags): max_trunc_itr: int = 20 flag_group: str = None + # hint related + use_task_hint: bool = False + hint_type: str = "docs" + hint_index_type: str = "sparse" + hint_query_type: str = "direct" + hint_index_path: str = "indexes/servicenow-docs-bm25" + hint_retriever_path: str = "google/embeddinggemma-300m" + hint_num_results: int = 5 + n_retrieval_queries: int = 1 + class MainPrompt(dp.Shrinkable): def __init__( @@ -68,7 +78,6 @@ def __init__( actions: list[str], memories: list[str], thoughts: list[str], - hints: list[str], previous_plan: str, step: int, flags: GenericPromptFlags, From 91119d6305eb657cca64045c0f8c2dd619a37c9d Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 11 Sep 2025 13:38:25 -0400 Subject: [PATCH 37/53] formatting --- .../agents/generic_agent_hinter/generic_agent.py | 7 +++++-- .../agents/generic_agent_hinter/generic_agent_prompt.py | 9 ++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 18a24468..0e17e711 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -328,6 +328,7 @@ def _get_task_hints(self) -> list[str]: print(f"Query: {query}") if self.flags.hint_index_type == "sparse": import bm25s + query_tokens = bm25s.tokenize(query) docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results) docs = [elem["text"] for elem in docs[0]] @@ -338,7 +339,9 @@ def _get_task_hints(self) -> list[str]: doc += " ...[truncated]" elif self.flags.hint_index_type == "dense": query_embedding = self.hint_retriever.encode(query) - _, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.flags.hint_num_results) + _, docs = self.hint_index.get_nearest_examples( + "embeddings", query_embedding, k=self.flags.hint_num_results + ) docs = docs["text"] return docs @@ -373,4 +376,4 @@ def _get_task_hints(self) -> list[str]: except Exception as e: print(f"Warning: Error getting hints for task {self.task_name}: {e}") - return [] \ No newline at end of file + return [] diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 2699024f..c986fbce 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -11,12 +11,11 @@ from typing import Literal import pandas as pd -from browsergym.core.action.base import AbstractActionSet - from agentlab.agents import dynamic_prompting as dp from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource from agentlab.llm.chat_api import ChatModel from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise +from browsergym.core.action.base import AbstractActionSet @dataclass @@ -292,11 +291,7 @@ def _parse_answer(self, text_answer): class TaskHint(dp.PromptElement): - def __init__( - self, - visible: bool, - task_hints: list[str] - ) -> None: + def __init__(self, visible: bool, task_hints: list[str]) -> None: super().__init__(visible=visible) self.task_hints = task_hints From a3b6ca46ea41264d9df3e543031a07bbb9a54cd4 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Fri, 12 Sep 2025 17:24:32 -0400 Subject: [PATCH 38/53] bug fix hint retrieval --- src/agentlab/agents/generic_agent_hinter/generic_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 0e17e711..68664ff0 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -321,7 +321,9 @@ def _get_task_hints(self) -> list[str]: if self.flags.hint_query_type == "goal": query = self.obs_history[-1]["goal_object"][0]["text"] elif self.flags.hint_query_type == "llm": - query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex) + queries, _ = self._get_queries() + # HACK: only 1 query supported + query = queries[0] else: raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}") From 49ebc8985966913af932d23efb87d95d0bf425a0 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Fri, 12 Sep 2025 17:25:18 -0400 Subject: [PATCH 39/53] improve launch script --- experiments/hinter/run_hinter_agent.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/experiments/hinter/run_hinter_agent.py b/experiments/hinter/run_hinter_agent.py index fb2e4d57..a5a0d544 100644 --- a/experiments/hinter/run_hinter_agent.py +++ b/experiments/hinter/run_hinter_agent.py @@ -30,6 +30,7 @@ def main(): parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25") parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m") parser.add_argument("--hint-num-results", type=int, default=5) + parser.add_argument("--debug", action="http://github.com/store_true") args = parser.parse_args() flags = FLAGS_GPT_4o @@ -49,11 +50,12 @@ def main(): benchmark = DEFAULT_BENCHMARKS[args.benchmark]() - # # shuffle env_args_list and pick first 33 - # import numpy as np - # rng = np.random.default_rng(42) - # rng.shuffle(benchmark.env_args_list) - # benchmark.env_args_list = benchmark.env_args_list[:33] + if args.debug: + # shuffle env_args_list and + import numpy as np + rng = np.random.default_rng(42) + rng.shuffle(benchmark.env_args_list) + benchmark.env_args_list = benchmark.env_args_list[:6] if args.relaunch: @@ -73,7 +75,7 @@ def main(): n_jobs=args.n_jobs, parallel_backend=args.parallel_backend, strict_reproducibility=args.reproducibility_mode, - n_relaunch=1, + n_relaunch=3, ) From ddea29b02bf1afc3e112c68292058648a0e57c76 Mon Sep 17 00:00:00 2001 From: Hadi Nekoei Date: Tue, 16 Sep 2025 20:20:27 -0400 Subject: [PATCH 40/53] get queries only for step level hint --- .../agents/generic_agent_hinter/generic_agent.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 0cbdb6b3..6cb63155 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -109,10 +109,11 @@ def get_action(self, obs): system_prompt = SystemMessage(dp.SystemPrompt().prompt) - queries, think_queries = self._get_queries() - - # use those queries to retrieve from the database and pass to prompt if step-level - queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None + if self.flags.hint_level == "step": + # use those queries to retrieve from the database and pass to prompt if step-level + queries_for_hints, think_queries = self._get_queries() + else: + queries_for_hints = None main_prompt = MainPrompt( action_set=self.action_set, From b38052ecd217953723e086dc3ca7340edf52a11a Mon Sep 17 00:00:00 2001 From: amanjaiswal73892 Date: Wed, 17 Sep 2025 15:50:41 +0000 Subject: [PATCH 41/53] Add webarenalite to agentlab loop.py --- src/agentlab/experiments/loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index de4b976a..f69322a6 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -907,7 +907,6 @@ def _move_old_exp(exp_dir): def _get_env_name(task_name: str): """Register tasks if needed (lazy import) and return environment name.""" - # lazy import if task_name.startswith("miniwob"): import browsergym.miniwob @@ -915,6 +914,7 @@ def _get_env_name(task_name: str): import browsergym.workarena elif task_name.startswith("webarena"): import browsergym.webarena + import browsergym.webarenalite elif task_name.startswith("visualwebarena"): import browsergym.visualwebarena elif task_name.startswith("assistantbench"): From 4189ca95c43ae645dfee92717b87c5c08a4d21da Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Wed, 24 Sep 2025 19:21:36 +0200 Subject: [PATCH 42/53] update stepwise hint queries prompt --- .../generic_agent_prompt.py | 33 ++++++++++++++----- .../agents/tool_use_agent/tool_use_agent.py | 6 ++++ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 10cfeef6..88db8724 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -18,6 +18,7 @@ from agentlab.llm.chat_api import ChatModel from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise +logger = logging.getLogger(__name__) @dataclass class GenericPromptFlags(dp.Flags): @@ -404,7 +405,7 @@ def _init(self): else: print(f"Warning: Hint database not found at {hint_db_path}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - + self.hints_source = HintsSource( hint_db_path=hint_db_path.as_posix(), hint_retrieval_mode=self.hint_retrieval_mode, @@ -531,13 +532,14 @@ def _prompt(self) -> HumanMessage: # Querying memory Before choosing an action, let's search our available documentation and memory for relevant context. -Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follow +Generate a brief, general summary of the current status to help identify useful hints. Return your answer in the following format: chain of thought -json list of strings for the queries. Return exactly {self.n_queries} -queries in the list. +json list of strings of queries -# Concrete Example +Additional instructions: List of queries should contain up to {self.n_queries} queries. Both the think and the queries blocks are required! +# Concrete Example +``` I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if I will be able to sort by both at the same time. @@ -546,6 +548,9 @@ def _prompt(self) -> HumanMessage: {example_queries_str} +``` +Note: do not generate backticks. +Now proceed to generate your own thoughts and queries. """ ) @@ -556,8 +561,18 @@ def shrink(self): self.obs.shrink() def _parse_answer(self, text_answer): - ans_dict = parse_html_tags_raise( - text_answer, keys=["think", "queries"], merge_multiple=True - ) - ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]")) + try: + ans_dict = parse_html_tags_raise( + text_answer, keys=["think", "queries"], merge_multiple=True + ) + except Exception as e: + t = text_answer.replace("\n", "\\n") + logger.exception(f"Failed to parse llm answer: {e}. RAW answer: {t}") + raise e + try: + ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]")) + except Exception as e: + t = text_answer.replace("\n", "\\n") + logger.exception(f"Failed to parse queries: {e}. RAW llm answer: {t}") + raise e return ans_dict diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index bd200da3..e33a3082 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -43,6 +43,7 @@ from agentlab.llm.tracking import cost_tracker_decorator logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) @dataclass @@ -423,12 +424,17 @@ def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: hint_topics = list(topic_to_hints.keys()) topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) prompt = self.llm_prompt.format(goal=goal, topics=topics) + logger.info(f"LLM choose hint topics prompt: {prompt}") if isinstance(llm, ChatModel): response: str = llm(messages=[dict(role="user", content=prompt)])["content"] else: response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think try: + response = response.strip() + if response.endswith("."): + response = response[:-1] + logger.info(f"LLM choose hint topics RAW response: {response}") topic_number = json.loads(response) if topic_number < 0 or topic_number >= len(hint_topics): logger.error(f"Wrong LLM hint id response: {response}, no hints") From 2e7de68d73e0177d4ce9058ccea36deb3913c048 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Wed, 24 Sep 2025 19:28:27 +0200 Subject: [PATCH 43/53] fix exc logging --- .../agents/generic_agent_hinter/generic_agent_prompt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 88db8724..ba76c8de 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -567,12 +567,12 @@ def _parse_answer(self, text_answer): ) except Exception as e: t = text_answer.replace("\n", "\\n") - logger.exception(f"Failed to parse llm answer: {e}. RAW answer: {t}") + logger.warning(f"Failed to parse llm answer: {e}. RAW answer: '{t}'. Will retry") raise e try: ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]")) except Exception as e: t = text_answer.replace("\n", "\\n") - logger.exception(f"Failed to parse queries: {e}. RAW llm answer: {t}") + logger.warning(f"Failed to parse queries: {e}. RAW llm answer: '{t}'. Will retry") raise e return ans_dict From 77ffacec4fafba676163fc8ede2e4f1dfb7f8a57 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Wed, 24 Sep 2025 20:23:26 +0200 Subject: [PATCH 44/53] non empty instruction --- .../agents/generic_agent_hinter/generic_agent_prompt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index ba76c8de..af11896f 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -551,6 +551,7 @@ def _prompt(self) -> HumanMessage: ``` Note: do not generate backticks. Now proceed to generate your own thoughts and queries. +Always return non-empty answer, its very important! """ ) @@ -569,10 +570,11 @@ def _parse_answer(self, text_answer): t = text_answer.replace("\n", "\\n") logger.warning(f"Failed to parse llm answer: {e}. RAW answer: '{t}'. Will retry") raise e + raw_queries = ans_dict.get("queries", "[]") try: - ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]")) + ans_dict["queries"] = json.loads(raw_queries) except Exception as e: t = text_answer.replace("\n", "\\n") - logger.warning(f"Failed to parse queries: {e}. RAW llm answer: '{t}'. Will retry") + logger.warning(f"Failed to parse queries: {e}. Queries block content: '{ans_dict['queries']}'. RAW llm answer: '{t}'. Will retry") raise e return ans_dict From bb15454cd33ab05820e5a2ff53efe7ee2baf15ba Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Wed, 24 Sep 2025 20:24:32 +0200 Subject: [PATCH 45/53] allow less then max hint queries --- src/agentlab/agents/generic_agent_hinter/generic_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 6cb63155..74e25b68 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -200,7 +200,7 @@ def _get_queries(self): ) queries = ans_dict.get("queries", []) - assert len(queries) == self.flags.n_retrieval_queries + assert len(queries) <= self.flags.n_retrieval_queries # TODO: we should probably propagate these chat_messages to be able to see them in xray return queries, ans_dict.get("think", None) From 04358c329f333e2fdc072fdc760ebad605b57197 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Thu, 25 Sep 2025 15:28:32 +0200 Subject: [PATCH 46/53] add generic agent gpt5-nano config --- src/agentlab/agents/generic_agent_hinter/__init__.py | 5 +++-- src/agentlab/agents/generic_agent_hinter/agent_configs.py | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/__init__.py b/src/agentlab/agents/generic_agent_hinter/__init__.py index 659aa35a..08b44255 100644 --- a/src/agentlab/agents/generic_agent_hinter/__init__.py +++ b/src/agentlab/agents/generic_agent_hinter/__init__.py @@ -13,6 +13,8 @@ AGENT_CLAUDE_SONNET_35, AGENT_CLAUDE_SONNET_35_VISION, AGENT_CUSTOM, + AGENT_GPT5_MINI, + AGENT_GPT5_NANO, AGENT_LLAMA3_70B, AGENT_LLAMA4_17B_INSTRUCT, AGENT_LLAMA31_70B, @@ -26,9 +28,7 @@ AGENT_o3_MINI, FLAGS_GPT_4o, GenericAgentArgs, - AGENT_GPT5_MINI, ) - from .generic_agent import GenericAgent, GenericAgentArgs __all__ = [ @@ -50,4 +50,5 @@ "AGENT_4o_MINI_VISION", "AGENT_CLAUDE_SONNET_35_VISION", "AGENT_GPT5_MINI", + "AGENT_GPT5_NANO", ] diff --git a/src/agentlab/agents/generic_agent_hinter/agent_configs.py b/src/agentlab/agents/generic_agent_hinter/agent_configs.py index 798445db..031b824c 100644 --- a/src/agentlab/agents/generic_agent_hinter/agent_configs.py +++ b/src/agentlab/agents/generic_agent_hinter/agent_configs.py @@ -365,6 +365,10 @@ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-mini-2025-08-07"], flags=GPT5_MINI_FLAGS, ) +AGENT_GPT5_NANO = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-nano-2025-08-07"], + flags=GPT5_MINI_FLAGS, +) AGENT_GPT5 = GenericAgentArgs( chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-2025-08-07"], From 0dcae12920aa7ecb8ea58eb06cb65aa177b39cbf Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Fri, 26 Sep 2025 17:04:38 +0000 Subject: [PATCH 47/53] make ray available on toolkit --- src/agentlab/experiments/launch_exp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/agentlab/experiments/launch_exp.py b/src/agentlab/experiments/launch_exp.py index 3bc6c54e..f14ae4c7 100644 --- a/src/agentlab/experiments/launch_exp.py +++ b/src/agentlab/experiments/launch_exp.py @@ -1,4 +1,5 @@ import logging +import os from importlib import import_module from pathlib import Path @@ -7,6 +8,8 @@ from agentlab.experiments.exp_utils import run_exp from agentlab.experiments.loop import ExpArgs, yield_all_exp_results +RAY_PUBLIC_DASHBOARD = os.environ.get("RAY_PUBLIC_DASHBOARD", "false") == "true" + def run_experiments( n_jobs, @@ -82,7 +85,9 @@ def run_experiments( elif parallel_backend == "ray": from agentlab.experiments.graph_execution_ray import execute_task_graph, ray - ray.init(num_cpus=n_jobs) + ray.init( + num_cpus=n_jobs, dashboard_host="0.0.0.0" if RAY_PUBLIC_DASHBOARD else "127.0.0.1" + ) try: execute_task_graph(exp_args_list, avg_step_timeout=avg_step_timeout) finally: From 15c5639424fde862225199dcb73f3eb372c77533 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Tue, 30 Sep 2025 14:01:51 +0000 Subject: [PATCH 48/53] check that hints db exists --- src/agentlab/agents/generic_agent_hinter/generic_agent.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 74e25b68..ce07becc 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -8,6 +8,7 @@ the agent, including model arguments and flags for various behaviors. """ +import os from copy import deepcopy from dataclasses import asdict, dataclass from functools import partial @@ -89,6 +90,8 @@ def __init__( self.max_retry = max_retry self.flags = flags + if self.flags.hint_db_path is not None: + assert os.path.exists(self.flags.hint_db_path), f"Hint database path {self.flags.hint_db_path} does not exist." self.action_set = self.flags.action.action_set.make_action_set() self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs) From ce866f6e902b6f309780d35cb8797e346c926016 Mon Sep 17 00:00:00 2001 From: Hadi Nekoei Date: Fri, 17 Oct 2025 11:40:08 -0400 Subject: [PATCH 49/53] Fix assignment of queries_for_hints variable --- src/agentlab/agents/generic_agent_hinter/generic_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index a06c5970..ef8f10c5 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -117,7 +117,7 @@ def get_action(self, obs): system_prompt = SystemMessage(dp.SystemPrompt().prompt) # use those queries to retrieve from the database and pass to prompt if step-level - queries_for_hints = ( + self.queries = ( self._get_queries()[0] if getattr(self.flags, "hint_level", "episode") == "step" else None ) From 87e2510fc3a4651cfbc53d1d53cc10b6b458eb1b Mon Sep 17 00:00:00 2001 From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com> Date: Mon, 20 Oct 2025 19:04:53 -0400 Subject: [PATCH 50/53] Improve generic agent hinter (#309) * Make LLM retreival topic index selection more robust --- .../agents/generic_agent_hinter/generic_agent.py | 9 +++++++-- .../generic_agent_hinter/generic_agent_prompt.py | 5 ++++- .../agents/generic_agent_hinter/tmlr_config.py | 2 +- src/agentlab/utils/hinting.py | 16 ++++++++++++++-- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index ef8f10c5..843879f8 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -93,7 +93,9 @@ def __init__( self.flags = flags if self.flags.hint_db_path is not None: - assert os.path.exists(self.flags.hint_db_path), f"Hint database path {self.flags.hint_db_path} does not exist." + assert os.path.exists( + self.flags.hint_db_path + ), f"Hint database path {self.flags.hint_db_path} does not exist." self.action_set = self.flags.action.action_set.make_action_set() self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs) @@ -118,7 +120,9 @@ def get_action(self, obs): # use those queries to retrieve from the database and pass to prompt if step-level self.queries = ( - self._get_queries()[0] if getattr(self.flags, "hint_level", "episode") == "step" else None + self._get_queries()[0] + if getattr(self.flags, "hint_level", "episode") == "step" + else None ) # get hints @@ -204,6 +208,7 @@ def _get_queries(self): ) chat_messages = Discussion([system_prompt, query_prompt.prompt]) + # BUG: Parsing fails multiple times. ans_dict = retry( self.chat_llm, chat_messages, diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 0fc08e41..5ccb73a9 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) + @dataclass class GenericPromptFlags(dp.Flags): """ @@ -403,6 +404,8 @@ def _parse_answer(self, text_answer): ans_dict["queries"] = json.loads(raw_queries) except Exception as e: t = text_answer.replace("\n", "\\n") - logger.warning(f"Failed to parse queries: {e}. Queries block content: '{ans_dict['queries']}'. RAW llm answer: '{t}'. Will retry") + logger.warning( + f"Failed to parse queries: {e}. Queries block content: '{ans_dict['queries']}'. RAW llm answer: '{t}'. Will retry" + ) raise e return ans_dict diff --git a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py index d222b7c0..5a749721 100644 --- a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py +++ b/src/agentlab/agents/generic_agent_hinter/tmlr_config.py @@ -23,7 +23,7 @@ use_think_history=True, # gpt-4o config except for this line use_diff=False, html_type="pruned_html", - use_screenshot=True, + use_screenshot=False, use_som=False, extract_visible_tag=True, extract_clickable_tag=True, diff --git a/src/agentlab/utils/hinting.py b/src/agentlab/utils/hinting.py index 6ba1f2d5..901d0361 100644 --- a/src/agentlab/utils/hinting.py +++ b/src/agentlab/utils/hinting.py @@ -12,11 +12,14 @@ import pandas as pd import requests from agentlab.llm.chat_api import ChatModel +import re +from agentlab.llm.response_api import APIPayload logger = logging.getLogger(__name__) class HintsSource: + def __init__( self, hint_db_path: str, @@ -27,7 +30,8 @@ def __init__( embedder_server: str = "http://localhost:5000/", llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n -Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""", +Choose hint topic for the task and return only its number. Use the following output format: +index for e.g. 0 for the topic with index 0. If you don't know the answer, return -1""", ) -> None: self.hint_db_path = hint_db_path self.hint_retrieval_mode = hint_retrieval_mode @@ -96,7 +100,15 @@ def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: else: response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think try: - topic_number = json.loads(response) + matches = re.findall(r"(-?\d+)", response) + if not matches: + logger.error(f"No choice tags found in LLM response: {response}") + return [] + if len(matches) > 1: + logger.warning( + f"LLM selected multiple topics for retrieval using only the first one." + ) + topic_number = int(matches[0]) if topic_number < 0 or topic_number >= len(hint_topics): logger.error(f"Wrong LLM hint id response: {response}, no hints") return [] From f06c6d04c515392ceef4d903423c6bbdf6b18969 Mon Sep 17 00:00:00 2001 From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com> Date: Tue, 21 Oct 2025 17:40:50 -0400 Subject: [PATCH 51/53] =?UTF-8?q?add=20new=20flag=20to=20skip=20hints=20wi?= =?UTF-8?q?th=20the=20current=20goal=20in=20the=20hint=20source=20t?= =?UTF-8?q?=E2=80=A6=20(#310)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add new flag to skip hints with the current goal in the hint source traces --- src/agentlab/utils/hinting.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/agentlab/utils/hinting.py b/src/agentlab/utils/hinting.py index 901d0361..506513d5 100644 --- a/src/agentlab/utils/hinting.py +++ b/src/agentlab/utils/hinting.py @@ -13,6 +13,7 @@ import requests from agentlab.llm.chat_api import ChatModel import re +import json from agentlab.llm.response_api import APIPayload logger = logging.getLogger(__name__) @@ -25,6 +26,7 @@ def __init__( hint_db_path: str, hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", skip_hints_for_current_task: bool = False, + skip_hints_for_current_goal: bool = False, top_n: int = 4, embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", embedder_server: str = "http://localhost:5000/", @@ -36,6 +38,7 @@ def __init__( self.hint_db_path = hint_db_path self.hint_retrieval_mode = hint_retrieval_mode self.skip_hints_for_current_task = skip_hints_for_current_task + self.skip_hints_for_current_goal = skip_hints_for_current_goal self.top_n = top_n self.embedder_model = embedder_model self.embedder_server = embedder_server @@ -45,7 +48,16 @@ def __init__( self.hint_db_path = Path(hint_db_path).as_posix() else: self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() - self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) + self.hint_db = pd.read_csv( + self.hint_db_path, + header=0, + index_col=None, + dtype=str, + converters={ + "trace_paths_json": lambda x: json.loads(x) if pd.notna(x) else [], + "source_trace_goals": lambda x: json.loads(x) if pd.notna(x) else [], + }, + ) logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") if self.hint_retrieval_mode == "emb": self.load_hint_vectors() @@ -84,7 +96,9 @@ def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: topic_to_hints = defaultdict(list) skip_hints = [] if self.skip_hints_for_current_task: - skip_hints = self.get_current_task_hints(task_name) + skip_hints += self.get_current_task_hints(task_name) + if self.skip_hints_for_current_goal: + skip_hints += self.get_current_goal_hints(goal) for _, row in self.hint_db.iterrows(): hint = row["hint"] if hint in skip_hints: @@ -128,7 +142,9 @@ def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: all_hints = self.uniq_hints["hint"].tolist() skip_hints = [] if self.skip_hints_for_current_task: - skip_hints = self.get_current_task_hints(task_name) + skip_hints += self.get_current_task_hints(task_name) + if self.skip_hints_for_current_goal: + skip_hints += self.get_current_goal_hints(goal) hint_embeddings = [] id_to_hint = {} for hint, emb in zip(all_hints, self.hint_embeddings): @@ -199,3 +215,7 @@ def get_current_task_hints(self, task_name): self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) ] return hints_df["hint"].tolist() + + def get_current_goal_hints(self, goal_str: str): + mask = self.hint_db["source_trace_goals"].apply(lambda goals: goal_str in goals) + return self.hint_db.loc[mask, "hint"].tolist() From cf0e9b3696ee070bfba093425c22f68d25cbf40e Mon Sep 17 00:00:00 2001 From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com> Date: Mon, 27 Oct 2025 18:21:21 -0400 Subject: [PATCH 52/53] Rename generic agent hinter to hint_use_agent (#311) * rename generic_agent_hinter to hint_use_agent for clarity * Add deprecation warning and module alias for generic_agent_hinter * improve module aliasing for submodules * Add todo rename agent name * black * bugfix: check for hint_db only when use_task_hint is true. * fix: address missing initialization and correct args reference in choose_hints method * black --- .../agents/generic_agent_hinter/__init__.py | 65 +++++-------------- .../agents/hint_use_agent/__init__.py | 54 +++++++++++++++ .../agent_configs.py | 0 .../generic_agent.py | 10 ++- .../generic_agent_prompt.py | 0 .../tmlr_config.py | 0 6 files changed, 76 insertions(+), 53 deletions(-) create mode 100644 src/agentlab/agents/hint_use_agent/__init__.py rename src/agentlab/agents/{generic_agent_hinter => hint_use_agent}/agent_configs.py (100%) rename src/agentlab/agents/{generic_agent_hinter => hint_use_agent}/generic_agent.py (97%) rename src/agentlab/agents/{generic_agent_hinter => hint_use_agent}/generic_agent_prompt.py (100%) rename src/agentlab/agents/{generic_agent_hinter => hint_use_agent}/tmlr_config.py (100%) diff --git a/src/agentlab/agents/generic_agent_hinter/__init__.py b/src/agentlab/agents/generic_agent_hinter/__init__.py index 08b44255..4ad73676 100644 --- a/src/agentlab/agents/generic_agent_hinter/__init__.py +++ b/src/agentlab/agents/generic_agent_hinter/__init__.py @@ -1,54 +1,19 @@ -""" -Baseline agent for all ServiceNow papers +import importlib, sys, warnings -This module contains the GenericAgent class, which is the baseline agent for all ServiceNow papers. \ -It is a simple agent that can be ran OOB on all BrowserGym environments. It is also shipped with \ -a few configurations that can be used to run it on different environments. -""" +OLD = __name__ +NEW = "agentlab.agents.hint_use_agent" +SUBS = ("agent_configs", "generic_agent_prompt", "generic_agent", "tmlr_config") -from .agent_configs import ( - AGENT_3_5, - AGENT_8B, - AGENT_37_SONNET, - AGENT_CLAUDE_SONNET_35, - AGENT_CLAUDE_SONNET_35_VISION, - AGENT_CUSTOM, - AGENT_GPT5_MINI, - AGENT_GPT5_NANO, - AGENT_LLAMA3_70B, - AGENT_LLAMA4_17B_INSTRUCT, - AGENT_LLAMA31_70B, - CHAT_MODEL_ARGS_DICT, - RANDOM_SEARCH_AGENT, - AGENT_4o, - AGENT_4o_MINI, - AGENT_4o_MINI_VISION, - AGENT_4o_VISION, - AGENT_o1_MINI, - AGENT_o3_MINI, - FLAGS_GPT_4o, - GenericAgentArgs, +warnings.warn( + f"{OLD} is renamed to {NEW}. {OLD} will be removed in future", + DeprecationWarning, + stacklevel=2, ) -from .generic_agent import GenericAgent, GenericAgentArgs -__all__ = [ - "AGENT_3_5", - "AGENT_4o", - "AGENT_4o_MINI", - "AGENT_4o_VISION", - "AGENT_o3_MINI", - "AGENT_o1_MINI", - "AGENT_LLAMA4_17B_INSTRUCT", - "AGENT_LLAMA3_70B", - "AGENT_LLAMA31_70B", - "AGENT_8B", - "RANDOM_SEARCH_AGENT", - "AGENT_CUSTOM", - "AGENT_CLAUDE_SONNET_35", - "AGENT_37_SONNET", - "AGENT_4o_VISION", - "AGENT_4o_MINI_VISION", - "AGENT_CLAUDE_SONNET_35_VISION", - "AGENT_GPT5_MINI", - "AGENT_GPT5_NANO", -] +# Alias the top-level +new_mod = importlib.import_module(NEW) +sys.modules[OLD] = new_mod + +# Alias known submodules +for sub in SUBS: + sys.modules[f"{OLD}.{sub}"] = importlib.import_module(f"{NEW}.{sub}") diff --git a/src/agentlab/agents/hint_use_agent/__init__.py b/src/agentlab/agents/hint_use_agent/__init__.py new file mode 100644 index 00000000..08b44255 --- /dev/null +++ b/src/agentlab/agents/hint_use_agent/__init__.py @@ -0,0 +1,54 @@ +""" +Baseline agent for all ServiceNow papers + +This module contains the GenericAgent class, which is the baseline agent for all ServiceNow papers. \ +It is a simple agent that can be ran OOB on all BrowserGym environments. It is also shipped with \ +a few configurations that can be used to run it on different environments. +""" + +from .agent_configs import ( + AGENT_3_5, + AGENT_8B, + AGENT_37_SONNET, + AGENT_CLAUDE_SONNET_35, + AGENT_CLAUDE_SONNET_35_VISION, + AGENT_CUSTOM, + AGENT_GPT5_MINI, + AGENT_GPT5_NANO, + AGENT_LLAMA3_70B, + AGENT_LLAMA4_17B_INSTRUCT, + AGENT_LLAMA31_70B, + CHAT_MODEL_ARGS_DICT, + RANDOM_SEARCH_AGENT, + AGENT_4o, + AGENT_4o_MINI, + AGENT_4o_MINI_VISION, + AGENT_4o_VISION, + AGENT_o1_MINI, + AGENT_o3_MINI, + FLAGS_GPT_4o, + GenericAgentArgs, +) +from .generic_agent import GenericAgent, GenericAgentArgs + +__all__ = [ + "AGENT_3_5", + "AGENT_4o", + "AGENT_4o_MINI", + "AGENT_4o_VISION", + "AGENT_o3_MINI", + "AGENT_o1_MINI", + "AGENT_LLAMA4_17B_INSTRUCT", + "AGENT_LLAMA3_70B", + "AGENT_LLAMA31_70B", + "AGENT_8B", + "RANDOM_SEARCH_AGENT", + "AGENT_CUSTOM", + "AGENT_CLAUDE_SONNET_35", + "AGENT_37_SONNET", + "AGENT_4o_VISION", + "AGENT_4o_MINI_VISION", + "AGENT_CLAUDE_SONNET_35_VISION", + "AGENT_GPT5_MINI", + "AGENT_GPT5_NANO", +] diff --git a/src/agentlab/agents/generic_agent_hinter/agent_configs.py b/src/agentlab/agents/hint_use_agent/agent_configs.py similarity index 100% rename from src/agentlab/agents/generic_agent_hinter/agent_configs.py rename to src/agentlab/agents/hint_use_agent/agent_configs.py diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/hint_use_agent/generic_agent.py similarity index 97% rename from src/agentlab/agents/generic_agent_hinter/generic_agent.py rename to src/agentlab/agents/hint_use_agent/generic_agent.py index 843879f8..afc688ed 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/hint_use_agent/generic_agent.py @@ -39,6 +39,7 @@ class GenericAgentArgs(AgentArgs): def __post_init__(self): try: # some attributes might be temporarily args.CrossProd for hyperparameter generation + # TODO: Rename the agent to HintUseAgent when appropriate self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace( "/", "_" ) @@ -92,7 +93,8 @@ def __init__( self.max_retry = max_retry self.flags = flags - if self.flags.hint_db_path is not None: + + if self.flags.hint_db_path is not None and self.flags.use_task_hint: assert os.path.exists( self.flags.hint_db_path ), f"Hint database path {self.flags.hint_db_path} does not exist." @@ -323,7 +325,8 @@ def _get_task_hints(self) -> list[str]: if self.flags.hint_type == "docs": if not hasattr(self, "hint_index"): print("Initializing hint index new time") - self._init() + # @patricebechard It seems _.init() method is missing do we still need it? + # self._init() if self.flags.hint_query_type == "goal": query = self.obs_history[-1]["goal_object"][0]["text"] elif self.flags.hint_query_type == "llm": @@ -331,6 +334,7 @@ def _get_task_hints(self) -> list[str]: # HACK: only 1 query supported query = queries[0] else: + # @patricebechard: This raises an error with the default value 'direct' raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}") print(f"Query: {query}") @@ -369,7 +373,7 @@ def _get_task_hints(self) -> list[str]: goal_or_queries = "\n".join(self.queries) task_hints = self.hints_source.choose_hints( - self.llm, + self.chat_llm, self.task_name, goal_or_queries, ) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/hint_use_agent/generic_agent_prompt.py similarity index 100% rename from src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py rename to src/agentlab/agents/hint_use_agent/generic_agent_prompt.py diff --git a/src/agentlab/agents/generic_agent_hinter/tmlr_config.py b/src/agentlab/agents/hint_use_agent/tmlr_config.py similarity index 100% rename from src/agentlab/agents/generic_agent_hinter/tmlr_config.py rename to src/agentlab/agents/hint_use_agent/tmlr_config.py From a46ddc69eeab891716ff4f3fa84e7b80c67b4129 Mon Sep 17 00:00:00 2001 From: amanjaiswal73892 Date: Mon, 27 Oct 2025 22:48:22 +0000 Subject: [PATCH 53/53] bugfix: skip HintSource init if use_task_hint is false --- .../agents/tool_use_agent/tool_use_agent.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 91594d3e..894616a4 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -319,14 +319,15 @@ class TaskHint(Block): def _init(self): """Initialize the block.""" - self.hints_source = HintsSource( - hint_db_path=self.hint_db_rel_path, - hint_retrieval_mode=self.hint_retrieval_mode, - top_n=self.top_n, - embedder_model=self.embedder_model, - embedder_server=self.embedder_server, - llm_prompt=self.llm_prompt, - ) + if self.use_task_hint: + self.hints_source = HintsSource( + hint_db_path=self.hint_db_rel_path, + hint_retrieval_mode=self.hint_retrieval_mode, + top_n=self.top_n, + embedder_model=self.embedder_model, + embedder_server=self.embedder_server, + llm_prompt=self.llm_prompt, + ) def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if not self.use_task_hint: