Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions c2rust-postprocess/postprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from postprocess.models.gpt import GPTModel
from postprocess.models.mock import MockGenerativeModel
from postprocess.transforms import get_transform_by_id
from postprocess.transforms.base import TransformError
from postprocess.transforms.comments import (
SYSTEM_INSTRUCTION,
AbstractGenerativeModel,
Expand Down Expand Up @@ -97,6 +98,16 @@ def build_arg_parser() -> argparse.ArgumentParser:
help="Update the Rust in-place",
)

parser.add_argument(
"--on-error",
type=str,
required=False,
default="keep-going",
choices=["abort", "keep-going", "warn"],
help="Handle per-function transform failures: abort at first failure,"
" keep going with exit 1, or warn and exit 0 (default: keep-going)",
)

parser.add_argument(
"--transform",
type=str,
Expand Down Expand Up @@ -174,19 +185,39 @@ def main(argv: Sequence[str] | None = None):
if transform_id.strip()
)
transforms = [
get_transform_by_id(transform_id, cache=cache, model=model)
get_transform_by_id(
transform_id,
cache=cache,
model=model,
)
for transform_id in transform_ids
]

failures = 0
failure_log_level = (
logging.WARNING if args.on_error == "warn" else logging.ERROR
)
for transform in transforms:
transform.apply_dir(
failures += transform.apply_dir(
root_rust_source_file=args.root_rust_source_file,
exclude_list=IdentifierExcludeList(src_path=args.exclude_file),
ident_filter=args.ident_filter,
update_rust=args.update_rust,
keep_going=args.on_error != "abort",
failure_log_level=failure_log_level,
)

if failures:
logging.log(
failure_log_level, f"Failed to transform {failures} function(s)"
)
if args.on_error != "warn":
return 1

return 0
except TransformError as error:
logging.exception(f"Aborting at first transform failure: {error}")
return 1
except KeyboardInterrupt:
logging.warning("Interrupted by user, terminating...")
return 130 # 128 + SIGINT(2)
46 changes: 36 additions & 10 deletions c2rust-postprocess/postprocess/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from postprocess.utils import get_highlighted_c


class TransformError(Exception):
"""A transform failed to process a single definition."""


class AbstractTransform(ABC):
"""
Abstract base class for LLM-driven transforms of c2rust transpiler output.
Expand Down Expand Up @@ -44,33 +48,44 @@ def apply_dir(
exclude_list: IdentifierExcludeList,
ident_filter: str | None = None,
update_rust: bool = True,
):
keep_going: bool = False,
failure_log_level: int = logging.ERROR,
) -> int:
"""
Run `self.apply_file` on each `*.rs` in `dir`
with a corresponding `*.c_decls.json`.

Returns the number of definitions that failed to transform.
"""
failures = 0
root_dir = root_rust_source_file.parent
c_decls_json_suffix = ".c_decls.json"
for c_decls_path in root_dir.glob(f"**/*{c_decls_json_suffix}"):
rs_path = c_decls_path.with_name(
c_decls_path.name.removesuffix(c_decls_json_suffix) + ".rs"
)
assert rs_path.exists()
self.apply_file(
failures += self.apply_file(
rust_source_file=rs_path,
exclude_list=exclude_list,
ident_filter=ident_filter,
update_rust=update_rust,
keep_going=keep_going,
failure_log_level=failure_log_level,
)
return failures

def apply_file(
self,
rust_source_file: Path,
exclude_list: IdentifierExcludeList,
ident_filter: str | None = None,
update_rust: bool = True,
) -> None:
keep_going: bool = False,
failure_log_level: int = logging.ERROR,
) -> int:
ident_regex = re.compile(ident_filter) if ident_filter else None
failures = 0

rust_definitions = get_rust_definitions(rust_source_file)
c_definitions = get_c_definitions(rust_source_file)
Expand Down Expand Up @@ -104,13 +119,24 @@ def apply_file(
f"C function {identifier} definition:\n{highlighted_c_definition}\n"
)

self.apply_ident(
rust_source_file=rust_source_file,
rust_definition=rust_definition,
c_definition=c_definition,
identifier=identifier,
update_rust=update_rust,
)
try:
self.apply_ident(
rust_source_file=rust_source_file,
rust_definition=rust_definition,
c_definition=c_definition,
identifier=identifier,
update_rust=update_rust,
)
except TransformError as error:
if not keep_going:
raise
logging.log(
failure_log_level,
f"Transform failed for {identifier} in {rust_source_file}: {error}",
)
failures += 1

return failures


# TODO: We probably want a an interface that generates validators specialized to
Expand Down
30 changes: 16 additions & 14 deletions c2rust-postprocess/postprocess/transforms/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
update_rust_definition,
)
from postprocess.models import AbstractGenerativeModel, api_key_from_env
from postprocess.transforms.base import AbstractTransform
from postprocess.transforms.base import AbstractTransform, TransformError
from postprocess.utils import get_highlighted_rust, remove_backticks

# TODO: get from model
Expand Down Expand Up @@ -115,18 +115,8 @@ def apply_ident(
logging.warning(
f"Cache miss for {identifier}; skipping since no API key was set..."
)
else:
logging.error(f"Model returned no response for {identifier}")
return

# TODO: move this to apply_file?
self.cache.update(
transform=transform,
identifier=identifier,
model=model,
messages=messages,
response=response,
)
return
raise TransformError(f"model returned no response for {identifier}")

rust_fn = remove_backticks(response)

Expand All @@ -136,7 +126,19 @@ def apply_ident(
rust_comments = get_rust_comments(rust_fn)
logging.debug(f"{rust_comments=}")

assert c_comments == rust_comments
if c_comments != rust_comments:
raise TransformError(
f"comments were not transferred verbatim for {identifier}:"
f"\n{c_comments=}\n{rust_comments=}"
)

self.cache.update(
transform=transform,
identifier=identifier,
model=model,
messages=messages,
response=response,
)

print(get_highlighted_rust(rust_fn))

Expand Down
Loading