Source code for open_atp.provers.aristotle

"""AristotleProver: a wrapper around Harmonic's Aristotle API.

No agentic sandbox is needed for generation -- we hand the lake project to the
hosted Aristotle agent via ``aristotlelib`` (submit -> wait -> download), unpack the
returned archive over the workdir, and let the shared verifier do the final check in
our own Docker sandbox. This is the platform's simplest end-to-end slice.

The remote interaction is isolated in :meth:`AristotleProver._submit_and_download`
so tests can stand in a fake result without touching the network or an API key.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import shutil
import tarfile
import tempfile
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

from open_atp.lean import ProofTask
from open_atp.provers.base import AutomatedProver, AutomatedProverConfig
from open_atp.verify import ProofResult

if TYPE_CHECKING:
    from aristotlelib import AgentTask, Project

log = logging.getLogger(__name__)

_T = TypeVar("_T")

_DEFAULT_PROMPT = (
    "Complete every `sorry` in this Lean project. Make the project compile and be "
    "sorry-free without introducing new axioms; do not weaken or delete the stated "
    "theorems."
)
# END _DEFAULT_PROMPT (docs literalinclude end marker -- keep adjacent)

# Directories never worth shipping to Aristotle / copying into the workdir.
_IGNORE = shutil.ignore_patterns(".lake", ".git", "*.tar.gz")


def _is_transient(exc: BaseException) -> bool:
    """True for errors worth retrying: a dropped connection, timeout, or 5xx.

    aristotlelib turns httpx transport failures during a plain request into an
    ``AristotleAPIError`` with no status code (its ``RequestError`` wrapper) and
    leaves HTTP status errors with their code; a streamed run instead surfaces the
    raw ``httpx`` error. We treat transport-level failures and server-side 5xx as
    transient, but let real 4xx (bad key, missing project) fail fast.
    """
    import httpx
    from aristotlelib.api_request import AristotleAPIError

    if isinstance(exc, httpx.TransportError):
        return True
    if isinstance(exc, AristotleAPIError):
        return exc.status_code is None or exc.status_code >= 500
    return False


[docs] @dataclass class AristotleProverConfig(AutomatedProverConfig): """Configuration for :class:`AristotleProver`. Extends :class:`~open_atp.provers.base.AutomatedProverConfig` (``timeout_s``, ``env``) with the hosted-API knobs. Attributes ---------- api_key_env : str Name of the environment variable holding the Harmonic API key. Default ``ARISTOTLE_API_KEY``. allow_agent_questions : bool Whether to let the hosted agent ask clarifying questions. Off by default: this is a headless API path and a prompt for stdin would hang the run. max_connection_retries : int Bounds per-call retries (list/refresh/download) when a connection drops. The hosted run lives server-side, so a dropped connection is recoverable: re-fetch rather than reporting the run failed. Default ``5``. max_resume_attempts : int Bounds how many times we re-attach to the event stream when it drops mid-run. Default ``20``. resume_backoff_seconds : float Initial sleep between retries/resumes, doubling (capped) between tries. Default ``5.0``. """ api_key_env: str = "ARISTOTLE_API_KEY" allow_agent_questions: bool = False max_connection_retries: int = 5 max_resume_attempts: int = 20 resume_backoff_seconds: float = 5.0
[docs] class AristotleProver(AutomatedProver): """Prove by handing the whole project to Harmonic's hosted Aristotle agent. Generation happens over the network (submit the lake project, wait, download the result archive, unpack it over the workdir); the shared :class:`~open_atp.verify.Verifier` then runs the same local compile/sorry/axiom check. Network-only, so it takes just the verify backend -- no ``agent_backend``. Examples -------- Build the config and construct the prover directly (network-only, so it takes just the verify backend): >>> from open_atp.backends.docker import DockerBackend, DockerConfig >>> from open_atp.provers.aristotle import AristotleProver, AristotleProverConfig >>> backend = DockerBackend(DockerConfig()) >>> config = AristotleProverConfig() >>> prover = AristotleProver(config, verification_backend=backend) >>> prover.config.api_key_env 'ARISTOTLE_API_KEY' """ name = "aristotle" config: AristotleProverConfig def _generate( self, task: ProofTask, wd: Path, logs_dir: Path, result: ProofResult ) -> None: # Stage the original project so the workdir is a complete project both for the # upload and, after extraction, for verification. shutil.copytree(task.project.root, wd, dirs_exist_ok=True, ignore=_IGNORE) original = { p.relative_to(task.project.root).as_posix(): p.read_text() for p in task.project.lean_files() } prompt = task.instructions or _DEFAULT_PROMPT # The raw result archive and the full run record both belong with the run's # logs, not the proof project. ``prove`` already created ``logs_dir``; the # hosted agent has no live stdout stream, so its record (events, transcript, # summary) is downloaded here rather than teed. result_tar = logs_dir / "aristotle_result.tar.gz" downloaded, metadata = asyncio.run( self._submit_and_download(wd, prompt, result_tar, logs_dir) ) if downloaded is not None: self._extract_over(downloaded, wd) # Report the .lean files Aristotle changed or added. completed: dict[str, str] = {} for path in sorted(wd.rglob("*.lean")): if ".lake" in path.parts: continue rel = path.relative_to(wd).as_posix() content = path.read_text() if original.get(rel) != content: completed[rel] = content # The hosted agent's run summary is its primary human-readable record; surface # it beside the event record in the logs dir. summary_src = wd / "ARISTOTLE_SUMMARY.md" if summary_src.is_file(): (logs_dir / "summary.md").write_text(summary_src.read_text()) result.completed_files = completed # The Aristotle API does not expose a per-run cost; leave it unset. result.cost_usd = None result.metadata = metadata async def _submit_and_download( self, project_dir: Path, prompt: str, dest_tar: Path, logs_dir: Path ) -> tuple[Path | None, dict[str, object]]: """Submit ``project_dir`` to Aristotle, wait, and download the result archive. Also syncs the full run record (task metadata, every event, a readable transcript, project metadata) to ``logs_dir`` on the host. Returns ``(downloaded_tar_or_None, metadata)``. Isolated for testing. """ import aristotlelib from aristotlelib import AgentQuestionsSetting, Project key = os.environ.get(self.config.api_key_env) if key: aristotlelib.set_api_key(key) questions = ( AgentQuestionsSetting.TIMEOUT_15_MIN if self.config.allow_agent_questions else AgentQuestionsSetting.DISABLED ) project = await Project.create_from_directory( prompt=prompt, project_dir=project_dir, agent_questions_setting=questions ) tasks, _ = await self._with_retry( lambda: project.get_tasks(limit=1), "list tasks" ) metadata: dict[str, object] = {"project_id": project.project_id} if not tasks: metadata["error"] = "Aristotle returned no task to wait on." return None, metadata agent_task = tasks[0] # Resume across dropped connections until the task truly settles server-side. await self._wait_until_terminal(agent_task) await self._with_retry(project.refresh, "refresh project") metadata.update( task_id=agent_task.agent_task_id, task_status=agent_task.status.name, percent_complete=agent_task.percent_complete, output_summary=agent_task.output_summary, ) # Sync the full run record to the host before returning. Best-effort: a hiccup # syncing logs must not discard an otherwise-good result. try: await self._sync_run_info(project, agent_task, logs_dir) except Exception: # noqa: BLE001 -- logs are nice-to-have, not the result log.warning("aristotle: failed to sync run record", exc_info=True) metadata["logs_dir"] = str(logs_dir) if not project.has_files: metadata["error"] = "Aristotle produced no output files." return None, metadata await self._with_retry( lambda: project.get_files(destination=dest_tar), "download files" ) return dest_tar, metadata async def _with_retry(self, op: Callable[[], Awaitable[_T]], what: str) -> _T: """Run an awaitable-returning ``op``, retrying transient connection failures. Backs off exponentially (capped) and re-raises once the retry budget is spent or the error is not transient, so a genuine bad key/4xx still fails fast. """ delay = self.config.resume_backoff_seconds for attempt in range(1, self.config.max_connection_retries + 1): try: return await op() except Exception as exc: # noqa: BLE001 -- re-raised unless transient if ( not _is_transient(exc) or attempt == self.config.max_connection_retries ): raise log.warning( "aristotle: %s failed (%s); retrying (attempt %d/%d)", what, exc, attempt, self.config.max_connection_retries, ) await asyncio.sleep(delay) delay = min(delay * 2, 60.0) raise AssertionError("unreachable") # loop either returns or raises async def _wait_until_terminal(self, agent_task: AgentTask) -> None: """Wait for the task to reach a terminal state, resuming across dropped links. aristotlelib's ``wait_for_completion`` swallows a dropped event stream and returns with a stale, still-running status while the task keeps going on the server. Treat any non-terminal status after it returns as a dropped connection, re-fetch the true state, and re-attach to the stream until the task actually settles (or we exhaust the resume budget, in which case we proceed with whatever output exists -- the run dashboard is the source of truth). """ from aristotlelib.agent_task import TaskStatus terminal = { TaskStatus.COMPLETE, TaskStatus.COMPLETE_WITH_ERRORS, TaskStatus.OUT_OF_BUDGET, TaskStatus.FAILED, TaskStatus.CANCELED, } delay = self.config.resume_backoff_seconds for attempt in range(1, self.config.max_resume_attempts + 1): try: await agent_task.wait_for_completion() except Exception as exc: # noqa: BLE001 -- re-raised unless transient if not _is_transient(exc): raise log.warning("aristotle: wait interrupted (%s); resuming", exc) await self._with_retry(agent_task.refresh, "refresh task status") if agent_task.status in terminal: return log.warning( "aristotle: connection dropped with task still %s; resuming " "(attempt %d/%d)", agent_task.status.name, attempt, self.config.max_resume_attempts, ) await asyncio.sleep(delay) delay = min(delay * 2, 60.0) log.warning( "aristotle: task %s still %s after %d resume attempts; proceeding with " "whatever output is available", agent_task.agent_task_id, agent_task.status.name, self.config.max_resume_attempts, ) async def _sync_run_info( self, project: Project, agent_task: AgentTask, logs_dir: Path ) -> None: """Download the task's metadata and full event log to ``logs_dir``. Writes ``project.json``, ``task.json``, ``events.json`` (every event, oldest-first), and a human-readable ``transcript.txt``. """ logs_dir.mkdir(parents=True, exist_ok=True) # Page through every event, oldest-first, so the transcript reads top-down. events = [] pagination_key = None while True: page, pagination_key = await self._with_retry( lambda: agent_task.get_events( limit=100, pagination_key=pagination_key, newest_first=False ), "fetch events", ) events.extend(page) if not pagination_key: break def _dump(obj: object) -> str: return json.dumps(obj, default=str, indent=2) (logs_dir / "project.json").write_text(_dump(project.model_dump())) (logs_dir / "task.json").write_text(_dump(agent_task.model_dump())) (logs_dir / "events.json").write_text(_dump([e.model_dump() for e in events])) (logs_dir / "transcript.txt").write_text( "\n\n".join(f"[{e.created_at.isoformat()}] {e}" for e in events) ) @staticmethod def _extract_over(tar_path: Path, workdir: Path) -> None: """Unpack Aristotle's archive over the workdir (completed files win). Aristotle wraps its result in a single top-level directory (e.g. ``<name>_aristotle/``); we unwrap that so files land at the workdir root and overwrite the originals, rather than nesting a second copy one level down. """ with tempfile.TemporaryDirectory() as tmp: staging = Path(tmp) with tarfile.open(tar_path, "r:gz") as tar: tar.extractall(staging, filter="data") # Unwrap iff everything sits under exactly one top-level directory. entries = list(staging.iterdir()) wrapped = len(entries) == 1 and entries[0].is_dir() source = entries[0] if wrapped else staging for item in source.rglob("*"): if item.is_file(): dest = workdir / item.relative_to(source) dest.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(item, dest)