Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions c2rust-postprocess/postprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from postprocess.utils import existing_file

DEFAULT_LLM_MODEL = "gemini-3.5-flash"
DEFAULT_TRANSFORMS = ("comments", "asserts", "formatting")


def build_arg_parser() -> argparse.ArgumentParser:
Expand Down Expand Up @@ -113,10 +114,11 @@ def build_arg_parser() -> argparse.ArgumentParser:
type=str,
required=False,
action="append",
default=["comments"],
default=None,
help=(
"Transform to apply; pass multiple times to apply multiple transforms "
"in sorted order (default: comments)"
"in the order provided; duplicate transforms are ignored "
f"(default: {', '.join(DEFAULT_TRANSFORMS)})"
),
)

Expand Down Expand Up @@ -177,13 +179,16 @@ def main(argv: Sequence[str] | None = None):

model = get_model(args.llm_model)

# sort transform IDs so transforms always run in the same order to
# maximize cache hits even if the user passed them in a different order
transform_ids = sorted(
transform_id.strip()
for transform_id in set(args.transform)
if transform_id.strip()
# De-duplicate transform IDs while preserving their first occurrence.
transform_args = args.transform or DEFAULT_TRANSFORMS
transform_ids = list(
dict.fromkeys(
transform_id.strip()
for transform_id in transform_args
if transform_id.strip()
)
)

transforms = [
get_transform_by_id(
transform_id,
Expand Down
15 changes: 12 additions & 3 deletions c2rust-postprocess/postprocess/definitions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from postprocess.utils import get_tool_path


class MergeRustError(RuntimeError):
    """Signal that merge_rust exited with a non-zero status.

    The message carries the exit code followed by whatever the tool wrote
    to stdout and stderr, when those streams are non-empty.
    """


def get_c_sourcefile(compile_commands, rustfile: Path) -> Path | None:
c_file_guesses = [rustfile.with_suffix(".c"), rustfile.with_suffix(".C")]

Expand Down Expand Up @@ -296,8 +300,13 @@ def update_rust_definition(
)

if result.returncode != 0:
print(result.stdout)
print(result.stderr)
raise RuntimeError(f"merge_rust failed with exit code {result.returncode}")
message_parts = [
f"merge_rust failed with exit code {result.returncode}"
]
if result.stdout.strip():
message_parts.append(f"stdout:\n{result.stdout.strip()}")
if result.stderr.strip():
message_parts.append(f"stderr:\n{result.stderr.strip()}")
raise MergeRustError("\n".join(message_parts))

logging.info(f"Updated Rust definition of {identifier}")
117 changes: 109 additions & 8 deletions c2rust-postprocess/postprocess/models/gpt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
import inspect
import json
from collections.abc import Callable, Iterable
from typing import Any
from typing import Any, Protocol, cast

from openai import OpenAI
from openai.types.responses import (
FunctionToolParam,
ResponseFunctionToolCall,
ResponseInputParam,
)
from openai.types.responses.response_input_param import FunctionCallOutput

from postprocess.models import AbstractGenerativeModel


class NamedCallable(Protocol):
    """Structural type for a callable that also exposes ``__name__``.

    Tool functions are keyed and advertised by ``__name__``, so the type
    checker needs a callable type that guarantees the attribute exists.
    """

    # Name under which the callable is registered and reported.
    __name__: str

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the callable with arbitrary arguments."""
        ...


class GPTModel(AbstractGenerativeModel):
def __init__(
self,
Expand All @@ -22,13 +36,100 @@ def generate_with_tools(
tools: Iterable[Callable[..., Any]] = (),
max_tool_loops: int = 5,
) -> str:
# TODO: implement tool calling support
assert not tools, "Tool calling not yet implemented for GPTModel"
tools = [self._named_tool(tool) for tool in tools]
tool_schemas = [self._tool_schema(tool) for tool in tools]
tool_by_name = {tool.__name__: tool for tool in tools}

if tool_schemas:
response = self.client.responses.create(
model=self.id,
input=messages[0]["content"],
max_tool_calls=max_tool_loops,
tools=tool_schemas,
)
else:
response = self.client.responses.create(
model=self.id,
input=messages[0]["content"],
max_tool_calls=max_tool_loops,
)

response = self.client.responses.create(
model=self.id,
input=messages[0]["content"],
max_tool_calls=max_tool_loops,
)
for _ in range(max_tool_loops):
tool_calls = [
cast(ResponseFunctionToolCall, item)
for item in response.output
if getattr(item, "type", None) == "function_call"
]
if not tool_calls:
return response.output_text

tool_outputs: ResponseInputParam = [
FunctionCallOutput(
type="function_call_output",
call_id=tool_call.call_id,
output=self._call_tool(tool_call, tool_by_name),
)
for tool_call in tool_calls
]
response = self.client.responses.create(
model=self.id,
input=tool_outputs,
previous_response_id=response.id,
max_tool_calls=max_tool_loops,
tools=tool_schemas,
)

return response.output_text

def _named_tool(self, tool: Callable[..., Any]) -> NamedCallable:
if not hasattr(tool, "__name__"):
raise TypeError(f"Tool must be a named function: {tool!r}")
return cast(NamedCallable, tool)

def _tool_schema(self, tool: NamedCallable) -> FunctionToolParam:
signature = inspect.signature(tool)
properties: dict[str, object] = {}
required: list[str] = []
for name, parameter in signature.parameters.items():
properties[name] = {
"type": self._json_schema_type(parameter.annotation),
}
if parameter.default is inspect.Parameter.empty:
required.append(name)

return {
"type": "function",
"name": tool.__name__,
"description": inspect.getdoc(tool) or f"Call `{tool.__name__}`.",
"parameters": {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
},
"strict": False,
}

def _json_schema_type(self, annotation: Any) -> str:
if annotation is bool:
return "boolean"
if annotation is int:
return "integer"
if annotation is float:
return "number"
return "string"

def _call_tool(
self,
tool_call: ResponseFunctionToolCall,
tool_by_name: dict[str, NamedCallable],
) -> str:
if tool_call.name not in tool_by_name:
raise ValueError(f"Unknown tool call: {tool_call.name}")

arguments = json.loads(tool_call.arguments or "{}")
if not isinstance(arguments, dict):
raise ValueError(f"Tool call arguments must be an object: {arguments}")

result = tool_by_name[tool_call.name](**arguments)
return result if isinstance(result, str) else json.dumps(result)
6 changes: 6 additions & 0 deletions c2rust-postprocess/postprocess/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from postprocess.cache import AbstractCache
from postprocess.models import AbstractGenerativeModel
from postprocess.transforms.asserts import AssertsTransform
from postprocess.transforms.base import AbstractTransform
from postprocess.transforms.comments import CommentsTransform
from postprocess.transforms.formatting import FormattingTransform


def get_transform_by_id(
Expand All @@ -13,5 +15,9 @@ def get_transform_by_id(
match id.lower():
case "comments":
return CommentsTransform(cache=cache, model=model)
case "asserts":
return AssertsTransform(cache=cache, model=model)
case "formatting":
return FormattingTransform(cache=cache, model=model)
case _:
raise ValueError(f"Unsupported transform: {id}")
Loading
Loading