From 516ffa042d6ac867f66e0a9d1ea85a04d62e9c71 Mon Sep 17 00:00:00 2001 From: akourne Date: Thu, 11 Jun 2026 13:51:09 -0800 Subject: [PATCH] postprocess: Add on-error policy for transform failures Add an --on-error option for per-function transform failures. The default keep-going mode continues after TransformError failures and exits nonzero if any function failed. abort stops at the first TransformError. warn continues after TransformError failures, reports them as warnings, and exits successfully. Only TransformError is treated as a recoverable per-function failure; other exceptions still surface as normal program errors. --- c2rust-postprocess/postprocess/__init__.py | 35 +++++++++++++- .../postprocess/transforms/base.py | 46 +++++++++++++++---- .../postprocess/transforms/comments.py | 30 ++++++------ 3 files changed, 85 insertions(+), 26 deletions(-) diff --git a/c2rust-postprocess/postprocess/__init__.py b/c2rust-postprocess/postprocess/__init__.py index fdf9374f90..306d0a1cca 100644 --- a/c2rust-postprocess/postprocess/__init__.py +++ b/c2rust-postprocess/postprocess/__init__.py @@ -17,6 +17,7 @@ from postprocess.models.gpt import GPTModel from postprocess.models.mock import MockGenerativeModel from postprocess.transforms import get_transform_by_id +from postprocess.transforms.base import TransformError from postprocess.transforms.comments import ( SYSTEM_INSTRUCTION, AbstractGenerativeModel, @@ -97,6 +98,16 @@ def build_arg_parser() -> argparse.ArgumentParser: help="Update the Rust in-place", ) + parser.add_argument( + "--on-error", + type=str, + required=False, + default="keep-going", + choices=["abort", "keep-going", "warn"], + help="Handle per-function transform failures: abort at first failure," + " keep going with exit 1, or warn and exit 0 (default: keep-going)", + ) + parser.add_argument( "--transform", type=str, @@ -174,19 +185,39 @@ def main(argv: Sequence[str] | None = None): if transform_id.strip() ) transforms = [ - get_transform_by_id(transform_id, cache=cache, model=model) + get_transform_by_id( + transform_id, + cache=cache, + model=model, + ) for transform_id in transform_ids ] + failures = 0 + failure_log_level = ( + logging.WARNING if args.on_error == "warn" else logging.ERROR + ) for transform in transforms: - transform.apply_dir( + failures += transform.apply_dir( root_rust_source_file=args.root_rust_source_file, exclude_list=IdentifierExcludeList(src_path=args.exclude_file), ident_filter=args.ident_filter, update_rust=args.update_rust, + keep_going=args.on_error != "abort", + failure_log_level=failure_log_level, + ) + + if failures: + logging.log( + failure_log_level, f"Failed to transform {failures} function(s)" ) + if args.on_error != "warn": + return 1 return 0 + except TransformError as error: + logging.exception(f"Aborting at first transform failure: {error}") + return 1 except KeyboardInterrupt: logging.warning("Interrupted by user, terminating...") return 130 # 128 + SIGINT(2) diff --git a/c2rust-postprocess/postprocess/transforms/base.py b/c2rust-postprocess/postprocess/transforms/base.py index 273d3ad53d..16a144a4f4 100644 --- a/c2rust-postprocess/postprocess/transforms/base.py +++ b/c2rust-postprocess/postprocess/transforms/base.py @@ -11,6 +11,10 @@ from postprocess.utils import get_highlighted_c +class TransformError(Exception): + """A transform failed to process a single definition.""" + + class AbstractTransform(ABC): """ Abstract base class for LLM-driven transforms of c2rust transpiler output. @@ -44,11 +48,16 @@ def apply_dir( exclude_list: IdentifierExcludeList, ident_filter: str | None = None, update_rust: bool = True, - ): + keep_going: bool = False, + failure_log_level: int = logging.ERROR, + ) -> int: """ Run `self.apply_file` on each `*.rs` in `dir` with a corresponding `*.c_decls.json`. + + Returns the number of definitions that failed to transform. """ + failures = 0 root_dir = root_rust_source_file.parent c_decls_json_suffix = ".c_decls.json" for c_decls_path in root_dir.glob(f"**/*{c_decls_json_suffix}"): @@ -56,12 +65,15 @@ def apply_dir( c_decls_path.name.removesuffix(c_decls_json_suffix) + ".rs" ) assert rs_path.exists() - self.apply_file( + failures += self.apply_file( rust_source_file=rs_path, exclude_list=exclude_list, ident_filter=ident_filter, update_rust=update_rust, + keep_going=keep_going, + failure_log_level=failure_log_level, ) + return failures def apply_file( self, @@ -69,8 +81,11 @@ def apply_file( exclude_list: IdentifierExcludeList, ident_filter: str | None = None, update_rust: bool = True, - ) -> None: + keep_going: bool = False, + failure_log_level: int = logging.ERROR, + ) -> int: ident_regex = re.compile(ident_filter) if ident_filter else None + failures = 0 rust_definitions = get_rust_definitions(rust_source_file) c_definitions = get_c_definitions(rust_source_file) @@ -104,13 +119,24 @@ def apply_file( f"C function {identifier} definition:\n{highlighted_c_definition}\n" ) - self.apply_ident( - rust_source_file=rust_source_file, - rust_definition=rust_definition, - c_definition=c_definition, - identifier=identifier, - update_rust=update_rust, - ) + try: + self.apply_ident( + rust_source_file=rust_source_file, + rust_definition=rust_definition, + c_definition=c_definition, + identifier=identifier, + update_rust=update_rust, + ) + except TransformError as error: + if not keep_going: + raise + logging.log( + failure_log_level, + f"Transform failed for {identifier} in {rust_source_file}: {error}", + ) + failures += 1 + + return failures # TODO: We probably want a an interface that generates validators specialized to diff --git a/c2rust-postprocess/postprocess/transforms/comments.py b/c2rust-postprocess/postprocess/transforms/comments.py index 60b2212329..d750d03f98 100644 --- a/c2rust-postprocess/postprocess/transforms/comments.py +++ b/c2rust-postprocess/postprocess/transforms/comments.py @@ -9,7 +9,7 @@ update_rust_definition, ) from postprocess.models import AbstractGenerativeModel, api_key_from_env -from postprocess.transforms.base import AbstractTransform +from postprocess.transforms.base import AbstractTransform, TransformError from postprocess.utils import get_highlighted_rust, remove_backticks # TODO: get from model @@ -115,18 +115,8 @@ def apply_ident( logging.warning( f"Cache miss for {identifier}; skipping since no API key was set..." ) - else: - logging.error(f"Model returned no response for {identifier}") - return - - # TODO: move this to apply_file? - self.cache.update( - transform=transform, - identifier=identifier, - model=model, - messages=messages, - response=response, - ) + return + raise TransformError(f"model returned no response for {identifier}") rust_fn = remove_backticks(response) @@ -136,7 +126,19 @@ def apply_ident( rust_comments = get_rust_comments(rust_fn) logging.debug(f"{rust_comments=}") - assert c_comments == rust_comments + if c_comments != rust_comments: + raise TransformError( + f"comments were not transferred verbatim for {identifier}:" + f"\n{c_comments=}\n{rust_comments=}" + ) + + self.cache.update( + transform=transform, + identifier=identifier, + model=model, + messages=messages, + response=response, + ) print(get_highlighted_rust(rust_fn))