diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1c386a57..ddd69e58 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: platform: ["ubuntu-latest"] tox-env: ["py310", "py311", "py312", "py313"] include: - - platform: macos-13 + - platform: macos-15 tox-env: "py313" - platform: macos-latest tox-env: "py313" diff --git a/docs/index.rst b/docs/index.rst index 997de169..8434017e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -63,6 +63,7 @@ User Guide using_task_task_conflict_prevention_lock efficient_run_on_multi_workers for_pandas + polars mypy_plugin API References diff --git a/docs/polars.rst b/docs/polars.rst new file mode 100644 index 00000000..b4ca1f4a --- /dev/null +++ b/docs/polars.rst @@ -0,0 +1,225 @@ +Polars Support +============== + +Gokart supports Polars DataFrames alongside pandas DataFrames for DataFrame-based file processors. This allows gradual migration from pandas to Polars or using both libraries simultaneously in your data pipelines. + + +Installation +------------ + +Polars support is optional. Install it with: + +.. code:: bash + + pip install gokart[polars] + +Or install Polars separately: + +.. code:: bash + + pip install polars + + +Basic Usage +----------- + +To use Polars DataFrames with gokart, specify ``dataframe_type='polars'`` when creating file processors: + +.. code:: python + + import polars as pl + from gokart import TaskOnKart + from gokart.file_processor import FeatherFileProcessor + + class MyPolarsTask(TaskOnKart[pl.DataFrame]): + def output(self): + return self.make_target( + 'path/to/target.feather', + processor=FeatherFileProcessor( + store_index_in_feather=False, + dataframe_type='polars' + ) + ) + + def run(self): + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + self.dump(df) + + +Supported File Processors +-------------------------- + +The following file processors support the ``dataframe_type`` parameter: + +CsvFileProcessor +^^^^^^^^^^^^^^^^ + +.. code:: python + + from gokart.file_processor import CsvFileProcessor + + # For Polars + processor = CsvFileProcessor(sep=',', encoding='utf-8', dataframe_type='polars') + + # For pandas (default) + processor = CsvFileProcessor(sep=',', encoding='utf-8', dataframe_type='pandas') + # or simply + processor = CsvFileProcessor(sep=',', encoding='utf-8') + + +JsonFileProcessor +^^^^^^^^^^^^^^^^^ + +.. code:: python + + from gokart.file_processor import JsonFileProcessor + + # For Polars + processor = JsonFileProcessor(orient='records', dataframe_type='polars') + + # For pandas (default) + processor = JsonFileProcessor(orient='records', dataframe_type='pandas') + + +ParquetFileProcessor +^^^^^^^^^^^^^^^^^^^^ + +.. code:: python + + from gokart.file_processor import ParquetFileProcessor + + # For Polars + processor = ParquetFileProcessor( + compression='gzip', + dataframe_type='polars' + ) + + # For pandas (default) + processor = ParquetFileProcessor( + compression='gzip', + dataframe_type='pandas' + ) + + +FeatherFileProcessor +^^^^^^^^^^^^^^^^^^^^ + +.. code:: python + + from gokart.file_processor import FeatherFileProcessor + + # For Polars + processor = FeatherFileProcessor( + store_index_in_feather=False, + dataframe_type='polars' + ) + + # For pandas (default) + processor = FeatherFileProcessor( + store_index_in_feather=True, + dataframe_type='pandas' + ) + +.. note:: + The ``store_index_in_feather`` parameter is pandas-specific and is ignored when using Polars. + + +Using Pandas and Polars Together +--------------------------------- + +Since projects often migrate from pandas gradually, gokart allows you to use both pandas and Polars simultaneously: + +.. code:: python + + import pandas as pd + import polars as pl + from gokart import TaskOnKart + from gokart.file_processor import FeatherFileProcessor + + class PandasTask(TaskOnKart[pd.DataFrame]): + """Task that outputs pandas DataFrame""" + def output(self): + return self.make_target( + 'path/to/pandas_output.feather', + processor=FeatherFileProcessor( + store_index_in_feather=False, + dataframe_type='pandas' + ) + ) + + def run(self): + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + self.dump(df) + + class PolarsTask(TaskOnKart[pl.DataFrame]): + """Task that outputs Polars DataFrame""" + def requires(self): + return PandasTask() + + def output(self): + return self.make_target( + 'path/to/polars_output.feather', + processor=FeatherFileProcessor( + store_index_in_feather=False, + dataframe_type='polars' + ) + ) + + def run(self): + # Load pandas DataFrame and convert to Polars + pandas_df = self.load() # Returns pandas DataFrame + polars_df = pl.from_pandas(pandas_df) + + # Process with Polars + result = polars_df.with_columns( + (pl.col('a') * 2).alias('a_doubled') + ) + + self.dump(result) + + +Default Behavior +---------------- + +When ``dataframe_type`` is not specified, file processors default to ``'pandas'`` for backward compatibility: + +.. code:: python + + # These are equivalent + processor = CsvFileProcessor(sep=',') + processor = CsvFileProcessor(sep=',', dataframe_type='pandas') + + +Important Notes +--------------- + +**File Format Compatibility** + +Files created with Polars processors can be read by pandas processors and vice versa. The underlying file formats (CSV, JSON, Parquet, Feather) are library-agnostic. + +**Pandas-specific Features** + +Some pandas-specific features are not available with Polars: + +- ``store_index_in_feather`` parameter in ``FeatherFileProcessor`` is ignored for Polars +- ``engine`` parameter in ``ParquetFileProcessor`` is ignored for Polars (uses Polars' default) + +**Error Handling** + +If you specify ``dataframe_type='polars'`` but Polars is not installed, you'll get an ``ImportError`` with installation instructions: + +.. code:: text + + ImportError: polars is required for dataframe_type='polars'. Install with: pip install polars + + +Migration Strategy +------------------ + +Recommended approach for migrating from pandas to Polars: + +1. Install Polars: ``pip install gokart[polars]`` +2. Create new tasks using ``dataframe_type='polars'`` +3. Keep existing tasks with ``dataframe_type='pandas'`` or default behavior +4. Gradually migrate tasks as needed +5. Convert DataFrames between libraries using ``pl.from_pandas()`` and ``df.to_pandas()`` when necessary diff --git a/gokart/file_processor.py b/gokart/file_processor.py index bd257cf7..e69de29b 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -1,304 +0,0 @@ -from __future__ import annotations - -import os -import xml.etree.ElementTree as ET -from abc import abstractmethod -from io import BytesIO -from logging import getLogger - -import dill -import luigi -import luigi.contrib.s3 -import luigi.format -import numpy as np -import pandas as pd -import pandas.errors -from luigi.format import TextFormat - -from gokart.object_storage import ObjectStorage -from gokart.utils import load_dill_with_pandas_backward_compatibility - -logger = getLogger(__name__) - - -class FileProcessor: - @abstractmethod - def format(self): - pass - - @abstractmethod - def load(self, file): - pass - - @abstractmethod - def dump(self, obj, file): - pass - - -class BinaryFileProcessor(FileProcessor): - """ - Pass bytes to this processor - - ``` - figure_binary = io.BytesIO() - plt.savefig(figure_binary) - figure_binary.seek(0) - BinaryFileProcessor().dump(figure_binary.read()) - ``` - """ - - def format(self): - return luigi.format.Nop - - def load(self, file): - return file.read() - - def dump(self, obj, file): - file.write(obj) - - -class _ChunkedLargeFileReader: - def __init__(self, file) -> None: - self._file = file - - def __getattr__(self, item): - return getattr(self._file, item) - - def read(self, n): - if n >= (1 << 31): - logger.info(f'reading a large file with total_bytes={n}.') - buffer = bytearray(n) - idx = 0 - while idx < n: - batch_size = min(n - idx, 1 << 31 - 1) - logger.info(f'reading bytes [{idx}, {idx + batch_size})...') - buffer[idx : idx + batch_size] = self._file.read(batch_size) - idx += batch_size - logger.info('done.') - return buffer - return self._file.read(n) - - -class PickleFileProcessor(FileProcessor): - def format(self): - return luigi.format.Nop - - def load(self, file): - if not file.seekable(): - # load_dill_with_pandas_backward_compatibility() requires file with seek() and readlines() implemented. - # Therefore, we need to wrap with BytesIO which makes file seekable and readlinesable. - # For example, ReadableS3File is not a seekable file. - return load_dill_with_pandas_backward_compatibility(BytesIO(file.read())) - return load_dill_with_pandas_backward_compatibility(_ChunkedLargeFileReader(file)) - - def dump(self, obj, file): - self._write(dill.dumps(obj, protocol=4), file) - - @staticmethod - def _write(buffer, file): - n = len(buffer) - idx = 0 - while idx < n: - logger.info(f'writing a file with total_bytes={n}...') - batch_size = min(n - idx, 1 << 31 - 1) - logger.info(f'writing bytes [{idx}, {idx + batch_size})') - file.write(buffer[idx : idx + batch_size]) - idx += batch_size - logger.info('done') - - -class TextFileProcessor(FileProcessor): - def format(self): - return None - - def load(self, file): - return [s.rstrip() for s in file.readlines()] - - def dump(self, obj, file): - if isinstance(obj, list): - for x in obj: - file.write(str(x) + '\n') - else: - file.write(str(obj)) - - -class CsvFileProcessor(FileProcessor): - def __init__(self, sep=',', encoding: str = 'utf-8'): - self._sep = sep - self._encoding = encoding - super().__init__() - - def format(self): - return TextFormat(encoding=self._encoding) - - def load(self, file): - try: - return pd.read_csv(file, sep=self._sep, encoding=self._encoding) - except pd.errors.EmptyDataError: - return pd.DataFrame() - - def dump(self, obj, file): - assert isinstance(obj, pd.DataFrame | pd.Series), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.' - obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding) - - -class GzipFileProcessor(FileProcessor): - def format(self): - return luigi.format.Gzip - - def load(self, file): - return [s.rstrip().decode() for s in file.readlines()] - - def dump(self, obj, file): - if isinstance(obj, list): - for x in obj: - file.write((str(x) + '\n').encode()) - else: - file.write(str(obj).encode()) - - -class JsonFileProcessor(FileProcessor): - def __init__(self, orient: str | None = None): - self._orient = orient - - def format(self): - return luigi.format.Nop - - def load(self, file): - try: - return pd.read_json(file, orient=self._orient, lines=True if self._orient == 'records' else False) - except pd.errors.EmptyDataError: - return pd.DataFrame() - - def dump(self, obj, file): - assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), ( - f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' - ) - if isinstance(obj, dict): - obj = pd.DataFrame.from_dict(obj) - obj.to_json(file, orient=self._orient, lines=True if self._orient == 'records' else False) - - -class XmlFileProcessor(FileProcessor): - def format(self): - return None - - def load(self, file): - try: - return ET.parse(file) - except ET.ParseError: - return ET.ElementTree() - - def dump(self, obj, file): - assert isinstance(obj, ET.ElementTree), f'requires ET.ElementTree, but {type(obj)} is passed.' - obj.write(file) - - -class NpzFileProcessor(FileProcessor): - def format(self): - return luigi.format.Nop - - def load(self, file): - return np.load(file)['data'] - - def dump(self, obj, file): - assert isinstance(obj, np.ndarray), f'requires np.ndarray, but {type(obj)} is passed.' - np.savez_compressed(file, data=obj) - - -class ParquetFileProcessor(FileProcessor): - def __init__(self, engine='pyarrow', compression=None): - self._engine = engine - self._compression = compression - super().__init__() - - def format(self): - return luigi.format.Nop - - def load(self, file): - # FIXME(mamo3gr): enable streaming (chunked) read with S3. - # pandas.read_parquet accepts file-like object - # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, - # which is needed for pandas to read a file in chunks. - if ObjectStorage.is_buffered_reader(file): - return pd.read_parquet(file.name) - else: - return pd.read_parquet(BytesIO(file.read())) - - def dump(self, obj, file): - assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' - # MEMO: to_parquet only supports a filepath as string (not a file handle) - obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression) - - -class FeatherFileProcessor(FileProcessor): - def __init__(self, store_index_in_feather: bool): - super().__init__() - self._store_index_in_feather = store_index_in_feather - self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__' - - def format(self): - return luigi.format.Nop - - def load(self, file): - # FIXME(mamo3gr): enable streaming (chunked) read with S3. - # pandas.read_feather accepts file-like object - # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, - # which is needed for pandas to read a file in chunks. - if ObjectStorage.is_buffered_reader(file): - loaded_df = pd.read_feather(file.name) - else: - loaded_df = pd.read_feather(BytesIO(file.read())) - - if self._store_index_in_feather: - if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns): - index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX] - index_column = index_columns[0] - index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :] - if index_name == 'None': - index_name = None - loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name) - loaded_df = loaded_df.drop(columns={index_column}) - - return loaded_df - - def dump(self, obj, file): - assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' - dump_obj = obj.copy() - - if self._store_index_in_feather: - index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' - assert index_column_name not in dump_obj.columns, ( - f'column name {index_column_name} already exists in dump_obj. \ - Consider not saving index by setting store_index_in_feather=False.' - ) - assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' - - dump_obj[index_column_name] = dump_obj.index - dump_obj = dump_obj.reset_index(drop=True) - - # to_feather supports "binary" file-like object, but file variable is text - dump_obj.to_feather(file.name) - - -def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor: - extension2processor = { - '.txt': TextFileProcessor(), - '.ini': TextFileProcessor(), - '.csv': CsvFileProcessor(sep=','), - '.tsv': CsvFileProcessor(sep='\t'), - '.pkl': PickleFileProcessor(), - '.gz': GzipFileProcessor(), - '.json': JsonFileProcessor(), - '.ndjson': JsonFileProcessor(orient='records'), - '.xml': XmlFileProcessor(), - '.npz': NpzFileProcessor(), - '.parquet': ParquetFileProcessor(compression='gzip'), - '.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather), - '.png': BinaryFileProcessor(), - '.jpg': BinaryFileProcessor(), - } - - extension = os.path.splitext(file_path)[1] - assert extension in extension2processor, f'{extension} is not supported. The supported extensions are {list(extension2processor.keys())}.' - return extension2processor[extension] diff --git a/gokart/file_processor/__init__.py b/gokart/file_processor/__init__.py new file mode 100644 index 00000000..0916e2eb --- /dev/null +++ b/gokart/file_processor/__init__.py @@ -0,0 +1,198 @@ +"""File processor module with support for multiple DataFrame backends.""" + +from __future__ import annotations + +import os + +# Export common processors and types from base +from gokart.file_processor.base import ( + BinaryFileProcessor, + DataFrameType, + FileProcessor, + GzipFileProcessor, + NpzFileProcessor, + PickleFileProcessor, + TextFileProcessor, + XmlFileProcessor, +) + +# Import backend-specific implementations +from gokart.file_processor.pandas import ( + CsvFileProcessorPandas, + FeatherFileProcessorPandas, + JsonFileProcessorPandas, + ParquetFileProcessorPandas, +) +from gokart.file_processor.polars import ( + CsvFileProcessorPolars, + FeatherFileProcessorPolars, + JsonFileProcessorPolars, + ParquetFileProcessorPolars, +) + + +class CsvFileProcessor(FileProcessor): + """CSV file processor with automatic backend selection based on dataframe_type.""" + + def __init__(self, sep=',', encoding: str = 'utf-8', dataframe_type: DataFrameType = 'pandas'): + """ + CSV file processor with support for both pandas and polars DataFrames. + + Args: + sep: CSV delimiter (default: ',') + encoding: File encoding (default: 'utf-8') + dataframe_type: DataFrame library to use for load() - 'pandas' or 'polars' (default: 'pandas') + """ + self._sep = sep + self._encoding = encoding + self._dataframe_type = dataframe_type # Store for tests + + if dataframe_type == 'polars': + self._impl: FileProcessor = CsvFileProcessorPolars(sep=sep, encoding=encoding) + else: + self._impl = CsvFileProcessorPandas(sep=sep, encoding=encoding) + + def format(self): + return self._impl.format() + + def load(self, file): + return self._impl.load(file) + + def dump(self, obj, file): + return self._impl.dump(obj, file) + + +class JsonFileProcessor(FileProcessor): + """JSON file processor with automatic backend selection based on dataframe_type.""" + + def __init__(self, orient: str | None = None, dataframe_type: DataFrameType = 'pandas'): + """ + JSON file processor with support for both pandas and polars DataFrames. + + Args: + orient: JSON orientation. 'records' for newline-delimited JSON. + dataframe_type: DataFrame library to use for load() - 'pandas' or 'polars' (default: 'pandas') + """ + self._orient = orient + self._dataframe_type = dataframe_type # Store for tests + + if dataframe_type == 'polars': + self._impl: FileProcessor = JsonFileProcessorPolars(orient=orient) + else: + self._impl = JsonFileProcessorPandas(orient=orient) + + def format(self): + return self._impl.format() + + def load(self, file): + return self._impl.load(file) + + def dump(self, obj, file): + return self._impl.dump(obj, file) + + +class ParquetFileProcessor(FileProcessor): + """Parquet file processor with automatic backend selection based on dataframe_type.""" + + def __init__(self, engine='pyarrow', compression=None, dataframe_type: DataFrameType = 'pandas'): + """ + Parquet file processor with support for both pandas and polars DataFrames. + + Args: + engine: Parquet engine (pandas-specific, ignored for polars). + compression: Compression type. + dataframe_type: DataFrame library to use for load() - 'pandas' or 'polars' (default: 'pandas') + """ + self._engine = engine + self._compression = compression + self._dataframe_type = dataframe_type # Store for tests + + if dataframe_type == 'polars': + self._impl: FileProcessor = ParquetFileProcessorPolars(engine=engine, compression=compression) + else: + self._impl = ParquetFileProcessorPandas(engine=engine, compression=compression) + + def format(self): + return self._impl.format() + + def load(self, file): + return self._impl.load(file) + + def dump(self, obj, file): + # Use the configured implementation (pandas by default) + return self._impl.dump(obj, file) + + +class FeatherFileProcessor(FileProcessor): + """Feather file processor with automatic backend selection based on dataframe_type.""" + + def __init__(self, store_index_in_feather: bool, dataframe_type: DataFrameType = 'pandas'): + """ + Feather file processor with support for both pandas and polars DataFrames. + + Args: + store_index_in_feather: Whether to store pandas index (pandas-only feature). + dataframe_type: DataFrame library to use for load() - 'pandas' or 'polars' (default: 'pandas') + """ + self._store_index_in_feather = store_index_in_feather + self._dataframe_type = dataframe_type # Store for tests + + if dataframe_type == 'polars': + self._impl: FileProcessor = FeatherFileProcessorPolars(store_index_in_feather=store_index_in_feather) + else: + self._impl = FeatherFileProcessorPandas(store_index_in_feather=store_index_in_feather) + + def format(self): + return self._impl.format() + + def load(self, file): + return self._impl.load(file) + + def dump(self, obj, file): + # Use the configured implementation (pandas by default) + return self._impl.dump(obj, file) + + +def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor: + """Create a file processor based on file extension with default parameters.""" + extension2processor = { + '.txt': TextFileProcessor(), + '.ini': TextFileProcessor(), + '.csv': CsvFileProcessor(sep=','), + '.tsv': CsvFileProcessor(sep='\t'), + '.pkl': PickleFileProcessor(), + '.gz': GzipFileProcessor(), + '.json': JsonFileProcessor(), + '.ndjson': JsonFileProcessor(orient='records'), + '.xml': XmlFileProcessor(), + '.npz': NpzFileProcessor(), + '.parquet': ParquetFileProcessor(compression='gzip'), + '.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather), + '.png': BinaryFileProcessor(), + '.jpg': BinaryFileProcessor(), + } + + extension = os.path.splitext(file_path)[1] + assert extension in extension2processor, f'{extension} is not supported. The supported extensions are {list(extension2processor.keys())}.' + return extension2processor[extension] + + +__all__ = [ + # Base classes and types + 'FileProcessor', + 'DataFrameType', + # Common processors + 'BinaryFileProcessor', + 'PickleFileProcessor', + 'TextFileProcessor', + 'GzipFileProcessor', + 'XmlFileProcessor', + 'NpzFileProcessor', + # DataFrame processors (with factory pattern) + 'CsvFileProcessor', + 'JsonFileProcessor', + 'ParquetFileProcessor', + 'FeatherFileProcessor', + # Utility functions + 'make_file_processor', +] diff --git a/gokart/file_processor/base.py b/gokart/file_processor/base.py new file mode 100644 index 00000000..d94257cf --- /dev/null +++ b/gokart/file_processor/base.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import xml.etree.ElementTree as ET +from abc import abstractmethod +from io import BytesIO +from logging import getLogger +from typing import Literal + +import dill +import luigi +import luigi.format +import numpy as np + +from gokart.utils import load_dill_with_pandas_backward_compatibility + +logger = getLogger(__name__) + +# Type alias for DataFrame library return type +DataFrameType = Literal['pandas', 'polars'] + + +class FileProcessor: + @abstractmethod + def format(self): + pass + + @abstractmethod + def load(self, file): + pass + + @abstractmethod + def dump(self, obj, file): + pass + + +class BinaryFileProcessor(FileProcessor): + """ + Pass bytes to this processor + + ``` + figure_binary = io.BytesIO() + plt.savefig(figure_binary) + figure_binary.seek(0) + BinaryFileProcessor().dump(figure_binary.read()) + ``` + """ + + def format(self): + return luigi.format.Nop + + def load(self, file): + return file.read() + + def dump(self, obj, file): + file.write(obj) + + +class _ChunkedLargeFileReader: + def __init__(self, file) -> None: + self._file = file + + def __getattr__(self, item): + return getattr(self._file, item) + + def read(self, n): + if n >= (1 << 31): + logger.info(f'reading a large file with total_bytes={n}.') + buffer = bytearray(n) + idx = 0 + while idx < n: + batch_size = min(n - idx, 1 << 31 - 1) + logger.info(f'reading bytes [{idx}, {idx + batch_size})...') + buffer[idx : idx + batch_size] = self._file.read(batch_size) + idx += batch_size + logger.info('done.') + return buffer + return self._file.read(n) + + +class PickleFileProcessor(FileProcessor): + def format(self): + return luigi.format.Nop + + def load(self, file): + if not file.seekable(): + # load_dill_with_pandas_backward_compatibility() requires file with seek() and readlines() implemented. + # Therefore, we need to wrap with BytesIO which makes file seekable and readlinesable. + # For example, ReadableS3File is not a seekable file. + return load_dill_with_pandas_backward_compatibility(BytesIO(file.read())) + return load_dill_with_pandas_backward_compatibility(_ChunkedLargeFileReader(file)) + + def dump(self, obj, file): + self._write(dill.dumps(obj, protocol=4), file) + + @staticmethod + def _write(buffer, file): + n = len(buffer) + idx = 0 + while idx < n: + logger.info(f'writing a file with total_bytes={n}...') + batch_size = min(n - idx, 1 << 31 - 1) + logger.info(f'writing bytes [{idx}, {idx + batch_size})') + file.write(buffer[idx : idx + batch_size]) + idx += batch_size + logger.info('done') + + +class TextFileProcessor(FileProcessor): + def format(self): + return None + + def load(self, file): + return [s.rstrip() for s in file.readlines()] + + def dump(self, obj, file): + if isinstance(obj, list): + for x in obj: + file.write(str(x) + '\n') + else: + file.write(str(obj)) + + +class GzipFileProcessor(FileProcessor): + def format(self): + return luigi.format.Gzip + + def load(self, file): + return [s.rstrip().decode() for s in file.readlines()] + + def dump(self, obj, file): + if isinstance(obj, list): + for x in obj: + file.write((str(x) + '\n').encode()) + else: + file.write(str(obj).encode()) + + +class XmlFileProcessor(FileProcessor): + def format(self): + return None + + def load(self, file): + try: + return ET.parse(file) + except ET.ParseError: + return ET.ElementTree() + + def dump(self, obj, file): + assert isinstance(obj, ET.ElementTree), f'requires ET.ElementTree, but {type(obj)} is passed.' + obj.write(file) + + +class NpzFileProcessor(FileProcessor): + def format(self): + return luigi.format.Nop + + def load(self, file): + return np.load(file)['data'] + + def dump(self, obj, file): + assert isinstance(obj, np.ndarray), f'requires np.ndarray, but {type(obj)} is passed.' + np.savez_compressed(file, data=obj) diff --git a/gokart/file_processor/pandas.py b/gokart/file_processor/pandas.py new file mode 100644 index 00000000..cd1b34a4 --- /dev/null +++ b/gokart/file_processor/pandas.py @@ -0,0 +1,140 @@ +"""Pandas-specific file processor implementations.""" + +from __future__ import annotations + +from io import BytesIO + +import luigi +import luigi.format +import pandas as pd +from luigi.format import TextFormat + +from gokart.file_processor.base import FileProcessor +from gokart.object_storage import ObjectStorage + + +class CsvFileProcessorPandas(FileProcessor): + """CSV file processor for pandas DataFrames.""" + + def __init__(self, sep=',', encoding: str = 'utf-8'): + self._sep = sep + self._encoding = encoding + super().__init__() + + def format(self): + return TextFormat(encoding=self._encoding) + + def load(self, file): + try: + return pd.read_csv(file, sep=self._sep, encoding=self._encoding) + except pd.errors.EmptyDataError: + return pd.DataFrame() + + def dump(self, obj, file): + if not isinstance(obj, pd.DataFrame | pd.Series): + raise TypeError(f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.') + obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding) + + +class JsonFileProcessorPandas(FileProcessor): + """JSON file processor for pandas DataFrames.""" + + def __init__(self, orient: str | None = None): + self._orient = orient + + def format(self): + return luigi.format.Nop + + def load(self, file): + try: + return pd.read_json(file, orient=self._orient, lines=True if self._orient == 'records' else False) + except pd.errors.EmptyDataError: + return pd.DataFrame() + + def dump(self, obj, file): + if isinstance(obj, dict): + obj = pd.DataFrame.from_dict(obj) + if not isinstance(obj, pd.DataFrame | pd.Series | dict): + raise TypeError(f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.') + obj.to_json(file, orient=self._orient, lines=True if self._orient == 'records' else False) + + +class ParquetFileProcessorPandas(FileProcessor): + """Parquet file processor for pandas DataFrames.""" + + def __init__(self, engine='pyarrow', compression=None): + self._engine = engine + self._compression = compression + super().__init__() + + def format(self): + return luigi.format.Nop + + def load(self, file): + # FIXME(mamo3gr): enable streaming (chunked) read with S3. + # pandas.read_parquet accepts file-like object + # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, + # which is needed for pandas to read a file in chunks. + if ObjectStorage.is_buffered_reader(file): + return pd.read_parquet(file.name) + else: + return pd.read_parquet(BytesIO(file.read())) + + def dump(self, obj, file): + if not isinstance(obj, pd.DataFrame): + raise TypeError(f'requires pd.DataFrame, but {type(obj)} is passed.') + # MEMO: to_parquet only supports a filepath as string (not a file handle) + obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression) + + +class FeatherFileProcessorPandas(FileProcessor): + """Feather file processor for pandas DataFrames.""" + + def __init__(self, store_index_in_feather: bool): + super().__init__() + self._store_index_in_feather = store_index_in_feather + self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__' + + def format(self): + return luigi.format.Nop + + def load(self, file): + # FIXME(mamo3gr): enable streaming (chunked) read with S3. + # pandas.read_feather accepts file-like object + # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, + # which is needed for pandas to read a file in chunks. + if ObjectStorage.is_buffered_reader(file): + loaded_df = pd.read_feather(file.name) + else: + loaded_df = pd.read_feather(BytesIO(file.read())) + + if self._store_index_in_feather: + if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns): + index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX] + index_column = index_columns[0] + index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :] + if index_name == 'None': + index_name = None + loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name) + loaded_df = loaded_df.drop(columns={index_column}) + + return loaded_df + + def dump(self, obj, file): + if not isinstance(obj, pd.DataFrame): + raise TypeError(f'requires pd.DataFrame, but {type(obj)} is passed.') + + dump_obj = obj.copy() + + if self._store_index_in_feather: + index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' + assert index_column_name not in dump_obj.columns, ( + f'column name {index_column_name} already exists in dump_obj. \nConsider not saving index by setting store_index_in_feather=False.' + ) + assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' + + dump_obj[index_column_name] = dump_obj.index + dump_obj = dump_obj.reset_index(drop=True) + + # to_feather supports "binary" file-like object, but file variable is text + dump_obj.to_feather(file.name) diff --git a/gokart/file_processor/polars.py b/gokart/file_processor/polars.py new file mode 100644 index 00000000..98b631d9 --- /dev/null +++ b/gokart/file_processor/polars.py @@ -0,0 +1,134 @@ +"""Polars-specific file processor implementations.""" + +from __future__ import annotations + +from io import BytesIO + +import luigi +import luigi.format +from luigi.format import TextFormat + +from gokart.file_processor.base import FileProcessor +from gokart.object_storage import ObjectStorage + +try: + import polars as pl + + HAS_POLARS = True +except ImportError: + HAS_POLARS = False + pl = None # type: ignore + + +class CsvFileProcessorPolars(FileProcessor): + """CSV file processor for polars DataFrames.""" + + def __init__(self, sep=',', encoding: str = 'utf-8'): + if not HAS_POLARS: + raise ImportError("polars is required for dataframe_type='polars'. Install with: pip install polars") + self._sep = sep + self._encoding = encoding + super().__init__() + + def format(self): + return TextFormat(encoding=self._encoding) + + def load(self, file): + try: + return pl.read_csv(file, separator=self._sep, encoding=self._encoding) + except Exception as e: + # Handle empty data gracefully + if 'empty' in str(e).lower() or 'no data' in str(e).lower(): + return pl.DataFrame() + raise + + def dump(self, obj, file): + if not isinstance(obj, pl.DataFrame): + raise TypeError(f'requires pl.DataFrame, but {type(obj)} is passed.') + obj.write_csv(file, separator=self._sep, include_header=True) + + +class JsonFileProcessorPolars(FileProcessor): + """JSON file processor for polars DataFrames.""" + + def __init__(self, orient: str | None = None): + if not HAS_POLARS: + raise ImportError("polars is required for dataframe_type='polars'. Install with: pip install polars") + self._orient = orient + + def format(self): + return luigi.format.Nop + + def load(self, file): + try: + if self._orient == 'records': + return pl.read_ndjson(file) + else: + return pl.read_json(file) + except Exception as e: + # Handle empty files + if 'empty' in str(e).lower() or 'no data' in str(e).lower(): + return pl.DataFrame() + raise + + def dump(self, obj, file): + if not isinstance(obj, pl.DataFrame): + raise TypeError(f'requires pl.DataFrame, but {type(obj)} is passed.') + if self._orient == 'records': + obj.write_ndjson(file) + else: + obj.write_json(file) + + +class ParquetFileProcessorPolars(FileProcessor): + """Parquet file processor for polars DataFrames.""" + + def __init__(self, engine='pyarrow', compression=None): + if not HAS_POLARS: + raise ImportError("polars is required for dataframe_type='polars'. Install with: pip install polars") + self._engine = engine # Ignored for polars + self._compression = compression + super().__init__() + + def format(self): + return luigi.format.Nop + + def load(self, file): + # polars.read_parquet can handle file paths or file-like objects + if ObjectStorage.is_buffered_reader(file): + return pl.read_parquet(file.name) + else: + return pl.read_parquet(BytesIO(file.read())) + + def dump(self, obj, file): + if not isinstance(obj, pl.DataFrame): + raise TypeError(f'requires pl.DataFrame, but {type(obj)} is passed.') + # polars write_parquet requires a file path + obj.write_parquet(file.name, compression=self._compression) + + +class FeatherFileProcessorPolars(FileProcessor): + """Feather file processor for polars DataFrames.""" + + def __init__(self, store_index_in_feather: bool): + if not HAS_POLARS: + raise ImportError("polars is required for dataframe_type='polars'. Install with: pip install polars") + super().__init__() + self._store_index_in_feather = store_index_in_feather # Ignored for polars + + def format(self): + return luigi.format.Nop + + def load(self, file): + # polars uses read_ipc for feather format + if ObjectStorage.is_buffered_reader(file): + return pl.read_ipc(file.name) + else: + return pl.read_ipc(BytesIO(file.read())) + + def dump(self, obj, file): + if not isinstance(obj, pl.DataFrame): + raise TypeError(f'requires pl.DataFrame, but {type(obj)} is passed.') + # polars uses write_ipc for feather format + # Note: store_index_in_feather is ignored for polars as it's pandas-specific + obj.write_ipc(file.name) diff --git a/pyproject.toml b/pyproject.toml index b58f35da..6caaf915 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,9 @@ classifiers = [ ] dynamic = ["version"] +[project.optional-dependencies] +polars = ["polars>=0.19.0"] + [project.urls] Homepage = "https://github.com/m3dev/gokart" Repository = "https://github.com/m3dev/gokart" @@ -50,6 +53,7 @@ test = [ "matplotlib", "moto", "mypy", + "polars>=0.19.0", "pytest", "pytest-cov", "pytest-xdist", diff --git a/test/file_processor/__init__.py b/test/file_processor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/file_processor/test_base.py b/test/file_processor/test_base.py new file mode 100644 index 00000000..91fb42bb --- /dev/null +++ b/test/file_processor/test_base.py @@ -0,0 +1,78 @@ +"""Tests for base file processors (non-DataFrame processors).""" + +from __future__ import annotations + +import os +import tempfile +import unittest +from collections.abc import Callable + +import boto3 +from luigi import LocalTarget +from moto import mock_aws + +from gokart.file_processor import PickleFileProcessor +from gokart.object_storage import ObjectStorage + + +class TestPickleFileProcessor(unittest.TestCase): + def test_dump_and_load_normal_obj(self): + var = 'abc' + processor = PickleFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.pkl' + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(var, f) + with local_target.open('r') as f: + loaded = processor.load(f) + + self.assertEqual(loaded, var) + + def test_dump_and_load_class(self): + import functools + + def plus1(func: Callable[..., int]) -> Callable[..., int]: + @functools.wraps(func) + def wrapped() -> int: + ret = func() + return ret + 1 + + return wrapped + + class A: + def __init__(self) -> None: + self.run = plus1(self.run) # type: ignore + + def run(self) -> int: # type: ignore + return 1 + + obj = A() + processor = PickleFileProcessor() + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.pkl' + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(obj, f) + with local_target.open('r') as f: + loaded = processor.load(f) + + self.assertEqual(loaded.run(), obj.run()) + + @mock_aws + def test_dump_and_load_with_readables3file(self): + conn = boto3.resource('s3', region_name='us-east-1') + conn.create_bucket(Bucket='test') + file_path = os.path.join('s3://test/', 'test.pkl') + + var = 'abc' + processor = PickleFileProcessor() + + target = ObjectStorage.get_object_storage_target(file_path, processor.format()) + with target.open('w') as f: + processor.dump(var, f) + with target.open('r') as f: + loaded = processor.load(f) + + self.assertEqual(loaded, var) diff --git a/test/file_processor/test_factory.py b/test/file_processor/test_factory.py new file mode 100644 index 00000000..7564358e --- /dev/null +++ b/test/file_processor/test_factory.py @@ -0,0 +1,54 @@ +"""Tests for file processor factory function.""" + +from __future__ import annotations + +import unittest + +from gokart.file_processor import ( + CsvFileProcessor, + FeatherFileProcessor, + GzipFileProcessor, + JsonFileProcessor, + NpzFileProcessor, + ParquetFileProcessor, + TextFileProcessor, + make_file_processor, +) + + +class TestMakeFileProcessor(unittest.TestCase): + def test_make_file_processor_with_txt_extension(self): + processor = make_file_processor('test.txt', store_index_in_feather=False) + self.assertIsInstance(processor, TextFileProcessor) + + def test_make_file_processor_with_csv_extension(self): + processor = make_file_processor('test.csv', store_index_in_feather=False) + self.assertIsInstance(processor, CsvFileProcessor) + + def test_make_file_processor_with_gz_extension(self): + processor = make_file_processor('test.gz', store_index_in_feather=False) + self.assertIsInstance(processor, GzipFileProcessor) + + def test_make_file_processor_with_json_extension(self): + processor = make_file_processor('test.json', store_index_in_feather=False) + self.assertIsInstance(processor, JsonFileProcessor) + + def test_make_file_processor_with_ndjson_extension(self): + processor = make_file_processor('test.ndjson', store_index_in_feather=False) + self.assertIsInstance(processor, JsonFileProcessor) + + def test_make_file_processor_with_npz_extension(self): + processor = make_file_processor('test.npz', store_index_in_feather=False) + self.assertIsInstance(processor, NpzFileProcessor) + + def test_make_file_processor_with_parquet_extension(self): + processor = make_file_processor('test.parquet', store_index_in_feather=False) + self.assertIsInstance(processor, ParquetFileProcessor) + + def test_make_file_processor_with_feather_extension(self): + processor = make_file_processor('test.feather', store_index_in_feather=True) + self.assertIsInstance(processor, FeatherFileProcessor) + + def test_make_file_processor_with_unsupported_extension(self): + with self.assertRaises(AssertionError): + make_file_processor('test.unsupported', store_index_in_feather=False) diff --git a/test/test_file_processor.py b/test/file_processor/test_pandas.py similarity index 59% rename from test/test_file_processor.py rename to test/file_processor/test_pandas.py index e7231ef4..5638a70a 100644 --- a/test/test_file_processor.py +++ b/test/file_processor/test_pandas.py @@ -1,28 +1,15 @@ +"""Tests for pandas-specific file processors.""" + from __future__ import annotations -import os import tempfile import unittest -from collections.abc import Callable -import boto3 import pandas as pd import pytest from luigi import LocalTarget -from moto import mock_aws - -from gokart.file_processor import ( - CsvFileProcessor, - FeatherFileProcessor, - GzipFileProcessor, - JsonFileProcessor, - NpzFileProcessor, - ParquetFileProcessor, - PickleFileProcessor, - TextFileProcessor, - make_file_processor, -) -from gokart.object_storage import ObjectStorage + +from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor class TestCsvFileProcessor(unittest.TestCase): @@ -126,69 +113,6 @@ def test_dump_and_load_json(self, orient, input_data, expected_json): pd.testing.assert_frame_equal(df_input, loaded_df) -class TestPickleFileProcessor(unittest.TestCase): - def test_dump_and_load_normal_obj(self): - var = 'abc' - processor = PickleFileProcessor() - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.pkl' - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(var, f) - with local_target.open('r') as f: - loaded = processor.load(f) - - self.assertEqual(loaded, var) - - def test_dump_and_load_class(self): - import functools - - def plus1(func: Callable[..., int]) -> Callable[..., int]: - @functools.wraps(func) - def wrapped() -> int: - ret = func() - return ret + 1 - - return wrapped - - class A: - def __init__(self) -> None: - self.run = plus1(self.run) # type: ignore - - def run(self) -> int: # type: ignore - return 1 - - obj = A() - processor = PickleFileProcessor() - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.pkl' - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(obj, f) - with local_target.open('r') as f: - loaded = processor.load(f) - - self.assertEqual(loaded.run(), obj.run()) - - @mock_aws - def test_dump_and_load_with_readables3file(self): - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket='test') - file_path = os.path.join('s3://test/', 'test.pkl') - - var = 'abc' - processor = PickleFileProcessor() - - target = ObjectStorage.get_object_storage_target(file_path, processor.format()) - with target.open('w') as f: - processor.dump(var, f) - with target.open('r') as f: - loaded = processor.load(f) - - self.assertEqual(loaded, var) - - class TestFeatherFileProcessor(unittest.TestCase): def test_feather_should_return_same_dataframe(self): df = pd.DataFrame({'a': [1]}) @@ -233,41 +157,3 @@ def test_feather_should_raise_error_index_name_is_None(self): with local_target.open('w') as f: with self.assertRaises(AssertionError): processor.dump(df, f) - - -class TestMakeFileProcessor(unittest.TestCase): - def test_make_file_processor_with_txt_extension(self): - processor = make_file_processor('test.txt', store_index_in_feather=False) - self.assertIsInstance(processor, TextFileProcessor) - - def test_make_file_processor_with_csv_extension(self): - processor = make_file_processor('test.csv', store_index_in_feather=False) - self.assertIsInstance(processor, CsvFileProcessor) - - def test_make_file_processor_with_gz_extension(self): - processor = make_file_processor('test.gz', store_index_in_feather=False) - self.assertIsInstance(processor, GzipFileProcessor) - - def test_make_file_processor_with_json_extension(self): - processor = make_file_processor('test.json', store_index_in_feather=False) - self.assertIsInstance(processor, JsonFileProcessor) - - def test_make_file_processor_with_ndjson_extension(self): - processor = make_file_processor('test.ndjson', store_index_in_feather=False) - self.assertIsInstance(processor, JsonFileProcessor) - - def test_make_file_processor_with_npz_extension(self): - processor = make_file_processor('test.npz', store_index_in_feather=False) - self.assertIsInstance(processor, NpzFileProcessor) - - def test_make_file_processor_with_parquet_extension(self): - processor = make_file_processor('test.parquet', store_index_in_feather=False) - self.assertIsInstance(processor, ParquetFileProcessor) - - def test_make_file_processor_with_feather_extension(self): - processor = make_file_processor('test.feather', store_index_in_feather=True) - self.assertIsInstance(processor, FeatherFileProcessor) - - def test_make_file_processor_with_unsupported_extension(self): - with self.assertRaises(AssertionError): - make_file_processor('test.unsupported', store_index_in_feather=False) diff --git a/test/file_processor/test_polars.py b/test/file_processor/test_polars.py new file mode 100644 index 00000000..7a366076 --- /dev/null +++ b/test/file_processor/test_polars.py @@ -0,0 +1,387 @@ +"""Tests for polars-specific file processors.""" + +from __future__ import annotations + +import tempfile + +import pandas as pd +import pytest +from luigi import LocalTarget + +from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor, ParquetFileProcessor + +try: + import polars as pl + + HAS_POLARS = True +except ImportError: + HAS_POLARS = False + + +@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') +class TestCsvFileProcessorWithPolars: + """Tests for CsvFileProcessor with polars support""" + + def test_dump_polars_dataframe(self): + """Test dumping a polars DataFrame""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = CsvFileProcessor(dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # Verify file was created and can be read by polars + loaded_df = pl.read_csv(temp_path) + assert loaded_df.equals(df) + + def test_load_polars_dataframe(self): + """Test loading a CSV as polars DataFrame""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = CsvFileProcessor(dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + df.write_csv(temp_path) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_and_load_polars_roundtrip(self): + """Test roundtrip dump and load with polars""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = CsvFileProcessor(dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_polars_with_pandas_load(self): + """Test that polars dump can be loaded by pandas processor""" + df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor_polars = CsvFileProcessor(dataframe_type='polars') + processor_pandas = CsvFileProcessor(dataframe_type='pandas') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + # Dump with polars + local_target = LocalTarget(path=temp_path, format=processor_polars.format()) + with local_target.open('w') as f: + processor_polars.dump(df_polars, f) + + # Load with pandas + with local_target.open('r') as f: + loaded_df = processor_pandas.load(f) + + assert isinstance(loaded_df, pd.DataFrame) + # Compare values + df_polars.equals(pl.from_pandas(loaded_df)) + + def test_polars_with_different_separator(self): + """Test polars with TSV (tab-separated values)""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = CsvFileProcessor(sep='\t', dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.tsv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_error_when_polars_not_available_for_load(self): + """Test error message when polars is requested but a polars operation fails""" + # This test is a bit tricky since polars IS installed in this test class + # We'll just verify the processor accepts the parameter + processor = CsvFileProcessor(dataframe_type='polars') + assert processor._dataframe_type == 'polars' + + +@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') +class TestJsonFileProcessorWithPolars: + """Tests for JsonFileProcessor with polars support""" + + def test_dump_polars_dataframe(self): + """Test dumping a polars DataFrame to JSON""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = JsonFileProcessor(orient=None, dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.json' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # Verify file was created and can be read by polars + loaded_df = pl.read_json(temp_path) + assert loaded_df.equals(df) + + def test_load_polars_dataframe(self): + """Test loading a JSON as polars DataFrame""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = JsonFileProcessor(orient=None, dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.json' + df.write_json(temp_path) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_and_load_polars_roundtrip(self): + """Test roundtrip dump and load with polars""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = JsonFileProcessor(orient=None, dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.json' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_and_load_ndjson_with_polars(self): + """Test ndjson (records orient) with polars""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = JsonFileProcessor(orient='records', dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.ndjson' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_polars_with_pandas_load(self): + """Test that polars dump can be loaded by pandas processor""" + df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor_polars = JsonFileProcessor(orient=None, dataframe_type='polars') + processor_pandas = JsonFileProcessor(orient=None, dataframe_type='pandas') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.json' + + # Dump with polars + local_target = LocalTarget(path=temp_path, format=processor_polars.format()) + with local_target.open('w') as f: + processor_polars.dump(df_polars, f) + + # Load with pandas + with local_target.open('r') as f: + loaded_df = processor_pandas.load(f) + + assert isinstance(loaded_df, pd.DataFrame) + # Compare values + assert list(loaded_df['a']) == [1, 2, 3] + assert list(loaded_df['b']) == [4, 5, 6] + + +@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') +class TestParquetFileProcessorWithPolars: + """Tests for ParquetFileProcessor with polars support""" + + def test_dump_polars_dataframe(self): + """Test dumping a polars DataFrame to Parquet""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = ParquetFileProcessor(dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.parquet' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # Verify file was created and can be read by polars + loaded_df = pl.read_parquet(temp_path) + assert loaded_df.equals(df) + + def test_load_polars_dataframe(self): + """Test loading a Parquet as polars DataFrame""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = ParquetFileProcessor(dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.parquet' + df.write_parquet(temp_path) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_and_load_polars_roundtrip(self): + """Test roundtrip dump and load with polars""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = ParquetFileProcessor(dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.parquet' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_polars_with_pandas_load(self): + """Test that polars dump can be loaded by pandas processor""" + df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor_polars = ParquetFileProcessor(dataframe_type='polars') + processor_pandas = ParquetFileProcessor(dataframe_type='pandas') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.parquet' + + # Dump with polars + local_target = LocalTarget(path=temp_path, format=processor_polars.format()) + with local_target.open('w') as f: + processor_polars.dump(df_polars, f) + + # Load with pandas + with local_target.open('r') as f: + loaded_df = processor_pandas.load(f) + + assert isinstance(loaded_df, pd.DataFrame) + df_polars.equals(pl.from_pandas(loaded_df)) + + def test_parquet_with_compression(self): + """Test polars with parquet compression""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = ParquetFileProcessor(compression='gzip', dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.parquet' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + +@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') +class TestFeatherFileProcessorWithPolars: + """Tests for FeatherFileProcessor with polars support""" + + def test_dump_polars_dataframe(self): + """Test dumping a polars DataFrame to Feather""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # Verify file was created and can be read by polars + loaded_df = pl.read_ipc(temp_path) + assert loaded_df.equals(df) + + def test_load_polars_dataframe(self): + """Test loading a Feather as polars DataFrame""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + df.write_ipc(temp_path) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_and_load_polars_roundtrip(self): + """Test roundtrip dump and load with polars""" + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + assert isinstance(loaded_df, pl.DataFrame) + assert loaded_df.equals(df) + + def test_dump_polars_with_pandas_load(self): + """Test that polars dump can be loaded by pandas processor""" + df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + processor_polars = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars') + processor_pandas = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='pandas') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + # Dump with polars + local_target = LocalTarget(path=temp_path, format=processor_polars.format()) + with local_target.open('w') as f: + processor_polars.dump(df_polars, f) + + # Load with pandas + with local_target.open('r') as f: + loaded_df = processor_pandas.load(f) + + assert isinstance(loaded_df, pd.DataFrame) + # Compare values + df_polars.equals(pl.from_pandas(loaded_df)) diff --git a/test/test_target.py b/test/test_target.py index 2d82e76f..792a7b1f 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -11,7 +11,7 @@ from matplotlib import pyplot from moto import mock_aws -from gokart.file_processor import _ChunkedLargeFileReader +from gokart.file_processor.base import _ChunkedLargeFileReader from gokart.target import make_model_target, make_target from test.util import _get_temporary_directory @@ -29,7 +29,7 @@ def test_save_and_load_pickle_file(self): target = make_target(file_path=file_path, unique_id=None) target.dump(obj) - with unittest.mock.patch('gokart.file_processor._ChunkedLargeFileReader', wraps=_ChunkedLargeFileReader) as monkey: + with unittest.mock.patch('gokart.file_processor.base._ChunkedLargeFileReader', wraps=_ChunkedLargeFileReader) as monkey: loaded = target.load() monkey.assert_called() diff --git a/uv.lock b/uv.lock index 8b9b021e..0cc2e97f 100644 --- a/uv.lock +++ b/uv.lock @@ -487,6 +487,11 @@ dependencies = [ { name = "uritemplate" }, ] +[package.optional-dependencies] +polars = [ + { name = "polars" }, +] + [package.dev-dependencies] lint = [ { name = "mypy" }, @@ -498,6 +503,7 @@ test = [ { name = "matplotlib" }, { name = "moto" }, { name = "mypy" }, + { name = "polars" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-xdist" }, @@ -518,12 +524,14 @@ requires-dist = [ { name = "luigi" }, { name = "numpy" }, { name = "pandas" }, + { name = "polars", marker = "extra == 'polars'", specifier = ">=0.19.0" }, { name = "pyarrow" }, { name = "redis" }, { name = "slack-sdk" }, { name = "typing-extensions", marker = "python_full_version < '3.13'", specifier = ">=4.11.0" }, { name = "uritemplate" }, ] +provides-extras = ["polars"] [package.metadata.requires-dev] lint = [ @@ -536,6 +544,7 @@ test = [ { name = "matplotlib" }, { name = "moto" }, { name = "mypy" }, + { name = "polars", specifier = ">=0.19.0" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-xdist" }, @@ -1202,6 +1211,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556, upload-time = "2024-04-20T21:34:40.434Z" }, ] +[[package]] +name = "polars" +version = "1.36.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "polars-runtime-32" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/dc/56f2a90c79a2cb13f9e956eab6385effe54216ae7a2068b3a6406bae4345/polars-1.36.1.tar.gz", hash = "sha256:12c7616a2305559144711ab73eaa18814f7aa898c522e7645014b68f1432d54c", size = 711993, upload-time = "2025-12-10T01:14:53.033Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/c6/36a1b874036b49893ecae0ac44a2f63d1a76e6212631a5b2f50a86e0e8af/polars-1.36.1-py3-none-any.whl", hash = "sha256:853c1bbb237add6a5f6d133c15094a9b727d66dd6a4eb91dbb07cdb056b2b8ef", size = 802429, upload-time = "2025-12-10T01:13:53.838Z" }, +] + +[[package]] +name = "polars-runtime-32" +version = "1.36.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/31/df/597c0ef5eb8d761a16d72327846599b57c5d40d7f9e74306fc154aba8c37/polars_runtime_32-1.36.1.tar.gz", hash = "sha256:201c2cfd80ceb5d5cd7b63085b5fd08d6ae6554f922bcb941035e39638528a09", size = 2788751, upload-time = "2025-12-10T01:14:54.172Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/ea/871129a2d296966c0925b078a9a93c6c5e7facb1c5eebfcd3d5811aeddc1/polars_runtime_32-1.36.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:327b621ca82594f277751f7e23d4b939ebd1be18d54b4cdf7a2f8406cecc18b2", size = 43494311, upload-time = "2025-12-10T01:13:56.096Z" }, + { url = "https://files.pythonhosted.org/packages/d8/76/0038210ad1e526ce5bb2933b13760d6b986b3045eccc1338e661bd656f77/polars_runtime_32-1.36.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:ab0d1f23084afee2b97de8c37aa3e02ec3569749ae39571bd89e7a8b11ae9e83", size = 39300602, upload-time = "2025-12-10T01:13:59.366Z" }, + { url = "https://files.pythonhosted.org/packages/54/1e/2707bee75a780a953a77a2c59829ee90ef55708f02fc4add761c579bf76e/polars_runtime_32-1.36.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:899b9ad2e47ceb31eb157f27a09dbc2047efbf4969a923a6b1ba7f0412c3e64c", size = 44511780, upload-time = "2025-12-10T01:14:02.285Z" }, + { url = "https://files.pythonhosted.org/packages/11/b2/3fede95feee441be64b4bcb32444679a8fbb7a453a10251583053f6efe52/polars_runtime_32-1.36.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:d9d077bb9df711bc635a86540df48242bb91975b353e53ef261c6fae6cb0948f", size = 40688448, upload-time = "2025-12-10T01:14:05.131Z" }, + { url = "https://files.pythonhosted.org/packages/05/0f/e629713a72999939b7b4bfdbf030a32794db588b04fdf3dc977dd8ea6c53/polars_runtime_32-1.36.1-cp39-abi3-win_amd64.whl", hash = "sha256:cc17101f28c9a169ff8b5b8d4977a3683cd403621841623825525f440b564cf0", size = 44464898, upload-time = "2025-12-10T01:14:08.296Z" }, + { url = "https://files.pythonhosted.org/packages/d1/d8/a12e6aa14f63784cead437083319ec7cece0d5bb9a5bfe7678cc6578b52a/polars_runtime_32-1.36.1-cp39-abi3-win_arm64.whl", hash = "sha256:809e73857be71250141225ddd5d2b30c97e6340aeaa0d445f930e01bef6888dc", size = 39798896, upload-time = "2025-12-10T01:14:11.568Z" }, +] + [[package]] name = "proto-plus" version = "1.26.0"