Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 139 additions & 59 deletions gokart/file_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import luigi.contrib.s3
import luigi.format
import numpy as np
import pandas as pd
import pandas.errors
from luigi.format import TextFormat

from gokart.object_storage import ObjectStorage
Expand Down Expand Up @@ -131,13 +129,31 @@ def format(self):

def load(self, file):
try:
return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
except pd.errors.EmptyDataError:
return pd.DataFrame()
import pandas as pd

try:
return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
except pd.errors.EmptyDataError:
return pd.DataFrame()
except ImportError:
import polars as pl

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, it is a little confused to fallback to polars when import error occurred.

In my opinion, we need to introduce some global feature flags to switch dataframe frameworks.
For backward compatibility, we set pandas in default, then add a configure function like the follwoing.

DATAFRAME_FRAMEWORK = 'pandas'
def setup_dataframe_framework(framework: Literal['pandas', 'polars']):
   if framework == 'polars':
     try
         import  polars
     except ImportError:
          raise RuntimeError(...)
    DATAFRAME_FRAMEWORK = framework

According to the flag, we can switch the implementation.

In addition, we can probably set processor class, though I do not check this work...
This would make each CsvFileProcessor simple

if DATAFRAME_FRAMEWORK == 'pandas'
  CsvFileProcessor  =  PandasCsvFileProcessor 
elif DATAFRAME_FRAMEWORK == 'polars'
  CsvFileProcessor  =  PolarsCsvFileProcessor 

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import os
from typing import Protocol, Type


class IFeature(Protocol):
    def run(self) -> None: ...


class Feature1:
    def __init__(self): ...
    def run(self):
        print('feature1')


class Feature2:
    def __init__(self): ...
    def run(self):
        print('feature2')


Feature: Type[IFeature]
if os.environ.get('FEATURE') == '1':
    Feature = Feature1
elif os.environ.get('FEATURE') == '2':
    Feature = Feature2
else:
    raise ValueError("Invalid FEATURE environment variable value. Please set it to '1' or '2'.")


Feature().run()
❯ uv run foo.py
Traceback (most recent call last):
  File "gokart/foo.py", line 27, in <module>
    raise ValueError("Invalid FEATURE environment variable value. Please set it to '1' or '2'.")
ValueError: Invalid FEATURE environment variable value. Please set it to '1' or '2'.
❯ FEATURE=1 uv run foo.py
feature1
❯ FEATURE=2 uv run foo.py
feature2

Switching class by an environment variable works!

@hirosassa hirosassa Mar 15, 2025

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hiro-o918 Thanks for the comment. I applied your suggestion. I think it looks fine!


try:
return pl.read_csv(file, sep=self._sep, encoding=self._encoding)
except pl.exceptions.NoDataError:
return pd.DataFrame()

def dump(self, obj, file):
assert isinstance(obj, (pd.DataFrame, pd.Series)), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.'
obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding)
try:
import pandas as pd

assert isinstance(obj, (pd.DataFrame, pd.Series)), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.'
obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding)
except ImportError:
import polars as pl

assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.'
obj.write_csv(file, separator=self._sep, include_header=True)


class GzipFileProcessor(FileProcessor):
Expand All @@ -161,17 +177,39 @@ def format(self):

def load(self, file):
try:
return pd.read_json(file)
except pd.errors.EmptyDataError:
return pd.DataFrame()
import pandas as pd

try:
return self.read_json(file)
except pd.errors.EmptyDataError:
return pd.DataFrame()
except ImportError:
import polars as pl

try:
return self.read_json(file)
except pl.exceptions.NoDataError:
return pl.DataFrame()

def dump(self, obj, file):
assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), (
f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.'
)
if isinstance(obj, dict):
obj = pd.DataFrame.from_dict(obj)
obj.to_json(file)
try:
import pandas as pd

assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), (
f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.'
)
if isinstance(obj, dict):
obj = pd.DataFrame.from_dict(obj)
obj.to_json(file)
except ImportError:
import polars as pl

assert isinstance(obj, pl.DataFrame) or isinstance(obj, pl.Series) or isinstance(obj, dict), (
f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.'
)
if isinstance(obj, dict):
obj = pl.from_dict(obj)
obj.write_json(file)


class XmlFileProcessor(FileProcessor):
Expand Down Expand Up @@ -211,19 +249,39 @@ def format(self):
return luigi.format.Nop

def load(self, file):
# FIXME(mamo3gr): enable streaming (chunked) read with S3.
# pandas.read_parquet accepts file-like object
# but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
# which is needed for pandas to read a file in chunks.
if ObjectStorage.is_buffered_reader(file):
return pd.read_parquet(file.name)
else:
return pd.read_parquet(BytesIO(file.read()))
try:
import pandas as pd

# FIXME(mamo3gr): enable streaming (chunked) read with S3.
# pandas.read_parquet accepts file-like object
# but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
# which is needed for pandas to read a file in chunks.
if ObjectStorage.is_buffered_reader(file):
return pd.read_parquet(file.name)
else:
return pd.read_parquet(BytesIO(file.read()))
except ImportError:
import polars as pl

if ObjectStorage.is_buffered_reader(file):
return pl.read_parquet(file.name)
else:
return pl.read_parquet(BytesIO(file.read()))

def dump(self, obj, file):
assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
# MEMO: to_parquet only supports a filepath as string (not a file handle)
obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression)
try:
import pandas as pd

assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
# MEMO: to_parquet only supports a filepath as string (not a file handle)
obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression)
except ImportError:
import polars as pl

assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.'
use_pyarrow = self._engine == 'pyarrow'
compression = 'uncompressed' if self._compression is None else self._compression
obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression)


class FeatherFileProcessor(FileProcessor):
Expand All @@ -236,44 +294,66 @@ def format(self):
return luigi.format.Nop

def load(self, file):
# FIXME(mamo3gr): enable streaming (chunked) read with S3.
# pandas.read_feather accepts file-like object
# but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
# which is needed for pandas to read a file in chunks.
if ObjectStorage.is_buffered_reader(file):
loaded_df = pd.read_feather(file.name)
else:
loaded_df = pd.read_feather(BytesIO(file.read()))

if self._store_index_in_feather:
if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX]
index_column = index_columns[0]
index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :]
if index_name == 'None':
index_name = None
loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name)
loaded_df = loaded_df.drop(columns={index_column})

return loaded_df
try:
import pandas as pd

# FIXME(mamo3gr): enable streaming (chunked) read with S3.
# pandas.read_feather accepts file-like object
# but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
# which is needed for pandas to read a file in chunks.
if ObjectStorage.is_buffered_reader(file):
loaded_df = pd.read_feather(file.name)
else:
loaded_df = pd.read_feather(BytesIO(file.read()))

if self._store_index_in_feather:
if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX]
index_column = index_columns[0]
index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :]
if index_name == 'None':
index_name = None
loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name)
loaded_df = loaded_df.drop(columns={index_column})

return loaded_df
except ImportError:
import polars as pl

# Since polars' DataFrame doesn't have index, just load feather file
if ObjectStorage.is_buffered_reader(file):
loaded_df = pl.read_ipc(file.name)
else:
loaded_df = pl.read_ipc(BytesIO(file.read()))

return loaded_df

def dump(self, obj, file):
assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
dump_obj = obj.copy()
try:
import pandas as pd

if self._store_index_in_feather:
index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
assert index_column_name not in dump_obj.columns, (
f'column name {index_column_name} already exists in dump_obj. \
assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
dump_obj = obj.copy()

if self._store_index_in_feather:
index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
assert index_column_name not in dump_obj.columns, (
f'column name {index_column_name} already exists in dump_obj. \
Consider not saving index by setting store_index_in_feather=False.'
)
assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.'
)
assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.'

dump_obj[index_column_name] = dump_obj.index
dump_obj = dump_obj.reset_index(drop=True)

dump_obj[index_column_name] = dump_obj.index
dump_obj = dump_obj.reset_index(drop=True)
# to_feather supports "binary" file-like object, but file variable is text
dump_obj.to_feather(file.name)
except ImportError:
import polars as pl

# to_feather supports "binary" file-like object, but file variable is text
dump_obj.to_feather(file.name)
assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.'
dump_obj = obj.copy()
dump_obj.write_ipc(file.name)


def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ Homepage = "https://github.com/m3dev/gokart"
Repository = "https://github.com/m3dev/gokart"
Documentation = "https://gokart.readthedocs.io/en/latest/"

[project.optional-dependencies]
pandas = ["pandas"]
polars = ["polars"]

[dependency-groups]
test = [
"fakeredis",
Expand Down