Skip to content

Environment#

AREEnvironment #

Bases: Environment

ARE Environment adapter following OSWorld/CRMArena pattern.

This class provides Agoge integration for ARE scenarios while preserving ARE's native event-based execution in an isolated environment.

All conversion logic (Agoge ↔ ARE formats) is handled here in the main process. The proxy only deals with ARE's native types.

Parameters:

Name Type Description Default
scenario_id str | None

ID of ARE scenario to load

None
scenario_kwargs dict | None

Additional kwargs for scenario initialization

None
max_steps int

Maximum steps before episode termination

50
verbosity_level str

Notification verbosity (LOW/MEDIUM/HIGH)

'LOW'
judge_model str

Model name for LLM judge (default: gpt-5-mini)

'gpt-5-mini'
Source code in src/agoge/environment/are_environment.py
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
class AREEnvironment(Environment):
    """ARE Environment adapter following OSWorld/CRMArena pattern.

    This class provides Agoge integration for ARE scenarios while preserving
    ARE's native event-based execution in an isolated environment.

    All conversion logic (Agoge ↔ ARE formats) is handled here in the main process.
    The proxy only deals with ARE's native types.

    Args:
        scenario_id: ID of ARE scenario to load
        scenario_kwargs: Additional kwargs for scenario initialization
        max_steps: Maximum steps before episode termination
        verbosity_level: Notification verbosity (LOW/MEDIUM/HIGH)
        judge_model: Model name for LLM judge (default: gpt-5-mini)
    """

    def __init__(
        self,
        scenario_id: str | None = None,
        scenario_kwargs: dict | None = None,
        max_steps: int = 50,
        verbosity_level: str = "LOW",
        judge_model: str = "gpt-5-mini",
        a2a_config: dict | None = None,
    ):
        """Initialize the ARE environment and create its proxy actor.

        Args:
            scenario_id: ID of ARE scenario to load.
            scenario_kwargs: Additional kwargs for scenario initialization.
            max_steps: Maximum steps before episode termination.
            verbosity_level: Notification verbosity (LOW/MEDIUM/HIGH).
            judge_model: Model name for LLM judge.
            a2a_config: Agent2Agent configuration dict. Example:
                {"model": "gpt-5-mini", "provider": "openai", "endpoint": None}
                If None, Agent2Agent is disabled.
        """
        super().__init__(max_steps=max_steps)

        # Define every attribute that other methods read *before* doing any work that can
        # fail. cleanup() reads self.proxy and _create_tool_message() appends to
        # self._all_internal_calls; both would otherwise raise AttributeError when the
        # actor creation below fails or when a tool result is processed before a reset.
        self.proxy = None
        self.judge_model = judge_model  # Reported in the evaluation details
        self._tool_schemas = None
        self._environment_stopped = False
        self._evaluation_details: dict | None = None  # Store detailed evaluation info for trajectory
        self._all_internal_calls: list = []  # Re-created at the start of every episode

        logger.info(
            f"Initializing ARE environment (max_steps={max_steps}, verbosity={verbosity_level}, "
            f"judge_model={judge_model})"
        )

        # Create the proxy actor pinned to the node this environment runs on.
        # soft=False makes the node affinity a hard requirement.
        # NOTE(review): the runtime_env mentioned in the log below is presumably declared
        # on AREProxy itself - it is not passed here.
        current_node_id = ray.get_runtime_context().get_node_id()
        self.proxy = AREProxy.options(
            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                node_id=current_node_id, soft=False
            )
        ).remote(
            scenario_id=scenario_id,
            scenario_kwargs=scenario_kwargs,
            max_steps=max_steps,
            verbosity_level=verbosity_level,
            judge_model=judge_model,
            a2a_config=a2a_config,
        )

        logger.info("ARE environment initialized successfully (using runtime_env)")

    async def _reset_impl(self, task: Task, **reset_kwargs) -> Chat:
        """Begin a fresh episode for ``task``.

        The returned observation starts with a SystemMessage carrying the
        episode-specific system prompt, followed by a UserMessage with the task
        instruction. Agents that support dynamic prompts (like DynamicPromptAgent)
        can extract the scenario-specific system prompt from that first message.
        """
        banner = "=" * 80
        logger.info(banner)
        logger.info(f"STARTING NEW EPISODE: {task.task_id}")
        logger.info(banner)

        # Per-episode ARE state starts from scratch
        self._environment_stopped = False
        self._evaluation_details = None
        self._episode_id = str(uuid.uuid4())
        self._all_internal_calls = []  # Expert agent internal calls, accumulated over all steps

        # The task is sent to the proxy as a plain dict so it serializes cleanly
        serialized_task = dict(
            task_id=task.task_id,
            inputs=task.inputs,
            metadata=task.metadata,
        )

        # The proxy performs the reset and builds the system prompt on its side
        logger.info("Calling proxy.reset() - this may take a few minutes on first run (runtime_env setup)")
        reset_result = await self.proxy.reset.remote(serialized_task, **reset_kwargs)
        logger.info("Proxy reset completed")

        # Raw ARE schemas -> Agoge format
        self._tool_schemas = list(map(self._convert_are_schema_to_agoge, reset_result["tool_schemas"]))

        logger.info(f"Episode initialized with {len(self._tool_schemas)} tools")

        system_prompt = reset_result["system_prompt"]
        instruction = reset_result["instruction"]

        # ARE's conversation format: the system prompt first, then the task wrapped
        # in ARE's "[TASK]: \n{content}\n" envelope.
        observation = Chat(
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=f"[TASK]: \n{instruction}\n"),
            ]
        )

        logger.info(f"Reset complete: system_prompt={len(system_prompt)} chars, instruction={len(instruction)} chars")

        # The instruction is already in the observation; drop it from the proxy's
        # notification queue so it is not delivered a second time.
        await self.proxy._clear_initial_user_message.remote()

        return observation

    async def _execute_tool_calls(self, action: AssistantMessage) -> tuple[list, bool, int]:
        """Run every tool call in ``action`` through the proxy, in order.

        A failure in one call is converted into an error ToolMessage for that call
        and does not stop the remaining calls from running.

        Returns:
            tuple of (tool_messages, wait_for_notification_called, wait_timeout)
        """
        if not action.tool_calls:
            logger.debug(f"Step {self.step_count}: No tool calls in action")
            return [], False, 0

        responses = []
        saw_wait = False
        timeout_s = 0

        for call in action.tool_calls:
            try:
                name, arguments = self._parse_tool_call(call)
                sends_to_user = name == "AgentUserInterface__send_message_to_user"

                if name == "SystemApp__wait_for_notification":
                    saw_wait = True
                    timeout_s = arguments.get("timeout", 0)
                    logger.debug(f"[WAIT_FOR_NOTIFICATION] Detected with timeout={timeout_s}s")

                if sends_to_user:
                    logger.info(f"[SEND_MESSAGE_TO_USER] Step {self.step_count}: Agent sending message to user")

                outcome = await self.proxy.execute_tool.remote(name, arguments)

                if sends_to_user:
                    # A message to the user closes the turn: let the proxy run turn validation
                    boundary = await self.proxy.process_turn_boundary.remote()

                    logger.info(
                        f"[SEND_MESSAGE_TO_USER] Turn boundary result: "
                        f"state={boundary['state']}, success={boundary['turn_success']}, "
                        f"is_last_turn={boundary['is_last_turn']}"
                    )

                    # The episode ends when the environment stopped/failed, when turn
                    # validation failed, or when this was the last turn (checked in that order).
                    if boundary["state"] in {"STOPPED", "FAILED"}:
                        self._environment_stopped = True
                        logger.info(f"[SEND_MESSAGE_TO_USER] Episode ending: env_state={boundary['state']}")
                    elif boundary["turn_success"] is False:
                        self._environment_stopped = True
                        logger.info(
                            f"[SEND_MESSAGE_TO_USER] Episode ending: "
                            f"turn validation failed - {boundary['turn_rationale']}"
                        )
                    elif boundary["is_last_turn"]:
                        self._environment_stopped = True
                        logger.info("[SEND_MESSAGE_TO_USER] Episode ending: last turn completed successfully")

                responses.append(self._create_tool_message(call.id, outcome))

            except Exception as e:
                # Report the failure back to the agent as the result of this tool call
                logger.exception(f"Error executing tool {call.id}")
                responses.append(ToolMessage(content=f"Error: {e!s}", tool_call_id=call.id))

        return responses, saw_wait, timeout_s

    async def _process_event_loop(self, wait_for_notification_called: bool) -> dict:
        """Advance the ARE event loop and collect the resulting environment state.

        When the agent called wait_for_notification, the current time and the
        environment state are read from the proxy without ticking; otherwise the
        proxy is ticked once. Pending notifications are fetched in both cases.

        Returns:
            dict with keys: current_time, queue_length, state, notifications
        """
        if not wait_for_notification_called:
            tick_state = await self.proxy.tick.remote()
            logger.debug(
                f"Step {self.step_count} Tick: time={tick_state['current_time']:.1f}s, "
                f"queue={tick_state['queue_length']}, state={tick_state['state']}"
            )

            tick_state["notifications"] = await self._get_pending_notifications()
            return tick_state

        now = (await self.proxy.get_current_time.remote())["current_time"]
        pending = await self._get_pending_notifications()

        logger.debug(
            f"[WAIT_FOR_NOTIFICATION] Completed: time={now:.1f}s, "
            f"notifications={len(pending)}"
        )

        snapshot = await self.proxy.get_environment_state.remote()
        return {
            "current_time": now,
            "queue_length": snapshot.get("queue_length", 0),
            "state": snapshot.get("state", "RUNNING"),
            "notifications": pending,
        }

    def _check_are_done(self, env_state: str) -> bool:
        """Tell whether the episode is over according to ARE.

        The episode is done once this environment has been flagged as stopped or
        the reported environment state is FAILED or STOPPED.
        """
        stopped = self._environment_stopped
        done = stopped or env_state in {"FAILED", "STOPPED"}

        logger.debug(
            f"[CHECK_ARE_DONE] Step {self.step_count}: "
            f"_environment_stopped={stopped}, env_state={env_state}, done={done}"
        )

        return done

    async def _step_impl(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Run the action's tool calls, advance the event loop and build the observation.

        Returns:
            tuple of (observation, DEFAULT_REWARD, done)
        """
        messages, waited, _timeout = await self._execute_tool_calls(action)

        loop_state = await self._process_event_loop(waited)
        # Notifications are delivered after the tool results of the same step
        messages.extend(loop_state.pop("notifications", []))

        observation = Chat(messages=messages)

        env_state = loop_state["state"]
        if env_state == "FAILED":
            logger.warning("ARE environment failed (likely turn validation failure or other error)")

        # Only ARE-specific termination is decided here; the max_steps limit is
        # enforced by the base class step().
        return observation, DEFAULT_REWARD, self._check_are_done(env_state)

    async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Run one step and score the episode as soon as it finishes.

        Delegates to the base class step() and, when that reports the episode as
        done and no evaluation has been stored yet, replaces the reward with the
        result of evaluate(). This covers every cause of termination (ARE
        environment stopped, validation failed, or max_steps reached).
        """
        chat, reward, done = await super().step(action)

        # Nothing to do while the episode is running or once it has been evaluated
        if not done or self._evaluation_details is not None:
            return chat, reward, done

        separator = "=" * 80
        logger.info(separator)
        logger.info(f"EPISODE DONE at step {self.step_count}")
        logger.info(separator)

        reward = await self.evaluate()
        logger.info(f"ARE evaluation reward: {reward}")
        logger.info(separator)

        return chat, reward, done

    async def _get_pending_notifications(self) -> list[UserMessage]:
        """Fetch the proxy's pending notifications as observation UserMessages.

        User messages and environment notifications are each merged into a single
        UserMessage; an ENVIRONMENT_STOP notification marks the environment as
        stopped and produces no message.
        """
        notifications = await self.proxy.get_pending_notifications.remote()
        if not notifications:
            return []

        def of_type(kind: str) -> list:
            return [entry for entry in notifications if entry["type"] == kind]

        from_user = of_type("USER_MESSAGE")
        from_env = of_type("ENVIRONMENT_NOTIFICATION")
        stop_signals = of_type("ENVIRONMENT_STOP")

        if stop_signals:
            self._environment_stopped = True
            logger.info(f"Environment stop: {stop_signals[0]['message']}")

        delivered = []

        if from_user:
            body = "\n".join(entry["message"] for entry in from_user)
            delivered.append(
                UserMessage(
                    content=f"User message updates:\n***\n{body}\n***\n",
                    label=MessageLabel.OBSERVATION,
                )
            )

        if from_env:
            from datetime import datetime

            # Each environment notification is prefixed with its timestamp
            lines = [
                f"[{datetime.fromisoformat(entry['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}] {entry['message']}"
                for entry in from_env
            ]
            body = "\n".join(lines)
            delivered.append(
                UserMessage(
                    content=f"Environment notifications:\n***\n{body}\n***\n",
                    label=MessageLabel.OBSERVATION,
                )
            )
            logger.debug(f"Delivered {len(from_env)} environment notification(s)")

        return delivered

    def get_tool_schemas(self) -> list[dict]:
        """Return the cached tool schemas, or an empty list when none are cached yet."""
        cached = self._tool_schemas
        return [] if cached is None else cached

    # === Internal conversion methods ===

    def _convert_are_schema_to_agoge(self, are_schema: dict) -> dict:
        """Translate one ARE tool schema into Agoge's OpenAI-style schema.

        Args:
            are_schema: ARE schema with keys: name, description, args. Each entry of
                ``args`` is read for: name, type, description, has_default, default.

        Returns:
            OpenAI-compatible schema with keys: name, description, parameters
        """
        props = {}
        required_names = []

        for spec in are_schema["args"]:
            json_type = PYTHON_TO_JSON_TYPE.get(spec["type"])
            entry = {"description": spec["description"]}
            if json_type:
                # Types without a JSON mapping are left untyped
                entry.update(type=json_type)

            has_default = spec["has_default"]
            if has_default and spec["default"] is not None:
                default_value = spec["default"]
                # Only defaults that survive JSON serialization are advertised
                try:
                    json.dumps(default_value)
                except (TypeError, ValueError):
                    logger.debug(f"Skipping non-serializable default for {spec['name']}")
                else:
                    entry["default"] = default_value

            props[spec["name"]] = entry

            # An argument without a default must be supplied by the caller
            if not has_default:
                required_names.append(spec["name"])

        parameters = {
            "type": "object",
            "properties": props,
            "required": required_names,
        }
        return {
            "name": are_schema["name"],
            "description": are_schema["description"],
            "parameters": parameters,
        }

    def _parse_tool_call(self, tool_call) -> tuple[str, dict]:
        """Parse Agoge tool call to (function_name, arguments).

        Args:
            tool_call: ToolFunctionCall from AssistantMessage

        Returns:
            Tuple of (function_name, arguments_dict)
        """
        func_name = tool_call.function.get("name", "")
        func_args_str = tool_call.function.get("arguments", "{}")

        # Parse arguments if string
        if isinstance(func_args_str, str):
            try:
                func_args = json.loads(func_args_str)
            except json.JSONDecodeError:
                logger.exception(f"Failed to parse arguments: {func_args_str}")
                func_args = {}
        else:
            func_args = func_args_str

        return func_name, func_args

    def _create_tool_message(self, tool_call_id: str, result_dict: dict) -> ToolMessage:
        """Build the ToolMessage that answers one tool call.

        Args:
            tool_call_id: ID of the tool call
            result_dict: Result dict with keys: result, success, internal_calls

        Returns:
            ToolMessage whose content is the result, prefixed with "Error: " when
            the call did not succeed
        """
        succeeded = result_dict["success"]
        payload = result_dict["result"]
        content = payload if succeeded else f"Error: {payload}"

        # internal_calls never reach the model context; they are kept aside so that
        # they can be attached to the evaluation details later.
        nested_calls = result_dict.get("internal_calls", [])
        if nested_calls:
            self._all_internal_calls.append(
                dict(
                    step=self.step_count,
                    tool_call_id=tool_call_id,
                    internal_calls=nested_calls,
                )
            )

        return ToolMessage(content=content, tool_call_id=tool_call_id)

    async def evaluate(self, termination_message: str | None = None) -> float:
        """Evaluate the scenario by asking the proxy to run ARE's validation.

        The validation result is converted into a reward and stored, together with
        its rationale and related metadata, for get_evaluation_details().

        Args:
            termination_message: Optional final message from agent (for compatibility with
                OSWorld/CRMArena). Not used here: the result comes entirely from the
                proxy's validate() call.

        Returns:
            float: Score between 0.0 and 1.0
                - 1.0 if validation succeeds (success=True)
                - 0.0 if validation fails (success=False or None), or if the remote
                  validation call itself fails

        Note:
            The validation itself runs inside the proxy actor.
            NOTE(review): an earlier version of this docstring described ARE validation
            as deterministic and free of external services, while a judge model name is
            passed to the proxy - confirm which holds.
        """
        logger.info("=" * 80)
        logger.info("STARTING VALIDATION")
        logger.info("=" * 80)

        # Recorded on every path so the stored details always have the same keys
        step_internal_calls = getattr(self, "_all_internal_calls", [])

        try:
            logger.debug("Calling proxy.validate.remote()...")
            result_dict = await self.proxy.validate.remote()
            logger.debug(f"Validation returned: {result_dict.keys() if result_dict else 'None'}")
        except (ray.exceptions.RayTaskError, ray.exceptions.RayActorError, TimeoutError) as e:
            self._evaluation_details = {
                "success": None,
                "rationale": f"Remote validation failed: {type(e).__name__} - {e!s}",
                "duration": None,
                "judge_model": None,
                "oracle_matching_failures": [],
                "step_internal_calls": step_internal_calls,
            }
            logger.exception("Remote validation error")
            return 0.0

        # The debug log above already tolerates an empty/None result; treat it as a
        # result without any field (success=None) instead of crashing on .get().
        result_dict = result_dict or {}

        # Extract all fields
        success = result_dict.get("success")
        rationale = result_dict.get("rationale", "No rationale provided")
        exception = result_dict.get("exception")
        duration = result_dict.get("duration")
        oracle_matching_failures = result_dict.get("oracle_matching_failures", [])

        # Log detailed validation info
        logger.info(f"Validation result: success={success}")
        logger.info(f"Validation rationale: {rationale}")
        if exception:
            logger.info(f"Validation exception: {exception}")
        if duration is not None:
            logger.info(f"Validation duration: {duration}s")

        # Log oracle matching failures at debug level
        if oracle_matching_failures:
            logger.debug(f"Oracle matching failures: {len(oracle_matching_failures)}")

        # If rationale doesn't already contain exception details, append them.
        # Both values are compared as text: rationale may be None and the exception
        # is not guaranteed to be a str, either of which breaks a plain `in` test.
        if exception and str(exception) not in str(rationale or ""):
            detailed_rationale = f"{rationale}: {exception}" if rationale else str(exception)
        else:
            detailed_rationale = rationale

        # Store detailed evaluation info for trajectory
        self._evaluation_details = {
            "success": success,
            "rationale": detailed_rationale,
            "duration": duration,
            "judge_model": self.judge_model if success is not None else None,
            "oracle_matching_failures": oracle_matching_failures,
            "step_internal_calls": step_internal_calls,
        }

        # Analyze the validation type based on the result
        has_rationale = rationale is not None and rationale != "No rationale provided"
        if success is True:
            if has_rationale:
                logger.info("✓ SUCCESS with rationale - judge validation worked correctly")
            else:
                logger.warning("⚠ SUCCESS with no rationale - likely using BASE CLASS validation (no judge!)")
                logger.warning("⚠ This means agent behavior was NOT compared against oracle events")
            return 1.0

        if success is False:
            if has_rationale:
                logger.warning("FAILURE with rationale - judge validation found mismatches")
            else:
                logger.warning("FAILURE with no rationale - basic validation failed (env.state == FAILED)")
            return 0.0

        logger.error("⚠ VALIDATION RETURNED None")
        if exception:
            logger.error(f"Exception during validation: {exception}")
        else:
            logger.error("No exception provided - likely using dummy validation (lambda returning None)")
        return 0.0

    async def get_evaluation_details(self) -> dict | None:
        """Get detailed evaluation information for trajectory storage.

        This method should be called after evaluate(), which stores a dictionary
        whose keys include success, rationale, duration, judge_model and
        oracle_matching_failures.

        Returns:
            dict | None: Evaluation details dictionary, or None if evaluation hasn't run yet
        """
        return self._evaluation_details

    async def cleanup(self):
        """Ask the proxy actor to clean up, on a best-effort basis.

        Safe to call on a partially constructed instance: when no proxy was ever
        created there is nothing to do. A failure of the proxy's own cleanup is
        logged as a warning and not raised.
        """
        logger.info("Closing ARE environment.")

        # __init__ may have failed before the proxy attribute was assigned
        proxy = getattr(self, "proxy", None)
        if proxy is None:
            return

        try:
            await proxy.cleanup.remote()
        except Exception as e:
            # Teardown must not mask the error that led here
            logger.warning(f"Error during proxy cleanup: {e}")

__init__(scenario_id=None, scenario_kwargs=None, max_steps=50, verbosity_level='LOW', judge_model='gpt-5-mini', a2a_config=None) #

Initialize the ARE environment.

Parameters:

Name Type Description Default
a2a_config dict | None

Agent2Agent configuration dict. Example: {"model": "gpt-5-mini", "provider": "openai", "endpoint": None} If None, Agent2Agent is disabled.

None
Source code in src/agoge/environment/are_environment.py
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
def __init__(
    self,
    scenario_id: str | None = None,
    scenario_kwargs: dict | None = None,
    max_steps: int = 50,
    verbosity_level: str = "LOW",
    judge_model: str = "gpt-5-mini",
    a2a_config: dict | None = None,
):
    """Initialize the ARE environment and create its proxy actor.

    Args:
        scenario_id: ID of ARE scenario to load.
        scenario_kwargs: Additional kwargs for scenario initialization.
        max_steps: Maximum steps before episode termination.
        verbosity_level: Notification verbosity (LOW/MEDIUM/HIGH).
        judge_model: Model name for LLM judge.
        a2a_config: Agent2Agent configuration dict. Example:
            {"model": "gpt-5-mini", "provider": "openai", "endpoint": None}
            If None, Agent2Agent is disabled.
    """
    super().__init__(max_steps=max_steps)

    # Define every attribute that other methods read *before* doing any work that can
    # fail. cleanup() reads self.proxy and _create_tool_message() appends to
    # self._all_internal_calls; both would otherwise raise AttributeError when the
    # actor creation below fails or when a tool result is processed before a reset.
    self.proxy = None
    self.judge_model = judge_model  # Reported in the evaluation details
    self._tool_schemas = None
    self._environment_stopped = False
    self._evaluation_details: dict | None = None  # Store detailed evaluation info for trajectory
    self._all_internal_calls: list = []  # Re-created at the start of every episode

    logger.info(
        f"Initializing ARE environment (max_steps={max_steps}, verbosity={verbosity_level}, "
        f"judge_model={judge_model})"
    )

    # Create the proxy actor pinned to the node this environment runs on.
    # soft=False makes the node affinity a hard requirement.
    # NOTE(review): the runtime_env mentioned in the log below is presumably declared
    # on AREProxy itself - it is not passed here.
    current_node_id = ray.get_runtime_context().get_node_id()
    self.proxy = AREProxy.options(
        scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
            node_id=current_node_id, soft=False
        )
    ).remote(
        scenario_id=scenario_id,
        scenario_kwargs=scenario_kwargs,
        max_steps=max_steps,
        verbosity_level=verbosity_level,
        judge_model=judge_model,
        a2a_config=a2a_config,
    )

    logger.info("ARE environment initialized successfully (using runtime_env)")

cleanup() async #

Clean up Ray actor resources.

Source code in src/agoge/environment/are_environment.py
2400
2401
2402
2403
2404
2405
2406
2407
async def cleanup(self):
    """Ask the proxy actor to clean up, on a best-effort basis.

    Safe to call on a partially constructed instance: when no proxy was ever
    created there is nothing to do. A failure of the proxy's own cleanup is
    logged as a warning and not raised.
    """
    logger.info("Closing ARE environment.")

    # __init__ may have failed before the proxy attribute was assigned
    proxy = getattr(self, "proxy", None)
    if proxy is None:
        return

    try:
        await proxy.cleanup.remote()
    except Exception as e:
        # Teardown must not mask the error that led here
        logger.warning(f"Error during proxy cleanup: {e}")

evaluate(termination_message=None) async #

Evaluate the scenario using ARE's native validation system.

This performs comprehensive validation:

- Environment state checks (FAILED/RUNNING/STOPPED)
- final_validation_checks() for scenario-specific validation
- Optional judge system comparison of agent events vs oracle events

Parameters:

Name Type Description Default
termination_message str | None

Optional final message from agent (for compatibility with OSWorld/CRMArena). Not used in ARE validation as it relies on environment state.

None

Returns:

Name Type Description
float float

Score between 0.0 and 1.0:

- 1.0 if validation succeeds (success=True)
- 0.0 if validation fails (success=False or None)

Note

Unlike OSWorld/CRMArena, ARE validation is deterministic and doesn't require external services. The validation checks ARE's internal state and event logs.

Source code in src/agoge/environment/are_environment.py
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
async def evaluate(self, termination_message: str | None = None) -> float:
    """Evaluate the scenario using ARE's native validation system.

    This performs comprehensive validation:
    - Environment state checks (FAILED/RUNNING/STOPPED)
    - final_validation_checks() for scenario-specific validation
    - Optional judge system comparison of agent events vs oracle events

    Args:
        _termination_message: Optional final message from agent (for compatibility with OSWorld/CRMArena).
                              Not used in ARE validation as it relies on environment state.

    Returns:
        float: Score between 0.0 and 1.0
            - 1.0 if validation succeeds (success=True)
            - 0.0 if validation fails (success=False or None)

    Note:
        Unlike OSWorld/CRMArena, ARE validation is deterministic and doesn't
        require external services. The validation checks ARE's internal state
        and event logs.
    """
    logger.info("=" * 80)
    logger.info("STARTING VALIDATION")
    logger.info("=" * 80)

    try:
        logger.debug("Calling proxy.validate.remote()...")
        result_dict = await self.proxy.validate.remote()
        logger.debug(f"Validation returned: {result_dict.keys() if result_dict else 'None'}")
    except (ray.exceptions.RayTaskError, ray.exceptions.RayActorError, TimeoutError) as e:
        self._evaluation_details = {
            "success": None,
            "rationale": f"Remote validation failed: {type(e).__name__} - {e!s}",
            "duration": None,
            "judge_model": None,
            "oracle_matching_failures": [],
        }
        logger.exception("Remote validation error")
        return 0.0

    # Extract all fields
    success = result_dict.get("success")
    rationale = result_dict.get("rationale", "No rationale provided")
    exception = result_dict.get("exception")
    duration = result_dict.get("duration")
    oracle_matching_failures = result_dict.get("oracle_matching_failures", [])

    # Log detailed validation info
    logger.info(f"Validation result: success={success}")
    logger.info(f"Validation rationale: {rationale}")
    if exception:
        logger.info(f"Validation exception: {exception}")
    if duration:
        logger.info(f"Validation duration: {duration}s")

    # Log oracle matching failures at debug level
    if oracle_matching_failures:
        logger.debug(f"Oracle matching failures: {len(oracle_matching_failures)}")

    # If rationale doesn't already contain exception details, append them
    if exception and exception not in rationale:
        detailed_rationale = f"{rationale}: {exception}"
    else:
        detailed_rationale = rationale

    # Store detailed evaluation info for trajectory
    self._evaluation_details = {
        "success": success,
        "rationale": detailed_rationale,
        "duration": duration,
        "judge_model": self.judge_model if success is not None else None,
        "oracle_matching_failures": result_dict.get("oracle_matching_failures", []),
        "step_internal_calls": getattr(self, "_all_internal_calls", []),
    }

    # Analyze the validation type based on the result
    if success is True:
        if rationale is None or rationale == "No rationale provided":
            logger.warning("⚠ SUCCESS with no rationale - likely using BASE CLASS validation (no judge!)")
            logger.warning("⚠ This means agent behavior was NOT compared against oracle events")
        else:
            logger.info("✓ SUCCESS with rationale - judge validation worked correctly")
        return 1.0
    elif success is False:
        if rationale is None or rationale == "No rationale provided":
            logger.warning("FAILURE with no rationale - basic validation failed (env.state == FAILED)")
        else:
            logger.warning("FAILURE with rationale - judge validation found mismatches")
        return 0.0
    else:
        logger.error("⚠ VALIDATION RETURNED None")
        if exception:
            logger.error(f"Exception during validation: {exception}")
        else:
            logger.error("No exception provided - likely using dummy validation (lambda returning None)")
        return 0.0

get_evaluation_details() async #

Get detailed evaluation information for trajectory storage.

This method should be called after evaluate() to retrieve comprehensive evaluation metadata including success status, rationale, validation type, and other relevant information for trajectory analysis.

Returns:

Type Description
dict | None

dict | None: Evaluation details dictionary, or None if evaluation hasn't run yet

Source code in src/agoge/environment/are_environment.py
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
async def get_evaluation_details(self) -> dict | None:
    """Get detailed evaluation information for trajectory storage.

    This method should be called after evaluate() to retrieve comprehensive
    evaluation metadata including success status, rationale, validation type,
    and other relevant information for trajectory analysis.

    Returns:
        dict | None: Evaluation details dictionary, or None if evaluation hasn't run yet
    """
    return self._evaluation_details

get_tool_schemas() #

Get available tool schemas.

Source code in src/agoge/environment/are_environment.py
2192
2193
2194
2195
2196
def get_tool_schemas(self) -> list[dict]:
    """Return the stored tool schemas, or an empty list when none are stored."""
    schemas = self._tool_schemas
    return [] if schemas is None else schemas

step(action) async #

Execute action and evaluate when episode ends.

This method wraps base class step() to ensure evaluate() is called whenever the episode ends, regardless of the cause (ARE environment stopped, validation failed, or max_steps reached).

Source code in src/agoge/environment/are_environment.py
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
    """Run one base-class step and score the episode once it is over.

    Delegates to the parent ``step()``. If that reports ``done`` and no
    evaluation details have been stored yet, ``evaluate()`` is awaited and its
    score replaces the reward returned by the parent. Otherwise the parent's
    result is passed through unchanged.

    Args:
        action: The agent's action, forwarded as-is to the parent ``step()``.

    Returns:
        tuple[Chat, float, bool]: (observation, reward, done)
    """
    chat, reward, done = await super().step(action)

    needs_evaluation = done and self._evaluation_details is None
    if not needs_evaluation:
        return chat, reward, done

    banner = "=" * 80
    logger.info(banner)
    logger.info(f"EPISODE DONE at step {self.step_count}")
    logger.info(banner)

    # The evaluation score supersedes whatever reward the parent step produced.
    reward = await self.evaluate()
    logger.info(f"ARE evaluation reward: {reward}")
    logger.info(banner)

    return chat, reward, done

CRMEnvironment #

Bases: Environment

CRM environment for distributed execution using Salesforce API.

This environment uses a Ray actor to isolate CRMArena dependencies and execute Salesforce queries through the SalesforceConnector.

Parameters:

Name Type Description Default
org_type str

Salesforce org type ('original', 'b2b', or 'b2c')

'original'
max_steps int

Maximum number of steps per episode

20
auth dict | None

Optional authentication dict (username/password/token or instance_url/session_id)

None
Source code in src/agoge/environment/crmarena.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
class CRMEnvironment(Environment):
    """CRM environment for distributed execution using Salesforce API.

    This environment uses a Ray actor to isolate CRMArena dependencies
    and execute Salesforce queries through the SalesforceConnector.

    Tool calls issued by the agent are dispatched by name onto ``self.toolset``.
    Only public (no leading underscore) callable attributes are treated as
    tools; the ``respond`` tool ends the episode and triggers evaluation.

    Args:
        org_type: Salesforce org type ('original', 'b2b', or 'b2c')
        max_steps: Maximum number of steps per episode
        auth: Optional authentication dict (username/password/token or instance_url/session_id)
    """

    # System metadata template for task-specific context
    SYSTEM_METADATA_TEMPLATE = """\
# Salesforce instance description
This is mainly to help you when you are using the free-form tools.
{schema}
# Additional task context
{system_metadata}
{optional}
"""

    def __init__(self, org_type: str = "original", max_steps: int = 20, auth: dict | None = None):
        """Initialize CRM environment.

        Args:
            org_type: Salesforce org type ('original', 'b2b', or 'b2c')
            max_steps: Maximum number of steps per episode
            auth: Optional authentication dict
        """
        super().__init__(max_steps=max_steps)
        self.org_type = org_type

        logger.info(f"Initializing CRMEnvironment with org_type='{org_type}'")
        self.connector_actor = SalesforceConnectorProxy.remote(org_type=org_type, auth=auth)

        from agoge.schema.tools.crmarena_tool import CRMArenaToolSet

        self.toolset = CRMArenaToolSet(self.connector_actor)

        # Answer passed to the 'respond' tool; None until the agent responds.
        self.final_answer: str | None = None

        logger.info("CRMEnvironment initialized successfully")

    def get_tool_schemas(self) -> list[dict]:
        """Return the CRM tool schemas."""
        return self.toolset.schema

    def _resolve_tool(self, func_name):
        """Return the public callable on the toolset named ``func_name``, or None.

        The name comes from the model's tool call, so it is not trusted:
        underscore-prefixed names (private helpers, dunders such as ``__init__``)
        and non-callable attributes are never dispatched.
        """
        if not isinstance(func_name, str) or func_name.startswith("_"):
            return None
        candidate = getattr(self.toolset, func_name, None)
        return candidate if callable(candidate) else None

    def _available_tool_names(self) -> list[str]:
        """List the toolset attribute names that ``_resolve_tool`` would accept."""
        return [name for name in dir(self.toolset) if self._resolve_tool(name) is not None]

    @staticmethod
    def _schema_name(schema: dict) -> str:
        """Extract a tool name from a schema dict without raising.

        Reads a top-level ``"name"`` first. Falls back to ``schema["function"]["name"]``
        for schemas that nest the name (presumably the OpenAI "tools" layout;
        verify against CRMArenaToolSet.schema). Returns "unknown" if neither exists.
        """
        name = schema.get("name")
        if name is None:
            name = (schema.get("function") or {}).get("name", "unknown")
        return str(name)

    async def _reset_impl(self, task: Task, **reset_kwargs) -> Chat:
        """Start a new CRM episode.

        Args:
            task: Task instance with metadata and inputs
            **reset_kwargs: Additional reset arguments (unused)

        Returns:
            Chat: Initial observation with system metadata and user query
        """
        if reset_kwargs:
            logger.warning(f"Unused reset_kwargs: {reset_kwargs}")

        # Reset final answer for new episode
        self.final_answer = None

        # Build task-specific system metadata
        schema = task.metadata.get("schema", "")
        required_context = task.metadata.get("required", "")
        optional_context = task.metadata.get("optional", "")

        self.system_metadata = self.SYSTEM_METADATA_TEMPLATE.format(
            schema=schema,
            system_metadata=required_context,
            optional=optional_context,
        )

        # Two system messages: agent's prompt + schema in env metadata
        env_system_msg = SystemMessage(content=self.system_metadata)
        user_msg = UserMessage(content=task.inputs["query"])

        return Chat(messages=[env_system_msg, user_msg])

    async def _step_impl(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Process one turn of CRM interaction.

        Each tool call that carries a ``function`` payload is executed against the
        toolset and answered with a ToolMessage (the result, an "unknown tool"
        notice, or an error description). If nothing was answered, a guidance
        UserMessage is returned instead.

        Args:
            action: Agent's action containing tool calls

        Returns:
            tuple[Chat, float, bool]: (observation, reward, done)
        """
        tool_calls = action.tool_calls or []
        tool_messages = []
        done = False

        for tool_call in tool_calls:
            if tool_call.function:
                func_name = tool_call.function.get("name", "unknown")

                try:
                    func_args = tool_call.function.get("arguments", {})
                    if isinstance(func_args, str):
                        func_args = json.loads(func_args)

                    # Dispatch only to public callables; see _resolve_tool.
                    tool_method = self._resolve_tool(func_name)
                    if tool_method is not None:
                        result = await tool_method(**func_args)

                        if func_name == "respond":
                            done = True
                            self.final_answer = func_args.get("content", "")
                            logger.debug(f"Stored final answer: '{self.final_answer}'")
                    else:
                        available_tools = self._available_tool_names()
                        result = f"Unknown CRM tool: {func_name}. Available tools: {available_tools}"

                    tool_messages.append(ToolMessage(content=str(result), tool_call_id=tool_call.id))

                except Exception as e:
                    # Report the failure back to the agent instead of aborting the episode.
                    logger.exception(f"Error executing {func_name}")
                    error_msg = f"Error executing {func_name}: {e}"
                    tool_messages.append(ToolMessage(content=error_msg, tool_call_id=tool_call.id))

        # If no tool calls were made, provide guidance
        if not tool_messages:
            available = ", ".join(self._schema_name(schema) for schema in self.get_tool_schemas())
            tool_messages.append(
                UserMessage(
                    content=f"Use CRM tools to gather information, or 'respond' to provide your final answer. "
                    f"Available tools: {available}"
                )
            )

        # Create observation with tool messages
        observation = Chat(messages=tool_messages, label=MessageLabel.OBSERVATION)

        # Reward calculation: 0.0 during episode, evaluate when done
        reward = 0.0
        if done:
            reward = await self._evaluate_task()

        return observation, reward, done

    async def _evaluate_task(self) -> float:
        """Evaluate task completion using CRMArena's evaluator.

        Sends the stored final answer and the task's ground truth to the
        connector actor's ``evaluate`` method. Any exception raised while
        evaluating is logged and scored as 0.0.

        Returns:
            float: Reward score (0.0 to 1.0)
        """

        task = self._current_task

        # Extract evaluation parameters
        ground_truth = task.eval_criteria.get("answer")
        reward_metric = task.eval_criteria.get("reward_metric", "exact_match")
        task_name = task.metadata.get("task_name", "unknown")

        if self.final_answer is None:
            logger.warning(f"Task {task.task_id}: No final answer provided by agent")
            return 0.0

        try:
            logger.info(f"Evaluating: proposed_answer={self.final_answer}, ground_truth={ground_truth}")

            # A single string answer is wrapped in a list before being sent.
            gt_answer_list = [ground_truth] if isinstance(ground_truth, str) else ground_truth
            result = await self.connector_actor.evaluate.remote(
                proposed_answer=self.final_answer,
                ground_truth=gt_answer_list,
                reward_metric=reward_metric,
                task_name=task_name,
            )

            # Handle different reward formats from evaluator
            reward_value = result.get("reward", 0.0)
            if isinstance(reward_value, dict):
                # If reward_metric == "fuzzy_match", reward_value is a dict of (em, f1, bleu, rouge)
                # For agoge compatibility, return BLEU score as the main metric for now
                reward = float(reward_value.get("bleu", 0.0))
            else:
                reward = float(reward_value)

            parsed_answer = result.get("parsed_answer", "")

            logger.info(
                f"Task evaluation complete: "
                f"task_id={task.task_id}, "
                f"metric={reward_metric}, "
                f"reward={reward:.2f}, "
                f"all_metrics={reward_value}, "
                f"parsed_answer={parsed_answer}, "
                f"ground_truth={ground_truth}"
            )

        except Exception:
            logger.exception(f"Error evaluating task {task.task_id}")
            return 0.0
        else:
            return reward

__init__(org_type='original', max_steps=20, auth=None) #

Initialize CRM environment.

Parameters:

Name Type Description Default
org_type str

Salesforce org type ('original', 'b2b', or 'b2c')

'original'
max_steps int

Maximum number of steps per episode

20
auth dict | None

Optional authentication dict

None
Source code in src/agoge/environment/crmarena.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
def __init__(self, org_type: str = "original", max_steps: int = 20, auth: dict | None = None):
    """Set up the CRM environment and its Salesforce connector.

    Creates the remote connector proxy, builds the CRMArena toolset on top of
    it, and clears the stored final answer.

    Args:
        org_type: Salesforce org type ('original', 'b2b', or 'b2c')
        max_steps: Maximum number of steps per episode
        auth: Optional authentication dict
    """
    super().__init__(max_steps=max_steps)
    self.org_type = org_type

    logger.info(f"Initializing CRMEnvironment with org_type='{org_type}'")
    connector = SalesforceConnectorProxy.remote(org_type=org_type, auth=auth)
    self.connector_actor = connector

    # Imported at call time rather than at module import, as in the original listing.
    from agoge.schema.tools.crmarena_tool import CRMArenaToolSet

    # The toolset is handed the same connector handle stored on the instance.
    self.toolset = CRMArenaToolSet(connector)
    self.final_answer: str | None = None

    logger.info("CRMEnvironment initialized successfully")

get_tool_schemas() #

Return the CRM tool schemas.

Source code in src/agoge/environment/crmarena.py
213
214
215
def get_tool_schemas(self) -> list[dict]:
    """Expose the ``schema`` attribute of this environment's toolset, uncopied."""
    toolset = self.toolset
    return toolset.schema

Environment #

Base class for agent-environment interaction following a gym-like interface.

This class provides two implementation patterns for subclasses:

Pattern 1 (Hook-based): Implement _reset_impl() and _step_impl() hooks. The base class handles episode lifecycle management, step counting, and max_steps enforcement automatically.

Pattern 2 (Override): Override reset() and step() directly for full control. Useful when you need custom lifecycle management or compatibility with existing code.

Subclasses should also set self.toolset or override get_tool_schemas() to expose available tools to the agent.

Attributes:

Name Type Description
max_steps int | None

Maximum number of steps allowed per episode (None = unlimited).

step_count int

Number of steps taken in the current episode.

current_state Any

Environment-specific state (defined by subclasses).

toolset ToolSet | None

Optional ToolSet instance for tool schema generation.

episode_active bool

Whether an episode is currently running.

current_task Task | None

Task associated with the current episode.

last_observation Chat | None

Most recent Chat observation from the environment.

last_action AssistantMessage | None

Most recent AssistantMessage action from the agent.

Example
class MyEnv(Environment):
    # Minimal hook-based environment: only the two hooks are implemented.
    def __init__(self):
        # Cap episodes at 100 steps and attach the toolset.
        super().__init__(max_steps=100)
        self.toolset = MyToolSet()

    async def _reset_impl(self, task, **kwargs):
        # Initial observation returned at the start of an episode.
        return Chat(messages=[UserMessage(content="Start")])

    async def _step_impl(self, action):
        # Reward 1.0 only when the action carries tool calls; this hook never sets done itself.
        reward = 1.0 if action.tool_calls else 0.0
        done = False
        return Chat(messages=[UserMessage(content="OK")]), reward, done
Source code in src/agoge/environment/environment.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
class Environment:
    """Base class for agent-environment interaction following a gym-like interface.

    This class provides two implementation patterns for subclasses:

    **Pattern 1 (Hook-based):** Implement ``_reset_impl()`` and ``_step_impl()`` hooks.
    The base class handles episode lifecycle management, step counting, and max_steps
    enforcement automatically.

    **Pattern 2 (Override):** Override ``reset()`` and ``step()`` directly for full control.
    Useful when you need custom lifecycle management or compatibility with existing code.

    Subclasses should also set ``self.toolset`` or override ``get_tool_schemas()`` to
    expose available tools to the agent.

    Attributes:
        max_steps: Maximum number of steps allowed per episode (None = unlimited).
        step_count: Number of steps taken in the current episode.
        current_state: Environment-specific state (defined by subclasses).
        toolset: Optional ToolSet instance for tool schema generation.
        episode_active: Whether an episode is currently running.
        current_task: Task passed to the most recent :meth:`reset` call.
        last_observation: Most recent Chat observation from the environment.
        last_action: Most recent AssistantMessage action from the agent.

    Example:
        ```python
        class MyEnv(Environment):
            def __init__(self):
                super().__init__(max_steps=100)
                self.toolset = MyToolSet()

            async def _reset_impl(self, task, **kwargs):
                return Chat(messages=[UserMessage(content="Start")])

            async def _step_impl(self, action):
                reward = 1.0 if action.tool_calls else 0.0
                done = False
                return Chat(messages=[UserMessage(content="OK")]), reward, done
        ```
    """

    def __init__(self, max_steps: int | None = None):
        """Initialize persistent environment state.

        Args:
            max_steps: Optional hard limit on the number of steps per episode.
                When provided, the base class will automatically flag episodes
                as done after this many calls to :meth:`step`. Use ``None`` for
                unlimited steps (default).

        Raises:
            ValueError: If max_steps is provided but less than 1.
        """
        if max_steps is not None and max_steps < 1:
            raise ValueError("max_steps must be None or >= 1")

        self.max_steps: int | None = max_steps
        self.step_count: int = 0
        self.current_state: Any = None

        self.toolset: ToolSet | None = None
        self._current_task: Task | None = None
        self._last_observation: Chat | None = None
        self._last_action: AssistantMessage | None = None
        self._episode_active: bool = False
        self._episode_id: str | None = None

    @property
    def episode_active(self) -> bool:
        """Whether an episode is currently active.

        Returns:
            bool: True if :meth:`reset` completed successfully and the episode
                has not terminated. False before reset, after a failed reset,
                or after the episode ends.
        """
        return self._episode_active

    @property
    def current_task(self) -> Task | None:
        """Get the task associated with the current episode.

        Returns:
            Task | None: The Task instance passed to the most recent
                :meth:`reset` call, or None if :meth:`reset` has never been
                called. The task is not cleared when the episode terminates.
        """
        return self._current_task

    @property
    def last_observation(self) -> Chat | None:
        """Get the most recent observation from the environment.

        Returns:
            Chat | None: The last Chat observation returned by :meth:`reset`
                or :meth:`step`, or None if no episode has started.
        """
        return self._last_observation

    @property
    def last_action(self) -> AssistantMessage | None:
        """Get the most recent action from the agent.

        Returns:
            AssistantMessage | None: The last action passed to :meth:`step`,
                or None if no action has been taken yet.
        """
        return self._last_action

    @property
    def step_limit_reached(self) -> bool:
        """Check if the step limit has been reached.

        Returns:
            bool: True if max_steps is set and step_count has reached or
                exceeded it, False otherwise.
        """
        return self.max_steps is not None and self.step_count >= self.max_steps

    def get_episode_id(self) -> str | None:
        """Get the unique identifier for the current episode.

        Returns:
            str | None: The episode UUID generated by the most recent
                :meth:`reset`, or None if no episode has started.
        """
        return self._episode_id

    def get_tool_schemas(self) -> list[dict[str, Any]]:
        """Get tool schemas available to the agent.

        By default, extracts schemas from ``self.toolset.schema``. Subclasses
        can override this method to provide custom tool schemas.

        Returns:
            list[dict[str, Any]]: List of tool schema dictionaries following
                the OpenAI function calling format. Returns empty list if
                no toolset is configured or the toolset has no ``schema``.

        Note:
            Returns shallow copies of schemas to prevent accidental mutation.
        """
        if self.toolset is None:
            return []
        # Defensive: accept any Sequence and copy out as list[dict[str, Any]]
        try:
            return [dict(x) for x in self.toolset.schema]  # shallow copy for safety
        except AttributeError:
            return []

    # ----------------------------
    # Episode lifecycle
    # ----------------------------
    async def reset(self, task: Task, **reset_kwargs: Any) -> Chat:
        """Start a new episode and return the initial observation.

        This method initializes episode state (step_count, active flag, task)
        and delegates to :meth:`_reset_impl` for environment-specific setup.
        Subclasses can override this method directly for full control, or
        implement :meth:`_reset_impl` to use the default lifecycle management.

        Args:
            task: Task instance defining what the agent should accomplish.
            **reset_kwargs: Optional keyword arguments passed to :meth:`_reset_impl`.
                Common examples: seed, difficulty, scenario, wait_seconds.

        Returns:
            Chat: Initial observation before the agent takes any action.
                Typically includes task instructions and initial state (e.g.,
                screenshot, system message).

        Raises:
            TypeError: If the returned observation is not a Chat instance.
            Exception: Anything raised by :meth:`_reset_impl` is re-raised
                unchanged.

        Note:
            After a successful reset(), ``episode_active`` will be True and
            ``step_count`` will be 0. If the reset fails, ``episode_active``
            is set back to False so :meth:`step` cannot run on a
            half-initialized episode.
        """
        self.step_count = 0
        self._current_task = task
        # Set before the hook runs so _reset_impl observes an active episode.
        self._episode_active = True
        self._episode_id = str(uuid.uuid4())

        try:
            chat = await self._reset_impl(task, **reset_kwargs)
            self._ensure_chat(chat, method_name="reset")
        except BaseException:
            # Covers cancellation too: a reset that did not finish must not
            # leave the episode marked active.
            self._episode_active = False
            raise

        self._last_observation = chat
        self._last_action = None
        return chat

    async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Advance the environment by one turn.

        Processes the agent's action, delegates to :meth:`_step_impl` for
        environment-specific logic, and handles step counting and termination.
        Subclasses can override this method directly for full control, or
        implement :meth:`_step_impl` to use the default lifecycle management.

        Args:
            action: Agent's action, typically an AssistantMessage containing
                tool calls to execute in the environment.

        Returns:
            tuple[Chat, float, bool]: A tuple containing:
                - **chat**: Observation after action execution (e.g., tool
                  results, screenshot). Should only contain new messages,
                  not the full conversation history.
                - **reward**: Scalar feedback for the agent. None is coerced
                  to 0.0. Final reward often computed by evaluators.
                - **done**: Whether the episode has terminated. Automatically
                  set to True if max_steps is reached.

        Raises:
            RuntimeError: If called while no episode is active (before
                :meth:`reset`, or after the episode has ended).
            TypeError: If the returned observation is not a Chat instance.

        Note:
            - Increments step_count before calling :meth:`_step_impl`
            - Sets episode_active to False when done=True
            - Enforces max_steps limit by forcing done=True when reached
        """
        if not self._episode_active:
            raise RuntimeError("Environment must be reset before calling step().")

        self.step_count += 1
        self._last_action = action

        chat, reward, done = await self._step_impl(action)
        self._ensure_chat(chat, method_name="step")

        # Normalize outputs
        done = bool(done)
        reward_value = float(reward) if reward is not None else 0.0

        # Enforce step cap (single source of truth: the step_limit_reached property)
        if not done and self.step_limit_reached:
            done = True

        self._last_observation = chat
        if done:
            self._episode_active = False

        return chat, reward_value, done

    async def _reset_impl(self, task: Task, **reset_kwargs: Any) -> Chat:
        """Environment-specific reset logic (hook for subclasses).

        This hook is called by :meth:`reset` after initializing episode state.
        Implement this method to set up environment-specific state and return
        the initial observation. The base class handles step_count initialization,
        episode activation, and storing the observation.

        Args:
            task: Task instance defining the episode goal.
            **reset_kwargs: Optional keyword arguments from :meth:`reset`.

        Returns:
            Chat: Initial observation to send to the agent.

        Raises:
            NotImplementedError: If not overridden by subclass (and :meth:`reset`
                is not overridden either).

        Note:
            Environments that override :meth:`reset` directly can ignore this hook.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement _reset_impl or override reset().")

    async def _step_impl(self, action: AssistantMessage) -> tuple[Chat, SupportsFloat | None, bool]:
        """Environment-specific step logic (hook for subclasses).

        This hook is called by :meth:`step` after incrementing step_count.
        Implement this method to execute the action and compute the next
        observation, reward, and termination flag. The base class handles
        step counting, max_steps enforcement, and episode deactivation.

        Args:
            action: Agent's action (typically containing tool calls).

        Returns:
            tuple[Chat, SupportsFloat | None, bool]: A tuple containing:
                - **chat**: New observation after executing the action
                - **reward**: Reward signal (None is coerced to 0.0)
                - **done**: Whether the episode should terminate

        Raises:
            NotImplementedError: If not overridden by subclass (and :meth:`step`
                is not overridden either).

        Note:
            Environments that override :meth:`step` directly can ignore this hook.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement _step_impl or override step().")

    def _ensure_chat(self, chat: Chat, *, method_name: str) -> None:
        """Validate that a method returned a Chat instance.

        Args:
            chat: Object to validate.
            method_name: Name of the method being validated (for error messages).

        Raises:
            TypeError: If chat is not a Chat instance.
        """
        if not isinstance(chat, Chat):
            raise TypeError(
                f"{self.__class__.__name__}.{method_name} must return a Chat instance (got {type(chat)!r})."
            )

current_task property #

Get the task associated with the current episode.

Returns:

Type Description
Task | None

Task | None: The Task instance passed to :meth:reset, or None if no episode is active.

episode_active property #

Whether an episode is currently active.

Returns:

Name Type Description
bool bool

True if reset() was called and the episode has not terminated. False before reset() or after the episode ends.

last_action property #

Get the most recent action from the agent.

Returns:

Type Description
AssistantMessage | None

AssistantMessage | None: The last action passed to step(), or None if no action has been taken yet.

last_observation property #

Get the most recent observation from the environment.

Returns:

Type Description
Chat | None

Chat | None: The last Chat observation returned by reset() or step(), or None if no episode has started.

step_limit_reached property #

Check if the step limit has been reached.

Returns:

Name Type Description
bool bool

True if max_steps is set and step_count has reached or exceeded it, False otherwise.

__init__(max_steps=None) #

Initialize persistent environment state.

Parameters:

Name Type Description Default
max_steps int | None

Optional hard limit on the number of steps per episode. When provided, the base class will automatically flag episodes as done after this many calls to step(). Use None for unlimited steps (default).

None

Raises:

Type Description
ValueError

If max_steps is provided but less than 1.

Source code in src/agoge/environment/environment.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def __init__(self, max_steps: int | None = None):
    """Initialize persistent environment state.

    Args:
        max_steps: Optional hard limit on the number of steps per episode.
            When provided, the base class will automatically flag episodes
            as done after this many calls to :meth:`step`. Use ``None`` for
            unlimited steps (default).

    Raises:
        ValueError: If max_steps is provided but less than 1.
    """
    if max_steps is not None and max_steps < 1:
        raise ValueError("max_steps must be None or >= 1")

    self.max_steps: int | None = max_steps
    self.step_count: int = 0
    self.current_state: Any = None

    self.toolset: ToolSet | None = None
    self._current_task: Task | None = None
    self._last_observation: Chat | None = None
    self._last_action: AssistantMessage | None = None
    self._episode_active: bool = False
    self._episode_id: str | None = None

get_episode_id() #

Get the unique identifier for the current episode.

Returns:

Type Description
str | None

str | None: The episode UUID, or None if no episode has started.

Source code in src/agoge/environment/environment.py
129
130
131
132
133
134
135
def get_episode_id(self) -> str | None:
    """Get the unique identifier for the current episode.

    Returns:
        str | None: The episode UUID, or None if no episode has started.
    """
    return self._episode_id

get_tool_schemas() #

Get tool schemas available to the agent.

By default, extracts schemas from self.toolset.schema. Subclasses can override this method to provide custom tool schemas.

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of tool schema dictionaries following the OpenAI function calling format. Returns empty list if no toolset is configured.

Note

Returns shallow copies of schemas to prevent accidental mutation.

Source code in src/agoge/environment/environment.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def get_tool_schemas(self) -> list[dict[str, Any]]:
    """Return the tool schemas the agent may call.

    The default implementation reads ``self.toolset.schema``; subclasses
    may override it to supply schemas from elsewhere.

    Returns:
        list[dict[str, Any]]: One dictionary per tool, in the OpenAI
            function-calling format. Empty when no toolset is configured
            or an ``AttributeError`` occurs while reading the schemas.

    Note:
        Each schema is shallow-copied so callers cannot mutate the
        toolset's own top-level dictionaries by accident.
    """
    toolset = self.toolset
    if toolset is None:
        return []

    copies = []
    # The attribute access and the copying stay inside the same try so that
    # an AttributeError from either still yields an empty list.
    try:
        for entry in toolset.schema:
            copies.append(dict(entry))
    except AttributeError:
        return []
    return copies

reset(task, **reset_kwargs) async #

Start a new episode and return the initial observation.

This method initializes episode state (step_count, active flag, task) and delegates to _reset_impl() for environment-specific setup. Subclasses can override this method directly for full control, or implement _reset_impl() to use the default lifecycle management.

Parameters:

Name Type Description Default
task Task

Task instance defining what the agent should accomplish.

required
**reset_kwargs Any

Optional keyword arguments passed to _reset_impl(). Common examples: seed, difficulty, scenario, wait_seconds.

{}

Returns:

Name Type Description
Chat Chat

Initial observation before the agent takes any action. Typically includes task instructions and initial state (e.g., screenshot, system message).

Raises:

Type Description
TypeError

If the returned observation is not a Chat instance.

Note

After calling reset(), episode_active will be True and step_count will be 0.

Source code in src/agoge/environment/environment.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
async def reset(self, task: Task, **reset_kwargs: Any) -> Chat:
    """Begin a fresh episode and hand back its first observation.

    Episode bookkeeping (step counter, active flag, task, episode id) is
    initialised here; the environment-specific setup is delegated to
    :meth:`_reset_impl`. Subclasses may either implement that hook or
    override this method outright.

    Args:
        task: Task instance defining what the agent should accomplish.
        **reset_kwargs: Forwarded unchanged to :meth:`_reset_impl`.
            Common examples: seed, difficulty, scenario, wait_seconds.

    Returns:
        Chat: The observation available before the agent acts, typically
            the task instructions plus initial state (e.g., screenshot,
            system message).

    Raises:
        TypeError: If :meth:`_reset_impl` returns something other than a Chat.

    Note:
        After calling reset(), ``episode_active`` will be True and
        ``step_count`` will be 0.
    """
    # All bookkeeping is in place before the hook runs, so the hook can rely on it.
    self._episode_id = str(uuid.uuid4())
    self._episode_active = True
    self._current_task = task
    self.step_count = 0

    initial_observation = await self._reset_impl(task, **reset_kwargs)
    self._ensure_chat(initial_observation, method_name="reset")

    # Only a validated observation replaces the previous episode's records.
    self._last_action = None
    self._last_observation = initial_observation
    return initial_observation

step(action) async #

Advance the environment by one turn.

Processes the agent's action, delegates to _step_impl() for environment-specific logic, and handles step counting and termination. Subclasses can override this method directly for full control, or implement _step_impl() to use the default lifecycle management.

Parameters:

Name Type Description Default
action AssistantMessage

Agent's action, typically an AssistantMessage containing tool calls to execute in the environment.

required

Returns:

Type Description
tuple[Chat, float, bool]

tuple[Chat, float, bool]: A tuple containing: - chat: Observation after action execution (e.g., tool results, screenshot). Should only contain new messages, not the full conversation history. - reward: Scalar feedback for the agent. None is coerced to 0.0. Final reward often computed by evaluators. - done: Whether the episode has terminated. Automatically set to True if max_steps is reached.

Raises:

Type Description
RuntimeError

If called before reset().

TypeError

If the returned observation is not a Chat instance.

Note
  • Increments step_count before calling _step_impl()
  • Sets episode_active to False when done=True
  • Enforces max_steps limit by forcing done=True when reached
Source code in src/agoge/environment/environment.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
    """Run one agent turn against the environment.

    The action is recorded, :meth:`_step_impl` performs the
    environment-specific work, and the result is normalised before being
    returned. Subclasses may implement :meth:`_step_impl` or override this
    method outright.

    Args:
        action: Agent's action, typically an AssistantMessage containing
            tool calls to execute in the environment.

    Returns:
        tuple[Chat, float, bool]: A tuple containing:
            - **chat**: Observation after action execution (e.g., tool
              results, screenshot). Should only contain new messages,
              not the full conversation history.
            - **reward**: Scalar feedback converted with ``float``; a
              None reward becomes 0.0.
            - **done**: Whether the episode has terminated. Forced to
              True once ``max_steps`` is reached.

    Raises:
        RuntimeError: If no episode is active (reset not called, or the
            previous episode already finished).
        TypeError: If the returned observation is not a Chat instance.

    Note:
        - Increments step_count before calling :meth:`_step_impl`
        - Sets episode_active to False when done=True
        - Enforces max_steps limit by forcing done=True when reached
    """
    if not self._episode_active:
        raise RuntimeError("Environment must be reset before calling step().")

    self.step_count += 1
    self._last_action = action

    observation, raw_reward, raw_done = await self._step_impl(action)
    self._ensure_chat(observation, method_name="step")

    # Bring the hook's loosely-typed outputs to the declared return types.
    finished = bool(raw_done)
    if raw_reward is None:
        reward_value = 0.0
    else:
        reward_value = float(raw_reward)

    # The step cap only matters while the hook has not already ended the episode.
    if not finished:
        cap = self.max_steps
        if cap is not None and self.step_count >= cap:
            finished = True

    self._last_observation = observation
    if finished:
        self._episode_active = False

    return observation, reward_value, finished

GuiEnvironment #

Bases: Environment

A GUI environment that reuses functionality from ComputerGuiEnvironment.

This environment provides GUI interaction capabilities similar to ComputerGuiEnvironment but extends from the base Environment class for compatibility with the training framework.

Source code in src/agoge/environment/gui_environment.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
class GuiEnvironment(Environment):
    """A GUI environment that reuses functionality from ComputerGuiEnvironment.

    This environment provides GUI interaction capabilities similar to ComputerGuiEnvironment
    but extends from the base Environment class for compatibility with the training framework.
    """

    def __init__(
        self,
        max_steps: int = 50,
        display_width: int = 1280,
        display_height: int = 768,
        sandbox_provider: Literal["clusterfudge", "e2b", "osworld_aws"] = "e2b",
        action_resolution: tuple[int, int] | None = None,
    ):
        """Initialize the GUI environment.

        Args:
            max_steps: Maximum number of steps before termination.
            display_width: Width of the GUI display in pixels.
            display_height: Height of the GUI display in pixels.
            sandbox_provider: The sandbox provider to use for the computer session.
            action_resolution: Optional logical coordinate system for agent actions (e.g., [1000, 1000] for Qwen-style).
                             Defaults to display dimensions if not provided.
        """
        super().__init__()
        self.computer_session: ClusterfudgeComputer | E2BComputer | OSWorldAWSComputer | None = None
        self.current_state = None
        self.step_count = 0
        self.max_steps = max_steps  # Prevent infinite loops
        self.display_width = display_width
        self.display_height = display_height
        self.sandbox_provider = sandbox_provider
        self.action_resolution = (
            (display_width, display_height) if action_resolution is None else tuple(action_resolution)
        )

    async def _initialize_session(self, task: Task):
        """Initialize the computer session if not already done."""
        logger.info(
            f"Initializing {self.sandbox_provider} session for task: {task.inputs.get('instruction', 'N/A')[:100]}"
        )
        computer_kwargs = {
            "display_width": self.display_width,
            "display_height": self.display_height,
        }
        match self.sandbox_provider:
            case "e2b":
                self.computer_session = E2BComputer(setup_script=task.inputs.get("setup_script"), **computer_kwargs)
            case "clusterfudge":
                self.computer_session = ClusterfudgeComputer(**computer_kwargs)
            case "osworld_aws":
                self.computer_session = OSWorldAWSComputer(task=task, **computer_kwargs)
            case _:
                raise ValueError(f"Invalid sandbox provider for computer: {self.sandbox_provider}")
        await self.computer_session.initialize()
        logger.info(f"Successfully initialized {self.sandbox_provider} session")

    async def _get_screenshot(self) -> str:
        """Get a screenshot and return it as base64 encoded string."""
        if self.computer_session is None:
            # Return a placeholder if no session
            return ""

        image_bytes = await self.computer_session.screenshot()

        # Log image type and convert if PNG
        img = Image.open(io.BytesIO(image_bytes))

        if img.format == "PNG":
            # Convert to JPEG
            jpeg_buffer = io.BytesIO()
            img.convert("RGB").save(jpeg_buffer, format="JPEG", quality=85)
            image_bytes = jpeg_buffer.getvalue()

        return "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode("utf-8")

    @ray_remote_logger
    async def reset(self, task: Task, **reset_kwargs: dict) -> Chat:
        """Start a new episode and return the initial observation.

        Args:
            task: Task instance to complete in this episode
            **reset_kwargs: Optional keyword arguments that influence the
                initial state (e.g. ``seed=42``, ``scenario="button_click"``).

        Returns:
            Chat: The initial observation with system instructions and current state.
        """
        logger.info(f"Resetting GUI environment (max_steps={self.max_steps})")
        # Reset episode state
        self.step_count = 0
        self.current_state = "GUI Environment initialized"

        # Initialize session asynchronously
        await self._initialize_session(task)

        # Create toolset after session is initialized
        match self.sandbox_provider:
            case "e2b":
                from agoge.schema.tools import E2BComputerGuiToolSet  # avoid circular import

                self.computer_toolset = E2BComputerGuiToolSet(
                    computer_session=self.computer_session,
                    screen_width=self.display_width,
                    screen_height=self.display_height,
                    action_width=self.action_resolution[0],
                    action_height=self.action_resolution[1],
                )
            case "clusterfudge":
                from agoge.schema.tools import ClusterfudgeComputerGuiToolSet  # avoid circular import

                self.computer_toolset = ClusterfudgeComputerGuiToolSet(computer_session=self.computer_session)
            case "osworld_aws":
                from agoge.schema.tools import OSWorldAWSComputerGuiToolSet  # avoid circular import

                self.computer_toolset = OSWorldAWSComputerGuiToolSet(computer_session=self.computer_session)
            case _:
                raise ValueError(f"Invalid sandbox provider for computer: {self.sandbox_provider}")

        # Create initial task message
        task_msg = UserMessage(content=task.inputs["instruction"])

        # Get initial screenshot asynchronously
        try:
            wait_seconds = reset_kwargs.get("wait_seconds", 15 if self.sandbox_provider == "e2b" else 0)
            if wait_seconds > 0:
                await asyncio.sleep(wait_seconds)
            screenshot = await self._get_screenshot()
            if screenshot:
                user_msg = UserMessage(
                    role="user", label=MessageLabel.OBSERVATION, content=[ImageURLPart(image_url={"url": screenshot})]
                )
            else:
                user_msg = UserMessage(
                    content=f"Current GUI state: {self.current_state}\n\n"
                    "What would you like to do? You can click buttons, type text, or navigate menus.",
                    label=MessageLabel.OBSERVATION,
                )
        except Exception:
            # Fallback if screenshot fails
            logger.warning("Failed to get initial screenshot, falling back to text observation", exc_info=True)
            user_msg = UserMessage(
                content=f"Current GUI state: {self.current_state}\n\n"
                "What would you like to do? You can click buttons, type text, or navigate menus.",
                label=MessageLabel.OBSERVATION,
            )

        return Chat(messages=[task_msg, user_msg])

    def get_tool_schemas(self) -> list[dict]:
        return self.computer_toolset.schema

    @ray_remote_logger
    async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Advance the environment by one turn.

        Args:
            action: The agent's next AssistantMessage (usually with tool calls).

        Returns:
            tuple[Chat, float, bool]:
                * chat - Observation after the action has been applied.
                * reward - Reward is 0 for every step. Only in os-world the reward is
                calculated at the last steps through the validatiors.
                * done - True if the episode terminated at this step.

        Raises:
            RuntimeError: If called before reset.
        """
        if self.current_state is None:
            raise RuntimeError("Environment must be reset before calling step.")

        self.step_count += 1
        done = self.step_count >= self.max_steps
        termination_message = None

        logger.debug(f"Step {self.step_count}/{self.max_steps}")

        # Extract tool calls if present
        tool_calls = action.tool_calls or []
        tool_messages = []

        if tool_calls:
            tool_names = [tc.function.get("name", "unknown") for tc in tool_calls if tc.function]
            logger.debug(f"Executing {len(tool_calls)} tool call(s): {', '.join(tool_names)}")

        # Execute tool calls
        for tool_call in tool_calls:
            if tool_call.function:
                try:
                    tool_response, is_terminate = await self._execute_tool_call(tool_call, self.computer_toolset)
                    tool_messages.append(ToolMessage(content=tool_response, tool_call_id=tool_call.id))
                    if is_terminate:
                        done = True
                        termination_message = tool_response
                        logger.info(f"Episode terminated by agent at step {self.step_count}")
                except Exception:
                    logger.exception("Error executing tool call")
                    tool_messages.append(
                        ToolMessage(
                            content=(
                                f"Could not execute {tool_call.function.get('name', 'unknown')}: error during execution"
                            ),
                            tool_call_id=tool_call.id,
                        )
                    )

        # Update state based on actions
        if tool_calls:
            self.current_state = f"Action completed. Step {self.step_count}/{self.max_steps}"
        else:
            self.current_state = f"No action taken. Step {self.step_count}/{self.max_steps}"

        # Create the observation with screenshot
        screenshot = await self._get_screenshot()
        if screenshot:
            # Create a message with the screenshot as the main content
            user_msg = UserMessage(
                role="user", label=MessageLabel.OBSERVATION, content=[ImageURLPart(image_url={"url": screenshot})]
            )
        else:
            # Fallback to text content if no screenshot available
            user_msg = UserMessage(
                role="user", label=MessageLabel.OBSERVATION, content=f"GUI State: {self.current_state}"
            )

        observation = Chat(messages=[*tool_messages, user_msg])

        reward = 0
        # Clean up sandbox when episode is done
        if done:
            if self.step_count >= self.max_steps and not termination_message:
                logger.info(f"Episode reached max steps ({self.max_steps})")

            # Run evaluation while sandbox is still active
            if hasattr(self.computer_session, "evaluate"):
                logger.warning(
                    "Using env reward for last step, this overwrites the reward calculator, will refactor soon."
                )
                try:
                    reward = await self.computer_session.evaluate(termination_message)
                    logger.info(f"Episode evaluation completed: reward={reward}")
                except Exception:
                    logger.exception("Failed to evaluate task")
                    reward = 0.0

            await self.cleanup()

        return observation, reward, done

    async def cleanup(self):
        """Clean up the computer session and close the sandbox."""
        if self.computer_session:
            try:
                await self.computer_session.close(None, None, None)
                logger.info("Sandbox closed successfully")
            except Exception:
                logger.exception("Error closing sandbox")
            finally:
                self.computer_session = None

    async def _execute_tool_call(
        self,
        tool_call,
        computer_tool: ("ClusterfudgeComputerGuiToolSet | E2BComputerGuiToolSet | OSWorldAWSComputerGuiToolSet"),
    ) -> tuple[str | None, bool]:
        try:
            func_name = tool_call.function.get("name", "").lower()
            func_args = tool_call.function.get("arguments", {})
            if isinstance(func_args, str):
                func_args = json.loads(func_args)

            response = None
            is_terminate = False

            if func_name == "left_click":
                pos = func_args.get("position")
                response = await computer_tool.left_click(position=tuple(pos) if pos else None)
            elif func_name == "right_click":
                pos = func_args.get("position")
                response = await computer_tool.right_click(position=tuple(pos) if pos else None)
            elif func_name == "middle_click":
                pos = func_args.get("position")
                response = await computer_tool.middle_click(position=tuple(pos) if pos else None)
            elif func_name == "double_click":
                pos = func_args.get("position")
                response = await computer_tool.double_click(position=tuple(pos) if pos else None)
            elif func_name == "triple_click":
                pos = func_args.get("position")
                response = await computer_tool.triple_click(position=tuple(pos) if pos else None)
            elif func_name == "mouse_move":
                pos = func_args.get("position")
                response = await computer_tool.mouse_move(position=tuple(pos) if pos else None)
            elif func_name == "left_click_drag":
                pos = func_args.get("position")
                response = await computer_tool.left_click_drag(position=tuple(pos) if pos else None)
            elif func_name == "type":
                response = await computer_tool.type(text=func_args.get("text", ""))
            elif func_name == "key":
                response = await computer_tool.key(key=func_args.get("key", ""))
            elif func_name == "scroll":
                pos = func_args.get("position")
                response = await computer_tool.scroll(
                    position=tuple(pos) if pos else None,
                    scroll_direction=func_args.get("scroll_direction", "down"),
                    scroll_amount=func_args.get("scroll_amount", 100),
                )
            elif func_name == "wait":
                response = await computer_tool.wait(time=func_args.get("time", 1))
            elif func_name == "terminate":
                is_terminate = True
                response = await computer_tool.terminate(status=func_args.get("status", 1))
            else:
                return f"Could not execute {func_name}: invalid action format", False

            if response is None:
                return f"Executed {func_name} successfully", is_terminate
        except Exception:
            logger.exception("Error executing tool call")
            return ("Error executing action", False)
        else:
            return response, is_terminate

__init__(max_steps=50, display_width=1280, display_height=768, sandbox_provider='e2b', action_resolution=None) #

Initialize the GUI environment.

Parameters:

Name Type Description Default
max_steps int

Maximum number of steps before termination.

50
display_width int

Width of the GUI display in pixels.

1280
display_height int

Height of the GUI display in pixels.

768
sandbox_provider Literal['clusterfudge', 'e2b', 'osworld_aws']

The sandbox provider to use for the computer session.

'e2b'
action_resolution tuple[int, int] | None

Optional logical coordinate system for agent actions (e.g., [1000, 1000] for Qwen-style). Defaults to display dimensions if not provided.

None
Source code in src/agoge/environment/gui_environment.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def __init__(
    self,
    max_steps: int = 50,
    display_width: int = 1280,
    display_height: int = 768,
    sandbox_provider: Literal["clusterfudge", "e2b", "osworld_aws"] = "e2b",
    action_resolution: tuple[int, int] | None = None,
):
    """Initialize the GUI environment.

    Args:
        max_steps: Maximum number of steps before termination.
        display_width: Width of the GUI display in pixels.
        display_height: Height of the GUI display in pixels.
        sandbox_provider: The sandbox provider to use for the computer session.
        action_resolution: Optional logical coordinate system for agent actions (e.g., [1000, 1000] for Qwen-style).
                         Defaults to display dimensions if not provided.

    Raises:
        ValueError: If max_steps is less than 1 (validated by the base class).
    """
    # The base class validates max_steps and initialises step_count,
    # current_state and the episode bookkeeping attributes.
    super().__init__(max_steps=max_steps)
    self.computer_session: ClusterfudgeComputer | E2BComputer | OSWorldAWSComputer | None = None
    # Assigned in reset(); defined here so the attribute exists before the first reset.
    self.computer_toolset = None
    self.display_width = display_width
    self.display_height = display_height
    self.sandbox_provider = sandbox_provider
    self.action_resolution = (
        (display_width, display_height) if action_resolution is None else tuple(action_resolution)
    )

cleanup() async #

Clean up the computer session and close the sandbox.

Source code in src/agoge/environment/gui_environment.py
279
280
281
282
283
284
285
286
287
288
async def cleanup(self):
    """Close the sandbox behind the current computer session, if any.

    A failure while closing is logged rather than raised, and the session
    reference is dropped whether or not closing succeeded.
    """
    session = self.computer_session
    if not session:
        # Nothing to close.
        return

    try:
        await session.close(None, None, None)
        logger.info("Sandbox closed successfully")
    except Exception:
        logger.exception("Error closing sandbox")
    finally:
        self.computer_session = None

reset(task, **reset_kwargs) async #

Start a new episode and return the initial observation.

Parameters:

Name Type Description Default
task Task

Task instance to complete in this episode

required
**reset_kwargs dict

Optional keyword arguments that influence the initial state (e.g. seed=42, scenario="button_click").

{}

Returns:

Name Type Description
Chat Chat

The initial observation with system instructions and current state.

Source code in src/agoge/environment/gui_environment.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
@ray_remote_logger
async def reset(self, task: Task, **reset_kwargs: dict) -> Chat:
    """Begin a fresh episode and produce its first observation.

    Args:
        task: Task to complete; ``task.inputs["instruction"]`` becomes the first message.
        **reset_kwargs: Optional overrides. Only ``wait_seconds`` is read here: the delay
            before the first screenshot (defaults to 15 for the ``e2b`` provider, else 0).

    Returns:
        Chat: The instruction message followed by an observation message (a screenshot, or
        a text fallback when no screenshot could be obtained).

    Raises:
        ValueError: If ``self.sandbox_provider`` is not one of the supported providers.
    """
    logger.info(f"Resetting GUI environment (max_steps={self.max_steps})")
    self.step_count = 0
    self.current_state = "GUI Environment initialized"

    # The toolset wraps the live session, so the session has to exist first.
    await self._initialize_session(task)

    provider = self.sandbox_provider
    if provider == "e2b":
        from agoge.schema.tools import E2BComputerGuiToolSet  # avoid circular import

        action_width, action_height = self.action_resolution[0], self.action_resolution[1]
        self.computer_toolset = E2BComputerGuiToolSet(
            computer_session=self.computer_session,
            screen_width=self.display_width,
            screen_height=self.display_height,
            action_width=action_width,
            action_height=action_height,
        )
    elif provider == "clusterfudge":
        from agoge.schema.tools import ClusterfudgeComputerGuiToolSet  # avoid circular import

        self.computer_toolset = ClusterfudgeComputerGuiToolSet(computer_session=self.computer_session)
    elif provider == "osworld_aws":
        from agoge.schema.tools import OSWorldAWSComputerGuiToolSet  # avoid circular import

        self.computer_toolset = OSWorldAWSComputerGuiToolSet(computer_session=self.computer_session)
    else:
        raise ValueError(f"Invalid sandbox provider for computer: {self.sandbox_provider}")

    def text_observation() -> UserMessage:
        # Text-only observation used whenever no screenshot is available.
        return UserMessage(
            content=f"Current GUI state: {self.current_state}\n\n"
            "What would you like to do? You can click buttons, type text, or navigate menus.",
            label=MessageLabel.OBSERVATION,
        )

    instruction_msg = UserMessage(content=task.inputs["instruction"])

    try:
        default_wait = 15 if self.sandbox_provider == "e2b" else 0
        wait_seconds = reset_kwargs.get("wait_seconds", default_wait)
        if wait_seconds > 0:
            await asyncio.sleep(wait_seconds)

        screenshot = await self._get_screenshot()
        if not screenshot:
            observation_msg = text_observation()
        else:
            observation_msg = UserMessage(
                role="user", label=MessageLabel.OBSERVATION, content=[ImageURLPart(image_url={"url": screenshot})]
            )
    except Exception:
        # Any failure while waiting or capturing degrades to the text observation.
        logger.warning("Failed to get initial screenshot, falling back to text observation", exc_info=True)
        observation_msg = text_observation()

    return Chat(messages=[instruction_msg, observation_msg])

step(action) async #

Advance the environment by one turn.

Parameters:

Name Type Description Default
action AssistantMessage

The agent's next AssistantMessage (usually with tool calls).

required

Returns:

Type Description
tuple[Chat, float, bool]

tuple[Chat, float, bool]: * chat - Observation after the action has been applied. * reward - Reward is 0 for every step; only for OSWorld is the reward computed, at the last step, by the validators. * done - True if the episode terminated at this step.

Raises:

Type Description
RuntimeError

If called before reset.

Source code in src/agoge/environment/gui_environment.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
@ray_remote_logger
async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
    """Advance the environment by one turn.

    Args:
        action: The agent's next AssistantMessage (usually with tool calls).

    Returns:
        tuple[Chat, float, bool]:
            * chat - Tool responses followed by the observation (a screenshot, or a text
              fallback when no screenshot is available).
            * reward - 0.0 on every step, except that on the final step the result of
              ``computer_session.evaluate`` is used when the session provides that method.
            * done - True if the episode terminated at this step.

    Raises:
        RuntimeError: If called before reset.
    """
    if self.current_state is None:
        raise RuntimeError("Environment must be reset before calling step.")

    self.step_count += 1
    done = self.step_count >= self.max_steps
    termination_message = None

    logger.debug(f"Step {self.step_count}/{self.max_steps}")

    # Extract tool calls if present
    tool_calls = action.tool_calls or []
    tool_messages = []

    if tool_calls:
        tool_names = [tc.function.get("name", "unknown") for tc in tool_calls if tc.function]
        logger.debug(f"Executing {len(tool_calls)} tool call(s): {', '.join(tool_names)}")

    # Execute tool calls; a failing call is reported back to the agent instead of raising.
    for tool_call in tool_calls:
        if not tool_call.function:
            continue
        try:
            tool_response, is_terminate = await self._execute_tool_call(tool_call, self.computer_toolset)
            tool_messages.append(ToolMessage(content=tool_response, tool_call_id=tool_call.id))
            if is_terminate:
                done = True
                termination_message = tool_response
                logger.info(f"Episode terminated by agent at step {self.step_count}")
        except Exception:
            logger.exception("Error executing tool call")
            tool_messages.append(
                ToolMessage(
                    content=(
                        f"Could not execute {tool_call.function.get('name', 'unknown')}: error during execution"
                    ),
                    tool_call_id=tool_call.id,
                )
            )

    # Update state based on actions
    if tool_calls:
        self.current_state = f"Action completed. Step {self.step_count}/{self.max_steps}"
    else:
        self.current_state = f"No action taken. Step {self.step_count}/{self.max_steps}"

    # A screenshot failure must not abort the step: on the final step that would skip the
    # evaluation and cleanup below and leave the sandbox running. Fall back to text instead.
    try:
        screenshot = await self._get_screenshot()
    except Exception:
        logger.warning("Failed to get screenshot, falling back to text observation", exc_info=True)
        screenshot = None

    if screenshot:
        # Create a message with the screenshot as the main content
        user_msg = UserMessage(
            role="user", label=MessageLabel.OBSERVATION, content=[ImageURLPart(image_url={"url": screenshot})]
        )
    else:
        # Fallback to text content if no screenshot available
        user_msg = UserMessage(
            role="user", label=MessageLabel.OBSERVATION, content=f"GUI State: {self.current_state}"
        )

    observation = Chat(messages=[*tool_messages, user_msg])

    reward = 0.0
    # Clean up sandbox when episode is done
    if done:
        if self.step_count >= self.max_steps and not termination_message:
            logger.info(f"Episode reached max steps ({self.max_steps})")

        # Run evaluation while sandbox is still active
        if hasattr(self.computer_session, "evaluate"):
            logger.warning(
                "Using env reward for last step, this overwrites the reward calculator, will refactor soon."
            )
            try:
                reward = await self.computer_session.evaluate(termination_message)
                logger.info(f"Episode evaluation completed: reward={reward}")
            except Exception:
                logger.exception("Failed to evaluate task")
                reward = 0.0

        await self.cleanup()

    return observation, reward, done

OSWorld #

Bases: Environment

Simple wrapper for OSWorld.

Source code in src/agoge/environment/osworld.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
class OSWorld(Environment):
    """Simple wrapper for OSWorld."""

    def __init__(
        self,
        display_width: int,
        display_height: int,
        vm_data_source: str | Path,
        vm_cache_dir: str | Path,
        action_resolution: tuple[int, int] | None = None,
        max_steps: int = 50,
        step_pause_seconds: int = 2,
        startup_wait_seconds: int = 60,
        traj_output_dir: str | Path | None = None,
        record_episodes: bool = False,
    ):
        """Configure the environment and make sure the VM data is cached locally.

        Args:
            display_width: Width of the screenshots handed to the agent.
            display_height: Height of the screenshots handed to the agent.
            vm_data_source: Directory or URL holding the VM data to copy from.
            vm_cache_dir: Local directory the VM data is cached in.
            action_resolution: Coordinate space used by the agent's actions; defaults to
                ``(display_width, display_height)``.
            max_steps: Forwarded to the base class.
            step_pause_seconds: Pause passed to the actor with every executed action.
            startup_wait_seconds: Base wait after a reset before the first screenshot.
            traj_output_dir: Directory recordings are written to; required when recording.
            record_episodes: Whether to record the screen for each episode.

        Raises:
            ValueError: If ``record_episodes`` is true but ``traj_output_dir`` is ``None``.
        """
        super().__init__(max_steps=max_steps)
        self.display_width = display_width
        self.display_height = display_height
        self.action_resolution = (
            (display_width, display_height) if action_resolution is None else tuple(action_resolution)
        )
        self.record_episodes = record_episodes

        # Only set up recording directory if recording is enabled
        if record_episodes:
            if traj_output_dir is None:
                raise ValueError("traj_output_dir required when record_episodes=True")
            self._traj_output_dir = Path(traj_output_dir)
            self._traj_output_dir.mkdir(parents=True, exist_ok=True)
        else:
            self._traj_output_dir = None

        # ``Path`` collapses the "//" of a URL ("gs://bucket/x" becomes "gs:/bucket/x"), after
        # which fsspec no longer recognises the protocol. Keep URL-style sources as strings;
        # ``_ensure_vm_data`` only ever stringifies its source argument.
        if isinstance(vm_data_source, str) and "://" in vm_data_source:
            source = vm_data_source
        else:
            source = Path(vm_data_source)

        # Ensure VM data is available locally on this worker node
        cache_path = self._ensure_vm_data(source, Path(vm_cache_dir))

        # Store actor kwargs for lazy initialization
        self._actor_kwargs = {
            "provider_name": "docker",
            "action_space": "pyautogui",
            "screen_size": (display_width, display_height),
            "headless": True,
            "require_a11y_tree": False,
            "require_terminal": False,
            "os_type": "Ubuntu",
            "enable_proxy": False,
            "vm_cache_dir": cache_path,
        }
        self.actor = None  # Lazy initialization on first reset

        # OSWorld docker provider uses 1920x1080 native resolution regardless of requested screen_size
        # Toolset scales coordinates from action_resolution (what agent sees) to
        # native_resolution (what OSWorld expects)
        self.toolset = PyAutoGui(
            native_resolution=(1920, 1080),
            action_resolution=self.action_resolution,
        )

        self.step_pause_seconds = step_pause_seconds
        self.startup_wait_seconds = startup_wait_seconds

    def _ensure_vm_data(self, source_dir: Path, cache_dir: Path) -> Path:
        """Populate the local VM cache from ``source_dir`` unless it already holds an image.

        The check and the copy run under a file lock stored next to ``cache_dir``. The cache
        counts as populated as soon as it contains any ``*.qcow2`` file.

        Args:
            source_dir: Location of the VM data; it is stringified and resolved via fsspec.
            cache_dir: Local directory that receives the VM data (created if missing).

        Returns:
            ``cache_dir``, unchanged.
        """
        import fsspec
        from filelock import FileLock

        download_lock = cache_dir.parent / "vm_download.lock"
        cache_dir.mkdir(parents=True, exist_ok=True)

        with FileLock(download_lock, timeout=600):
            already_cached = any(cache_dir.glob("*.qcow2"))
            if already_cached:
                logger.info(f"VM data already exists in {cache_dir}")
            else:
                logger.info(f"Syncing VM data from {source_dir} to {cache_dir}")

                # fsspec resolves the filesystem implementation from the source string.
                filesystem, remote_path = fsspec.core.url_to_fs(str(source_dir))
                # Copy the contents of the source directory into the cache directory.
                filesystem.get(f"{remote_path}/*", f"{cache_dir}/")

                logger.info(f"VM data synced to {cache_dir}")

        return cache_dir

    async def _reset_impl(self, task: Task, **reset_kwargs: Any) -> Chat:
        """Environment-specific reset logic for OSWorld.

        Builds the task configuration, creates the ``DesktopEnvProxy`` actor on first use,
        resets it, waits for the desktop to settle, optionally starts a screen recording and
        returns the instruction together with a fresh screenshot.

        Args:
            task: Task instance defining the episode goal. ``task.inputs`` is merged into the
                task config, ``task.eval_criteria["os-world-eval"]`` supplies the evaluator
                and ``task.metadata`` may carry ``config`` and ``proxy`` entries.
            **reset_kwargs: Optional keyword arguments (currently unused, logs warning).

        Returns:
            Chat: Initial observation containing task instruction and screenshot.
        """
        if reset_kwargs:
            logger.warning(f"Unused reset_kwargs: {reset_kwargs}")

        task_config = {
            "id": task.task_id,
            **task.inputs,
            "evaluator": task.eval_criteria["os-world-eval"],
        }
        # Setup steps are optional; the key is only added when the task provides them.
        config_steps = task.metadata.get("config")
        if config_steps:
            task_config["config"] = config_steps

        # Lazy actor creation - only create once per environment instance
        if self.actor is None:
            # Update enable_proxy based on task metadata
            # (read only here, so the first task's setting sticks for this actor)
            enable_proxy = task.metadata.get("proxy", False)
            self._actor_kwargs["enable_proxy"] = enable_proxy
            logger.info(f"Creating DesktopEnvProxy actor with enable_proxy={enable_proxy}")

            # Force DesktopEnvProxy to run on the same node as this Runner
            current_node_id = ray.get_runtime_context().get_node_id()
            self.actor = DesktopEnvProxy.options(  # pyright: ignore
                scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                    node_id=current_node_id, soft=False
                )
            ).remote(**self._actor_kwargs)

        # The observation returned by the reset call is replaced by a fresh one below.
        obs = await self.actor.call.remote("reset", task_config=task_config)

        logger.debug(
            "OSWorld requires KVM (nested virtualization) to be enabled. "
            "If you see KVM errors above, OSWorld will not function correctly. "
            "On GCP: enable nested virtualization by adding "
            "'advancedMachineFeatures.enableNestedVirtualization=true' to your instance config, "
            "then stop and restart the instance. "
            "Search: 'GCP enable nested virtualization' or 'KVM not available OSWorld docker'"
        )

        # Wait for desktop environment to fully initialize
        # (a random factor in [1, 2) is applied to the configured wait)
        startup_wait_seconds = (1 + random.random()) * self.startup_wait_seconds
        logger.info(f"Waiting {startup_wait_seconds}s for desktop environment to initialize...")
        await asyncio.sleep(startup_wait_seconds)

        # Start recording if enabled
        if self.record_episodes:
            try:
                await self.actor.start_recording.remote()
                logger.info("Started screen recording")
            except Exception:
                # A recording failure is logged and the reset carries on.
                logger.exception("Failed to start screen recording")

        # Get a fresh screenshot after waiting
        obs = await self.actor.call.remote("_get_obs")

        # Create instruction message
        instruction_msg = UserMessage(content=task.inputs["instruction"])

        # Get screenshot observation
        screenshot_msg = self._obs_to_message(obs)

        return Chat(messages=[instruction_msg, screenshot_msg])

    def get_tool_schemas(self) -> list[dict]:
        """Return the tool schemas exposed by this environment's toolset."""
        schemas = self.toolset.schema
        return schemas

    async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Run the base-class step and finalise the recording once the episode is over.

        Args:
            action: The agent's message for this turn.

        Returns:
            The ``(chat, reward, done)`` triple produced by the base-class ``step``.
        """
        observation, reward, finished = await super().step(action)

        if finished and self.record_episodes:
            # Save the recording as soon as the episode has ended.
            await self._stop_recording()

        return observation, reward, finished

    async def _step_impl(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Environment-specific step logic for OSWorld.

        Each tool call is translated into a pyautogui command by the toolset and executed by
        the actor. A failing call (unknown tool, bad or malformed arguments, execution error)
        is reported back as a tool message and does not abort the step.

        Args:
            action: Agent's action containing tool calls to execute.

        Returns:
            tuple[Chat, float, bool]: Observation, reward, and done flag.
                The base class handles step counting and max_steps enforcement.
        """
        obs, reward, done = None, 0.0, False
        tool_messages = []

        for tool_call in action.tool_calls or []:
            func_name = tool_call.function["name"]

            # Argument parsing lives inside the try block so that malformed arguments are
            # reported to the agent rather than crashing the whole step.
            try:
                func_args = tool_call.function.get("arguments") or {}
                # TODO(FC): disgusting for compatibility reasons.
                # this should have been parsed in the agent.
                if isinstance(func_args, str):
                    func_args = json.loads(func_args)
                if not isinstance(func_args, dict):
                    raise TypeError(f"arguments must be a JSON object, got {type(func_args).__name__}")

                # Handle special actions
                if func_name == "terminate":
                    logger.info(f"Terminate called with status: {func_args.get('status', 'N/A')}")
                elif func_name == "screenshot":
                    logger.warning(f"{func_name} tool called. Not part of toolset. Replacing with a wait")
                    func_name, func_args = "wait", {"time": self.step_pause_seconds}

                # Get pyautogui command from toolset
                if not hasattr(self.toolset, func_name):
                    raise AttributeError(f"Tool '{func_name}' not found in toolset")

                method = getattr(self.toolset, func_name)
                pyautogui_cmd = await method(**func_args)

                # Log the generated command
                logger.info(f"pyautogui command: {pyautogui_cmd!s}")

                # Execute in OSWorld - gym interface returns (obs, reward, done, info)
                obs, reward, done, _ = await self.actor.call.remote(
                    "step", action=pyautogui_cmd, pause=self.step_pause_seconds
                )
                logger.debug(f"Step completed: done={done}, reward={reward}")

                # Create tool message for this action
                tool_messages.append(
                    ToolMessage(content=f"Executed {func_name} successfully", tool_call_id=tool_call.id)
                )

            except (AttributeError, TypeError, json.JSONDecodeError) as e:
                # Invalid tool name, invalid parameters or unparsable arguments
                error_msg = f"Invalid action '{func_name}': {e}"
                logger.info(f"Action error: {error_msg}")
                tool_messages.append(ToolMessage(content=error_msg, tool_call_id=tool_call.id))
            except Exception as e:
                # Catch any other errors during execution
                error_msg = f"Action '{func_name}' failed: {type(e).__name__}: {e}"
                logger.warning(f"Unexpected action error: {error_msg}")
                tool_messages.append(ToolMessage(content=error_msg, tool_call_id=tool_call.id))

            # if environment has finished then stop with the extra actions
            if done:
                break

        # If no observation was captured (no tool calls?), get current screenshot
        obs = obs or await self.actor.call.remote("_get_obs")
        try:
            if done:
                logger.info("Episode done, evaluating trajectory")
                reward = await self.actor.call.remote("evaluate")
                logger.info(f"Evaluation reward: {reward}")
        except Exception:
            logger.exception("Evaluation failed. Not overwriting reward.")

        screenshot_msg = self._obs_to_message(obs)
        return Chat(messages=[*tool_messages, screenshot_msg]), float(reward), done

    async def _stop_recording(self) -> None:
        """End the screen recording and store it as ``<traj_output_dir>/<episode_id>.mp4``.

        Nothing is saved (a warning is logged) when the episode id or the output directory
        is missing. Any exception is logged and swallowed.
        """
        try:
            episode_id = self._episode_id
            if not episode_id:
                logger.warning("No episode_id available for recording filename")
                return

            output_dir = self._traj_output_dir
            if output_dir is None:
                logger.warning("No traj_output_dir available for recording filename")
                return

            recording_path = output_dir / f"{episode_id}.mp4"
            await self.actor.end_recording.remote(str(recording_path))

            logger.info(f"Saved screen recording to {recording_path}")
        except Exception:
            logger.exception("Failed to stop/save screen recording")

    def _obs_to_message(self, obs: dict) -> UserMessage:
        """Turn an OSWorld observation into a user message carrying the screenshot.

        The screenshot is resized to ``(display_width, display_height)`` when its size
        differs, converted to RGB and embedded as a base64 JPEG data URI.

        Args:
            obs: Observation dict; the encoded image bytes are read from ``obs["screenshot"]``.

        Returns:
            A message with the screenshot as an image part, or a plain text message when the
            observation carries no screenshot.
        """
        screenshot = obs.get("screenshot")
        if not screenshot:
            return UserMessage(content="No screenshot available")

        img = Image.open(io.BytesIO(screenshot))
        expected = (self.display_width, self.display_height)

        if img.size != expected:
            logger.debug(
                f"Screenshot size mismatch: got {img.size}, expected {expected}. "
                "Resizing (expected with docker provider).",
            )
            img = img.resize(expected, Image.Resampling.LANCZOS)

        # JPEG cannot store an alpha channel or a palette; saving e.g. an RGBA screenshot
        # would raise, so normalise the mode first.
        if img.mode != "RGB":
            img = img.convert("RGB")

        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=85)
        data_uri = "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")

        return UserMessage(content=[ImageURLPart(image_url={"url": data_uri})])

    async def cleanup(self):
        """Ask the remote desktop actor to close, if an actor has been created."""
        logger.info("Closing (cleanup) environment.")
        actor = self.actor
        if actor is None:
            # No reset ever created an actor, so there is nothing to close.
            return
        await actor.call.remote("close")

cleanup() async #

Clean up the docker and environment.

Source code in src/agoge/environment/osworld.py
532
533
534
535
536
async def cleanup(self):
    """Ask the remote desktop actor to close, if an actor has been created."""
    logger.info("Closing (cleanup) environment.")
    actor = self.actor
    if actor is None:
        # No reset ever created an actor, so there is nothing to close.
        return
    await actor.call.remote("close")

step(action) async #

Execute action and handle recording stop on episode end.

Source code in src/agoge/environment/osworld.py
404
405
406
407
408
409
410
411
412
async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
    """Run the base-class step and finalise the recording once the episode is over.

    Args:
        action: The agent's message for this turn.

    Returns:
        The ``(chat, reward, done)`` triple produced by the base-class ``step``.
    """
    observation, reward, finished = await super().step(action)

    if finished and self.record_episodes:
        # Save the recording as soon as the episode has ended.
        await self._stop_recording()

    return observation, reward, finished

OSWorldG #

Bases: Environment

Single-step offline environment for OSWorld-G benchmark.

The environment presents an instruction and a single screenshot. The agent is expected to respond with a computer_use tool call describing the grounded action (e.g. a click position). Actions are not executed; evaluation happens outside the environment.

Source code in src/agoge/environment/osworld_g.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
class OSWorldG(Environment):
    """Single-step offline environment for OSWorld-G benchmark.

    The environment presents an instruction and a single screenshot. The agent is
    expected to respond with a ``computer_use`` tool call describing the grounded
    action (e.g. a click position). Actions are not executed; evaluation happens
    outside the environment.
    """

    def __init__(self, *, max_steps: int = 1) -> None:
        """Create the single-step environment.

        Args:
            max_steps: Accepted for interface compatibility only. Any value other than 1 is
                reported with a warning and replaced by 1.
        """
        if max_steps != 1:
            logger.warning("OSWorldG overrides max_steps to 1 (received %s).", max_steps)
        # Actually apply the override announced above instead of forwarding the caller's value.
        super().__init__(max_steps=1)
        self.toolset = ComputerUse(screen_width=1920, screen_height=1080)

    async def _reset_impl(self, task: Task, **reset_kwargs: Any) -> Chat:
        """Load the task's screenshot and instruction and configure the toolset frame.

        Args:
            task: Task whose ``inputs`` hold ``instruction`` and an ``images``/``image``
                payload. ``task.metadata["image_size"]`` (width, height), when present, is
                used to derive the model frame via ``smart_resize``.
            **reset_kwargs: Stored in ``current_state`` and otherwise unused here.

        Returns:
            Chat: The screenshot message followed by the instruction message.

        Raises:
            ValueError: If the instruction or the image payload is missing.
        """
        instruction = task.inputs.get("instruction")
        image_payload = task.inputs.get("images")
        if image_payload is None:
            image_payload = task.inputs.get("image")

        if not instruction:
            raise ValueError("Task inputs must include an 'instruction' string.")
        if image_payload is None:
            raise ValueError("Task inputs must include an 'image' payload.")

        image_data_uri = self._to_data_uri(image_payload)
        logger.debug(f"OSWorld-G reset with task_id={task.task_id}")

        screenshot_msg = UserMessage(
            label=MessageLabel.OBSERVATION,
            content=[ImageURLPart(image_url={"url": image_data_uri})],
        )
        instruction_msg = UserMessage(
            content=f"Please complete the following tasks via mouse click or wait: {instruction}"
        )

        image_size = task.metadata.get("image_size")
        if image_size is not None:
            original_width, original_height = image_size[0], image_size[1]
            # smart_resize takes and returns (height, width), hence the swapped order.
            resized_height, resized_width = smart_resize(
                original_height,
                original_width,
                factor=28,
                min_pixels=4 * 32 * 32,
                max_pixels=2116800,
            )
            model_image_size = (resized_width, resized_height)
        else:
            # Without a recorded image size, keep the toolset's configured screen frame
            # rather than failing on ``None[0]``.
            logger.warning("Task metadata has no 'image_size'; keeping the toolset's default frame.")
            model_image_size = (self.toolset.screen_width, self.toolset.screen_height)

        self.toolset.set_model_frame(*model_image_size)

        self.current_state = {
            "task_id": task.task_id,
            "instruction": instruction,
            "image": image_data_uri,
            "eval_criteria": task.eval_criteria,
            "metadata": task.metadata,
            "reset_kwargs": reset_kwargs,
            "last_action": None,
            "image_size": image_size,
        }

        return Chat(messages=[screenshot_msg, instruction_msg])

    async def _step_impl(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Acknowledge the agent's tool calls and score the episode.

        No action is executed. Every tool call receives an acknowledgement message and the
        reward comes from ``evaluate``. The tool arguments are not parsed here, so malformed
        arguments cannot abort the step; ``get_coordinates`` parses them during evaluation.

        Args:
            action: The agent's message, usually carrying one tool call.

        Returns:
            tuple[Chat, float, bool]: The acknowledgement messages, the reward, and ``True``
            (the episode always ends after one step).
        """
        tool_messages = [
            ToolMessage(
                content=f"Executed {tool_call.function['name']} successfully",
                tool_call_id=tool_call.id,
            )
            for tool_call in action.tool_calls or []
        ]

        observation = Chat(messages=tool_messages)
        reward = self.evaluate(action.tool_calls)
        return observation, reward, True

    def evaluate(self, tool_calls: list[ToolCall]) -> float:
        """Score the first tool call against the task's target region.

        Args:
            tool_calls: Tool calls from the agent's action; only the first one is scored.

        Returns:
            1.0 when ``_eval_coords`` accepts the predicted coordinates, otherwise 0.0
            (also 0.0 when there are no tool calls).
        """
        if not tool_calls:
            logger.warning("No tool calls provided for evaluation; returning 0 reward.")
            return 0.0

        first_call = tool_calls[0]
        predicted = self.get_coordinates(first_call)

        criteria = self.current_state["eval_criteria"]
        target_type = criteria.get("box_type", "unknown")
        target_box, target_size = self._normalise_box(criteria.get("box_coordinates", []), target_type)
        frame_size = self.current_state["metadata"].get("image_size", [1920, 1080])

        logger.info(
            f"tool call: {first_call}, coordinates: {predicted}, box_type: {target_type}, "
            f"box_size: {target_size}, box_coordinates: {target_box}, image_size: {frame_size}"
        )

        hit = self._eval_coords(predicted, target_type, target_size, target_box, frame_size)
        return 1.0 if hit else 0.0

    def _normalise_box(self, box: list[int | float], box_type: str) -> tuple[list[float], list[float] | None]:
        """Normalise dataset box representation into coordinate and size components."""
        if not box:
            return [-1, -1, -1, -1], None

        if box_type == "bbox":
            if len(box) >= 4:
                x, y, width, height = box[:4]
                return [x, y, x + width, y + height], [width, height]
            logger.warning("BBox missing dimensions: %s", box)
            return [-1, -1, -1, -1], None

        if box_type == "polygon":
            return list(box), None

        if box_type == "refusal":
            return list(box), None

        logger.warning("Unexpected box type '%s'; returning sentinel box.", box_type)
        return [-1, -1, -1, -1], None

    def _eval_coords(
        self,
        coordinate: list[int],
        boxes_type: str,
        boxes_size: list[int],
        boxes_coordinate: list[int],
        image_size: list[int],
    ) -> bool:
        """Decide whether the predicted location satisfies the target region.

        Args:
            coordinate: Predicted ``[x1, y1, x2, y2]``; its centre point is what gets tested.
                If every value lies in ``[0, 1]`` the values are treated as relative and
                scaled by ``image_size``.
            boxes_type: ``"bbox"``, ``"polygon"`` or ``"refusal"``.
            boxes_size: Unused; kept for interface compatibility.
            boxes_coordinate: Target region: corners ``[x1, y1, x2, y2]`` for a bbox, a flat
                ``[x0, y0, x1, y1, ...]`` vertex list for a polygon.
            image_size: ``[width, height]`` used to scale relative coordinates.

        Returns:
            True when the centre point lies inside the bbox/polygon, or, for ``refusal``,
            when both centre coordinates are negative. False otherwise, including for a
            missing bbox and for unknown box types.
        """

        def _is_point_in_rectangle(point, rect):
            return rect[0] <= point[0] <= rect[2] and rect[1] <= point[1] <= rect[3]

        def _is_point_in_polygon(point, polygon):
            # Ray-casting test over the flat [x0, y0, x1, y1, ...] vertex list.
            x, y = point
            n = len(polygon) // 2
            inside = False

            j = n - 1
            for i in range(n):
                xi, yi = polygon[i * 2], polygon[i * 2 + 1]
                xj, yj = polygon[j * 2], polygon[j * 2 + 1]

                # The first condition guarantees yi != yj, so the division is safe.
                if (yi > y) != (yj > y) and x < (xj - xi) * (y - yi) / (yj - yi) + xi:
                    inside = not inside
                j = i

            return inside

        # Detect first whether the coordinates are relative (between 0 and 1).
        if all(0 <= coord <= 1 for coord in coordinate):
            # Expand them to the image width (even indices) and height (odd indices).
            coordinate = [coord * image_size[i % 2] for i, coord in enumerate(coordinate)]

        # Centre point of the predicted box.
        center_x = (coordinate[0] + coordinate[2]) / 2
        center_y = (coordinate[1] + coordinate[3]) / 2
        center_point = [center_x, center_y]

        if boxes_type == "bbox":
            # The sentinel box [-1, -1, -1, -1] stands for "no box". Without this check a
            # prediction that itself falls back to (-1, -1, -1, -1) (e.g. a ``wait`` call)
            # would land "inside" the sentinel and be scored as a hit.
            box_missing = (
                not boxes_coordinate
                or len(boxes_coordinate) < 4
                or all(value == -1 for value in boxes_coordinate[:4])
            )
            if box_missing:
                logger.warning("BBox coordinates missing after normalisation: %s", boxes_coordinate)
                return False
            return _is_point_in_rectangle(center_point, boxes_coordinate)
        elif boxes_type == "polygon":
            return _is_point_in_polygon(center_point, boxes_coordinate)
        elif boxes_type == "refusal":
            # A refusal is expressed by a centre point with both coordinates negative.
            return all(center_point[i] < 0 for i in range(2))
        else:
            logger.warning("Unknown box type '%s'; returning failure.", boxes_type)
            return False

    def get_coordinates(self, tool_call: ToolCall) -> tuple[int, int, int, int] | list[int | float]:
        """Extract the predicted ``[x1, y1, x2, y2]`` coordinates from a tool call.

        The location is read from the ``position``, ``coordinate`` or ``box`` argument (first
        non-empty one). A two-element point ``[x, y]`` becomes ``[x, y, x, y]``; a sequence of
        four or more values is truncated to its first four. A dict payload contributes its
        ``end`` point, or else its ``start`` point.

        Args:
            tool_call: Tool call whose ``function`` mapping holds ``name`` and ``arguments``
                (either a dict or a JSON string).

        Returns:
            The four coordinates, or ``(-1, -1, -1, -1)`` for a ``wait`` call and whenever
            the arguments are unparsable, not a mapping, or contain no usable location.
        """
        func_name = tool_call.function["name"]
        func_args = tool_call.function.get("arguments", {})
        default_coords = (-1, -1, -1, -1)

        if isinstance(func_args, str):
            try:
                func_args = json.loads(func_args)
            except json.JSONDecodeError:
                logger.warning("Failed to parse tool arguments for %s; returning default coordinates.", func_name)
                return default_coords

        # Arguments may be None or decode to a non-object JSON value (e.g. a list); neither
        # supports ``.get`` below, so treat them like any other unusable payload.
        if not isinstance(func_args, dict):
            logger.warning(
                "Tool arguments for %s are not a mapping (%r); returning default coordinates.",
                func_name,
                func_args,
            )
            return default_coords

        # Get coordinates - return default for wait or missing position
        raw_position = func_args.get("position") or func_args.get("coordinate") or func_args.get("box")

        if func_name == "wait" or raw_position is None:
            if raw_position is None:
                logger.warning("No coordinate-like payload found in tool args %s; returning default.", func_args)
            return default_coords

        if isinstance(raw_position, (list, tuple)):
            if len(raw_position) == 2:
                x, y = raw_position
                return [x, y, x, y]
            if len(raw_position) >= 4:
                return list(raw_position[:4])

        # Some computer_use actions (e.g., drag) may encode a path. Take the final point if present.
        if isinstance(raw_position, dict):
            for key in ("end", "start"):
                point = raw_position.get(key)
                if isinstance(point, (list, tuple)) and len(point) == 2:
                    x, y = point
                    return [x, y, x, y]

        logger.warning("Unexpected coordinate payload %s; returning default.", raw_position)
        return default_coords

    def _to_data_uri(self, image_payload: Any) -> str:
        """Turn a supported image payload into a data URI.

        Args:
            image_payload: A ``data:`` URI (returned unchanged), a path to an image file,
                raw image bytes, or a PIL image.

        Returns:
            The data URI string.

        Raises:
            FileNotFoundError: If a string payload is neither a data URI nor an existing file.
            TypeError: If the payload type is not supported.
        """
        if isinstance(image_payload, str):
            if image_payload.startswith("data:"):
                return image_payload

            image_path = Path(image_payload)
            if not image_path.is_file():
                raise FileNotFoundError(f"Image path does not exist: {image_payload!r}")
            return self._encode_bytes(image_path.read_bytes())

        if isinstance(image_payload, (bytes, bytearray)):
            raw = bytes(image_payload)
            return self._encode_bytes(raw)

        if isinstance(image_payload, Image.Image):
            return self._encode_image(image_payload)

        raise TypeError(
            f"Unsupported image payload type {type(image_payload)!r}. Expected str path/data URI, bytes, or PIL Image."
        )

    def _encode_bytes(self, data: bytes) -> str:
        """Decode raw image bytes and re-encode them through ``_encode_image``.

        Raises:
            ValueError: If opening or encoding the image fails for any reason.
        """
        try:
            stream = io.BytesIO(data)
            with Image.open(stream) as decoded:
                return self._encode_image(decoded)
        except Exception as exc:
            raise ValueError("Unable to decode image bytes for OSWorldG.") from exc

    def _encode_image(self, image: Image.Image) -> str:
        """Encode a PIL image as a base64 JPEG data URI (converted to RGB, quality 85)."""
        rgb_image = image.convert("RGB")
        with io.BytesIO() as jpeg_buffer:
            rgb_image.save(jpeg_buffer, format="JPEG", quality=85)
            jpeg_bytes = jpeg_buffer.getvalue()
        payload = base64.b64encode(jpeg_bytes).decode("utf-8")
        return f"data:image/jpeg;base64,{payload}"

evaluate(tool_calls) #

Evaluate the tool calls and return a reward.

Source code in src/agoge/environment/osworld_g.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def evaluate(self, tool_calls: list[ToolCall]) -> float:
    """Evaluate the tool calls and return a reward.

    Only the first tool call is scored. Its coordinates are checked against
    the target box stored in ``self.current_state["eval_criteria"]``.

    Args:
        tool_calls: Tool calls produced by the agent.

    Returns:
        ``1.0`` if ``_eval_coords`` reports success, otherwise ``0.0``.
        An empty ``tool_calls`` list also yields ``0.0``.
    """
    if not tool_calls:
        logger.warning("No tool calls provided for evaluation; returning 0 reward.")
        return 0.0

    tool_call = tool_calls[0]
    coordinates = self.get_coordinates(tool_call)

    eval_criteria = self.current_state["eval_criteria"]
    # box_type is looked up once and used for both normalisation and scoring.
    box_type = eval_criteria.get("box_type", "unknown")
    box_raw = eval_criteria.get("box_coordinates", [])
    box_coordinates, box_size = self._normalise_box(box_raw, box_type)
    image_size = self.current_state["metadata"].get("image_size", [1920, 1080])

    # Lazy %-style arguments: the message is only formatted when INFO is enabled.
    logger.info(
        "tool call: %s, coordinates: %s, box_type: %s, box_size: %s, box_coordinates: %s, image_size: %s",
        tool_call,
        coordinates,
        box_type,
        box_size,
        box_coordinates,
        image_size,
    )

    success = self._eval_coords(coordinates, box_type, box_size, box_coordinates, image_size)
    return 1.0 if success else 0.0

ScreenSpot #

Bases: OfflineGroundingEnvironment

Single-step offline environment for ScreenSpot benchmarks (v2, Pro, etc.).

Source code in src/agoge/environment/screenspot_v2.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class ScreenSpot(OfflineGroundingEnvironment):
    """One-step offline grounding environment for the ScreenSpot benchmark family (v2, Pro, etc.)."""

    def __init__(
        self,
        *,
        default_screen_width: int = 1920,
        default_screen_height: int = 1080,
        action_resolution: tuple[int, int] | None = None,
    ) -> None:
        """Create the environment with a single-step budget.

        Args:
            default_screen_width: Initial screen width given to the toolset.
            default_screen_height: Initial screen height given to the toolset.
            action_resolution: Coordinate space used by the model as
                ``(width, height)`` (e.g. 1000x1000 for Qwen). When ``None``,
                the dimensions returned by ``smart_resize`` are used.
        """
        super().__init__(max_steps=1)
        self.default_screen_width = default_screen_width
        self.default_screen_height = default_screen_height
        self.action_resolution = action_resolution
        self.toolset = ComputerUse(screen_width=default_screen_width, screen_height=default_screen_height)

    async def _reset_impl(self, task: Task, **reset_kwargs: Any) -> Chat:
        """Load a task, size the toolset to its image and build the opening chat.

        Args:
            task: Task whose ``inputs`` carry an ``instruction`` and an image
                under ``images`` or ``image``; ``metadata`` may carry
                ``image_size``.
            **reset_kwargs: Stored verbatim in ``current_state``.

        Returns:
            A chat containing the screenshot message followed by the
            formatted instruction message.

        Raises:
            ValueError: If the instruction is missing/empty or no image is given.
        """
        inputs = task.inputs
        instruction = inputs.get("instruction")
        image_payload = inputs.get("images") or inputs.get("image")

        if not instruction:
            raise ValueError("Task inputs must include an 'instruction' string.")
        if image_payload is None:
            raise ValueError("Task inputs must include an 'image' payload.")

        metadata = dict(task.metadata or {})
        image_size = metadata.get("image_size")
        size_is_usable = bool(image_size) and len(image_size) == 2
        if not size_is_usable:
            # No usable size in the metadata: derive it from the image itself.
            inferred_w, inferred_h = self._get_image_size(image_payload)
            image_size = [int(inferred_w), int(inferred_h)]
            metadata["image_size"] = image_size

        width = int(image_size[0])
        height = int(image_size[1])
        toolset = self.toolset
        if toolset.screen_width != width or toolset.screen_height != height:
            toolset.screen_width = width
            toolset.screen_height = height
            logger.debug("ScreenSpot-v2 screen resized to %s x %s", width, height)

        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=28,
            min_pixels=4 * 32 * 32,
            max_pixels=2116800,
        )

        # An explicit action_resolution (e.g. 1000x1000 for Qwen models) wins;
        # otherwise the model frame is the resized image.
        model_width, model_height = (
            self.action_resolution if self.action_resolution is not None else (resized_width, resized_height)
        )
        self.toolset.set_model_frame(model_width, model_height)

        image_data_uri = self._to_data_uri(image_payload)
        logger.debug("ScreenSpot-v2 reset with task_id=%s", task.task_id)

        opening_messages = [
            self._make_screenshot_message(image_data_uri),
            UserMessage(content=PROMPT_TEMPLATE.format(instruction=instruction)),
        ]

        # Always store the int-normalised size, whatever the metadata held.
        metadata["image_size"] = [width, height]

        self.current_state = {
            "task_id": task.task_id,
            "instruction": instruction,
            "image": image_data_uri,
            "eval_criteria": task.eval_criteria,
            "metadata": metadata,
            "reset_kwargs": reset_kwargs,
            "last_action": None,
            "image_size": [width, height],
            "model_image_size": [model_width, model_height],
        }

        return Chat(messages=opening_messages)

TestEnv #

Bases: Environment

Source code in src/agoge/environment/test.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class TestEnv(Environment):
    """Minimal environment that serves one fixed prompt and finishes after one step."""

    def __init__(self, **kwargs):
        """Set up the test environment.

        Args:
            **kwargs: Arbitrary keyword arguments (e.g., display_width,
                display_height). They are accepted so agent configs that
                reference such fields keep working, and are otherwise
                ignored because TestEnv has no use for display dimensions.
        """
        super().__init__()
        self.toolset = TestToolSet()

    def get_tool_schemas(self):
        """Return the schema exposed by the test toolset."""
        return self.toolset.schema

    async def reset(self, task: Task, **reset_kwargs) -> Chat:
        """Build, store and return the fixed two-message starting chat.

        The system message tells the model it must answer with a tool call,
        which RLAgent requires. The user message holds an embedded JPEG and a
        short text request.
        """
        message_adapter = TypeAdapter(ChatMessage)

        system_prompt = (
            "You are a helpful assistant. You MUST use the print_test_num tool to respond. "
            "Always call the print_test_num function with a test_num value. "
            "Do not respond with text only - you must make a tool call."
        )
        system_msg = message_adapter.validate_python({"role": "system", "content": system_prompt})

        image_url = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCABAAEADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigD//2Q=="  # noqa: E501
        user_content = [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": "Please call the print_test_num tool with any number to complete this task."},
        ]
        user_msg = message_adapter.validate_python({"role": "user", "content": user_content})

        self.current_state = Chat.model_validate({"messages": [system_msg, user_msg]})
        return self.current_state

    async def step(self, action: AssistantMessage) -> tuple[Chat, float, bool]:
        """Ignore the action and end the episode: empty chat, reward 1.0, done."""
        return Chat(messages=[]), 1.0, True

__init__(**kwargs) #

Initialize TestEnv.

Parameters:

Name Type Description Default
**kwargs

Accepts any keyword arguments (e.g., display_width, display_height) for compatibility with agent configs that reference these fields, but ignores them as TestEnv doesn't use display dimensions.

{}
Source code in src/agoge/environment/test.py
24
25
26
27
28
29
30
31
32
33
def __init__(self, **kwargs):
    """Set up the test environment.

    Args:
        **kwargs: Arbitrary keyword arguments (e.g., display_width,
            display_height). They are accepted so agent configs that
            reference such fields keep working, and are otherwise ignored
            because TestEnv has no use for display dimensions.
    """
    super().__init__()
    self.toolset = TestToolSet()