Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions gokart/file_processor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,21 +161,21 @@ def dump(self, obj, file):
return self._impl.dump(obj, file)


def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
def make_file_processor(file_path: str, store_index_in_feather: bool = True, *, dataframe_type: DataFrameType = 'pandas') -> FileProcessor:
"""Create a file processor based on file extension with default parameters."""
extension2processor = {
'.txt': TextFileProcessor(),
'.ini': TextFileProcessor(),
'.csv': CsvFileProcessor(sep=','),
'.tsv': CsvFileProcessor(sep='\t'),
'.csv': CsvFileProcessor(sep=',', dataframe_type=dataframe_type),
'.tsv': CsvFileProcessor(sep='\t', dataframe_type=dataframe_type),
'.pkl': PickleFileProcessor(),
'.gz': GzipFileProcessor(),
'.json': JsonFileProcessor(),
'.ndjson': JsonFileProcessor(orient='records'),
'.json': JsonFileProcessor(dataframe_type=dataframe_type),
'.ndjson': JsonFileProcessor(dataframe_type=dataframe_type, orient='records'),
'.xml': XmlFileProcessor(),
'.npz': NpzFileProcessor(),
'.parquet': ParquetFileProcessor(compression='gzip'),
'.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather),
'.parquet': ParquetFileProcessor(compression='gzip', dataframe_type=dataframe_type),
'.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather, dataframe_type=dataframe_type),
'.png': BinaryFileProcessor(),
'.jpg': BinaryFileProcessor(),
}
Expand Down
12 changes: 10 additions & 2 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import gokart.target
from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock
from gokart.file_processor import FileProcessor
from gokart.file_processor import FileProcessor, make_file_processor
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper
from gokart.utils import FlattenableItems, flatten, map_flattenable_items
from gokart.utils import FlattenableItems, flatten, get_dataframe_type_from_task, map_flattenable_items

logger = getLogger(__name__)

Expand Down Expand Up @@ -219,6 +219,10 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool
file_path = os.path.join(self.workspace_directory, formatted_relative_file_path)
unique_id = self.make_unique_id() if use_unique_id else None

# Auto-select processor based on type parameter if not provided
if processor is None and relative_file_path is not None:
processor = self._create_processor_for_dataframe_type(file_path)

task_lock_params = make_task_lock_params(
file_path=file_path,
unique_id=unique_id,
Expand All @@ -232,6 +236,10 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
)

def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor:
    """Build the default file processor for ``file_path`` for this task.

    The DataFrame backend is resolved from this task via
    ``get_dataframe_type_from_task`` and forwarded, together with the task's
    ``store_index_in_feather`` setting, to ``make_file_processor``.
    """
    return make_file_processor(
        file_path,
        dataframe_type=get_dataframe_type_from_task(self),
        store_index_in_feather=self.store_index_in_feather,
    )

def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
formatted_relative_file_path = (
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip')
Expand Down
51 changes: 50 additions & 1 deletion gokart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from collections.abc import Callable, Iterable
from io import BytesIO
from typing import Any, Protocol, TypeAlias, TypeVar
from typing import Any, Literal, Protocol, TypeAlias, TypeVar, get_args, get_origin

import dill
import luigi
Expand Down Expand Up @@ -92,3 +92,52 @@ def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> An
assert file.seekable(), f'{file} is not seekable.'
file.seek(0)
return pd.read_pickle(file)


def get_dataframe_type_from_task(task: Any) -> Literal['pandas', 'polars', 'polars-lazy']:
    """
    Extract DataFrame type from TaskOnKart[T] type parameter.

    Examines the type parameter T of a TaskOnKart subclass to determine
    whether it uses pandas or polars DataFrames/LazyFrames. The decision is
    made from the top-level package of ``T.__module__`` (exactly ``'polars'``
    or ``'pandas'``), so a type that merely lives in a module whose name
    contains one of those words is not mistaken for a DataFrame.

    Args:
        task: A TaskOnKart instance or class

    Returns:
        'pandas', 'polars', or 'polars-lazy' (defaults to 'pandas' if type cannot be determined)

    Examples:
        >>> class MyTask(TaskOnKart[pd.DataFrame]): pass
        >>> get_dataframe_type_from_task(MyTask())
        'pandas'

        >>> class MyPolarsTask(TaskOnKart[pl.DataFrame]): pass
        >>> get_dataframe_type_from_task(MyPolarsTask())
        'polars'
    """
    task_class = task if isinstance(task, type) else task.__class__

    # Walk the MRO to find TaskOnKart[...] even when defined on a parent class.
    # ``__mro__`` is read instead of calling ``mro()`` because the latter is an
    # unbound method when ``task_class`` is ``type`` itself and would raise.
    for cls in getattr(task_class, '__mro__', (task_class,)):
        for base in getattr(cls, '__orig_bases__', ()):
            origin = get_origin(base)
            # Matched by name so this module does not have to import TaskOnKart.
            if getattr(origin, '__name__', None) != 'TaskOnKart':
                continue

            args = get_args(base)
            if not args:
                continue
            df_type = args[0]

            # Compare the top-level package only; a substring test would also
            # match unrelated modules such as ``myproject.polars_helpers``.
            module = getattr(df_type, '__module__', '') or ''
            package = module.partition('.')[0]

            if package == 'polars':
                if getattr(df_type, '__name__', '') == 'LazyFrame':
                    return 'polars-lazy'
                return 'polars'
            if package == 'pandas':
                return 'pandas'

    return 'pandas'  # Default to pandas for backward compatibility
154 changes: 153 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import unittest

from gokart.utils import flatten, map_flattenable_items
import pandas as pd
import pytest

from gokart.task import TaskOnKart
from gokart.utils import flatten, get_dataframe_type_from_task, map_flattenable_items

try:
import polars as pl

HAS_POLARS = True
except ImportError:
HAS_POLARS = False


class TestFlatten(unittest.TestCase):
Expand Down Expand Up @@ -34,3 +45,144 @@ def test_map_flattenable_items(self):
),
{'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}},
)


class TestGetDataFrameTypeFromTask(unittest.TestCase):
    """Tests for get_dataframe_type_from_task function.

    Each test declares its own throwaway task class and checks the string
    returned for it. Tests that reference ``pl`` are skipped when polars
    could not be imported (``HAS_POLARS`` is False).
    """

    def test_pandas_dataframe_from_instance(self):
        """Test detecting pandas DataFrame from task instance."""

        class _PandasTaskInstance(TaskOnKart[pd.DataFrame]):
            pass

        task = _PandasTaskInstance()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_pandas_dataframe_from_class(self):
        """Test detecting pandas DataFrame from task class."""

        class _PandasTaskClass(TaskOnKart[pd.DataFrame]):
            pass

        self.assertEqual(get_dataframe_type_from_task(_PandasTaskClass), 'pandas')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_dataframe_from_instance(self):
        """Test detecting polars DataFrame from task instance."""

        class _PolarsTaskInstance(TaskOnKart[pl.DataFrame]):
            pass

        task = _PolarsTaskInstance()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_dataframe_from_class(self):
        """Test detecting polars DataFrame from task class."""

        class _PolarsTaskClass(TaskOnKart[pl.DataFrame]):
            pass

        self.assertEqual(get_dataframe_type_from_task(_PolarsTaskClass), 'polars')

    def test_no_type_parameter_defaults_to_pandas(self):
        """Test that tasks without type parameter default to pandas."""

        # Create a class without __orig_bases__ by not using type parameters
        class PlainTask:
            pass

        task = PlainTask()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_non_taskonkart_class_defaults_to_pandas(self):
        """Test that non-TaskOnKart classes default to pandas."""

        class RegularClass:
            pass

        task = RegularClass()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_taskonkart_with_non_dataframe_type(self):
        """Test TaskOnKart with non-DataFrame type parameter defaults to pandas."""

        class _StringTask(TaskOnKart[str]):
            pass

        task = _StringTask()
        # Should default to pandas since str's module is neither pandas nor polars
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_nested_inheritance_pandas(self):
        """Test that a subclass of a pandas-parametrized task is reported as pandas."""

        class _BasePandasTask(TaskOnKart[pd.DataFrame]):
            pass

        class _DerivedPandasTask(_BasePandasTask):
            pass

        task = _DerivedPandasTask()
        # _DerivedPandasTask declares no type parameter of its own; the expected
        # value is 'pandas'. Note that 'pandas' is also the fallback result, so
        # this assertion alone cannot tell detection through the parent apart
        # from the default; the polars variants below cover that distinction.
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_nested_inheritance_polars(self):
        """Test detecting polars DataFrame type through nested inheritance."""

        class _BasePolarsTask(TaskOnKart[pl.DataFrame]):
            pass

        class _DerivedPolarsTask(_BasePolarsTask):
            pass

        task = _DerivedPolarsTask()
        # Function should detect 'polars' through the inheritance chain
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_lazyframe_from_instance(self):
        """Test detecting polars LazyFrame from task instance."""

        class _LazyTaskInstance(TaskOnKart[pl.LazyFrame]):
            pass

        task = _LazyTaskInstance()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_lazyframe_from_class(self):
        """Test detecting polars LazyFrame from task class."""

        class _LazyTaskClass(TaskOnKart[pl.LazyFrame]):
            pass

        self.assertEqual(get_dataframe_type_from_task(_LazyTaskClass), 'polars-lazy')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_nested_inheritance_polars_lazyframe(self):
        """Test detecting polars LazyFrame type through nested inheritance."""

        class _BaseLazyTask(TaskOnKart[pl.LazyFrame]):
            pass

        class _DerivedLazyTask(_BaseLazyTask):
            pass

        task = _DerivedLazyTask()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_nested_inheritance_polars_with_mixin(self):
        """Derived class with multiple bases should still detect polars through MRO."""

        class _Mixin:
            pass

        class _BasePolarsTaskWithMixin(TaskOnKart[pl.DataFrame]):
            pass

        # _DerivedTaskWithMixin lists a plain mixin next to the parametrized
        # base and has no TaskOnKart[...] among its own bases, so 'polars' has
        # to be found on the parent class.
        # NOTE(review): whether the derived class gets its own __orig_bases__
        # in this situation is not established here - confirm before relying on it.
        class _DerivedTaskWithMixin(_BasePolarsTaskWithMixin, _Mixin):
            pass

        task = _DerivedTaskWithMixin()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')