Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c2rust-postprocess/postprocess/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def apply_ident(
c_definition: CDefinition,
identifier: str,
update_rust: bool = True,
) -> None:
) -> str | None:
"""
Implementations should apply transform to a single Rust definition
with the given identifier.
Expand Down
35 changes: 32 additions & 3 deletions c2rust-postprocess/postprocess/transforms/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from postprocess.models import AbstractGenerativeModel, api_key_from_env
from postprocess.transforms.base import AbstractTransform, TransformError
from postprocess.transforms.trim import TrimTransform
from postprocess.utils import get_highlighted_rust, remove_backticks

# TODO: get from model
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(self, cache: AbstractCache, model: AbstractGenerativeModel):
super().__init__(SYSTEM_INSTRUCTION)
self.cache = cache
self.model = model
self.trim_transform = TrimTransform(cache, model)

def apply_ident(
self,
Expand All @@ -86,8 +88,7 @@ def apply_ident(
if rust_comments:
logging.info(
f"Skipping Rust fn {identifier} with existing comments:\
\n{rust_comments} in\
\n{rust_definition}"
\n{get_highlighted_rust(rust_definition)}"
)
return

Expand All @@ -98,6 +99,31 @@ def apply_ident(
logging.info(f"Skipping C function without comments: {identifier}")
return

match self.trim_transform.apply_ident(
rust_source_file=rust_source_file,
rust_definition=rust_definition,
c_definition=c_definition,
identifier=identifier,
update_rust=False, # nothing to update here
):
case None:
logging.error(
f"Trim transform failed for {identifier}, "
"skipping comments transfer"
)
return
case str() as trimmed_c_definition:
# TODO: consider trimming both the definition and the preprocessed
# definition instead of possibly replacing the original
# definition with the trimmed and preprocessed one.
c_definition = CDefinition(
definition=trimmed_c_definition, preprocessed_definition=None
)
case _:
raise AssertionError(
"Unexpected return type from trim transform: expected None or str"
)

# TODO: make this function take a model and get prompt from model
prompt_text = """
Transfer the comments from the following C function to the corresponding Rust function.
Expand Down Expand Up @@ -169,7 +195,10 @@ def apply_ident(
response=response,
)

print(get_highlighted_rust(rust_fn))
logging.info(
f"Comments transferred to Rust fn {identifier}:\
\n{get_highlighted_rust(rust_fn)}"
)

# TODO: move this to apply_file?
# the challenge is that not all transforms will update Rust code
Expand Down
88 changes: 88 additions & 0 deletions c2rust-postprocess/postprocess/transforms/trim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import logging
from pathlib import Path
from textwrap import dedent

from postprocess.cache import AbstractCache
from postprocess.definitions import CDefinition, get_c_comments
from postprocess.models import AbstractGenerativeModel
from postprocess.transforms.base import AbstractTransform
from postprocess.utils import remove_backticks

# System prompt handed to AbstractTransform.__init__ by TrimTransform.
SYSTEM_INSTRUCTION = (
    "You are a helpful assistant that removes top-of-file prologues and"
    " other unnecessary content such as preprocessor definitions and"
    " comments that precede and are unrelated to the definition of a C function."
)


class TrimTransform(AbstractTransform):
    """Ask the model to strip unrelated leading content from a C function.

    The prompt requests removal of prologues, preprocessor definitions, and
    comments unrelated to the named function. Accepted responses are stored in
    ``cache`` keyed on the transform name, identifier, model id, and messages,
    and are reused on later calls with the same key.
    """

    def __init__(self, cache: AbstractCache, model: AbstractGenerativeModel):
        super().__init__(SYSTEM_INSTRUCTION)
        self.cache = cache
        self.model = model

    def apply_ident(
        self,
        rust_source_file: Path,
        rust_definition: str,
        c_definition: CDefinition,
        identifier: str,
        update_rust: bool = True,
    ) -> str | None:
        """
        Return the trimmed C definition of ``identifier``.

        ``rust_source_file``, ``rust_definition`` and ``update_rust`` are part
        of the shared transform signature but are not read here; this
        transform never modifies Rust code.

        Returns the trimmed definition as a string, or ``None`` when the C
        definition has no comments, when the model returns no response, or
        when the response is empty once backticks are removed.
        """
        c_comments = get_c_comments(c_definition.effective)
        if not c_comments:
            logging.info(
                f"{self.__class__.__name__}: "
                f"Skipping C function without comments: {identifier}"
            )
            return None

        prompt = """
        Remove any prologues, preprocessor definitions, and comments that are unrelated to the definition
        of the C function `{identifier}`. Respond with the trimmed C function definition; say nothing else.

        C function:
        ```c
        {c_definition}
        ```
        """  # noqa: E501
        prompt = dedent(
            prompt
        ).strip()  # note: dedent then format since the C definition isn't indented
        prompt = prompt.format(
            identifier=identifier, c_definition=c_definition.effective
        )

        messages = [
            {"role": "user", "content": prompt},
        ]

        transform = self.__class__.__name__
        model = self.model.id
        if response := self.cache.lookup(
            transform=transform,
            identifier=identifier,
            model=model,
            messages=messages,
        ):
            return response

        response = self.model.generate_with_tools(messages)

        if response is None:
            logging.error("Model returned no response")
            return None

        response = remove_backticks(response)

        # An empty result is never a usable C definition. Caching it would be
        # pointless (the truthiness check on lookup treats it as a miss), and
        # returning it would hand callers an empty definition to work from.
        if not response or not response.strip():
            logging.error(
                f"{transform}: model returned an empty definition for {identifier}"
            )
            return None

        # TODO: validate that function definition is still present?

        self.cache.update(
            transform=transform,
            identifier=identifier,
            model=model,
            messages=messages,
            response=response,
        )

        return response
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
transform = "CommentsTransform"
identifier = "partition"
model = "gemini-3-flash-preview"
response = """/// Lomuto Partition Scheme:
/// Partitions the array so that elements < pivot are on the left,
/// and elements >= pivot are on the right.
response = """/// Partition the subarray around the last element as pivot and return pivot's final index.
#[no_mangle]
pub unsafe extern "C" fn partition(
mut arr: *mut ::core::ffi::c_int,
mut low: ::core::ffi::c_int,
mut high: ::core::ffi::c_int,
) -> ::core::ffi::c_int {
// Partition the subarray around the last element as pivot and return pivot's final index.
let mut pivot: ::core::ffi::c_int = *arr.offset(high as isize);
let mut i: ::core::ffi::c_int = low - 1 as ::core::ffi::c_int;
let mut j: ::core::ffi::c_int = low;
Expand Down Expand Up @@ -42,11 +39,6 @@ Respond with the Rust function definition with the transferred comments; say not

C function:
```c
/*
* Lomuto Partition Scheme:
* Partitions the array so that elements < pivot are on the left,
* and elements >= pivot are on the right.
*/
int partition (int arr[], int low, int high)
{
// Partition the subarray around the last element as pivot and return pivot's final index.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
/// Lomuto Partition Scheme:
/// Partitions the array so that elements < pivot are on the left,
/// and elements >= pivot are on the right.
/// Partition the subarray around the last element as pivot and return pivot's final index.
#[no_mangle]
pub unsafe extern "C" fn partition(
mut arr: *mut ::core::ffi::c_int,
mut low: ::core::ffi::c_int,
mut high: ::core::ffi::c_int,
) -> ::core::ffi::c_int {
// Partition the subarray around the last element as pivot and return pivot's final index.
let mut pivot: ::core::ffi::c_int = *arr.offset(high as isize);
let mut i: ::core::ffi::c_int = low - 1 as ::core::ffi::c_int;
let mut j: ::core::ffi::c_int = low;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
transform = "TrimTransform"
identifier = "partition"
model = "gemini-3-flash-preview"
response = """int partition (int arr[], int low, int high)
{
// Partition the subarray around the last element as pivot and return pivot's final index.
int pivot = arr[high];
int i = low - 1;

for (int j = low; j <= high - 1; j++) {
if (arr[j] <= pivot) {
i++;
// Move elements <= pivot into the left partition.
swap(&arr[i], &arr[j]);
}
}
// Place pivot just after the final element of the left partition.
swap(&arr[i + 1], &arr[high]);
return i + 1;
}"""

[[messages]]
role = "user"
content = """Remove any prologues, preprocessor definitions, and comments that are unrelated to the definition
of the C function `partition`. Respond with the trimmed C function definition; say nothing else.

C function:
```c
/*
* Lomuto Partition Scheme:
* Partitions the array so that elements < pivot are on the left,
* and elements >= pivot are on the right.
*/
int partition (int arr[], int low, int high)
{
// Partition the subarray around the last element as pivot and return pivot's final index.
int pivot = arr[high];
int i = low - 1;

for (int j = low; j <= high - 1; j++) {
if (arr[j] <= pivot) {
i++;
// Move elements <= pivot into the left partition.
swap(&arr[i], &arr[j]);
}
}
// Place pivot just after the final element of the left partition.
swap(&arr[i + 1], &arr[high]);
return i + 1;
}
```"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
int partition (int arr[], int low, int high)
{
// Partition the subarray around the last element as pivot and return pivot's final index.
int pivot = arr[high];
int i = low - 1;

for (int j = low; j <= high - 1; j++) {
if (arr[j] <= pivot) {
i++;
// Move elements <= pivot into the left partition.
swap(&arr[i], &arr[j]);
}
}
// Place pivot just after the final element of the left partition.
swap(&arr[i + 1], &arr[high]);
return i + 1;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
transform = "TrimTransform"
identifier = "quickSort"
model = "gemini-3-flash-preview"
response = """void quickSort(int arr[], int low, int high)
{
if (low < high) {
/* pi is the partitioning index; arr[pi] is now at the right place */
int pi = partition(arr, low, high);

/* Recursively sort elements before and after partition */
quickSort(arr, low, pi - 1);
quickSort(arr, pi + 1, high);
}
}"""

[[messages]]
role = "user"
content = """Remove any prologues, preprocessor definitions, and comments that are unrelated to the definition
of the C function `quickSort`. Respond with the trimmed C function definition; say nothing else.

C function:
```c
void quickSort(int arr[], int low, int high)
{
if (low < high) {
/* pi is the partitioning index; arr[pi] is now at the right place */
int pi = partition(arr, low, high);

/* Recursively sort elements before and after partition */
quickSort(arr, low, pi - 1);
quickSort(arr, pi + 1, high);
}
}
```"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
void quickSort(int arr[], int low, int high)
{
if (low < high) {
/* pi is the partitioning index; arr[pi] is now at the right place */
int pi = partition(arr, low, high);

/* Recursively sort elements before and after partition */
quickSort(arr, low, pi - 1);
quickSort(arr, pi + 1, high);
}
}
6 changes: 6 additions & 0 deletions tests/integration/tests/lua/conf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@ refactor:

cargo.refactor:
autogen: true

postprocess:
autogen: true

cargo.postprocess:
autogen: true
Loading