Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ impl InferredDataType {
#[derive(Debug, Clone, Default)]
pub struct Format {
header: bool,
header_validation: bool,
delimiter: Option<u8>,
escape: Option<u8>,
quote: Option<u8>,
Expand All @@ -291,6 +292,16 @@ impl Format {
self
}

/// Specify whether to validate the CSV header against the schema, defaults to `false`
///
/// When `true`, the first row gets validated against the schema before any data is read
///
/// Only applies when [`Self::with_header`] is set to `true`
pub fn with_header_validation(mut self, validate_header: bool) -> Self {
self.header_validation = validate_header;
self
}

/// Specify a custom delimiter character, defaults to comma `','`
pub fn with_delimiter(mut self, delimiter: u8) -> Self {
self.delimiter = Some(delimiter);
Expand Down Expand Up @@ -610,6 +621,9 @@ pub struct Decoder {
/// Rows to skip
to_skip: usize,

/// Whether to validate the first skipped row against the schema
header_validation: bool,

/// Current line number
line_number: usize,

Expand All @@ -635,6 +649,20 @@ impl Decoder {
/// network sources such as object storage
pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
if self.to_skip != 0 {
if self.header_validation {
let (skipped, bytes) = self.record_decoder.decode(buf, 1)?;

if skipped == 0 {
return Ok(bytes);
}

let rows = self.record_decoder.flush()?;
validate_header(&rows, self.schema.fields())?;
self.header_validation = false;
self.to_skip -= 1;
return Ok(bytes);
}

// Skip in units of `to_read` to avoid over-allocating buffers
let to_skip = self.to_skip.min(self.batch_size);
let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
Expand Down Expand Up @@ -678,6 +706,24 @@ impl Decoder {
}
}

/// Checks the first record of `rows` against the names of `fields`, position by position
///
/// # Errors
///
/// Returns [`ArrowError::CsvError`] if `rows` yields no record, or at the first column
/// whose header text differs from the name of the corresponding field
fn validate_header(rows: &StringRecords<'_>, fields: &Fields) -> Result<(), ArrowError> {
    let header = match rows.iter().next() {
        Some(record) => record,
        None => {
            return Err(ArrowError::CsvError(
                "CSV header validation failed: no header row found".to_string(),
            ));
        }
    };

    // Stops at the first mismatching column, so only one error is ever reported
    fields.iter().enumerate().try_for_each(|(idx, field)| {
        let expected = field.name();
        let actual = header.get(idx);
        if actual == expected {
            Ok(())
        } else {
            Err(ArrowError::CsvError(format!(
                "CSV header does not match schema at column {idx}: expected {expected:?} but found {actual:?}"
            )))
        }
    })
}

/// Parses a slice of [`StringRecords`] into a [RecordBatch]
fn parse(
rows: &StringRecords<'_>,
Expand Down Expand Up @@ -1154,6 +1200,14 @@ impl ReaderBuilder {
self
}

/// Choose whether the CSV header is validated against the schema
///
/// Defaults to `false`, and only takes effect when [`Self::with_header`] is set to `true`
pub fn with_header_validation(mut self, validate_header: bool) -> Self {
    // The flag is stored on the builder's format
    let format = &mut self.format;
    format.header_validation = validate_header;
    self
}

/// Overrides the [Format] of this [ReaderBuilder]
pub fn with_format(mut self, format: Format) -> Self {
self.format = format;
Expand Down Expand Up @@ -1261,6 +1315,7 @@ impl ReaderBuilder {
Decoder {
schema: self.schema,
to_skip: start,
header_validation: self.format.header && self.format.header_validation,
record_decoder,
line_number: start,
end,
Expand Down Expand Up @@ -2351,6 +2406,86 @@ mod tests {
}
}

#[test]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might also be a good idea to add a test with `with_truncated_rows` (which allows rows to have fewer columns than specified)

e.g.

    #[test]
    fn test123() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]));

        let csv = "a\n1\n";
        let a = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_header_validation(true)
            .with_truncated_rows(true)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next();
        dbg!(a);
    }

output

running 1 test
[arrow-csv/src/reader/mod.rs:2457:9] a = Some(
    Err(
        CsvError(
            "CSV header does not match schema at column 1: expected \"b\" but found \"\"",
        ),
    ),
)
test reader::tests::test123 ... ok

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be desirable to pass the validation in this case? Or would it make more sense to keep the current behavior?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, if we're validating, it would make sense to error out as it currently does.

@XiNiHa XiNiHa Jun 23, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fn test_header_validation() {
    // The header names the second column "c" while the schema expects "b"
    let csv = "a,c\n1,2\n";
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    // With validation enabled the first read must report the mismatch
    let mut validating = ReaderBuilder::new(schema.clone())
        .with_header(true)
        .with_header_validation(true)
        .build_buffered(Cursor::new(csv.as_bytes()))
        .unwrap();
    let err = validating.next().unwrap().unwrap_err();
    assert_eq!(
        err.to_string(),
        "Csv error: CSV header does not match schema at column 1: expected \"b\" but found \"c\""
    );

    // With validation disabled the same input yields its single data row
    let mut lenient = ReaderBuilder::new(schema)
        .with_header(true)
        .with_header_validation(false)
        .build_buffered(Cursor::new(csv.as_bytes()))
        .unwrap();
    let batch = lenient.next().unwrap().unwrap();
    assert_eq!(batch.num_rows(), 1);
}

#[test]
fn test_header_validation_with_buffered_reader() {
    let csv = "a,b\n1,2\n";
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));

    // One-byte buffer capacity; presumably chosen so the header reaches the decoder
    // in several small pieces rather than in a single call
    let source = std::io::BufReader::with_capacity(1, Cursor::new(csv.as_bytes()));
    let mut reader = ReaderBuilder::new(schema)
        .with_header(true)
        .with_header_validation(true)
        .build_buffered(source)
        .unwrap();
    let batch = reader.next().unwrap().unwrap();

    // The header matches the schema, so the single data row is decoded as usual
    assert_eq!(batch.num_rows(), 1);
    let first_column = batch.column(0).as_primitive::<Int32Type>();
    assert_eq!(first_column.value(0), 1);
}

#[test]
fn test_header_validation_with_truncated_rows() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));

    // The header only names column "a"; truncated rows are enabled, yet header
    // validation is still expected to reject the missing "b" column
    let csv = "a\n1\n";
    // `schema` is not needed afterwards, so it is moved into the builder rather than cloned
    let err = ReaderBuilder::new(schema)
        .with_header(true)
        .with_header_validation(true)
        .with_truncated_rows(true)
        .build_buffered(Cursor::new(csv.as_bytes()))
        .unwrap()
        .next()
        .unwrap()
        .unwrap_err()
        .to_string();
    assert_eq!(
        err,
        "Csv error: CSV header does not match schema at column 1: expected \"b\" but found \"\"",
    );
}

#[test]
fn test_null_boolean() {
let csv = "true,false\nFalse,True\n,True\nFalse,";
Expand Down
Loading