From a104164ea2aa1822591c99ecfc61a4e89fcfe5ea Mon Sep 17 00:00:00 2001 From: Ryo Kitagawa Date: Tue, 16 Dec 2025 00:06:09 +0900 Subject: [PATCH 1/3] feat: add automatic processor selection based on DataFrame type parameter --- gokart/task.py | 38 ++++++++++++- gokart/utils.py | 50 ++++++++++++++++- test/test_utils.py | 136 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 221 insertions(+), 3 deletions(-) diff --git a/gokart/task.py b/gokart/task.py index a9db1b69..4728d595 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -25,7 +25,7 @@ from gokart.required_task_output import RequiredTaskOutput from gokart.target import TargetOnKart from gokart.task_complete_check import task_complete_check_wrapper -from gokart.utils import FlattenableItems, flatten, map_flattenable_items +from gokart.utils import FlattenableItems, flatten, get_dataframe_type_from_task, map_flattenable_items logger = getLogger(__name__) @@ -219,6 +219,10 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None + # Auto-select processor based on type parameter if not provided + if processor is None and relative_file_path is not None: + processor = self._create_processor_for_dataframe_type(file_path) + task_lock_params = make_task_lock_params( file_path=file_path, unique_id=unique_id, @@ -232,6 +236,38 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather ) + def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor | None: + """ + Create a file processor with appropriate return_type based on task's type parameter. 
+ + Args: + file_path: Path to the file + + Returns: + FileProcessor with return_type set, or None to use default processor + """ + from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor, ParquetFileProcessor + + extension = os.path.splitext(file_path)[1] + df_type = get_dataframe_type_from_task(self) + + # Create custom processor for DataFrame-supporting file types with type parameter + if extension == '.csv': + return CsvFileProcessor(sep=',', dataframe_type=df_type) + elif extension == '.tsv': + return CsvFileProcessor(sep='\t', dataframe_type=df_type) + elif extension == '.json': + return JsonFileProcessor(orient=None, dataframe_type=df_type) + elif extension == '.ndjson': + return JsonFileProcessor(orient='records', dataframe_type=df_type) + elif extension == '.parquet': + return ParquetFileProcessor(dataframe_type=df_type) + elif extension == '.feather': + return FeatherFileProcessor(store_index_in_feather=self.store_index_in_feather, dataframe_type=df_type) + + # For other file types, use default processor selection + return None + def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip') diff --git a/gokart/utils.py b/gokart/utils.py index 510db5c9..61b16cad 100644 --- a/gokart/utils.py +++ b/gokart/utils.py @@ -3,7 +3,7 @@ import os from collections.abc import Callable, Iterable from io import BytesIO -from typing import Any, Protocol, TypeAlias, TypeVar +from typing import Any, Literal, Protocol, TypeAlias, TypeVar, get_args, get_origin import dill import luigi @@ -92,3 +92,51 @@ def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> An assert file.seekable(), f'{file} is not seekable.' 
file.seek(0) return pd.read_pickle(file) + + +def get_dataframe_type_from_task(task: Any) -> Literal['pandas', 'polars', 'polars-lazy']: + """ + Extract DataFrame type from TaskOnKart[T] type parameter. + + Examines the type parameter T of a TaskOnKart subclass to determine + whether it uses pandas or polars DataFrames/LazyFrames. + + Args: + task: A TaskOnKart instance or class + + Returns: + 'pandas', 'polars', or 'polars-lazy' (defaults to 'pandas' if type cannot be determined) + + Examples: + >>> class MyTask(TaskOnKart[pd.DataFrame]): pass + >>> get_dataframe_type_from_task(MyTask()) + 'pandas' + + >>> class MyPolarsTask(TaskOnKart[pl.DataFrame]): pass + >>> get_dataframe_type_from_task(MyPolarsTask()) + 'polars' + """ + task_class = task if isinstance(task, type) else task.__class__ + + if not hasattr(task_class, '__orig_bases__'): + return 'pandas' + + for base in task_class.__orig_bases__: + origin = get_origin(base) + # Check if this is a TaskOnKart subclass + if origin and hasattr(origin, '__name__') and origin.__name__ == 'TaskOnKart': + args = get_args(base) + if args: + df_type = args[0] + module = getattr(df_type, '__module__', '') + + # Check module name to determine DataFrame type + if 'polars' in module: + name = getattr(df_type, '__name__', '') + if name == 'LazyFrame': + return 'polars-lazy' + return 'polars' + elif 'pandas' in module: + return 'pandas' + + return 'pandas' # Default to pandas for backward compatibility diff --git a/test/test_utils.py b/test/test_utils.py index 9b49d330..33cdd100 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,6 +1,17 @@ import unittest -from gokart.utils import flatten, map_flattenable_items +import pandas as pd +import pytest + +from gokart.task import TaskOnKart +from gokart.utils import flatten, get_dataframe_type_from_task, map_flattenable_items + +try: + import polars as pl + + HAS_POLARS = True +except ImportError: + HAS_POLARS = False class TestFlatten(unittest.TestCase): @@ -34,3 +45,126 @@ 
def test_map_flattenable_items(self): ), {'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}}, ) + + +class TestGetDataFrameTypeFromTask(unittest.TestCase): + """Tests for get_dataframe_type_from_task function.""" + + def test_pandas_dataframe_from_instance(self): + """Test detecting pandas DataFrame from task instance.""" + + class _PandasTaskInstance(TaskOnKart[pd.DataFrame]): + pass + + task = _PandasTaskInstance() + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + def test_pandas_dataframe_from_class(self): + """Test detecting pandas DataFrame from task class.""" + + class _PandasTaskClass(TaskOnKart[pd.DataFrame]): + pass + + self.assertEqual(get_dataframe_type_from_task(_PandasTaskClass), 'pandas') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_polars_dataframe_from_instance(self): + """Test detecting polars DataFrame from task instance.""" + + class _PolarsTaskInstance(TaskOnKart[pl.DataFrame]): + pass + + task = _PolarsTaskInstance() + self.assertEqual(get_dataframe_type_from_task(task), 'polars') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_polars_dataframe_from_class(self): + """Test detecting polars DataFrame from task class.""" + + class _PolarsTaskClass(TaskOnKart[pl.DataFrame]): + pass + + self.assertEqual(get_dataframe_type_from_task(_PolarsTaskClass), 'polars') + + def test_no_type_parameter_defaults_to_pandas(self): + """Test that tasks without type parameter default to pandas.""" + + # Create a class without __orig_bases__ by not using type parameters + class PlainTask: + pass + + task = PlainTask() + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + def test_non_taskonkart_class_defaults_to_pandas(self): + """Test that non-TaskOnKart classes default to pandas.""" + + class RegularClass: + pass + + task = RegularClass() + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + def 
test_taskonkart_with_non_dataframe_type(self): + """Test TaskOnKart with non-DataFrame type parameter defaults to pandas.""" + + class _StringTask(TaskOnKart[str]): + pass + + task = _StringTask() + # Should default to pandas since str module is not 'pandas' or 'polars' + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + def test_nested_inheritance_pandas(self): + """Test that nested inheritance without direct type parameter defaults to pandas.""" + + class _BasePandasTask(TaskOnKart[pd.DataFrame]): + pass + + class _DerivedPandasTask(_BasePandasTask): + pass + + task = _DerivedPandasTask() + # _DerivedPandasTask declares no type parameter itself; 'pandas' is expected either way, + # whether resolved from _BasePandasTask or from the default fallback + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_nested_inheritance_polars(self): + """Test detecting polars DataFrame type through nested inheritance.""" + + class _BasePolarsTask(TaskOnKart[pl.DataFrame]): + pass + + class _DerivedPolarsTask(_BasePolarsTask): + pass + + task = _DerivedPolarsTask() + # Function should detect 'polars' through the inheritance chain + self.assertEqual(get_dataframe_type_from_task(task), 'polars') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_polars_lazyframe_from_instance(self): + class _LazyTaskInstance(TaskOnKart[pl.LazyFrame]): + pass + + task = _LazyTaskInstance() + self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_polars_lazyframe_from_class(self): + class _LazyTaskClass(TaskOnKart[pl.LazyFrame]): + pass + + self.assertEqual(get_dataframe_type_from_task(_LazyTaskClass), 'polars-lazy') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_nested_inheritance_polars_lazyframe(self): + class _BaseLazyTask(TaskOnKart[pl.LazyFrame]): + pass + + class
_DerivedLazyTask(_BaseLazyTask): + pass + + task = _DerivedLazyTask() + self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy') From 37b522daa4294bbcedb2cd3731861eb9a1e5c53d Mon Sep 17 00:00:00 2001 From: Ryo Kitagawa Date: Sat, 14 Feb 2026 16:00:27 +0900 Subject: [PATCH 2/3] refactor: use make_file_processor in TaskOnKart to create file processors with dataframe type support --- gokart/file_processor/__init__.py | 14 +++++++------- gokart/task.py | 23 ++--------------------- 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/gokart/file_processor/__init__.py b/gokart/file_processor/__init__.py index db41d3b9..a076d0d7 100644 --- a/gokart/file_processor/__init__.py +++ b/gokart/file_processor/__init__.py @@ -161,21 +161,21 @@ def dump(self, obj, file): return self._impl.dump(obj, file) -def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor: +def make_file_processor(file_path: str, dataframe_type: DataFrameType = 'pandas', store_index_in_feather: bool = True) -> FileProcessor: """Create a file processor based on file extension with default parameters.""" extension2processor = { '.txt': TextFileProcessor(), '.ini': TextFileProcessor(), - '.csv': CsvFileProcessor(sep=','), - '.tsv': CsvFileProcessor(sep='\t'), + '.csv': CsvFileProcessor(sep=',', dataframe_type=dataframe_type), + '.tsv': CsvFileProcessor(sep='\t', dataframe_type=dataframe_type), '.pkl': PickleFileProcessor(), '.gz': GzipFileProcessor(), - '.json': JsonFileProcessor(), - '.ndjson': JsonFileProcessor(orient='records'), + '.json': JsonFileProcessor(dataframe_type=dataframe_type), + '.ndjson': JsonFileProcessor(dataframe_type=dataframe_type, orient='records'), '.xml': XmlFileProcessor(), '.npz': NpzFileProcessor(), - '.parquet': ParquetFileProcessor(compression='gzip'), - '.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather), + '.parquet': ParquetFileProcessor(compression='gzip', dataframe_type=dataframe_type), + '.feather': 
FeatherFileProcessor(store_index_in_feather=store_index_in_feather, dataframe_type=dataframe_type), '.png': BinaryFileProcessor(), '.jpg': BinaryFileProcessor(), } diff --git a/gokart/task.py b/gokart/task.py index 4728d595..c67979d8 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -19,7 +19,7 @@ import gokart.target from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock -from gokart.file_processor import FileProcessor +from gokart.file_processor import FileProcessor, make_file_processor from gokart.pandas_type_config import PandasTypeConfigMap from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter from gokart.required_task_output import RequiredTaskOutput @@ -246,27 +246,8 @@ def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor Returns: FileProcessor with return_type set, or None to use default processor """ - from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor, ParquetFileProcessor - - extension = os.path.splitext(file_path)[1] df_type = get_dataframe_type_from_task(self) - - # Create custom processor for DataFrame-supporting file types with type parameter - if extension == '.csv': - return CsvFileProcessor(sep=',', dataframe_type=df_type) - elif extension == '.tsv': - return CsvFileProcessor(sep='\t', dataframe_type=df_type) - elif extension == '.json': - return JsonFileProcessor(orient=None, dataframe_type=df_type) - elif extension == '.ndjson': - return JsonFileProcessor(orient='records', dataframe_type=df_type) - elif extension == '.parquet': - return ParquetFileProcessor(dataframe_type=df_type) - elif extension == '.feather': - return FeatherFileProcessor(store_index_in_feather=self.store_index_in_feather, dataframe_type=df_type) - - # For other file types, use default processor selection - return None + return 
make_file_processor(file_path, dataframe_type=df_type, store_index_in_feather=self.store_index_in_feather) def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( From 48e1d6f37cf851d8e81d207aa7b35599f165ada2 Mon Sep 17 00:00:00 2001 From: Ryo Kitagawa Date: Sun, 15 Feb 2026 13:27:27 +0900 Subject: [PATCH 3/3] fix: comments --- gokart/file_processor/__init__.py | 2 +- gokart/task.py | 11 +---------- gokart/utils.py | 19 ++++++++++--------- test/test_utils.py | 18 ++++++++++++++++++ 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/gokart/file_processor/__init__.py b/gokart/file_processor/__init__.py index a076d0d7..c5c80371 100644 --- a/gokart/file_processor/__init__.py +++ b/gokart/file_processor/__init__.py @@ -161,7 +161,7 @@ def dump(self, obj, file): return self._impl.dump(obj, file) -def make_file_processor(file_path: str, dataframe_type: DataFrameType = 'pandas', store_index_in_feather: bool = True) -> FileProcessor: +def make_file_processor(file_path: str, store_index_in_feather: bool = True, *, dataframe_type: DataFrameType = 'pandas') -> FileProcessor: """Create a file processor based on file extension with default parameters.""" extension2processor = { '.txt': TextFileProcessor(), diff --git a/gokart/task.py b/gokart/task.py index c67979d8..a32d82f3 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -236,16 +236,7 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather ) - def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor | None: - """ - Create a file processor with appropriate return_type based on task's type parameter. 
- - Args: - file_path: Path to the file - - Returns: - FileProcessor with return_type set, or None to use default processor - """ + def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor: df_type = get_dataframe_type_from_task(self) return make_file_processor(file_path, dataframe_type=df_type, store_index_in_feather=self.store_index_in_feather) diff --git a/gokart/utils.py b/gokart/utils.py index 61b16cad..4508d071 100644 --- a/gokart/utils.py +++ b/gokart/utils.py @@ -118,15 +118,16 @@ def get_dataframe_type_from_task(task: Any) -> Literal['pandas', 'polars', 'pola """ task_class = task if isinstance(task, type) else task.__class__ - if not hasattr(task_class, '__orig_bases__'): - return 'pandas' - - for base in task_class.__orig_bases__: - origin = get_origin(base) - # Check if this is a TaskOnKart subclass - if origin and hasattr(origin, '__name__') and origin.__name__ == 'TaskOnKart': - args = get_args(base) - if args: + # Walk the MRO to find TaskOnKart[...] 
even when defined on a parent class + mro = task_class.mro() if hasattr(task_class, 'mro') else [task_class] + + for cls in mro: + for base in getattr(cls, '__orig_bases__', ()): + origin = get_origin(base) + if origin and hasattr(origin, '__name__') and origin.__name__ == 'TaskOnKart': + args = get_args(base) + if not args: + continue df_type = args[0] module = getattr(df_type, '__module__', '') diff --git a/test/test_utils.py b/test/test_utils.py index 33cdd100..3616466f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -168,3 +168,21 @@ class _DerivedLazyTask(_BaseLazyTask): task = _DerivedLazyTask() self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_nested_inheritance_polars_with_mixin(self): + """Derived class with multiple bases should still detect polars through MRO.""" + + class _Mixin: + pass + + class _BasePolarsTaskWithMixin(TaskOnKart[pl.DataFrame]): + pass + + # The derived class lists only plain classes as bases; TaskOnKart[...] is declared on its parent, + # so detection has to find it through the inheritance chain (MRO). + class _DerivedTaskWithMixin(_BasePolarsTaskWithMixin, _Mixin): + pass + + task = _DerivedTaskWithMixin() + self.assertEqual(get_dataframe_type_from_task(task), 'polars')