diff --git a/c2rust-postprocess/postprocess/transforms/base.py b/c2rust-postprocess/postprocess/transforms/base.py index 5ce27c6781..8476c6e49a 100644 --- a/c2rust-postprocess/postprocess/transforms/base.py +++ b/c2rust-postprocess/postprocess/transforms/base.py @@ -36,7 +36,7 @@ def apply_ident( c_definition: CDefinition, identifier: str, update_rust: bool = True, - ) -> None: + ) -> str | None: """ Implementations should apply transform to a single Rust definition with the given identifier. diff --git a/c2rust-postprocess/postprocess/transforms/comments.py b/c2rust-postprocess/postprocess/transforms/comments.py index 4bf5c644a2..06cecee46e 100644 --- a/c2rust-postprocess/postprocess/transforms/comments.py +++ b/c2rust-postprocess/postprocess/transforms/comments.py @@ -11,6 +11,7 @@ ) from postprocess.models import AbstractGenerativeModel, api_key_from_env from postprocess.transforms.base import AbstractTransform, TransformError +from postprocess.transforms.trim import TrimTransform from postprocess.utils import get_highlighted_rust, remove_backticks # TODO: get from model @@ -73,6 +74,7 @@ def __init__(self, cache: AbstractCache, model: AbstractGenerativeModel): super().__init__(SYSTEM_INSTRUCTION) self.cache = cache self.model = model + self.trim_transform = TrimTransform(cache, model) def apply_ident( self, @@ -86,8 +88,7 @@ def apply_ident( if rust_comments: logging.info( f"Skipping Rust fn {identifier} with existing comments:\ - \n{rust_comments} in\ - \n{rust_definition}" + \n{get_highlighted_rust(rust_definition)}" ) return @@ -98,6 +99,31 @@ def apply_ident( logging.info(f"Skipping C function without comments: {identifier}") return + match self.trim_transform.apply_ident( + rust_source_file=rust_source_file, + rust_definition=rust_definition, + c_definition=c_definition, + identifier=identifier, + update_rust=False, # nothing to update here + ): + case None: + logging.error( + f"Trim transform failed for {identifier}, " + "skipping comments transfer" + ) + 
return + case str() as trimmed_c_definition: + # TODO: consider trimming both the definition and the preprocessed + # definition instead of possibly replacing the original + # definition with the trimmed and preprocessed one. + c_definition = CDefinition( + definition=trimmed_c_definition, preprocessed_definition=None + ) + case _: + raise AssertionError( + "Unexpected return type from trim transform: expected None or str" + ) + # TODO: make this function take a model and get prompt from model prompt_text = """ Transfer the comments from the following C function to the corresponding Rust function. @@ -169,7 +195,10 @@ def apply_ident( response=response, ) - print(get_highlighted_rust(rust_fn)) + logging.info( + f"Comments transferred to Rust fn {identifier}:\ + \n{get_highlighted_rust(rust_fn)}" + ) # TODO: move this to apply_file? # the challenge is that not all transforms will update Rust code diff --git a/c2rust-postprocess/postprocess/transforms/trim.py b/c2rust-postprocess/postprocess/transforms/trim.py new file mode 100644 index 0000000000..96ab787d80 --- /dev/null +++ b/c2rust-postprocess/postprocess/transforms/trim.py @@ -0,0 +1,88 @@ +import logging +from pathlib import Path +from textwrap import dedent + +from postprocess.cache import AbstractCache +from postprocess.definitions import CDefinition, get_c_comments +from postprocess.models import AbstractGenerativeModel +from postprocess.transforms.base import AbstractTransform +from postprocess.utils import remove_backticks + +SYSTEM_INSTRUCTION = ( + "You are a helpful assistant that removes top-of-file prologues and" + " other unnecessary content such as preprocessor definitions and" + " comments that preceede and are unrelated to the definition of a C function." 
class TrimTransform(AbstractTransform):
    """Ask the model to strip unrelated leading content from a C definition.

    The transform sends the C definition to the generative model with a prompt
    asking it to remove prologues, preprocessor definitions and comments that
    are unrelated to the named function, and returns the trimmed text.
    Responses are memoized in ``cache`` under the transform name, identifier,
    model id and prompt messages.
    """

    def __init__(self, cache: AbstractCache, model: AbstractGenerativeModel):
        """Store the response cache and the model used to perform the trim."""
        super().__init__(SYSTEM_INSTRUCTION)
        self.cache = cache
        self.model = model

    def apply_ident(
        self,
        rust_source_file: Path,
        rust_definition: str,
        c_definition: CDefinition,
        identifier: str,
        update_rust: bool = True,
    ) -> str | None:
        """Return the trimmed C definition of ``identifier``, or ``None``.

        Args:
            rust_source_file: Unused here; accepted for interface parity with
                the other transforms.
            rust_definition: Unused here; this transform only looks at C.
            c_definition: C definition to trim; its ``effective`` text is what
                gets sent to the model.
            identifier: Name of the C function, interpolated into the prompt
                and used as part of the cache key.
            update_rust: Unused here; this transform never writes Rust code.

        Returns:
            The trimmed C definition as a string. ``None`` when the C
            definition has no comments (nothing to trim), when the model
            returns no response, or when the response is empty once the code
            fence has been removed.
        """
        transform = self.__class__.__name__
        effective_c = c_definition.effective

        if not get_c_comments(effective_c):
            logging.info(
                f"{transform}: Skipping C function without comments: {identifier}"
            )
            return None

        # NOTE: the prompt text is part of the cache key (via `messages`), so
        # it must stay byte-for-byte stable for recorded responses to be hit.
        template = """
        Remove any prologues, preprocessor definitions, and comments that are unrelated to the definition
        of the C function `{identifier}`. Respond with the trimmed C function definition; say nothing else.

        C function:
        ```c
        {c_definition}
        ```
        """  # noqa: E501
        # Dedent *before* formatting: the C definition isn't indented, so
        # substituting it first would defeat the common-prefix detection.
        prompt = dedent(template).strip()
        prompt = prompt.format(identifier=identifier, c_definition=effective_c)

        messages = [
            {"role": "user", "content": prompt},
        ]
        model = self.model.id

        if cached := self.cache.lookup(
            transform=transform,
            identifier=identifier,
            model=model,
            messages=messages,
        ):
            return cached

        response = self.model.generate_with_tools(messages)
        if response is None:
            logging.error("Model returned no response")
            return None

        # TODO: validate that function definition is still present?

        trimmed = remove_backticks(response)
        if not trimmed.strip():
            # An empty definition is never a valid trim. Caching it would also
            # be pointless because the falsy value can never satisfy the cache
            # lookup above. Report failure instead of handing callers "".
            logging.error(
                f"{transform}: Model returned an empty definition for {identifier}"
            )
            return None

        self.cache.update(
            transform=transform,
            identifier=identifier,
            model=model,
            messages=messages,
            response=trimmed,
        )

        return trimmed
let mut pivot: ::core::ffi::c_int = *arr.offset(high as isize); let mut i: ::core::ffi::c_int = low - 1 as ::core::ffi::c_int; let mut j: ::core::ffi::c_int = low; @@ -42,11 +39,6 @@ Respond with the Rust function definition with the transferred comments; say not C function: ```c -/* - * Lomuto Partition Scheme: - * Partitions the array so that elements < pivot are on the left, - * and elements >= pivot are on the right. - */ int partition (int arr[], int low, int high) { // Partition the subarray around the last element as pivot and return pivot's final index. diff --git a/c2rust-postprocess/tests/llm-cache/CommentsTransform/partition/bf27c23abbb2f5c2d7dba1891b748252983959f5dc232f35146250b6d2d1fb29/response.txt b/c2rust-postprocess/tests/llm-cache/CommentsTransform/partition/37c722369b72db3bc2af3c6d410ca7c363cfe54a7adef376c1e9a7f38beffd2d/response.txt similarity index 81% rename from c2rust-postprocess/tests/llm-cache/CommentsTransform/partition/bf27c23abbb2f5c2d7dba1891b748252983959f5dc232f35146250b6d2d1fb29/response.txt rename to c2rust-postprocess/tests/llm-cache/CommentsTransform/partition/37c722369b72db3bc2af3c6d410ca7c363cfe54a7adef376c1e9a7f38beffd2d/response.txt index d4daefdadb..6542e787af 100644 --- a/c2rust-postprocess/tests/llm-cache/CommentsTransform/partition/bf27c23abbb2f5c2d7dba1891b748252983959f5dc232f35146250b6d2d1fb29/response.txt +++ b/c2rust-postprocess/tests/llm-cache/CommentsTransform/partition/37c722369b72db3bc2af3c6d410ca7c363cfe54a7adef376c1e9a7f38beffd2d/response.txt @@ -1,13 +1,10 @@ -/// Lomuto Partition Scheme: -/// Partitions the array so that elements < pivot are on the left, -/// and elements >= pivot are on the right. +/// Partition the subarray around the last element as pivot and return pivot's final index. 
#[no_mangle] pub unsafe extern "C" fn partition( mut arr: *mut ::core::ffi::c_int, mut low: ::core::ffi::c_int, mut high: ::core::ffi::c_int, ) -> ::core::ffi::c_int { - // Partition the subarray around the last element as pivot and return pivot's final index. let mut pivot: ::core::ffi::c_int = *arr.offset(high as isize); let mut i: ::core::ffi::c_int = low - 1 as ::core::ffi::c_int; let mut j: ::core::ffi::c_int = low; diff --git a/c2rust-postprocess/tests/llm-cache/TrimTransform/partition/a4e1c68bf25ff0d12996deed99aff2f12fb77a27642c356431d8fb7bd012b82e/metadata.toml b/c2rust-postprocess/tests/llm-cache/TrimTransform/partition/a4e1c68bf25ff0d12996deed99aff2f12fb77a27642c356431d8fb7bd012b82e/metadata.toml new file mode 100644 index 0000000000..a8ffba29a3 --- /dev/null +++ b/c2rust-postprocess/tests/llm-cache/TrimTransform/partition/a4e1c68bf25ff0d12996deed99aff2f12fb77a27642c356431d8fb7bd012b82e/metadata.toml @@ -0,0 +1,51 @@ +transform = "TrimTransform" +identifier = "partition" +model = "gemini-3-flash-preview" +response = """int partition (int arr[], int low, int high) +{ + // Partition the subarray around the last element as pivot and return pivot's final index. + int pivot = arr[high]; + int i = low - 1; + + for (int j = low; j <= high - 1; j++) { + if (arr[j] <= pivot) { + i++; + // Move elements <= pivot into the left partition. + swap(&arr[i], &arr[j]); + } + } + // Place pivot just after the final element of the left partition. + swap(&arr[i + 1], &arr[high]); + return i + 1; +}""" + +[[messages]] +role = "user" +content = """Remove any prologues, preprocessor definitions, and comments that are unrelated to the definition +of the C function `partition`. Respond with the trimmed C function definition; say nothing else. + +C function: +```c +/* + * Lomuto Partition Scheme: + * Partitions the array so that elements < pivot are on the left, + * and elements >= pivot are on the right. 
+ */ +int partition (int arr[], int low, int high) +{ + // Partition the subarray around the last element as pivot and return pivot's final index. + int pivot = arr[high]; + int i = low - 1; + + for (int j = low; j <= high - 1; j++) { + if (arr[j] <= pivot) { + i++; + // Move elements <= pivot into the left partition. + swap(&arr[i], &arr[j]); + } + } + // Place pivot just after the final element of the left partition. + swap(&arr[i + 1], &arr[high]); + return i + 1; +} +```""" diff --git a/c2rust-postprocess/tests/llm-cache/TrimTransform/partition/a4e1c68bf25ff0d12996deed99aff2f12fb77a27642c356431d8fb7bd012b82e/response.txt b/c2rust-postprocess/tests/llm-cache/TrimTransform/partition/a4e1c68bf25ff0d12996deed99aff2f12fb77a27642c356431d8fb7bd012b82e/response.txt new file mode 100644 index 0000000000..dd535b54b6 --- /dev/null +++ b/c2rust-postprocess/tests/llm-cache/TrimTransform/partition/a4e1c68bf25ff0d12996deed99aff2f12fb77a27642c356431d8fb7bd012b82e/response.txt @@ -0,0 +1,17 @@ +int partition (int arr[], int low, int high) +{ + // Partition the subarray around the last element as pivot and return pivot's final index. + int pivot = arr[high]; + int i = low - 1; + + for (int j = low; j <= high - 1; j++) { + if (arr[j] <= pivot) { + i++; + // Move elements <= pivot into the left partition. + swap(&arr[i], &arr[j]); + } + } + // Place pivot just after the final element of the left partition. 
+ swap(&arr[i + 1], &arr[high]); + return i + 1; +} \ No newline at end of file diff --git a/c2rust-postprocess/tests/llm-cache/TrimTransform/quickSort/7f54c99381ba6e83aae5416848fb8435ec3538f98668b1ce6d12953bc1d19990/metadata.toml b/c2rust-postprocess/tests/llm-cache/TrimTransform/quickSort/7f54c99381ba6e83aae5416848fb8435ec3538f98668b1ce6d12953bc1d19990/metadata.toml new file mode 100644 index 0000000000..66d33d0f58 --- /dev/null +++ b/c2rust-postprocess/tests/llm-cache/TrimTransform/quickSort/7f54c99381ba6e83aae5416848fb8435ec3538f98668b1ce6d12953bc1d19990/metadata.toml @@ -0,0 +1,34 @@ +transform = "TrimTransform" +identifier = "quickSort" +model = "gemini-3-flash-preview" +response = """void quickSort(int arr[], int low, int high) +{ + if (low < high) { + /* pi is the partitioning index; arr[pi] is now at the right place */ + int pi = partition(arr, low, high); + + /* Recursively sort elements before and after partition */ + quickSort(arr, low, pi - 1); + quickSort(arr, pi + 1, high); + } +}""" + +[[messages]] +role = "user" +content = """Remove any prologues, preprocessor definitions, and comments that are unrelated to the definition +of the C function `quickSort`. Respond with the trimmed C function definition; say nothing else. 
+ +C function: +```c +void quickSort(int arr[], int low, int high) +{ + if (low < high) { + /* pi is the partitioning index; arr[pi] is now at the right place */ + int pi = partition(arr, low, high); + + /* Recursively sort elements before and after partition */ + quickSort(arr, low, pi - 1); + quickSort(arr, pi + 1, high); + } +} +```""" diff --git a/c2rust-postprocess/tests/llm-cache/TrimTransform/quickSort/7f54c99381ba6e83aae5416848fb8435ec3538f98668b1ce6d12953bc1d19990/response.txt b/c2rust-postprocess/tests/llm-cache/TrimTransform/quickSort/7f54c99381ba6e83aae5416848fb8435ec3538f98668b1ce6d12953bc1d19990/response.txt new file mode 100644 index 0000000000..83f9d3e6ce --- /dev/null +++ b/c2rust-postprocess/tests/llm-cache/TrimTransform/quickSort/7f54c99381ba6e83aae5416848fb8435ec3538f98668b1ce6d12953bc1d19990/response.txt @@ -0,0 +1,11 @@ +void quickSort(int arr[], int low, int high) +{ + if (low < high) { + /* pi is the partitioning index; arr[pi] is now at the right place */ + int pi = partition(arr, low, high); + + /* Recursively sort elements before and after partition */ + quickSort(arr, low, pi - 1); + quickSort(arr, pi + 1, high); + } +} \ No newline at end of file diff --git a/tests/integration/tests/lua/conf.yml b/tests/integration/tests/lua/conf.yml index 5da76c055c..82220e45bf 100644 --- a/tests/integration/tests/lua/conf.yml +++ b/tests/integration/tests/lua/conf.yml @@ -33,3 +33,9 @@ refactor: cargo.refactor: autogen: true + +postprocess: + autogen: true + +cargo.postprocess: + autogen: true