Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/daft-logical-plan/src/scan_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct ParquetScanBuilder {
pub schema: Option<SchemaRef>,
pub file_path_column: Option<String>,
pub hive_partitioning: bool,
pub ignore_corrupt_files: bool,
}

impl ParquetScanBuilder {
Expand All @@ -47,6 +48,7 @@ impl ParquetScanBuilder {
io_config: None,
file_path_column: None,
hive_partitioning: false,
ignore_corrupt_files: false,
}
}
pub fn infer_schema(mut self, infer_schema: bool) -> Self {
Expand Down Expand Up @@ -95,13 +97,18 @@ impl ParquetScanBuilder {
self
}

pub fn ignore_corrupt_files(mut self, ignore_corrupt_files: bool) -> Self {
self.ignore_corrupt_files = ignore_corrupt_files;
self
}

pub async fn finish(self) -> DaftResult<LogicalPlanBuilder> {
let cfg = ParquetSourceConfig {
coerce_int96_timestamp_unit: self.coerce_int96_timestamp_unit,
field_id_mapping: self.field_id_mapping,
row_groups: self.row_groups,
chunk_size: self.chunk_size,
ignore_corrupt_files: false,
ignore_corrupt_files: self.ignore_corrupt_files,
};

let operator = Arc::new(
Expand Down
3 changes: 3 additions & 0 deletions src/daft-sql/src/table_provider/read_parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl TryFrom<SQLFunctionArguments> for ParquetScanBuilder {
let file_path_column = args.try_get_named("file_path_column")?;
let multithreaded = args.try_get_named("multithreaded")?.unwrap_or(true);
let hive_partitioning = args.try_get_named("hive_partitioning")?.unwrap_or(false);
let ignore_corrupt_files = args.try_get_named("ignore_corrupt_files")?.unwrap_or(false);

let field_id_mapping = None; // TODO
let row_groups = None; // TODO
Expand All @@ -64,6 +65,7 @@ impl TryFrom<SQLFunctionArguments> for ParquetScanBuilder {
schema,
file_path_column,
hive_partitioning,
ignore_corrupt_files,
})
}
}
Expand All @@ -88,6 +90,7 @@ impl SQLTableFunction for ReadParquetFunction {
"io_config",
"file_path_column",
"hive_partitioning",
"ignore_corrupt_files",
],
1, // 1 positional argument (path)
)?;
Expand Down
22 changes: 22 additions & 0 deletions tests/sql/test_sql_table_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import os

import pyarrow as pa
import pyarrow.parquet as papq
import pytest
Expand Down Expand Up @@ -132,6 +134,26 @@ def test_sql_read_parquet_file_options(tmp_path):
assert actual.to_pydict() == expect.to_pydict()


def test_sql_read_parquet_ignore_corrupt_files(tmp_path):
good_path = tmp_path / "good.parquet"
bad_path = tmp_path / "bad.parquet"
papq.write_table(pa.table({"x": [1, 2]}), good_path)
bad_path.write_bytes(b"PAR1" + b"\x00" * 20 + b"PAR1")

with pytest.raises(Exception):
daft.sql(f"SELECT * FROM read_parquet('{tmp_path.as_posix()}')").collect()

df = daft.sql(f"SELECT * FROM read_parquet('{tmp_path.as_posix()}', ignore_corrupt_files => true)").collect()

assert df.to_pydict() == {"x": [1, 2]}
skipped = df.skipped_corrupt_files
assert len(skipped) == 1
path, reason, partial = skipped[0]
assert os.path.basename(path) == "bad.parquet"
assert reason
assert not partial


def test_sql_read_csv(sample_csv_path):
actual = daft.sql(f"SELECT * FROM read_csv('{sample_csv_path}')")
expect = daft.read_csv(sample_csv_path)
Expand Down
Loading