diff --git a/c2rust-postprocess/postprocess/__init__.py b/c2rust-postprocess/postprocess/__init__.py index 306d0a1cca..767ccb5794 100644 --- a/c2rust-postprocess/postprocess/__init__.py +++ b/c2rust-postprocess/postprocess/__init__.py @@ -25,6 +25,7 @@ from postprocess.utils import existing_file DEFAULT_LLM_MODEL = "gemini-3.5-flash" +DEFAULT_TRANSFORMS = ("comments", "asserts", "formatting") def build_arg_parser() -> argparse.ArgumentParser: @@ -113,10 +114,11 @@ def build_arg_parser() -> argparse.ArgumentParser: type=str, required=False, action="append", - default=["comments"], + default=None, help=( "Transform to apply; pass multiple times to apply multiple transforms " - "in sorted order (default: comments)" + "in the order provided; duplicate transforms are ignored " + f"(default: {', '.join(DEFAULT_TRANSFORMS)})" ), ) @@ -177,13 +179,16 @@ def main(argv: Sequence[str] | None = None): model = get_model(args.llm_model) - # sort transform IDs to transforms always run in the same order to - # maximize cache hits even if the user passed them in a different order - transform_ids = sorted( - transform_id.strip() - for transform_id in set(args.transform) - if transform_id.strip() + # De-duplicate transform IDs while preserving their first occurrence. + transform_args = args.transform or DEFAULT_TRANSFORMS + transform_ids = list( + dict.fromkeys( + transform_id.strip() + for transform_id in transform_args + if transform_id.strip() + ) ) + transforms = [ get_transform_by_id( transform_id, diff --git a/c2rust-postprocess/postprocess/definitions/__init__.py b/c2rust-postprocess/postprocess/definitions/__init__.py index 9c22a7f063..d6e998cf12 100644 --- a/c2rust-postprocess/postprocess/definitions/__init__.py +++ b/c2rust-postprocess/postprocess/definitions/__init__.py @@ -22,6 +22,10 @@ from postprocess.utils import get_tool_path +class MergeRustError(RuntimeError): + """The merge_rust tool rejected a replacement Rust definition.""" + + def get_c_sourcefile(compile_commands, rustfile: Path) -> Path | None: c_file_guesses = [rustfile.with_suffix(".c"), rustfile.with_suffix(".C")] @@ -296,8 +300,13 @@ def update_rust_definition( ) if result.returncode != 0: - print(result.stdout) - print(result.stderr) - raise RuntimeError(f"merge_rust failed with exit code {result.returncode}") + message_parts = [ + f"merge_rust failed with exit code {result.returncode}" + ] + if result.stdout.strip(): + message_parts.append(f"stdout:\n{result.stdout.strip()}") + if result.stderr.strip(): + message_parts.append(f"stderr:\n{result.stderr.strip()}") + raise MergeRustError("\n".join(message_parts)) logging.info(f"Updated Rust definition of {identifier}") diff --git a/c2rust-postprocess/postprocess/models/gpt.py b/c2rust-postprocess/postprocess/models/gpt.py index eb0c8904a4..d37a5f1798 100644 --- a/c2rust-postprocess/postprocess/models/gpt.py +++ b/c2rust-postprocess/postprocess/models/gpt.py @@ -1,11 +1,25 @@ +import inspect +import json from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, Protocol, cast from openai import OpenAI +from openai.types.responses import ( + FunctionToolParam, + ResponseFunctionToolCall, + ResponseInputParam, +) +from openai.types.responses.response_input_param import FunctionCallOutput from postprocess.models import AbstractGenerativeModel +class NamedCallable(Protocol): + __name__: str + + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + + class GPTModel(AbstractGenerativeModel): def __init__( self, @@ -22,13 +36,100 @@ def generate_with_tools( tools: Iterable[Callable[..., Any]] = (), max_tool_loops: int = 5, ) -> str: - # TODO: implement tool calling support - assert not tools, "Tool calling not yet implemented for GPTModel" + tools = [self._named_tool(tool) for tool in tools] + tool_schemas = [self._tool_schema(tool) for tool in tools] + tool_by_name = {tool.__name__: tool for tool in tools} + + if tool_schemas: + response = self.client.responses.create( + model=self.id, + input=messages[0]["content"], + max_tool_calls=max_tool_loops, + tools=tool_schemas, + ) + else: + response = self.client.responses.create( + model=self.id, + input=messages[0]["content"], + max_tool_calls=max_tool_loops, + ) - response = self.client.responses.create( - model=self.id, - input=messages[0]["content"], - max_tool_calls=max_tool_loops, - ) + for _ in range(max_tool_loops): + tool_calls = [ + cast(ResponseFunctionToolCall, item) + for item in response.output + if getattr(item, "type", None) == "function_call" + ] + if not tool_calls: + return response.output_text + + tool_outputs: ResponseInputParam = [ + FunctionCallOutput( + type="function_call_output", + call_id=tool_call.call_id, + output=self._call_tool(tool_call, tool_by_name), + ) + for tool_call in tool_calls + ] + response = self.client.responses.create( + model=self.id, + input=tool_outputs, + previous_response_id=response.id, + max_tool_calls=max_tool_loops, + tools=tool_schemas, + ) return response.output_text + + def _named_tool(self, tool: Callable[..., Any]) -> NamedCallable: + if not hasattr(tool, "__name__"): + raise TypeError(f"Tool must be a named function: {tool!r}") + return cast(NamedCallable, tool) + + def _tool_schema(self, tool: NamedCallable) -> FunctionToolParam: + signature = inspect.signature(tool) + properties: dict[str, object] = {} + required: list[str] = [] + for name, parameter in signature.parameters.items(): + properties[name] = { + "type": self._json_schema_type(parameter.annotation), + } + if parameter.default is inspect.Parameter.empty: + required.append(name) + + return { + "type": "function", + "name": tool.__name__, + "description": inspect.getdoc(tool) or f"Call `{tool.__name__}`.", + "parameters": { + "type": "object", + "properties": properties, + "required": required, + "additionalProperties": False, + }, + "strict": False, + } + + def _json_schema_type(self, annotation: Any) -> str: + if annotation is bool: + return "boolean" + if annotation is int: + return "integer" + if annotation is float: + return "number" + return "string" + + def _call_tool( + self, + tool_call: ResponseFunctionToolCall, + tool_by_name: dict[str, NamedCallable], + ) -> str: + if tool_call.name not in tool_by_name: + raise ValueError(f"Unknown tool call: {tool_call.name}") + + arguments = json.loads(tool_call.arguments or "{}") + if not isinstance(arguments, dict): + raise ValueError(f"Tool call arguments must be an object: {arguments}") + + result = tool_by_name[tool_call.name](**arguments) + return result if isinstance(result, str) else json.dumps(result) diff --git a/c2rust-postprocess/postprocess/transforms/__init__.py b/c2rust-postprocess/postprocess/transforms/__init__.py index 8cc791ed13..2ba21fddff 100644 --- a/c2rust-postprocess/postprocess/transforms/__init__.py +++ b/c2rust-postprocess/postprocess/transforms/__init__.py @@ -1,7 +1,9 @@ from postprocess.cache import AbstractCache from postprocess.models import AbstractGenerativeModel +from postprocess.transforms.asserts import AssertsTransform from postprocess.transforms.base import AbstractTransform from postprocess.transforms.comments import CommentsTransform +from postprocess.transforms.formatting import FormattingTransform def get_transform_by_id( @@ -13,5 +15,9 @@ def get_transform_by_id( match id.lower(): case "comments": return CommentsTransform(cache=cache, model=model) + case "asserts": + return AssertsTransform(cache=cache, model=model) + case "formatting": + return FormattingTransform(cache=cache, model=model) case _: raise ValueError(f"Unsupported transform: {id}") diff --git a/c2rust-postprocess/postprocess/transforms/asserts.py b/c2rust-postprocess/postprocess/transforms/asserts.py new file mode 100644 index 0000000000..262f0bbaf9 --- /dev/null +++ b/c2rust-postprocess/postprocess/transforms/asserts.py @@ -0,0 +1,194 @@ +import logging +from collections.abc import Callable +from pathlib import Path +from textwrap import dedent + +from postprocess.cache import AbstractCache +from postprocess.definitions import CDefinition, MergeRustError +from postprocess.models import AbstractGenerativeModel, api_key_from_env +from postprocess.transforms.base import ( + AbstractTransform, + TransformCandidate, + TransformError, +) +from postprocess.utils import get_highlighted_rust, remove_backticks + +SYSTEM_INSTRUCTION = ( + "You are a helpful assistant that rewrites c2rust-transpiled assert patterns " + "into idiomatic Rust assert! macros." +) + + +class AssertsTransformPrompt: + c_function: str + rust_function: str + prompt_text: str + identifier: str + + __slots__ = ("c_function", "rust_function", "prompt_text", "identifier") + + def __init__( + self, c_function: str, rust_function: str, prompt_text: str, identifier: str + ): + self.c_function = c_function + self.rust_function = rust_function + self.prompt_text = prompt_text + self.identifier = identifier + + def __str__(self) -> str: + return ( + self.prompt_text + + "\n\n" + + "C function:\n```c\n" + + self.c_function + + "```\n\n" + + "Rust function:\n```rust\n" + + self.rust_function + + "```\n" + ) + + +class AssertsTransform(AbstractTransform): + def __init__(self, cache: AbstractCache, model: AbstractGenerativeModel): + super().__init__(SYSTEM_INSTRUCTION) + self.cache = cache + self.model = model + + @staticmethod + def get_validation_fn(expected_assert_count: int) -> Callable[[str], str]: + def validate_response(rust_fn: str) -> str: + rust_fn = remove_backticks(rust_fn) + + if "__assert_fail(" in rust_fn: + return ( + "FAILURE: Rust function still contains __assert_fail. " + "Rewrite those into assert! calls. " + "Reply with the full Rust function definition only; " + "say nothing else." + ) + + actual_assert_count = rust_fn.count("assert!(") + if actual_assert_count < expected_assert_count: + return ( + "FAILURE: Missing rewritten assert! calls. " + f"Expected at least {expected_assert_count}, " + f"got {actual_assert_count}. " + "Reply with the full Rust function definition only; " + "say nothing else." + ) + + return "SUCCESS: Asserts transformed correctly!" + + return validate_response + + def try_apply_ident( + self, + rust_source_file: Path, + rust_definition: str, + c_definition: CDefinition, + identifier: str, + attempt: int = 0, + previous_error: MergeRustError | None = None, + ) -> TransformCandidate | None: + _ = (rust_source_file, attempt) + expected_assert_count = rust_definition.count("__assert_fail(") + if expected_assert_count == 0: + logging.info( + f"{self.__class__.__name__}: " + f"Skipping function without transpiled asserts: {identifier}" + ) + return + + prompt_text = """ + Rewrite the Rust function below by replacing transpiled C assert-macro + expansions (which call __assert_fail) with idiomatic Rust assert! calls. + + Requirements: + - Preserve function behavior. + - Preserve formatting and indentation. + - Keep all non-assert logic unchanged. + - Return the full Rust function definition only; say nothing else. + """ + prompt_text = dedent(prompt_text).strip() + + prompt = AssertsTransformPrompt( + c_function=c_definition.effective, + rust_function=rust_definition, + prompt_text=prompt_text, + identifier=identifier, + ) + + messages = [{"role": "user", "content": str(prompt)}] + messages = self.with_merge_retry_message(messages, previous_error) + + transform = self.__class__.__name__ + model = self.model.id + + if response := self.cache.lookup( + transform=transform, + identifier=identifier, + model=model, + messages=messages, + ): + rust_fn = remove_backticks(response) + return TransformCandidate( + identifier=identifier, + messages=messages, + response=response, + definition=rust_fn, + ) + + validate_response = self.get_validation_fn(expected_assert_count) + + # TODO: control attempts from command line args + for _attempt in range(3): + response = self.model.generate_with_tools( + messages, tools=[validate_response] + ) + if response is None: + if api_key_from_env(model) is None: + logging.warning( + f"Cache miss for {identifier}; " + "skipping since no API key was set..." + ) + return + logging.warning("Model returned no response") + continue + + if response.strip() == "": + logging.warning("Model returned empty response") + continue + + validation_result = validate_response(response) + if not validation_result.startswith("SUCCESS"): + logging.warning( + f"Model response for {identifier} failed validation: " + + validation_result + + "\nResponse was:\n" + + response + ) + continue + + break + else: + raise TransformError( + f"Model failed to produce valid response after multiple " + f"attempts for {identifier}" + ) + + rust_fn = remove_backticks(response) + if rust_fn == rust_definition: + logging.warning( + f"{self.__class__.__name__}: " + f"No assert rewrite changes for function: {identifier}" + ) + return + + print(get_highlighted_rust(rust_fn)) + + return TransformCandidate( + identifier=identifier, + messages=messages, + response=response, + definition=rust_fn, + ) diff --git a/c2rust-postprocess/postprocess/transforms/base.py b/c2rust-postprocess/postprocess/transforms/base.py index 8476c6e49a..8305fbfe14 100644 --- a/c2rust-postprocess/postprocess/transforms/base.py +++ b/c2rust-postprocess/postprocess/transforms/base.py @@ -1,12 +1,16 @@ import logging import re from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path +from typing import Any, cast from postprocess.definitions import ( CDefinition, + MergeRustError, get_c_definitions, get_rust_definitions, + update_rust_definition, ) from postprocess.exclude_list import IdentifierExcludeList from postprocess.utils import get_highlighted_c @@ -16,11 +20,23 @@ class TransformError(Exception): """A transform failed to process a single definition.""" +@dataclass(frozen=True) +class TransformCandidate: + identifier: str + messages: list[dict[str, Any]] + response: str + # Full replacement Rust definition for `identifier`, ready for merge_rust. + definition: str + cacheable: bool = True + + class AbstractTransform(ABC): """ Abstract base class for LLM-driven transforms of c2rust transpiler output. """ + max_merge_attempts = 3 + def __init__(self, system_instruction: str): self._system_instruction = system_instruction @@ -28,7 +44,6 @@ def __init__(self, system_instruction: str): def system_instruction(self) -> str: return self._system_instruction - @abstractmethod def apply_ident( self, rust_source_file: Path, @@ -38,10 +53,106 @@ def apply_ident( update_rust: bool = True, ) -> str | None: """ - Implementations should apply transform to a single Rust definition - with the given identifier. + Apply transform to one Rust definition and commit it through merge_rust. + """ + previous_error: MergeRustError | None = None + + for attempt in range(self.max_merge_attempts): + candidate = self.try_apply_ident( + rust_source_file=rust_source_file, + rust_definition=rust_definition, + c_definition=c_definition, + identifier=identifier, + attempt=attempt, + previous_error=previous_error, + ) + if candidate is None: + return None + + if candidate.definition == rust_definition: + self.cache_candidate(candidate) + logging.info( + f"{self.__class__.__name__}: " + f"No changes for function: {identifier}" + ) + return None + + try: + if update_rust: + update_rust_definition( + root_rust_source_file=rust_source_file, + identifier=candidate.identifier, + new_definition=candidate.definition, + ) + except MergeRustError as error: + previous_error = error + logging.warning( + f"merge_rust rejected transform candidate for {identifier} " + f"in {rust_source_file} on attempt " + f"{attempt + 1}/{self.max_merge_attempts}: {error}" + ) + continue + + self.cache_candidate(candidate) + + logging.info( + f"{self.__class__.__name__}: " + f"Updated Rust fn {identifier}" + ) + return candidate.definition + + raise TransformError( + f"merge_rust failed after {self.max_merge_attempts} attempts " + f"for {identifier}" + ) from previous_error + + @abstractmethod + def try_apply_ident( + self, + rust_source_file: Path, + rust_definition: str, + c_definition: CDefinition, + identifier: str, + attempt: int = 0, + previous_error: MergeRustError | None = None, + ) -> TransformCandidate | None: + """ + Return a candidate Rust replacement, or None to skip this identifier. """ - pass + raise NotImplementedError + + def cache_candidate(self, candidate: TransformCandidate) -> None: + if not candidate.cacheable: + return + + transform = cast(Any, self) + transform.cache.update( + transform=self.__class__.__name__, + identifier=candidate.identifier, + model=transform.model.id, + messages=candidate.messages, + response=candidate.response, + ) + + @staticmethod + def with_merge_retry_message( + messages: list[dict[str, Any]], previous_error: MergeRustError | None + ) -> list[dict[str, Any]]: + if previous_error is None: + return messages + + return [ + *messages, + { + "role": "user", + "content": ( + "The previous Rust function definition was rejected by " + f"merge_rust:\n{previous_error}\n\n" + "Return a syntactically valid full Rust function definition " + "only; say nothing else." + ), + }, + ] def apply_dir( self, diff --git a/c2rust-postprocess/postprocess/transforms/comments.py b/c2rust-postprocess/postprocess/transforms/comments.py index 06cecee46e..479c9e1e20 100644 --- a/c2rust-postprocess/postprocess/transforms/comments.py +++ b/c2rust-postprocess/postprocess/transforms/comments.py @@ -5,12 +5,16 @@ from postprocess.cache import AbstractCache from postprocess.definitions import ( CDefinition, + MergeRustError, get_c_comments, get_rust_comments, - update_rust_definition, ) from postprocess.models import AbstractGenerativeModel, api_key_from_env -from postprocess.transforms.base import AbstractTransform, TransformError +from postprocess.transforms.base import ( + AbstractTransform, + TransformCandidate, + TransformError, +) from postprocess.transforms.trim import TrimTransform from postprocess.utils import get_highlighted_rust, remove_backticks @@ -76,14 +80,16 @@ def __init__(self, cache: AbstractCache, model: AbstractGenerativeModel): self.model = model self.trim_transform = TrimTransform(cache, model) - def apply_ident( + def try_apply_ident( self, rust_source_file: Path, rust_definition: str, c_definition: CDefinition, identifier: str, - update_rust: bool = True, - ) -> None: + attempt: int = 0, + previous_error: MergeRustError | None = None, + ) -> TransformCandidate | None: + _ = attempt rust_comments = get_rust_comments(rust_definition) if rust_comments: logging.info( @@ -150,6 +156,7 @@ def apply_ident( messages = [ {"role": "user", "content": str(prompt)}, ] + messages = self.with_merge_retry_message(messages, previous_error) transform = self.__class__.__name__ identifier = prompt.identifier @@ -160,7 +167,13 @@ def apply_ident( model=model, messages=messages, ): - return + rust_fn = remove_backticks(response) + return TransformCandidate( + identifier=identifier, + messages=messages, + response=response, + definition=rust_fn, + ) response = self.model.generate_with_tools(messages) @@ -187,24 +200,14 @@ def apply_ident( f"\n{c_comments=}\n{rust_comments=}" ) - self.cache.update( - transform=transform, - identifier=identifier, - model=model, - messages=messages, - response=response, - ) - logging.info( f"Comments transferred to Rust fn {identifier}:\ \n{get_highlighted_rust(rust_fn)}" ) - # TODO: move this to apply_file? - # the challenge is that not all transforms will update Rust code - if update_rust: - update_rust_definition( - root_rust_source_file=rust_source_file, - identifier=prompt.identifier, - new_definition=rust_fn, - ) + return TransformCandidate( + identifier=identifier, + messages=messages, + response=response, + definition=rust_fn, + ) diff --git a/c2rust-postprocess/postprocess/transforms/formatting.py b/c2rust-postprocess/postprocess/transforms/formatting.py new file mode 100644 index 0000000000..c0c958d354 --- /dev/null +++ b/c2rust-postprocess/postprocess/transforms/formatting.py @@ -0,0 +1,222 @@ +import logging +import re +from collections.abc import Callable +from pathlib import Path +from textwrap import dedent + +from postprocess.cache import AbstractCache +from postprocess.definitions import CDefinition, MergeRustError +from postprocess.models import AbstractGenerativeModel, api_key_from_env +from postprocess.transforms.base import ( + AbstractTransform, + TransformCandidate, + TransformError, +) +from postprocess.utils import get_highlighted_rust, remove_backticks + +SYSTEM_INSTRUCTION = ( + "You are a helpful assistant that conservatively reformats c2rust-transpiled " + "Rust functions for compactness while preserving idiomatic Rust formatting." +) + + +class FormattingTransformPrompt: + c_function: str + rust_function: str + prompt_text: str + identifier: str + + __slots__ = ("c_function", "rust_function", "prompt_text", "identifier") + + def __init__( + self, c_function: str, rust_function: str, prompt_text: str, identifier: str + ): + self.c_function = c_function + self.rust_function = rust_function + self.prompt_text = prompt_text + self.identifier = identifier + + def __str__(self) -> str: + return ( + self.prompt_text + + "\n\n" + + "C function:\n```c\n" + + self.c_function + + "```\n\n" + + "Rust function:\n```rust\n" + + self.rust_function + + "```\n" + ) + + +class FormattingTransform(AbstractTransform): + def __init__(self, cache: AbstractCache, model: AbstractGenerativeModel): + super().__init__(SYSTEM_INSTRUCTION) + self.cache = cache + self.model = model + + @staticmethod + def should_attempt_formatting(c_function: str, rust_function: str) -> bool: + c_line_count = len([line for line in c_function.splitlines() if line.strip()]) + rust_line_count = len( + [line for line in rust_function.splitlines() if line.strip()] + ) + + if rust_line_count <= max(c_line_count * 2, c_line_count + 20): + return False + + compactable_item_lines = 0 + for line in rust_function.splitlines(): + stripped = line.strip() + if len(stripped) <= 48 and re.fullmatch(r"[^,{};]+,", stripped): + compactable_item_lines += 1 + + return " = [" in rust_function and compactable_item_lines >= 16 + + @staticmethod + def get_validation_fn(identifier: str) -> Callable[[str], str]: + def validate_response(rust_fn: str) -> str: + rust_fn = remove_backticks(rust_fn).strip() + + if not rust_fn: + return ( + "FAILURE: Empty response. Reply with the full Rust function " + "definition only; say nothing else." + ) + + if "```" in rust_fn: + return ( + "FAILURE: Response contains Markdown code fences. Reply with the " + "full Rust function definition only; say nothing else." + ) + + if f"fn {identifier}" not in rust_fn: + return ( + f"FAILURE: Response does not contain function `{identifier}`. " + "Reply with the full Rust function definition only; say nothing else." # noqa: E501 + ) + + return "SUCCESS: Function formatted correctly!" + + return validate_response + + def try_apply_ident( + self, + rust_source_file: Path, + rust_definition: str, + c_definition: CDefinition, + identifier: str, + attempt: int = 0, + previous_error: MergeRustError | None = None, + ) -> TransformCandidate | None: + _ = (rust_source_file, attempt) + if not self.should_attempt_formatting(c_definition.effective, rust_definition): + logging.info( + f"{self.__class__.__name__}: " + f"Skipping function without obvious compactness issue: {identifier}" + ) + return + + prompt_text = """ + Reformat the Rust function below only where the transpiled formatting is + needlessly verbose compared with the corresponding C function. + + Most Rust functions should stay exactly as rustfmt would format them. Make + changes only for mechanically expanded tables, arrays, lookup data, or similar + data-heavy structures where the Rust version is much longer than the C version + because rustfmt placed one small element per line. + + Requirements: + - Preserve Rust syntax, behavior, attributes, signature, names, types, + expressions, comments, and control flow. + - Do not try to make ordinary Rust statements imitate C brace or indentation + style. Keep ordinary code idiomatic for Rust. + - For compacted data structures, take formatting clues from the C version: + group comparable numbers of elements per line, keep related comments near the + same data, and preserve useful visual structure. + - Add #[rustfmt::skip] to the function if needed so rustfmt will not expand the + compacted data structure again. + - If there is no clear table/array/data-structure compactness problem, return + the original Rust function unchanged. + - Return the full Rust function definition only; say nothing else. + """ + prompt_text = dedent(prompt_text).strip() + + prompt = FormattingTransformPrompt( + c_function=c_definition.effective, + rust_function=rust_definition, + prompt_text=prompt_text, + identifier=identifier, + ) + + messages = [{"role": "user", "content": str(prompt)}] + messages = self.with_merge_retry_message(messages, previous_error) + + transform = self.__class__.__name__ + model = self.model.id + + if response := self.cache.lookup( + transform=transform, + identifier=identifier, + model=model, + messages=messages, + ): + rust_fn = remove_backticks(response) + return TransformCandidate( + identifier=identifier, + messages=messages, + response=response, + definition=rust_fn, + ) + + validate_response = self.get_validation_fn(identifier) + + for _attempt in range(3): + response = self.model.generate_with_tools( + messages, tools=[validate_response] + ) + if response is None: + if api_key_from_env(model) is None: + logging.warning( + f"Cache miss for {identifier}; " + "skipping since no API key was set..." + ) + return + logging.warning("Model returned no response") + continue + + validation_result = validate_response(response) + if not validation_result.startswith("SUCCESS"): + logging.warning( + f"Model response for {identifier} failed validation: " + + validation_result + + "\nResponse was:\n" + + response + ) + continue + + break + else: + raise TransformError( + f"Model failed to produce valid response after multiple " + f"attempts for {identifier}" + ) + + rust_fn = remove_backticks(response) + if rust_fn == rust_definition: + logging.warning( + f"{self.__class__.__name__}: " + f"No formatting changes for function: {identifier}" + ) + else: + logging.info( + f"Formatted Rust fn {identifier}:\ + \n{get_highlighted_rust(rust_fn)}" + ) + + return TransformCandidate( + identifier=identifier, + messages=messages, + response=response, + definition=rust_fn, + ) diff --git a/c2rust-postprocess/postprocess/transforms/trim.py b/c2rust-postprocess/postprocess/transforms/trim.py index 96ab787d80..7e712f5029 100644 --- a/c2rust-postprocess/postprocess/transforms/trim.py +++ b/c2rust-postprocess/postprocess/transforms/trim.py @@ -3,9 +3,9 @@ from textwrap import dedent from postprocess.cache import AbstractCache -from postprocess.definitions import CDefinition, get_c_comments +from postprocess.definitions import CDefinition, MergeRustError, get_c_comments from postprocess.models import AbstractGenerativeModel -from postprocess.transforms.base import AbstractTransform +from postprocess.transforms.base import AbstractTransform, TransformCandidate from postprocess.utils import remove_backticks SYSTEM_INSTRUCTION = ( @@ -21,6 +21,25 @@ def __init__(self, cache: AbstractCache, model: AbstractGenerativeModel): self.cache = cache self.model = model + def try_apply_ident( + self, + rust_source_file: Path, + rust_definition: str, + c_definition: CDefinition, + identifier: str, + attempt: int = 0, + previous_error: MergeRustError | None = None, + ) -> TransformCandidate | None: + _ = ( + rust_source_file, + rust_definition, + c_definition, + identifier, + attempt, + previous_error, + ) + raise NotImplementedError("TrimTransform overrides apply_ident directly") + def apply_ident( self, rust_source_file: Path, diff --git a/tools/merge_rust/Cargo.lock b/tools/merge_rust/Cargo.lock index ca20e777f6..8c17b06205 100644 --- a/tools/merge_rust/Cargo.lock +++ b/tools/merge_rust/Cargo.lock @@ -98,12 +98,36 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "hashbrown" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" + [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown", + "serde", + "serde_core", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ -127,6 +151,7 @@ name = "merge_rust" version = "0.1.0" dependencies = [ "clap", + "indexmap", "proc-macro2", "quote", "rust_util", diff --git a/tools/merge_rust/src/main.rs b/tools/merge_rust/src/main.rs index 0527912c2c..c3067392b2 100644 --- a/tools/merge_rust/src/main.rs +++ b/tools/merge_rust/src/main.rs @@ -173,6 +173,7 @@ fn main() { } // Apply the collected rewrites to each file. + let mut pending_writes = Vec::new(); for (file_path, mut rewrites) in file_rewrites { if rewrites.len() == 0 { continue; @@ -201,9 +202,18 @@ fn main() { } new_src.push_str(&old_src[pos..]); + if let Err(error) = syn::parse_file(&new_src) { + eprintln!("merged Rust in {:?} failed to parse: {error}", file_path); + std::process::exit(1); + } + + pending_writes.push((file_path, new_src, rewrites.len())); + } + + for (file_path, new_src, rewrite_count) in pending_writes { let tmp_path = file_path.with_extension(".new"); fs::write(&tmp_path, &new_src).unwrap(); fs::rename(&tmp_path, &file_path).unwrap(); - eprintln!("applied {} rewrites to {:?}", rewrites.len(), file_path); + eprintln!("applied {} rewrites to {:?}", rewrite_count, file_path); } } diff --git a/tools/merge_rust/tests/reject_invalid.rs b/tools/merge_rust/tests/reject_invalid.rs new file mode 100644 index 0000000000..74154480f3 --- /dev/null +++ b/tools/merge_rust/tests/reject_invalid.rs @@ -0,0 +1,57 @@ +use std::fs; +use std::path::PathBuf; +use std::process::Command; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn temp_dir() -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + let path = std::env::temp_dir().join(format!( + "merge_rust_reject_invalid_{}_{}", + std::process::id(), + nanos + )); + fs::create_dir(&path).unwrap(); + path +} + +#[test] +fn update_only_rejects_invalid_replacement_without_writing() { + let dir = temp_dir(); + let src_path = dir.join("lib.rs"); + let snippets_path = dir.join("snippets.json"); + let original = "\ +pub fn good() -> i32 { + 1 +} + +pub fn other() -> i32 { + 2 +} +"; + + fs::write(&src_path, original).unwrap(); + fs::write( + &snippets_path, + "{\"good\":\"pub fn good() -> i32 {\\n 1\\n\"}", + ) + .unwrap(); + + let output = Command::new(env!("CARGO_BIN_EXE_merge_rust")) + .arg("--update-only") + .arg(&src_path) + .arg(&snippets_path) + .output() + .unwrap(); + + assert!(!output.status.success()); + assert_eq!(fs::read_to_string(&src_path).unwrap(), original); + + let stderr = String::from_utf8(output.stderr).unwrap(); + assert!(stderr.contains("merged Rust in")); + assert!(stderr.contains("failed to parse")); + + fs::remove_dir_all(dir).unwrap(); +}