"""AristotleProver: a wrapper around Harmonic's Aristotle API.
No agentic sandbox is needed for generation -- we hand the lake project to the
hosted Aristotle agent via ``aristotlelib`` (submit -> wait -> download), unpack the
returned archive over the workdir, and let the shared verifier do the final check in
our own Docker sandbox. This is the platform's simplest end-to-end slice.
The remote interaction is isolated in :meth:`AristotleProver._submit_and_download`
so tests can stand in a fake result without touching the network or an API key.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import shutil
import tarfile
import tempfile
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from open_atp.backends.base import ComputeBackend
from open_atp.lean import ProofTask
from open_atp.provers.base import (
AutomatedProver,
ProofResult,
compose_prompt,
)
if TYPE_CHECKING:
from aristotlelib import AgentTask, Project
log = logging.getLogger(__name__)
_T = TypeVar("_T")
PROVER_PROMPT = (
"Complete every `sorry` in this Lean project. Make the project compile and be "
"sorry-free without introducing new axioms; do not weaken or delete the stated "
"theorems."
)
# END PROVER_PROMPT (docs literalinclude end marker -- keep adjacent)
# Directories never worth shipping to Aristotle / copying into the workdir.
_IGNORE = shutil.ignore_patterns(".lake", ".git", "*.tar.gz")
def _is_transient(exc: BaseException) -> bool:
"""True for errors worth retrying: a dropped connection, timeout, or 5xx.
aristotlelib turns httpx transport failures during a plain request into an
``AristotleAPIError`` with no status code (its ``RequestError`` wrapper) and
leaves HTTP status errors with their code; a streamed run instead surfaces the
raw ``httpx`` error. We treat transport-level failures and server-side 5xx as
transient, but let real 4xx (bad key, missing project) fail fast.
"""
import httpx
from aristotlelib.api_request import AristotleAPIError
if isinstance(exc, httpx.TransportError):
return True
if isinstance(exc, AristotleAPIError):
return exc.status_code is None or exc.status_code >= 500
return False
[docs]
class AristotleProver(AutomatedProver):
"""Prove by handing the whole project to Harmonic's hosted Aristotle agent.
Generation happens over the network (submit the lake project, wait, download the
result archive, unpack it over the workdir); the shared
:class:`~open_atp.verify.Verifier` then runs the same local compile/sorry/axiom
check. Generation is network-only, so the backend is used solely for that final
check -- unlike the agentic provers, there is no live session to reuse.
Parameters
----------
backend : ComputeBackend
The sandbox used only for the final verify; Aristotle generates over the
network, so there is no live session to reuse.
api_key : str, optional
The Harmonic API key. ``None`` (default) reads it from the host
``ARISTOTLE_API_KEY`` env var.
allow_agent_questions : bool
Whether to let the hosted agent ask clarifying questions. Off by default:
this is a headless API path and a prompt for stdin would hang the run.
max_connection_retries : int
Bounds per-call retries (list/refresh/download) when a connection drops.
The hosted run lives server-side, so a dropped connection is recoverable:
re-fetch rather than reporting the run failed. Default ``5``.
max_resume_attempts : int
Bounds how many times we re-attach to the event stream when it drops
mid-run. Default ``20``.
resume_backoff_seconds : float
Initial sleep between retries/resumes, doubling (capped) between tries.
Default ``5.0``.
timeout_s : int
Wall-clock budget for the generation run, in seconds. Default ``1800``.
Attributes
----------
prover_prompt : str
The prover's own prompt handed to Aristotle, before any user prompt.
Examples
--------
Construct the prover directly (network-only generation, so the backend is just
the verify backend):
>>> from open_atp.backends.docker import DockerBackend
>>> from open_atp.provers.aristotle import AristotleProver
>>> backend = DockerBackend()
>>> prover = AristotleProver(backend=backend)
>>> prover.name
'aristotle'
Or build the same prover from the standard catalog by name, taking its
baked-in defaults (see :func:`~open_atp.config.standard_prover`):
>>> from open_atp import standard_prover
>>> prover = standard_prover("aristotle", backend=DockerBackend())
>>> prover.name
'aristotle'
Complete a task's ``sorry``\\s with
:meth:`~open_atp.provers.base.AutomatedProver.prove`, here on a bundled example
(this hits the hosted Aristotle API, needing ``ARISTOTLE_API_KEY``, and runs
Docker for the verify):
>>> import tempfile
>>> from open_atp.examples import EXAMPLE, example_task
>>> task = example_task(EXAMPLE.ABS_MUL_LT)
>>> result = prover.prove(task, tempfile.mkdtemp()) # doctest: +SKIP
>>> result.success # doctest: +SKIP
True
"""
name = "aristotle"
def __init__(
self,
*,
backend: ComputeBackend,
api_key: str | None = None,
allow_agent_questions: bool = False,
max_connection_retries: int = 5,
max_resume_attempts: int = 20,
resume_backoff_seconds: float = 5.0,
timeout_s: int = 1800,
) -> None:
super().__init__(backend=backend, timeout_s=timeout_s)
#: The Harmonic API key, or ``None`` to read ``ARISTOTLE_API_KEY`` at run time.
self._api_key = api_key
#: Whether to let the hosted agent ask clarifying questions.
self.allow_agent_questions = allow_agent_questions
#: Bounds per-call retries when a connection drops.
self.max_connection_retries = max_connection_retries
#: Bounds re-attaches to the event stream when it drops mid-run.
self.max_resume_attempts = max_resume_attempts
#: Initial sleep between retries/resumes, doubling (capped) between tries.
self.resume_backoff_seconds = resume_backoff_seconds
@property
def prover_prompt(self) -> str:
"""The prover's own prompt handed to Aristotle, before any user prompt."""
return PROVER_PROMPT
def _generate(
self, task: ProofTask, wd: Path, logs_dir: Path, result: ProofResult
) -> None:
# Stage the original project so the workdir is a complete project both for the
# upload and, after extraction, for verification.
shutil.copytree(task.project.root, wd, dirs_exist_ok=True, ignore=_IGNORE)
original = {
p.relative_to(task.project.root).as_posix(): p.read_text()
for p in task.project.lean_files()
}
prompt = compose_prompt(self.prover_prompt, task.user_prompt)
# The raw result archive and the full run record both belong with the run's
# logs, not the proof project. ``prove`` already created ``logs_dir``; the
# hosted agent has no live stdout stream, so its record (events, transcript,
# summary) is downloaded here rather than teed.
result_tar = logs_dir / "aristotle_result.tar.gz"
downloaded, metadata = asyncio.run(
self._submit_and_download(wd, prompt, result_tar, logs_dir)
)
if downloaded is not None:
self._extract_over(downloaded, wd)
# Report the .lean files Aristotle changed or added.
completed: dict[str, str] = {}
for path in sorted(wd.rglob("*.lean")):
if ".lake" in path.parts:
continue
rel = path.relative_to(wd).as_posix()
content = path.read_text()
if original.get(rel) != content:
completed[rel] = content
# The hosted agent's run summary is its primary human-readable record; surface
# it beside the event record in the logs dir.
summary_src = wd / "ARISTOTLE_SUMMARY.md"
if summary_src.is_file():
(logs_dir / "summary.md").write_text(summary_src.read_text())
result.completed_files = completed
# The Aristotle API does not expose a per-run cost; leave it unset.
result.cost_usd = None
result.metadata = metadata
async def _submit_and_download(
self, project_dir: Path, prompt: str, dest_tar: Path, logs_dir: Path
) -> tuple[Path | None, dict[str, object]]:
"""Submit ``project_dir`` to Aristotle, wait, and download the result archive.
Also syncs the full run record (task metadata, every event, a readable
transcript, project metadata) to ``logs_dir`` on the host.
Returns ``(downloaded_tar_or_None, metadata)``. Isolated for testing.
"""
import aristotlelib
from aristotlelib import AgentQuestionsSetting, Project
key = self._api_key or os.environ.get("ARISTOTLE_API_KEY")
if key:
aristotlelib.set_api_key(key)
questions = (
AgentQuestionsSetting.TIMEOUT_15_MIN
if self.allow_agent_questions
else AgentQuestionsSetting.DISABLED
)
project = await Project.create_from_directory(
prompt=prompt, project_dir=project_dir, agent_questions_setting=questions
)
tasks, _ = await self._with_retry(
lambda: project.get_tasks(limit=1), "list tasks"
)
metadata: dict[str, object] = {"project_id": project.project_id}
if not tasks:
metadata["error"] = "Aristotle returned no task to wait on."
return None, metadata
agent_task = tasks[0]
# Resume across dropped connections until the task truly settles server-side.
await self._wait_until_terminal(agent_task)
await self._with_retry(project.refresh, "refresh project")
metadata.update(
task_id=agent_task.agent_task_id,
task_status=agent_task.status.name,
percent_complete=agent_task.percent_complete,
output_summary=agent_task.output_summary,
)
# Sync the full run record to the host before returning. Best-effort: a hiccup
# syncing logs must not discard an otherwise-good result.
try:
await self._sync_run_info(project, agent_task, logs_dir)
except Exception: # noqa: BLE001 -- logs are nice-to-have, not the result
log.warning("aristotle: failed to sync run record", exc_info=True)
metadata["logs_dir"] = str(logs_dir)
if not project.has_files:
metadata["error"] = "Aristotle produced no output files."
return None, metadata
await self._with_retry(
lambda: project.get_files(destination=dest_tar), "download files"
)
return dest_tar, metadata
async def _with_retry(self, op: Callable[[], Awaitable[_T]], what: str) -> _T:
"""Run an awaitable-returning ``op``, retrying transient connection failures.
Backs off exponentially (capped) and re-raises once the retry budget is spent
or the error is not transient, so a genuine bad key/4xx still fails fast.
"""
delay = self.resume_backoff_seconds
for attempt in range(1, self.max_connection_retries + 1):
try:
return await op()
except Exception as exc: # noqa: BLE001 -- re-raised unless transient
if not _is_transient(exc) or attempt == self.max_connection_retries:
raise
log.warning(
"aristotle: %s failed (%s); retrying (attempt %d/%d)",
what,
exc,
attempt,
self.max_connection_retries,
)
await asyncio.sleep(delay)
delay = min(delay * 2, 60.0)
raise AssertionError("unreachable") # loop either returns or raises
async def _wait_until_terminal(self, agent_task: AgentTask) -> None:
"""Wait for the task to reach a terminal state, resuming across dropped links.
aristotlelib's ``wait_for_completion`` swallows a dropped event stream and
returns with a stale, still-running status while the task keeps going on the
server. Treat any non-terminal status after it returns as a dropped connection,
re-fetch the true state, and re-attach to the stream until the task actually
settles (or we exhaust the resume budget, in which case we proceed with
whatever output exists -- the run dashboard is the source of truth).
"""
from aristotlelib.agent_task import TaskStatus
terminal = {
TaskStatus.COMPLETE,
TaskStatus.COMPLETE_WITH_ERRORS,
TaskStatus.OUT_OF_BUDGET,
TaskStatus.FAILED,
TaskStatus.CANCELED,
}
delay = self.resume_backoff_seconds
for attempt in range(1, self.max_resume_attempts + 1):
try:
await agent_task.wait_for_completion()
except Exception as exc: # noqa: BLE001 -- re-raised unless transient
if not _is_transient(exc):
raise
log.warning("aristotle: wait interrupted (%s); resuming", exc)
await self._with_retry(agent_task.refresh, "refresh task status")
if agent_task.status in terminal:
return
log.warning(
"aristotle: connection dropped with task still %s; resuming "
"(attempt %d/%d)",
agent_task.status.name,
attempt,
self.max_resume_attempts,
)
await asyncio.sleep(delay)
delay = min(delay * 2, 60.0)
log.warning(
"aristotle: task %s still %s after %d resume attempts; proceeding with "
"whatever output is available",
agent_task.agent_task_id,
agent_task.status.name,
self.max_resume_attempts,
)
async def _sync_run_info(
self, project: Project, agent_task: AgentTask, logs_dir: Path
) -> None:
"""Download the task's metadata and full event log to ``logs_dir``.
Writes ``project.json``, ``task.json``, ``events.json`` (every event,
oldest-first), and a human-readable ``transcript.txt``.
"""
logs_dir.mkdir(parents=True, exist_ok=True)
# Page through every event, oldest-first, so the transcript reads top-down.
events = []
pagination_key = None
while True:
page, pagination_key = await self._with_retry(
lambda: agent_task.get_events(
limit=100, pagination_key=pagination_key, newest_first=False
),
"fetch events",
)
events.extend(page)
if not pagination_key:
break
def _dump(obj: object) -> str:
return json.dumps(obj, default=str, indent=2)
(logs_dir / "project.json").write_text(_dump(project.model_dump()))
(logs_dir / "task.json").write_text(_dump(agent_task.model_dump()))
(logs_dir / "events.json").write_text(_dump([e.model_dump() for e in events]))
(logs_dir / "transcript.txt").write_text(
"\n\n".join(f"[{e.created_at.isoformat()}] {e}" for e in events)
)
@staticmethod
def _extract_over(tar_path: Path, workdir: Path) -> None:
"""Unpack Aristotle's archive over the workdir (completed files win).
Aristotle wraps its result in a single top-level directory (e.g.
``<name>_aristotle/``); we unwrap that so files land at the workdir root and
overwrite the originals, rather than nesting a second copy one level down.
"""
with tempfile.TemporaryDirectory() as tmp:
staging = Path(tmp)
with tarfile.open(tar_path, "r:gz") as tar:
tar.extractall(staging, filter="data")
# Unwrap iff everything sits under exactly one top-level directory.
entries = list(staging.iterdir())
wrapped = len(entries) == 1 and entries[0].is_dir()
source = entries[0] if wrapped else staging
for item in source.rglob("*"):
if item.is_file():
dest = workdir / item.relative_to(source)
dest.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(item, dest)