diff --git a/daft/datasets/__init__.py b/daft/datasets/__init__.py index f817abae8d0..899a2ffb9dc 100644 --- a/daft/datasets/__init__.py +++ b/daft/datasets/__init__.py @@ -1,3 +1,4 @@ -from daft.datasets.common_crawl import common_crawl +from . import lerobot +from .common_crawl import common_crawl -__all__ = ["common_crawl"] +__all__ = ["common_crawl", "lerobot"] diff --git a/daft/datasets/lerobot.py b/daft/datasets/lerobot.py new file mode 100644 index 00000000000..c64c1c38b58 --- /dev/null +++ b/daft/datasets/lerobot.py @@ -0,0 +1,333 @@ +"""LeRobot Dataset v3.0 helpers for `daft.datasets`. + +This module reads the file-based LeRobot v3 layout (`meta/episodes`, `data`, +`videos`) and exposes episode-level scans plus frame expansion utilities. + +See https://huggingface.co/docs/lerobot/lerobot-dataset-v3 for format details. +""" + +from __future__ import annotations + +import json +import re +from typing import TYPE_CHECKING, Any, TypedDict, cast + +import daft +from daft.api_annotations import PublicAPI +from daft.datatype import DataType +from daft.exceptions import DaftCoreException +from daft.expressions import col, lit +from daft.file import VideoFile +from daft.functions import lpad +from daft.functions.file_ import video_file +from daft.udf import func + +if TYPE_CHECKING: + from daft.daft import IOConfig + from daft.dataframe import DataFrame + + +def _normalize_dataset_root(uri: str) -> str: + """Return a canonical dataset root prefix (no trailing slash) for path joins.""" + u = uri.strip() + # Input looks like a Hugging Face repo ID, i.e. "org/name" + is_hf_repo_id = bool(re.fullmatch(r"[\w.-]+/[\w.-]+", uri)) + + if is_hf_repo_id: + u = f"hf://datasets/{u}" + return u.rstrip("/") + + +@func(return_dtype=DataType.image()) +def _decode_lerobot_video_timestamp( + file: VideoFile, + episode_from_timestamp_s: float, + frame_timestamp_s: float, + tolerance_s: float, + image_width_i: int, + image_height_i: int, +) -> Any: # returns a PIL.Image; PIL is an optional dependency imported lazily below + """Pick the decoded frame closest in time to ``episode_from_timestamp_s + frame_timestamp_s``.""" + try: + import av as av_mod + except ImportError as err: + raise ImportError("Decoding LeRobot MP4 shards requires PyAV. Install with `pip install av`.") from err + try: + from PIL import Image as PILImage + except ImportError as err: + raise ImportError( + "Decoding LeRobot MP4 shards requires Pillow. Install with `pip install daft[video]` or `pip install pillow`." + ) from err + abs_ts = float(episode_from_timestamp_s) + float(frame_timestamp_s) + tolerance = float(tolerance_s) + width_i = int(image_width_i) + height_i = int(image_height_i) + width: int | None + height: int | None + if width_i > 0 and height_i > 0: + width, height = width_i, height_i + else: + width, height = None, None + + loaded: list[tuple[float, Any]] = [] + decode_cap = 20_000 + decoded = 0 + + with file.open() as f_open: + with av_mod.open(f_open) as container: + stream = container.streams.video[0] + # Match LeRobot: seek backwards to preceding keyframe, then decode forwards. + container.seek(max(0, int(abs_ts * av_mod.time_base)), backward=True) + + tail_s = max(0.1, tolerance * 50.0, 1.0 / 24.0) + for vf in container.decode(stream): + if vf.pts is None: + continue + current_ts = float(vf.pts * stream.time_base) + pil_img = PILImage.fromarray(vf.to_ndarray(format="rgb24"), mode="RGB") + if width is not None and height is not None: + pil_img = pil_img.resize((width, height), PILImage.Resampling.NEAREST) + + loaded.append((current_ts, pil_img)) + decoded += 1 + + if decoded >= decode_cap: + raise ValueError("Exceeded decode frame budget while aligning to parquet timestamps.") + if current_ts >= abs_ts + tail_s: + break + + if not loaded: + raise ValueError(f"No frames decoded from shard while seeking timestamp {abs_ts:.6f}s.") + + closest_ts, closest_img = min(loaded, key=lambda item: abs(item[0] - abs_ts)) + closest_dist = abs(closest_ts - abs_ts) + if closest_dist > tolerance: + raise ValueError( + f"No frame matched timestamp {abs_ts:.6f}s within tolerance {tolerance} " + f"(closest distance observed: {closest_dist})." + ) + return closest_img + + +class Feature(TypedDict): + dtype: str + + +class LeRobotInfo(TypedDict): + codebase_version: str + data_path: str + video_path: str + fps: float + features: dict[str, Feature] + + +def _read_info(normalized_uri: str, io_config: IOConfig | None = None) -> LeRobotInfo: + with daft.open_file(f"{normalized_uri}/meta/info.json", io_config=io_config) as f: + info = cast("LeRobotInfo", json.load(f)) + if info["codebase_version"] != "v3.0": + raise ValueError("`daft.datasets.lerobot` currently only supports LeRobot datasets of v3 and above") + return info + + +@PublicAPI +def read( + dataset_uri: str, + io_config: IOConfig | None = None, + include_stats: bool = False, + load_video_frames: str | list[str] | bool = False, +) -> DataFrame: + """Read a LeRobot v3 dataset as a lazy DataFrame with one row per frame. + + Reads the per-episode metadata under ``meta/episodes`` and the per-frame + sensor data under ``data``, joins them on ``episode_index``, and broadcasts + each episode's metadata across its frames. Optionally decodes the matching + video frame for one or more camera keys into an image column. + + Args: + dataset_uri: Huggingface repo id (``org/name``), or a local / remote + directory (``s3://...``, ``hf://datasets/...``). + io_config: Optional IO configuration for remote reads. + include_stats: If True, keep the per-episode ``stats/*`` columns + (per-feature min/max/mean/std/quantiles). Defaults to False. + load_video_frames: Which camera keys to decode into image columns, + aligned to each frame's timestamp. Defaults to False (decode + nothing). Pass True to decode every video feature, a single key + (``"observation.image"``), or a list of keys. Decoding requires the + optional ``av`` (PyAV) and ``Pillow`` dependencies. + + Returns: + Lazy DataFrame with one row per frame: the frame's sensor columns, the + broadcast episode metadata, and one image column per decoded video key. + """ + root = _normalize_dataset_root(dataset_uri) + info = _read_info(root, io_config=io_config) + + # Keep the per-episode video metadata (notably `videos/{key}/from_timestamp`, + # the time within the shard where each episode's footage begins). We need it + # to translate episode-local frame timestamps into absolute shard timestamps + # when decoding, and drop these internal columns again before returning. + episode_df = read_episodes( + dataset_uri, io_config=io_config, include_stats=include_stats, include_video_metadata=True + ) + df = load_episode_frames(episode_df, dataset_uri, io_config=io_config) + + # Load video frames into memory + if load_video_frames is not False: + if load_video_frames is True: + video_keys = [name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video"] + elif isinstance(load_video_frames, str): + video_keys = [load_video_frames] + elif isinstance(load_video_frames, list) and all(isinstance(k, str) for k in load_video_frames): + video_keys = load_video_frames + else: + raise ValueError(f"Invalid value provided for argument load_video_frames=`{load_video_frames}`") + + # An MP4 shard packs many episodes back to back, so the shard's internal + # frame numbering is NOT the parquet's episode-local `frame_index` (which + # resets to 0 each episode). Seeking by `frame_index` only happens to work + # for the first episode in each shard. Instead, seek by absolute timestamp: + # `from_timestamp` (where this episode begins in the shard) + the per-frame + # episode-local `timestamp`. That keeps a single coordinate system end to end. + fps = float(info["fps"]) + tolerance_s = 1.0 / fps / 2.0 # half a frame period: any closer frame is unambiguously "the" frame + + # To increase parallelism, reduce batch size + df = df.into_batches(16) + for k in video_keys: + df = df.with_column( + k, + _decode_lerobot_video_timestamp( + col(f"videos/{k}/video"), + col(f"videos/{k}/from_timestamp"), + col("timestamp"), + lit(tolerance_s), + lit(0), # image_width: 0 disables resize (decode at native resolution) + lit(0), # image_height: 0 disables resize + ), + ) + df = df.exclude(f"videos/{k}/video") + + + # Drop the internal per-episode video metadata we kept above (chunk/file index, + # from/to timestamp). This restores read_episodes' default of hiding these. + df = df.exclude(*(c for c in df.column_names if c.startswith("videos/") and not c.endswith("/video"))) + + return df + + +@PublicAPI +def read_episodes( + dataset_uri: str, + io_config: IOConfig | None = None, + include_meta: bool = False, + include_stats: bool = False, + include_video_metadata: bool = False, +) -> DataFrame: + """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per episode). + + This reads the `meta/episodes/**/*.parquet` path under the dataset root. + + Args: + dataset_uri: Huggingface repo id (`org/name`), + or a local / remote directory (`s3://...`, `hf://datasets/...`) + io_config: Optional IO configuration for remote reads. + include_meta: If True, keep the internal ``meta/episodes/*`` columns + (the chunk/file indices locating each episode's own metadata shard). + These are bookkeeping for random access into the sharded metadata + and carry no analytical value once the rows are loaded. Defaults to + False. + include_stats: If True, keep the per-episode ``stats/*`` columns + (per-feature min/max/mean/std/quantiles). Defaults to False. + include_video_metadata: If True, keep the per-episode ``videos/{key}/*`` + columns (the chunk/file indices and from/to timestamps locating each + episode's footage within its video shard). Defaults to False. + + Returns: + Lazy DataFrame of episode metadata, one row per episode. Always includes + a ``videos/{key}/video`` file-handle column per video feature; the + ``include_*`` flags control which additional column families are kept. + """ + root = _normalize_dataset_root(dataset_uri) + info = _read_info(root, io_config=io_config) + df = daft.read_parquet(f"{root}/meta/episodes/**/*.parquet", io_config=io_config) + if not include_meta: + df = df.exclude(*(c for c in df.column_names if c.startswith("meta/"))) + if not include_stats: + df = df.exclude(*(c for c in df.column_names if c.startswith("stats/"))) + + # Get the video keys + video_keys = set(name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video") + + for key in video_keys: + file_name_expr = ( + lit(f"{root}/videos/{key}/chunk-") + + lpad(col(f"videos/{key}/chunk_index").cast(DataType.string), 3, "0") + + lit("/file-") + + lpad(col(f"videos/{key}/file_index").cast(DataType.string), 3, "0") + + lit(".mp4") + ) + + df = df.with_column(f"videos/{key}/video", video_file(file_name_expr, verify=False, io_config=io_config)) + + if not include_video_metadata: + df = df.exclude(*(c for c in df.column_names if c.startswith("videos/") and not c.endswith("/video"))) + + return df + + +@PublicAPI +def load_episode_frames( + episodes: DataFrame, + dataset_uri: str, + io_config: IOConfig | None = None, +) -> DataFrame: + """Expand an episode-level DataFrame into a frame-level DataFrame. + + Reads the per-frame parquet under ``data/**`` and joins it to the provided + episode metadata on ``episode_index``, producing one row per frame. Episode + metadata is broadcast across each episode's frames. + + Filter ``episodes`` before calling this to expand only the episodes you need; + only the surviving episodes contribute to the join. + + Args: + episodes: Episode-level DataFrame, typically from :func:`read_episodes` + (optionally filtered). Must contain an ``episode_index`` column. + dataset_uri: The same dataset identifier passed to :func:`read_episodes` + (Huggingface repo id ``org/name``, or a local / remote directory such + as ``s3://...`` or ``hf://datasets/...``). + io_config: Optional IO configuration for remote reads. + + Returns: + Lazy DataFrame with one row per frame. + """ + root = _normalize_dataset_root(dataset_uri) + + frame_df = daft.read_parquet(f"{root}/data/**", io_config=io_config) + df = episodes.join(frame_df, on=["episode_index"]) + df = df.exclude("data/chunk_index", "data/file_index") + return df + + +@PublicAPI +def read_tasks(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: + """Load task metadata as a DataFrame. + + Prefers ``meta/tasks.parquet`` (current LeRobot default). Falls back to legacy + ``meta/tasks.jsonl`` when the Parquet file is missing. + """ + root = _normalize_dataset_root(dataset_uri) + + pq_url = f"{root}/meta/tasks.parquet" + try: + return daft.read_parquet(pq_url, io_config=io_config) + except (OSError, DaftCoreException, FileNotFoundError): + return daft.read_json(f"{root}/meta/tasks.jsonl", io_config=io_config) + + +__all__ = [ + "load_episode_frames", + "read", + "read_episodes", + "read_tasks", +] diff --git a/daft/file/video.py b/daft/file/video.py index 4069e4832c0..4bc07145a59 100644 --- a/daft/file/video.py +++ b/daft/file/video.py @@ -220,3 +220,44 @@ def frames( ) frame_index += 1 + + def get_frame_by_idx(self, idx: int) -> PIL.Image.Image: + if not pil_image.module_available(): + raise ImportError( + "The 'pillow' module is required for frame decoding. Install it with `pip install daft[video]`." + ) + if idx < 0: + raise IndexError(f"Frame index {idx} is out of range") + + with self.open() as f: + with av.open(f) as container: + video = next( + (stream for stream in container.streams if stream.type == "video"), + None, + ) + if video is None: + raise ValueError("No video stream found") + + time_base = float(video.time_base) if video.time_base else None + fps = float(video.average_rate) if video.average_rate else None + if fps is None and video.guessed_rate: + fps = float(video.guessed_rate) + start_pts = video.start_time or 0 + + # Seek to the nearest preceding keyframe at or before the target frame. + if idx > 0 and time_base is not None and fps is not None: + target_time = idx / fps + seek_timestamp = int(target_time / time_base) + container.seek(seek_timestamp, stream=video, backward=True) + + for frame_idx, frame in enumerate(container.decode(video)): + current_frame_index = frame_idx + if frame.pts is not None and time_base is not None and fps is not None: + current_frame_index = int(round((frame.pts - start_pts) * time_base * fps)) + + if current_frame_index == idx: + return frame.to_image() + if current_frame_index > idx: + break + + raise IndexError(f"Frame index {idx} is out of range") diff --git a/daft/functions/video.py b/daft/functions/video.py index 98fef6d093c..419fde030e9 100644 --- a/daft/functions/video.py +++ b/daft/functions/video.py @@ -144,8 +144,8 @@ def frames_impl( def video_frames( file_expr: Expression, *, - start_time: float = 0, - end_time: float | None = None, + start_time: float | Expression = 0, + end_time: float | None | Expression = None, width: int | None = None, height: int | None = None, is_key_frame: bool | None = None, @@ -157,8 +157,10 @@ def video_frames( Args: file_expr (VideoFile Expression): The video file to decode frames from. - start_time (float, optional): Start of the time range in seconds. Defaults to 0. - end_time (float | None, optional): End of the time range in seconds. Defaults to None (all frames). + start_time (float | Expression, optional): Start of the time range in seconds. Defaults to 0. + If an expression is provided, the start time will be dynamic per row. + end_time (float | None | Expression, optional): End of the time range in seconds. Defaults to None (all frames). + If an expression is provided, the end time will be dynamic per row. width (int | None, optional): Target width for resizing frames. Must be provided with ``height``. height (int | None, optional): Target height for resizing frames. Must be provided with ``width``. is_key_frame (bool | None, optional): If True, decode only keyframes. If False, @@ -190,3 +192,40 @@ def video_frames( is_key_frame=is_key_frame, sample_interval_seconds=sample_interval_seconds, ) # type: ignore + + +def get_frame_by_idx_impl( + file: daft.VideoFile, + idx: int, +) -> PIL.Image.Image: + return file.get_frame_by_idx(idx) + + +video_get_frame_by_idx_fn = Func._from_func( + get_frame_by_idx_impl, + return_dtype=daft.DataType.image(), + unnest=False, + use_process=None, + is_batch=False, + batch_size=None, + max_retries=None, + on_error=None, + name_override="video_get_frame_by_idx", +) + + +def get_video_frame_by_idx( + file_expr: Expression, + idx: int | Expression, +) -> Expression: + """Get a frame from a video file by index. + + Args: + file_expr (VideoFile Expression): The video file to get the frame from. + idx (int | Integer Expression): The index of the frame to get. + If an expression is provided, the index will be dynamic per row. + + Returns: + Expression (Image Expression): The frame as an image. + """ + return video_get_frame_by_idx_fn(file_expr, idx) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 5401060b4d2..d7333cbd715 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -31,6 +31,7 @@ * [Batch Inference](use-case/batch-inference.md) * Datasets * [Common Crawl](datasets/common-crawl.md) + * [LeRobot v3](datasets/lerobot.md) * Data Connectors * [Connectors](connectors/index.md) * [Custom Connectors](connectors/custom.md) diff --git a/docs/api/datasets.md b/docs/api/datasets.md index 57cc969eb85..8adf8604ca5 100644 --- a/docs/api/datasets.md +++ b/docs/api/datasets.md @@ -10,3 +10,27 @@ Check out our [Common Crawl dataset guide](../datasets/common-crawl.md) for more options: filters: ["!^_"] heading_level: 3 + +## LeRobot v3 + +See the [LeRobot v3 dataset guide](../datasets/lerobot.md) for episode vs frame workflows and Hub/local paths. + +::: daft.datasets.lerobot.read + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_episodes + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.load_episode_frames + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_tasks + options: + filters: ["!^_"] + heading_level: 3 diff --git a/docs/datasets/lerobot.md b/docs/datasets/lerobot.md new file mode 100644 index 00000000000..1b54a4cd1f5 --- /dev/null +++ b/docs/datasets/lerobot.md @@ -0,0 +1,70 @@ +# LeRobot v3 datasets with Daft + +[LeRobot Dataset v3.0](https://huggingface.co/docs/lerobot/lerobot-dataset-v3) stores robot learning data as chunked Parquet (`meta/episodes`, `data/`) and per-camera MP4 shards under `videos/`. Daft exposes this layout under [`daft.datasets.lerobot`](../api/datasets.md) so you can stay at **episode granularity** for filtering, then expand to **frames** only for the episodes you need. + +!!! warning "Beta" + + This API is new and may evolve as we add optimizations (for example deeper integration with Parquet predicate pushdown). + +## Frame-level reads + +Use [`daft.datasets.lerobot.read`](../api/datasets.md#daft.datasets.lerobot.read) for the common case: a lazy DataFrame with one row per frame, episode metadata broadcast onto each frame. Pass `load_video_frames=True` (or a camera key / list of keys) to also decode each row's camera image from the MP4 shards. + +```python +import daft +from daft.datasets import lerobot + +df = lerobot.read("your-org/your-robot-dataset", load_video_frames=True) +``` + +`dataset_uri` can be: + +- A local directory that contains `meta/`, `data/`, etc. +- An `hf://datasets/org/name` URI (Hub layout matches the on-disk v3 tree) +- A bare `org/name` string, which is interpreted as `hf://datasets/org/name` + +## Episode metadata + +Use [`daft.datasets.lerobot.read_episodes`](../api/datasets.md#daft.datasets.lerobot.read_episodes) to scan `meta/episodes/**/*.parquet` (one row per episode). Per-episode `meta/` and `stats/` columns are hidden by default; opt in with `include_meta=True` / `include_stats=True`. + +```python +import daft +from daft.datasets.lerobot import load_episode_frames, read_episodes + +repo = "hf://datasets/your-org/your-robot-dataset" +ep = read_episodes(repo) +long = ep.where(daft.col("length") > 100) +frames = load_episode_frames(long, repo) +``` + +[`load_episode_frames`](../api/datasets.md#daft.datasets.lerobot.load_episode_frames) reads the per-frame Parquet under `data/**` and joins it to the provided episode rows on `episode_index`, producing one row per frame. Filter the episode DataFrame first so only the surviving episodes contribute frames. + +## Tasks + +[`read_tasks`](../api/datasets.md#daft.datasets.lerobot.read_tasks) loads task metadata, preferring `meta/tasks.parquet` and falling back to legacy `meta/tasks.jsonl`. + +## Video frames + +With `load_video_frames`, [`read`](../api/datasets.md#daft.datasets.lerobot.read) decodes each frame from its MP4 shard by **timestamp**: a shard packs many episodes back to back, so Daft combines the episode's `from_timestamp` offset within the shard with the frame's episode-local `timestamp`, and matches the closest decoded frame within half a frame period. Decoding requires PyAV and Pillow (`pip install av pillow`). + +## API reference + +::: daft.datasets.lerobot.read + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_episodes + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.load_episode_frames + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_tasks + options: + filters: ["!^_"] + heading_level: 3 diff --git a/tests/datasets/test_lerobot.py b/tests/datasets/test_lerobot.py new file mode 100644 index 00000000000..81e451a7dc4 --- /dev/null +++ b/tests/datasets/test_lerobot.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import json +import pathlib +import shutil + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +import daft +from daft.datasets.lerobot import load_episode_frames, read, read_episodes, read_tasks + + +def _write_table(path, table: pa.Table) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(table, path) + + +@pytest.fixture +def tiny_lerobot_v3(tmp_path): + """Minimal on-disk LeRobot v3 layout (two episodes, one shared data shard).""" + root = tmp_path / "ds" + (root / "meta" / "episodes" / "chunk-000").mkdir(parents=True) + (root / "data" / "chunk-000").mkdir(parents=True) + + episodes_tbl = pa.table( + { + "episode_index": [0, 1], + "length": [2, 2], + "task_index": [0, 0], + "data/chunk_index": [0, 0], + "data/file_index": [0, 0], + "dataset_from_index": [0, 2], + "dataset_to_index": [2, 4], + # Hidden by default; exposed via include_meta / include_stats. + "meta/episode_uuid": ["a", "b"], + "stats/action_mean": [0.1, 0.2], + } + ) + _write_table(root / "meta/episodes/chunk-000/file-000.parquet", episodes_tbl) + + frames_tbl = pa.table( + { + "index": [0, 1, 2, 3], + "episode_index": [0, 0, 1, 1], + "frame_index": [0, 1, 0, 1], + "timestamp": [0.0, 1 / 30, 0.0, 1 / 30], + "task_index": [0, 0, 0, 0], + } + ) + _write_table(root / "data/chunk-000/file-000.parquet", frames_tbl) + + info = { + "codebase_version": "v3.0", + "fps": 30, + "data_path": "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet", + "video_path": "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4", + "features": {}, + "total_episodes": 2, + "total_frames": 4, + "total_tasks": 1, + } + (root / "meta").mkdir(parents=True, exist_ok=True) + (root / "meta" / "info.json").write_text(json.dumps(info), encoding="utf-8") + + tasks_tbl = pa.table({"task_index": [0], "task": ["pick"]}) + _write_table(root / "meta" / "tasks.parquet", tasks_tbl) + + return str(root) + + +@pytest.fixture +def tiny_lerobot_v3_video(tmp_path): + """Single-episode dataset with an MP4 shard for ``load_video_frames`` tests.""" + av = pytest.importorskip("av") + pytest.importorskip("PIL", reason="Pillow required for decoded image rows") + + root = tmp_path / "ds_vid" + (root / "meta" / "episodes" / "chunk-000").mkdir(parents=True) + (root / "data" / "chunk-000").mkdir(parents=True) + + video_key = "camera.test" + video_dir = root / "videos" / video_key / "chunk-000" + video_dir.mkdir(parents=True) + shutil.copy(pathlib.Path("tests/assets/sample_video.mp4"), video_dir / "file-000.mp4") + + # The episode's footage starts at the shard's first frame timestamp; frame + # timestamps in the data parquet are episode-local (start at 0). + with av.open(str(video_dir / "file-000.mp4")) as c: + s = c.streams.video[0] + fps = float(s.average_rate) + eps_from_ts = None + for fr in c.decode(s): + if fr.pts is not None: + eps_from_ts = float(fr.pts * s.time_base) + break + assert eps_from_ts is not None + + n_frames = 3 + timestamps = [i / fps for i in range(n_frames)] + + episodes_tbl = pa.table( + { + "episode_index": [0], + "length": [n_frames], + "task_index": [0], + "data/chunk_index": [0], + "data/file_index": [0], + "dataset_from_index": [0], + "dataset_to_index": [n_frames], + f"videos/{video_key}/chunk_index": [0], + f"videos/{video_key}/file_index": [0], + f"videos/{video_key}/from_timestamp": [eps_from_ts], + } + ) + _write_table(root / "meta/episodes/chunk-000/file-000.parquet", episodes_tbl) + + frames_tbl = pa.table( + { + "index": list(range(n_frames)), + "episode_index": [0] * n_frames, + "frame_index": list(range(n_frames)), + "timestamp": timestamps, + "task_index": [0] * n_frames, + } + ) + _write_table(root / "data/chunk-000/file-000.parquet", frames_tbl) + + info = { + "codebase_version": "v3.0", + "fps": fps, + "data_path": "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet", + "video_path": "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4", + "features": { + video_key: {"dtype": "video"}, + }, + "total_episodes": 1, + "total_frames": n_frames, + "total_tasks": 1, + } + (root / "meta" / "info.json").write_text(json.dumps(info), encoding="utf-8") + + tasks_tbl = pa.table({"task_index": [0], "task": ["pick"]}) + _write_table(root / "meta" / "tasks.parquet", tasks_tbl) + + return str(root) + + +def test_read_episodes_and_load_episode_frames(tiny_lerobot_v3): + ep = read_episodes(tiny_lerobot_v3).sort("episode_index") + assert ep.count_rows() == 2 + + frames = load_episode_frames(ep, tiny_lerobot_v3).sort("index") + assert frames.count_rows() == 4 + assert set(frames.to_pydict()["episode_index"]) == {0, 1} + + f0 = load_episode_frames(ep.where(daft.col("episode_index") == 0), tiny_lerobot_v3).sort("frame_index") + assert f0.count_rows() == 2 + assert f0.to_pydict()["episode_index"] == [0, 0] + + +def test_read_frame_rows(tiny_lerobot_v3): + df = read(tiny_lerobot_v3).sort("index").collect() + assert df.count_rows() == 4 + cols = df.column_names + for expected in ("episode_index", "frame_index", "timestamp"): + assert expected in cols + + +def test_read_episodes_meta_and_stats_columns(tiny_lerobot_v3): + default_cols = read_episodes(tiny_lerobot_v3).column_names + assert "meta/episode_uuid" not in default_cols + assert "stats/action_mean" not in default_cols + + with_meta = read_episodes(tiny_lerobot_v3, include_meta=True).column_names + assert "meta/episode_uuid" in with_meta + + with_stats = read_episodes(tiny_lerobot_v3, include_stats=True).column_names + assert "stats/action_mean" in with_stats + + +def test_read_episodes_rejects_non_v3(tmp_path): + root = tmp_path / "ds_old" + (root / "meta").mkdir(parents=True) + info = {"codebase_version": "v2.1", "fps": 30, "features": {}} + (root / "meta" / "info.json").write_text(json.dumps(info), encoding="utf-8") + + with pytest.raises(ValueError, match="v3"): + read_episodes(str(root)) + + +def test_read_tasks_parquet(tiny_lerobot_v3): + t = read_tasks(tiny_lerobot_v3).collect() + assert t.count_rows() == 1 + + +def test_read_episodes_has_no_dataset_root_column(tiny_lerobot_v3): + ep = read_episodes(tiny_lerobot_v3) + assert "lerobot_dataset_root" not in ep.column_names + + +def test_read_load_video_frames_explicit_key(tiny_lerobot_v3_video): + df = read(tiny_lerobot_v3_video, load_video_frames=["camera.test"]).select("camera.test").collect() + assert df.count_rows() == 3 + img0 = df.to_pydict()["camera.test"][0] + if hasattr(img0, "mode"): + assert img0.mode == "RGB" + assert img0.size[0] > 10 and img0.size[1] > 10 + else: + import numpy as np + + assert isinstance(img0, np.ndarray) + assert img0.ndim == 3 and img0.shape[2] >= 3 + + +def test_read_load_video_frames_inferred_keys(tiny_lerobot_v3_video): + df = read(tiny_lerobot_v3_video, load_video_frames=True).collect() + assert df.count_rows() == 3 + assert "camera.test" in df.column_names + # Internal per-episode video metadata must not leak into the result. + assert not any(c.startswith("videos/") for c in df.column_names) diff --git a/tests/file/test_video.py b/tests/file/test_video.py index 067a1cf8071..539c36946a9 100644 --- a/tests/file/test_video.py +++ b/tests/file/test_video.py @@ -338,3 +338,24 @@ def test_video_frames_expression_sample_interval(sample_video_path): result = df.to_pydict()["frames"][0] assert len(result) == 10 + + +def test_get_frame_by_idx(sample_video_path): + """get_frame_by_idx returns the same frame as frames() for a given index.""" + import numpy as np + + file = daft.VideoFile(sample_video_path) + all_frames = list(file.frames()) + + for idx in (0, 1, 50, 150, 289): + expected = all_frames[idx]["data"] + actual = file.get_frame_by_idx(idx) + np.testing.assert_array_equal(np.array(actual), np.array(expected)) + + +def test_get_frame_by_idx_out_of_range(sample_video_path): + file = daft.VideoFile(sample_video_path) + with pytest.raises(IndexError): + file.get_frame_by_idx(290) + with pytest.raises(IndexError): + file.get_frame_by_idx(-1)