diff --git a/c2rust-postprocess/postprocess/__init__.py b/c2rust-postprocess/postprocess/__init__.py index fdf9374f90..306d0a1cca 100644 --- a/c2rust-postprocess/postprocess/__init__.py +++ b/c2rust-postprocess/postprocess/__init__.py @@ -17,6 +17,7 @@ from postprocess.models.gpt import GPTModel from postprocess.models.mock import MockGenerativeModel from postprocess.transforms import get_transform_by_id +from postprocess.transforms.base import TransformError from postprocess.transforms.comments import ( SYSTEM_INSTRUCTION, AbstractGenerativeModel, @@ -97,6 +98,16 @@ def build_arg_parser() -> argparse.ArgumentParser: help="Update the Rust in-place", ) + parser.add_argument( + "--on-error", + type=str, + required=False, + default="keep-going", + choices=["abort", "keep-going", "warn"], + help="Handle per-function transform failures: abort at first failure," + " keep going with exit 1, or warn and exit 0 (default: keep-going)", + ) + parser.add_argument( "--transform", type=str, @@ -174,19 +185,39 @@ def main(argv: Sequence[str] | None = None): if transform_id.strip() ) transforms = [ - get_transform_by_id(transform_id, cache=cache, model=model) + get_transform_by_id( + transform_id, + cache=cache, + model=model, + ) for transform_id in transform_ids ] + failures = 0 + failure_log_level = ( + logging.WARNING if args.on_error == "warn" else logging.ERROR + ) for transform in transforms: - transform.apply_dir( + failures += transform.apply_dir( root_rust_source_file=args.root_rust_source_file, exclude_list=IdentifierExcludeList(src_path=args.exclude_file), ident_filter=args.ident_filter, update_rust=args.update_rust, + keep_going=args.on_error != "abort", + failure_log_level=failure_log_level, + ) + + if failures: + logging.log( + failure_log_level, f"Failed to transform {failures} function(s)" ) + if args.on_error != "warn": + return 1 return 0 + except TransformError as error: + logging.exception(f"Aborting at first transform failure: {error}") + return 1 except KeyboardInterrupt: logging.warning("Interrupted by user, terminating...") return 130 # 128 + SIGINT(2) diff --git a/c2rust-postprocess/postprocess/transforms/base.py b/c2rust-postprocess/postprocess/transforms/base.py index 273d3ad53d..16a144a4f4 100644 --- a/c2rust-postprocess/postprocess/transforms/base.py +++ b/c2rust-postprocess/postprocess/transforms/base.py @@ -11,6 +11,10 @@ from postprocess.utils import get_highlighted_c +class TransformError(Exception): + """A transform failed to process a single definition.""" + + class AbstractTransform(ABC): """ Abstract base class for LLM-driven transforms of c2rust transpiler output. @@ -44,11 +48,16 @@ def apply_dir( exclude_list: IdentifierExcludeList, ident_filter: str | None = None, update_rust: bool = True, - ): + keep_going: bool = False, + failure_log_level: int = logging.ERROR, + ) -> int: """ Run `self.apply_file` on each `*.rs` in `dir` with a corresponding `*.c_decls.json`. + + Returns the number of definitions that failed to transform. """ + failures = 0 root_dir = root_rust_source_file.parent c_decls_json_suffix = ".c_decls.json" for c_decls_path in root_dir.glob(f"**/*{c_decls_json_suffix}"): @@ -56,12 +65,15 @@ def apply_dir( c_decls_path.name.removesuffix(c_decls_json_suffix) + ".rs" ) assert rs_path.exists() - self.apply_file( + failures += self.apply_file( rust_source_file=rs_path, exclude_list=exclude_list, ident_filter=ident_filter, update_rust=update_rust, + keep_going=keep_going, + failure_log_level=failure_log_level, ) + return failures def apply_file( self, @@ -69,8 +81,11 @@ def apply_file( exclude_list: IdentifierExcludeList, ident_filter: str | None = None, update_rust: bool = True, - ) -> None: + keep_going: bool = False, + failure_log_level: int = logging.ERROR, + ) -> int: ident_regex = re.compile(ident_filter) if ident_filter else None + failures = 0 rust_definitions = get_rust_definitions(rust_source_file) c_definitions = get_c_definitions(rust_source_file) @@ -104,13 +119,24 @@ def apply_file( f"C function {identifier} definition:\n{highlighted_c_definition}\n" ) - self.apply_ident( - rust_source_file=rust_source_file, - rust_definition=rust_definition, - c_definition=c_definition, - identifier=identifier, - update_rust=update_rust, - ) + try: + self.apply_ident( + rust_source_file=rust_source_file, + rust_definition=rust_definition, + c_definition=c_definition, + identifier=identifier, + update_rust=update_rust, + ) + except TransformError as error: + if not keep_going: + raise + logging.log( + failure_log_level, + f"Transform failed for {identifier} in {rust_source_file}: {error}", + ) + failures += 1 + + return failures # TODO: We probably want a an interface that generates validators specialized to diff --git a/c2rust-postprocess/postprocess/transforms/comments.py b/c2rust-postprocess/postprocess/transforms/comments.py index 60b2212329..d750d03f98 100644 --- a/c2rust-postprocess/postprocess/transforms/comments.py +++ b/c2rust-postprocess/postprocess/transforms/comments.py @@ -9,7 +9,7 @@ update_rust_definition, ) from postprocess.models import AbstractGenerativeModel, api_key_from_env -from postprocess.transforms.base import AbstractTransform +from postprocess.transforms.base import AbstractTransform, TransformError from postprocess.utils import get_highlighted_rust, remove_backticks # TODO: get from model @@ -115,18 +115,8 @@ def apply_ident( logging.warning( f"Cache miss for {identifier}; skipping since no API key was set..." ) - else: - logging.error(f"Model returned no response for {identifier}") - return - - # TODO: move this to apply_file? - self.cache.update( - transform=transform, - identifier=identifier, - model=model, - messages=messages, - response=response, - ) + return + raise TransformError(f"model returned no response for {identifier}") rust_fn = remove_backticks(response) @@ -136,7 +126,19 @@ def apply_ident( rust_comments = get_rust_comments(rust_fn) logging.debug(f"{rust_comments=}") - assert c_comments == rust_comments + if c_comments != rust_comments: + raise TransformError( + f"comments were not transferred verbatim for {identifier}:" + f"\n{c_comments=}\n{rust_comments=}" + ) + + self.cache.update( + transform=transform, + identifier=identifier, + model=model, + messages=messages, + response=response, + ) print(get_highlighted_rust(rust_fn))