From 795d8a21a77c11bd3671cb77bd192474471595be Mon Sep 17 00:00:00 2001 From: Srinivas Lade Date: Fri, 15 May 2026 13:59:06 -0700 Subject: [PATCH 01/11] test it out --- daft/datasets/__init__.py | 5 +- daft/datasets/lerobot.py | 255 +++++++++++++++++++++++++++++++++ docs/SUMMARY.md | 1 + docs/api/datasets.md | 29 ++++ docs/datasets/lerobot.md | 70 +++++++++ tests/datasets/test_lerobot.py | 95 ++++++++++++ 6 files changed, 453 insertions(+), 2 deletions(-) create mode 100644 daft/datasets/lerobot.py create mode 100644 docs/datasets/lerobot.md create mode 100644 tests/datasets/test_lerobot.py diff --git a/daft/datasets/__init__.py b/daft/datasets/__init__.py index f817abae8d0..899a2ffb9dc 100644 --- a/daft/datasets/__init__.py +++ b/daft/datasets/__init__.py @@ -1,3 +1,4 @@ -from daft.datasets.common_crawl import common_crawl +from . import lerobot +from .common_crawl import common_crawl -__all__ = ["common_crawl"] +__all__ = ["common_crawl", "lerobot"] diff --git a/daft/datasets/lerobot.py b/daft/datasets/lerobot.py new file mode 100644 index 00000000000..d1121b52e67 --- /dev/null +++ b/daft/datasets/lerobot.py @@ -0,0 +1,255 @@ +"""LeRobot Dataset v3.0 helpers for `daft.datasets`. + +This module reads the file-based LeRobot v3 layout (``meta/episodes``, ``data``, +``videos``) and exposes episode-level scans plus frame expansion utilities. + +See https://huggingface.co/docs/lerobot/lerobot-dataset-v3 for format details. 
+""" + +from __future__ import annotations + +import json +import re +import urllib.error +import urllib.request +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from daft.api_annotations import PublicAPI +from daft.convert import from_pydict, from_pylist +from daft.datatype import DataType +from daft.exceptions import DaftCoreException +from daft.expressions import Expression, col, lit +from daft.functions import lpad +from daft.io import read_parquet + +if TYPE_CHECKING: + from daft.daft import IOConfig + from daft.dataframe import DataFrame + +__all__ = [ + "episodes", + "load_episode_frames", + "read_info", + "read_stats", + "read_tasks", +] + +# Column names used by Hugging Face LeRobot v3 metadata / data shards. +_DATA_CHUNK_CANDIDATES = ("data/chunk_index", "data_chunk_index") +_DATA_FILE_CANDIDATES = ("data/file_index", "data_file_index") + +_LEROBOT_ROOT_COL = "lerobot_dataset_root" +_DATA_PATH_COL = "lerobot_data_parquet_path" + + +def _is_probable_hf_repo_id(uri: str) -> bool: + return bool(re.fullmatch(r"[\w.-]+/[\w.-]+", uri)) + + +def _normalize_dataset_root(uri: str) -> str: + """Return a canonical dataset root prefix (no trailing slash) for path joins.""" + u = uri.strip() + if _is_probable_hf_repo_id(u): + u = f"hf://datasets/{u}" + return u.rstrip("/") + + +def _https_base_for_hf_datasets_root(root: str) -> str | None: + if not root.startswith("hf://datasets/"): + return None + repo_id = root.removeprefix("hf://datasets/") + return f"https://huggingface.co/datasets/{repo_id}/resolve/main" + + +def _read_json_object(root: str, relative_path: str) -> dict[str, Any]: + """Load a small JSON object from ``{root}/{relative_path}`` (local or hf://datasets).""" + https_base = _https_base_for_hf_datasets_root(root) + if https_base is not None: + url = f"{https_base}/{relative_path.lstrip('/')}" + try: + with urllib.request.urlopen(url, timeout=60) as resp: + return json.loads(resp.read().decode("utf-8")) + except urllib.error.HTTPError 
as e: + raise FileNotFoundError(f"Failed to download JSON from {url!r}: {e}") from e + + path = Path(root) / relative_path + if not path.is_file(): + raise FileNotFoundError(f"Missing file at {path}") + return json.loads(path.read_text(encoding="utf-8")) + + +def _read_jsonl_records_local(path: Path) -> list[dict[str, Any]]: + if not path.is_file(): + raise FileNotFoundError(f"Missing file at {path}") + out: list[dict[str, Any]] = [] + for line in path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line: + continue + out.append(json.loads(line)) + return out + + +def _read_jsonl_records_remote(url: str) -> list[dict[str, Any]]: + try: + with urllib.request.urlopen(url, timeout=60) as resp: + text = resp.read().decode("utf-8") + except urllib.error.HTTPError as e: + raise FileNotFoundError(f"Failed to download JSONL from {url!r}: {e}") from e + out: list[dict[str, Any]] = [] + for line in text.splitlines(): + line = line.strip() + if not line: + continue + out.append(json.loads(line)) + return out + + +def _pick_first_column(df: DataFrame, candidates: tuple[str, ...]) -> str: + names = set(df.column_names) + for c in candidates: + if c in names: + return c + raise ValueError( + "Expected one of columns " + + ", ".join(repr(c) for c in candidates) + + f" in episodes dataframe, but found columns: {sorted(names)}" + ) + + +def _data_parquet_path_expr(root_expr: Expression, chunk_col: str, file_col: str) -> Expression: + """Build ``{root}/data/chunk-XXX/file-YYY.parquet``.""" + chunk_str = lpad(col(chunk_col).cast(DataType.string()), 3, "0") + file_str = lpad(col(file_col).cast(DataType.string()), 3, "0") + return ( + root_expr.cast(DataType.string()) + lit("/data/chunk-") + chunk_str + lit("/file-") + file_str + lit(".parquet") + ) + + +@PublicAPI +def episodes( + dataset_uri: str, + io_config: IOConfig | None = None, + *, + dataset_path_column: str | None = None, +) -> DataFrame: + """Load LeRobot v3 episode metadata as a lazy DataFrame (one 
row per episode). + + This reads ``meta/episodes/**/*.parquet`` under the dataset root and adds + ``lerobot_dataset_root`` so downstream helpers can build ``data/`` and + ``videos/`` paths without threading the root string manually. + + Args: + dataset_uri: Local directory, ``hf://datasets/org/name`` URI, or bare + ``org/name`` which is treated as a Hub dataset id. + io_config: Optional IO configuration for remote reads. + dataset_path_column: If set, include the resolved dataset root string in + a column with this name (in addition to ``lerobot_dataset_root``). + + Returns: + Lazy episode metadata DataFrame. + """ + root = _normalize_dataset_root(dataset_uri) + meta_glob = f"{root}/meta/episodes/**/*.parquet" + df = read_parquet(meta_glob, io_config=io_config) + df = df.with_column(_LEROBOT_ROOT_COL, lit(root)) + if dataset_path_column is not None: + df = df.with_column(dataset_path_column, lit(root)) + return df + + +@PublicAPI +def load_episode_frames( + episodes: DataFrame, + *, + io_config: IOConfig | None = None, + columns: list[str] | None = None, +) -> DataFrame: + """Expand filtered episode rows into frame-level rows from ``data/`` Parquet shards. + + This executes a small eager step to discover distinct shard paths from the + current logical plan, then lazily reads only those Parquet files and keeps + rows whose ``episode_index`` appears in ``episodes``. + + Preconditions: + + - ``episodes`` must include ``lerobot_dataset_root`` (added by :func:`episodes`) + plus either ``data/chunk_index`` / ``data/file_index`` (canonical LeRobot v3) + or the ``data_chunk_index`` / ``data_file_index`` spelling. + + Args: + episodes: Episode-level dataframe (typically filtered) from :func:`episodes`. + io_config: Optional IO configuration for remote reads. + columns: Optional projection of frame columns (passed to :meth:`daft.DataFrame.select`). + + Returns: + Lazy frame-level dataframe. 
+ """ + if _LEROBOT_ROOT_COL not in episodes.column_names: + raise ValueError( + f"Missing {_LEROBOT_ROOT_COL!r} column on episodes dataframe. " + "Construct episodes via daft.datasets.lerobot.episodes(...)." + ) + + chunk_col = _pick_first_column(episodes, _DATA_CHUNK_CANDIDATES) + file_col = _pick_first_column(episodes, _DATA_FILE_CANDIDATES) + + with_paths = episodes.with_column( + _DATA_PATH_COL, + _data_parquet_path_expr(col(_LEROBOT_ROOT_COL), chunk_col, file_col), + ) + + paths = with_paths.select(_DATA_PATH_COL).distinct().to_pydict()[_DATA_PATH_COL] + if len(paths) == 0: + empty_cols = list(columns) if columns is not None else ["episode_index", "frame_index", "timestamp"] + return from_pydict({c: [] for c in empty_cols}) + + frames = read_parquet(paths, io_config=io_config) + if columns is not None: + frames = frames.select(*columns) + allowed = with_paths.select("episode_index").distinct() + return frames.join(allowed, on="episode_index", how="inner") + + +@PublicAPI +def read_info(dataset_uri: str) -> dict[str, Any]: + """Load ``meta/info.json`` for a LeRobot v3 dataset.""" + root = _normalize_dataset_root(dataset_uri) + return _read_json_object(root, "meta/info.json") + + +@PublicAPI +def read_stats(dataset_uri: str) -> dict[str, Any]: + """Load ``meta/stats.json`` for a LeRobot v3 dataset.""" + root = _normalize_dataset_root(dataset_uri) + return _read_json_object(root, "meta/stats.json") + + +@PublicAPI +def read_tasks(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: + """Load task metadata as a DataFrame. + + Prefers ``meta/tasks.parquet`` (current LeRobot default). Falls back to legacy + ``meta/tasks.jsonl`` when the Parquet file is missing. 
+ """ + root = _normalize_dataset_root(dataset_uri) + + https_base = _https_base_for_hf_datasets_root(root) + if https_base is not None: + pq_url = f"{https_base}/meta/tasks.parquet" + try: + return read_parquet(pq_url, io_config=io_config) + except (OSError, DaftCoreException, FileNotFoundError): + url = f"{https_base}/meta/tasks.jsonl" + return from_pylist(_read_jsonl_records_remote(url)) + + pq_path = Path(root) / "meta" / "tasks.parquet" + if pq_path.is_file(): + return read_parquet(str(pq_path), io_config=io_config) + + jsonl_path = Path(root) / "meta" / "tasks.jsonl" + if jsonl_path.is_file(): + return from_pylist(_read_jsonl_records_local(jsonl_path)) + + raise FileNotFoundError(f"No tasks metadata found under {root}/meta (tasks.parquet or tasks.jsonl)") diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 5401060b4d2..d7333cbd715 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -31,6 +31,7 @@ * [Batch Inference](use-case/batch-inference.md) * Datasets * [Common Crawl](datasets/common-crawl.md) + * [LeRobot v3](datasets/lerobot.md) * Data Connectors * [Connectors](connectors/index.md) * [Custom Connectors](connectors/custom.md) diff --git a/docs/api/datasets.md b/docs/api/datasets.md index 57cc969eb85..29d93ffe8d7 100644 --- a/docs/api/datasets.md +++ b/docs/api/datasets.md @@ -10,3 +10,32 @@ Check out our [Common Crawl dataset guide](../datasets/common-crawl.md) for more options: filters: ["!^_"] heading_level: 3 + +## LeRobot v3 + +See the [LeRobot v3 dataset guide](../datasets/lerobot.md) for episode vs frame workflows and Hub/local paths. 
+ +::: daft.datasets.lerobot.episodes + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.load_episode_frames + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_tasks + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_info + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_stats + options: + filters: ["!^_"] + heading_level: 3 diff --git a/docs/datasets/lerobot.md b/docs/datasets/lerobot.md new file mode 100644 index 00000000000..98bc7d95f0f --- /dev/null +++ b/docs/datasets/lerobot.md @@ -0,0 +1,70 @@ +# LeRobot v3 datasets with Daft + +[LeRobot Dataset v3.0](https://huggingface.co/docs/lerobot/lerobot-dataset-v3) stores robot learning data as chunked Parquet (`meta/episodes`, `data/`) and per-camera MP4 shards under `videos/`. Daft exposes this layout under [`daft.datasets.lerobot`](../api/datasets.md) so you can stay at **episode granularity** for filtering, then expand to **frames** only for the episodes you need. + +!!! warning "Beta" + + This API is new and may evolve as we add optimizations (for example deeper integration with Parquet predicate pushdown). + +## Episode metadata + +Use [`daft.datasets.lerobot.episodes`](../api/datasets.md#daft.datasets.lerobot.episodes) to scan `meta/episodes/**/*.parquet`. The dataframe includes a `lerobot_dataset_root` column used by frame expansion helpers. + +`dataset_uri` can be: + +- A local directory that contains `meta/`, `data/`, etc. 
+- An `hf://datasets/org/name` URI (Hub layout matches the on-disk v3 tree) +- A bare `org/name` string, which is interpreted as `hf://datasets/org/name` + +```python +import daft +from daft.datasets.lerobot import episodes, load_episode_frames + +repo = "hf://datasets/your-org/your-robot-dataset" +ep = episodes(repo) +long = ep.where(daft.col("length") > 100) +frames = load_episode_frames(long) +``` + +[`load_episode_frames`](../api/datasets.md#daft.datasets.lerobot.load_episode_frames) reads only the `data/chunk-*/file-*.parquet` shards referenced by the (possibly filtered) episode rows, then keeps rows whose `episode_index` is still present. It runs a small eager step to list **distinct shard paths**; the heavy Parquet scan stays lazy afterward. Pass `columns=[...]` to project frame fields with `daft.DataFrame.select` semantics; include `episode_index` in `columns`, because the projection is applied before the final join on that column. + +## Dataset-level JSON and tasks + +Bounded metadata files are exposed as small helpers: + +- [`read_info`](../api/datasets.md#daft.datasets.lerobot.read_info) → `meta/info.json` +- [`read_stats`](../api/datasets.md#daft.datasets.lerobot.read_stats) → `meta/stats.json` +- [`read_tasks`](../api/datasets.md#daft.datasets.lerobot.read_tasks) → prefers `meta/tasks.parquet`, falls back to `meta/tasks.jsonl` + +For Hub datasets, JSON is fetched over HTTPS from `resolve/main` (public files only unless your environment supplies credentials via your HTTP stack). + +## Video frames + +Daft already decodes video via [`read_video_frames`](../api/io.md#daft.read_video_frames) and [`daft.VideoFile`](../modalities/videos.md). Episode metadata includes per-camera chunk/file indices and timestamp offsets (`videos/{camera}/...` fields in LeRobot v3). Build the MP4 path from those columns (plus `lerobot_dataset_root`), then call `read_video_frames` with `sample_interval_seconds` or decode with `daft.functions.video_frames` for precise timestamps. 
+ +## API reference + +::: daft.datasets.lerobot.episodes + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.load_episode_frames + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_tasks + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_info + options: + filters: ["!^_"] + heading_level: 3 + +::: daft.datasets.lerobot.read_stats + options: + filters: ["!^_"] + heading_level: 3 diff --git a/tests/datasets/test_lerobot.py b/tests/datasets/test_lerobot.py new file mode 100644 index 00000000000..b76d3b05dbb --- /dev/null +++ b/tests/datasets/test_lerobot.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +import daft +from daft.datasets.lerobot import episodes, load_episode_frames, read_info, read_stats, read_tasks + + +def _write_table(path, table: pa.Table) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(table, path) + + +@pytest.fixture +def tiny_lerobot_v3(tmp_path): + """Minimal on-disk LeRobot v3 layout (two episodes, one shared data shard).""" + root = tmp_path / "ds" + (root / "meta" / "episodes" / "chunk-000").mkdir(parents=True) + (root / "data" / "chunk-000").mkdir(parents=True) + + episodes_tbl = pa.table( + { + "episode_index": [0, 1], + "length": [2, 2], + "task_index": [0, 0], + "data/chunk_index": [0, 0], + "data/file_index": [0, 0], + "dataset_from_index": [0, 2], + "dataset_to_index": [2, 4], + } + ) + _write_table(root / "meta/episodes/chunk-000/file-000.parquet", episodes_tbl) + + frames_tbl = pa.table( + { + "index": [0, 1, 2, 3], + "episode_index": [0, 0, 1, 1], + "frame_index": [0, 1, 0, 1], + "timestamp": [0.0, 1 / 30, 0.0, 1 / 30], + "task_index": [0, 0, 0, 0], + } + ) + _write_table(root / "data/chunk-000/file-000.parquet", frames_tbl) + + info = { + "codebase_version": "v3.0", + "fps": 30, + "features": {}, + "total_episodes": 2, + 
"total_frames": 4, + "total_tasks": 1, + } + (root / "meta").mkdir(parents=True, exist_ok=True) + (root / "meta" / "info.json").write_text(json.dumps(info), encoding="utf-8") + (root / "meta" / "stats.json").write_text(json.dumps({"ok": True}), encoding="utf-8") + + tasks_tbl = pa.table({"task_index": [0], "task": ["pick"]}) + _write_table(root / "meta" / "tasks.parquet", tasks_tbl) + + return str(root) + + +def test_episodes_and_load_episode_frames(tiny_lerobot_v3): + ep = episodes(tiny_lerobot_v3).sort("episode_index") + assert ep.count_rows() == 2 + assert "lerobot_dataset_root" in ep.column_names + + frames = load_episode_frames(ep).sort("index") + assert frames.count_rows() == 4 + assert set(frames.to_pydict()["episode_index"]) == {0, 1} + + f0 = load_episode_frames(ep.where(daft.col("episode_index") == 0)).sort("frame_index") + assert f0.count_rows() == 2 + assert f0.to_pydict()["episode_index"] == [0, 0] + + +def test_read_info_and_stats(tiny_lerobot_v3): + info = read_info(tiny_lerobot_v3) + assert info["total_episodes"] == 2 + assert read_stats(tiny_lerobot_v3) == {"ok": True} + + +def test_read_tasks_parquet(tiny_lerobot_v3): + t = read_tasks(tiny_lerobot_v3).collect() + assert t.count_rows() == 1 + + +def test_load_episode_frames_requires_root(tiny_lerobot_v3): + ep = daft.read_parquet(f"{tiny_lerobot_v3}/meta/episodes/**/*.parquet") + with pytest.raises(ValueError, match="lerobot_dataset_root"): + load_episode_frames(ep) From f8b838c2b6fa68f24018d2567e55c600ff08d1d0 Mon Sep 17 00:00:00 2001 From: Srinivas Lade Date: Mon, 18 May 2026 15:17:05 -0700 Subject: [PATCH 02/11] save some shit for now --- daft/datasets/lerobot.py | 348 ++++++++++++++++++++++----------- tests/datasets/test_lerobot.py | 113 ++++++++++- 2 files changed, 344 insertions(+), 117 deletions(-) diff --git a/daft/datasets/lerobot.py b/daft/datasets/lerobot.py index d1121b52e67..274785afc32 100644 --- a/daft/datasets/lerobot.py +++ b/daft/datasets/lerobot.py @@ -8,38 +8,29 @@ from 
__future__ import annotations -import json import re -import urllib.error -import urllib.request from pathlib import Path from typing import TYPE_CHECKING, Any +import daft from daft.api_annotations import PublicAPI -from daft.convert import from_pydict, from_pylist from daft.datatype import DataType from daft.exceptions import DaftCoreException from daft.expressions import Expression, col, lit +from daft.file import VideoFile from daft.functions import lpad -from daft.io import read_parquet +from daft.functions.file_ import video_file +from daft.udf import func if TYPE_CHECKING: from daft.daft import IOConfig from daft.dataframe import DataFrame -__all__ = [ - "episodes", - "load_episode_frames", - "read_info", - "read_stats", - "read_tasks", -] # Column names used by Hugging Face LeRobot v3 metadata / data shards. _DATA_CHUNK_CANDIDATES = ("data/chunk_index", "data_chunk_index") _DATA_FILE_CANDIDATES = ("data/file_index", "data_file_index") -_LEROBOT_ROOT_COL = "lerobot_dataset_root" _DATA_PATH_COL = "lerobot_data_parquet_path" @@ -62,50 +53,6 @@ def _https_base_for_hf_datasets_root(root: str) -> str | None: return f"https://huggingface.co/datasets/{repo_id}/resolve/main" -def _read_json_object(root: str, relative_path: str) -> dict[str, Any]: - """Load a small JSON object from ``{root}/{relative_path}`` (local or hf://datasets).""" - https_base = _https_base_for_hf_datasets_root(root) - if https_base is not None: - url = f"{https_base}/{relative_path.lstrip('/')}" - try: - with urllib.request.urlopen(url, timeout=60) as resp: - return json.loads(resp.read().decode("utf-8")) - except urllib.error.HTTPError as e: - raise FileNotFoundError(f"Failed to download JSON from {url!r}: {e}") from e - - path = Path(root) / relative_path - if not path.is_file(): - raise FileNotFoundError(f"Missing file at {path}") - return json.loads(path.read_text(encoding="utf-8")) - - -def _read_jsonl_records_local(path: Path) -> list[dict[str, Any]]: - if not path.is_file(): - raise 
FileNotFoundError(f"Missing file at {path}") - out: list[dict[str, Any]] = [] - for line in path.read_text(encoding="utf-8").splitlines(): - line = line.strip() - if not line: - continue - out.append(json.loads(line)) - return out - - -def _read_jsonl_records_remote(url: str) -> list[dict[str, Any]]: - try: - with urllib.request.urlopen(url, timeout=60) as resp: - text = resp.read().decode("utf-8") - except urllib.error.HTTPError as e: - raise FileNotFoundError(f"Failed to download JSONL from {url!r}: {e}") from e - out: list[dict[str, Any]] = [] - for line in text.splitlines(): - line = line.strip() - if not line: - continue - out.append(json.loads(line)) - return out - - def _pick_first_column(df: DataFrame, candidates: tuple[str, ...]) -> str: names = set(df.column_names) for c in candidates: @@ -120,51 +67,175 @@ def _pick_first_column(df: DataFrame, candidates: tuple[str, ...]) -> str: def _data_parquet_path_expr(root_expr: Expression, chunk_col: str, file_col: str) -> Expression: """Build ``{root}/data/chunk-XXX/file-YYY.parquet``.""" - chunk_str = lpad(col(chunk_col).cast(DataType.string()), 3, "0") - file_str = lpad(col(file_col).cast(DataType.string()), 3, "0") + chunk_str = lpad(col(chunk_col).cast(DataType.string), 3, "0") + file_str = lpad(col(file_col).cast(DataType.string), 3, "0") return ( - root_expr.cast(DataType.string()) + lit("/data/chunk-") + chunk_str + lit("/file-") + file_str + lit(".parquet") + root_expr.cast(DataType.string) + lit("/data/chunk-") + chunk_str + lit("/file-") + file_str + lit(".parquet") ) +def _video_mp4_path_expr(root_expr: Expression, video_key: str) -> Expression: + """Build ``{root}/videos/{video_key}/chunk-XXX/file-YYY.mp4`` from episode parquet columns.""" + chunk_col = f"videos/{video_key}/chunk_index" + file_col = f"videos/{video_key}/file_index" + chunk_str = lpad(col(chunk_col).cast(DataType.string), 3, "0") + file_str = lpad(col(file_col).cast(DataType.string), 3, "0") + return ( + root_expr.cast(DataType.string) 
+ + lit(f"/videos/{video_key}/chunk-") + + chunk_str + + lit("/file-") + + file_str + + lit(".mp4") + ) + + +def _video_feature_keys(features: dict[str, Any]) -> tuple[str, ...]: + keys: list[str] = [] + for name, meta in sorted(features.items()): + if isinstance(meta, dict) and meta.get("dtype") == "video": + keys.append(name) + return tuple(keys) + + +@func(return_dtype=DataType.image()) +def _decode_lerobot_video_timestamp( + file: VideoFile, + episode_from_timestamp_s: float, + frame_timestamp_s: float, + tolerance_s: float, + image_width_i: int, + image_height_i: int, +): + """Pick the decoded frame closest in time to ``episode_from_timestamp_s + frame_timestamp_s``.""" + try: + import av as av_mod + except ImportError as err: + raise ImportError("Decoding LeRobot MP4 shards requires PyAV. Install with `pip install av`.") from err + try: + from PIL import Image as PILImage + except ImportError as err: + raise ImportError( + "Decoding LeRobot MP4 shards requires Pillow. Install with `pip install daft[video]` or `pip install pillow`." + ) from err + abs_ts = float(episode_from_timestamp_s) + float(frame_timestamp_s) + tolerance = float(tolerance_s) + width_i = int(image_width_i) + height_i = int(image_height_i) + width: int | None + height: int | None + if width_i > 0 and height_i > 0: + width, height = width_i, height_i + else: + width, height = None, None + + loaded: list[tuple[float, Any]] = [] + decode_cap = 20_000 + decoded = 0 + + with file.open() as f_open: + with av_mod.open(f_open) as container: + stream = container.streams.video[0] + # Match LeRobot: seek backwards to preceding keyframe, then decode forwards. 
+ container.seek(max(0, int(abs_ts / av_mod.time_base)), backward=True) + + tail_s = max(0.1, tolerance * 50.0, 1.0 / 24.0) + for vf in container.decode(stream): + if vf.pts is None: + continue + current_ts = float(vf.pts * stream.time_base) + pil_img = PILImage.fromarray(vf.to_ndarray(format="rgb24"), mode="RGB") + if width is not None and height is not None: + pil_img = pil_img.resize((width, height), PILImage.Resampling.NEAREST) + + loaded.append((current_ts, pil_img)) + decoded += 1 + + if decoded >= decode_cap: + raise ValueError("Exceeded decode frame budget while aligning to parquet timestamps.") + if current_ts >= abs_ts + tail_s: + break + + if not loaded: + raise ValueError(f"No frames decoded from shard while seeking timestamp {abs_ts:.6f}s.") + + closest_ts, closest_img = min(loaded, key=lambda item: abs(item[0] - abs_ts)) + closest_dist = abs(closest_ts - abs_ts) + if closest_dist > tolerance: + raise ValueError( + f"No frame matched timestamp {abs_ts:.6f}s within tolerance {tolerance} " + f"(closest distance observed: {closest_dist})." + ) + return closest_img + + +def _assert_episodes_have_video_cols(episodes: DataFrame, video_keys: tuple[str, ...]) -> None: + names = episodes.column_names + missing = [ + candidate + for vk in video_keys + for candidate in ( + f"videos/{vk}/chunk_index", + f"videos/{vk}/file_index", + f"videos/{vk}/from_timestamp", + ) + if candidate not in names + ] + if missing: + raise ValueError( + "Episodes dataframe is missing LeRobot video index columns needed for decoding: " + + ", ".join(repr(x) for x in missing) + ) + + @PublicAPI -def episodes( - dataset_uri: str, - io_config: IOConfig | None = None, - *, - dataset_path_column: str | None = None, -) -> DataFrame: - """Load LeRobot v3 episode metadata as a lazy DataFrame (one row per episode). +def read_episodes(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: + """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per episode). 
- This reads ``meta/episodes/**/*.parquet`` under the dataset root and adds - ``lerobot_dataset_root`` so downstream helpers can build ``data/`` and - ``videos/`` paths without threading the root string manually. + This reads ``meta/episodes/**/*.parquet`` under the dataset root. Args: dataset_uri: Local directory, ``hf://datasets/org/name`` URI, or bare ``org/name`` which is treated as a Hub dataset id. io_config: Optional IO configuration for remote reads. - dataset_path_column: If set, include the resolved dataset root string in - a column with this name (in addition to ``lerobot_dataset_root``). Returns: Lazy episode metadata DataFrame. """ root = _normalize_dataset_root(dataset_uri) - meta_glob = f"{root}/meta/episodes/**/*.parquet" - df = read_parquet(meta_glob, io_config=io_config) - df = df.with_column(_LEROBOT_ROOT_COL, lit(root)) - if dataset_path_column is not None: - df = df.with_column(dataset_path_column, lit(root)) - return df + return daft.read_parquet(f"{root}/meta/episodes/**/*.parquet", io_config=io_config) + + +@PublicAPI +def read_frames(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: + """Read all frame data from a LeRobot v3 dataset into a lazy DataFrame (one row per frame). + + This reads the ``data/chunk-XXX/file-YYY.parquet`` under the dataset root. If you only need a subset of the frames, use :func:`load_episode_frames` instead from a filtered episodes dataframe from :func:`read_episodes`. + + Args: + dataset_uri: Same dataset root as passed to :func:`read_episodes` (local path, + ``hf://datasets/org/name``, or bare ``org/name`` Hub id). + io_config: Optional IO configuration for remote reads. + + Returns: + Lazy DataFrame of frame metadata. 
+ """ + root = _normalize_dataset_root(dataset_uri) + return daft.read_parquet(f"{root}/data/**/*.parquet", io_config=io_config) @PublicAPI def load_episode_frames( episodes: DataFrame, + dataset_uri: str, *, io_config: IOConfig | None = None, columns: list[str] | None = None, + decode_videos: bool = False, + video_keys: list[str] | None = None, + timestamp_tolerance_seconds: float = 1e-4, + decode_image_width: int | None = None, + decode_image_height: int | None = None, ) -> DataFrame: """Expand filtered episode rows into frame-level rows from ``data/`` Parquet shards. @@ -172,60 +243,111 @@ def load_episode_frames( current logical plan, then lazily reads only those Parquet files and keeps rows whose ``episode_index`` appears in ``episodes``. - Preconditions: + Optionally decodes MP4 shards under ``videos//chunk-XXX/file-YYY.mp4`` into + :class:`~daft.datatype.DataType` ``image()`` columns keyed by LeRobot ``feature_key`` strings + (typically ``dtype: "video"`` entries in ``meta/info.json``), using the episode-level + ``videos//from_timestamp``, ``videos//chunk_index``, and ``videos//file_index`` + fields plus row ``timestamp`` values (matching how ``LeRobotDataset`` aligns frames). - - ``episodes`` must include ``lerobot_dataset_root`` (added by :func:`episodes`) - plus either ``data/chunk_index`` / ``data/file_index`` (canonical LeRobot v3) - or the ``data_chunk_index`` / ``data_file_index`` spelling. + Preconditions: + - ``episodes`` must include either ``data/chunk_index`` / ``data/file_index`` + (canonical LeRobot v3) or the ``data_chunk_index`` / ``data_file_index`` spelling. Args: - episodes: Episode-level dataframe (typically filtered) from :func:`episodes`. + episodes: Episode-level dataframe (typically filtered) from :func:`read_episodes`. + dataset_uri: Same dataset root as passed to :func:`read_episodes` (local path, + ``hf://datasets/org/name``, or bare ``org/name`` Hub id). io_config: Optional IO configuration for remote reads. 
columns: Optional projection of frame columns (passed to :meth:`daft.DataFrame.select`). + decode_videos: When ``True``, add decoded camera images for each declared ``video_keys`` subset. + Requires PyAV and Pillow plus per-episode columns ``videos//{chunk_index,file_index,from_timestamp}``. + video_keys: Subset of video feature keys from ``meta/info.json`` (must have ``dtype: "video"``). + When ``None`` and ``decode_videos`` is enabled, all video features are decoded. + timestamp_tolerance_seconds: Maximum |PTS - (from_timestamp + timestamp)| in seconds (LeRobot default is ~1e-4). + decode_image_width: If set with ``decode_image_height``, nearest-neighbor resize decoded frames. + decode_image_height: See ``decode_image_width``. Returns: Lazy frame-level dataframe. """ - if _LEROBOT_ROOT_COL not in episodes.column_names: - raise ValueError( - f"Missing {_LEROBOT_ROOT_COL!r} column on episodes dataframe. " - "Construct episodes via daft.datasets.lerobot.episodes(...)." - ) + root = _normalize_dataset_root(dataset_uri) + root_expr = lit(root) chunk_col = _pick_first_column(episodes, _DATA_CHUNK_CANDIDATES) file_col = _pick_first_column(episodes, _DATA_FILE_CANDIDATES) with_paths = episodes.with_column( _DATA_PATH_COL, - _data_parquet_path_expr(col(_LEROBOT_ROOT_COL), chunk_col, file_col), + _data_parquet_path_expr(root_expr, chunk_col, file_col), ) paths = with_paths.select(_DATA_PATH_COL).distinct().to_pydict()[_DATA_PATH_COL] if len(paths) == 0: empty_cols = list(columns) if columns is not None else ["episode_index", "frame_index", "timestamp"] - return from_pydict({c: [] for c in empty_cols}) + return daft.from_pydict({c: [] for c in empty_cols}) + + frames = daft.read_parquet(paths, io_config=io_config) + + decode_w = decode_image_width or 0 + decode_h = decode_image_height or 0 + if (decode_image_width is None) != (decode_image_height is None): + raise ValueError("decode_image_width and decode_image_height must both be set or both omitted.") + + if 
decode_videos: + info = read_info(root) + meta_vkeys = _video_feature_keys(info.get("features", {})) + if video_keys is None: + selected_vkeys = meta_vkeys + else: + unknown = sorted(set(video_keys) - set(meta_vkeys)) + if unknown: + raise ValueError( + "video_keys contains keys not declared as dtype 'video' in meta/info.json: " + + ", ".join(repr(k) for k in unknown) + ) + selected_vkeys = tuple(video_keys) + + if selected_vkeys: + if "timestamp" not in frames.column_names: + raise ValueError( + "decode_videos requires a `timestamp` column on frame rows (LeRobot v3 Parquet default)." + ) + _assert_episodes_have_video_cols(episodes, selected_vkeys) + + join_cols = ["episode_index"] + for vk in selected_vkeys: + join_cols.extend( + [ + f"videos/{vk}/chunk_index", + f"videos/{vk}/file_index", + f"videos/{vk}/from_timestamp", + ] + ) + vid_meta = episodes.select(*join_cols).distinct() + frames = frames.join(vid_meta, on="episode_index", how="inner") + + for vk in selected_vkeys: + frames = frames.with_column( + vk, + _decode_lerobot_video_timestamp( + video_file( + _video_mp4_path_expr(root_expr, vk), + io_config=io_config, + ), + col(f"videos/{vk}/from_timestamp"), + col("timestamp"), + lit(timestamp_tolerance_seconds), + lit(decode_w), + lit(decode_h), + ), + ) - frames = read_parquet(paths, io_config=io_config) if columns is not None: frames = frames.select(*columns) allowed = with_paths.select("episode_index").distinct() return frames.join(allowed, on="episode_index", how="inner") -@PublicAPI -def read_info(dataset_uri: str) -> dict[str, Any]: - """Load ``meta/info.json`` for a LeRobot v3 dataset.""" - root = _normalize_dataset_root(dataset_uri) - return _read_json_object(root, "meta/info.json") - - -@PublicAPI -def read_stats(dataset_uri: str) -> dict[str, Any]: - """Load ``meta/stats.json`` for a LeRobot v3 dataset.""" - root = _normalize_dataset_root(dataset_uri) - return _read_json_object(root, "meta/stats.json") - - @PublicAPI def read_tasks(dataset_uri: str, 
io_config: IOConfig | None = None) -> DataFrame: """Load task metadata as a DataFrame. @@ -239,17 +361,23 @@ def read_tasks(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame if https_base is not None: pq_url = f"{https_base}/meta/tasks.parquet" try: - return read_parquet(pq_url, io_config=io_config) + return daft.read_parquet(pq_url, io_config=io_config) except (OSError, DaftCoreException, FileNotFoundError): - url = f"{https_base}/meta/tasks.jsonl" - return from_pylist(_read_jsonl_records_remote(url)) + return daft.read_json(f"{root}/meta/tasks.jsonl", io_config=io_config) pq_path = Path(root) / "meta" / "tasks.parquet" if pq_path.is_file(): - return read_parquet(str(pq_path), io_config=io_config) + return daft.read_parquet(str(pq_path), io_config=io_config) jsonl_path = Path(root) / "meta" / "tasks.jsonl" if jsonl_path.is_file(): - return from_pylist(_read_jsonl_records_local(jsonl_path)) + return daft.read_json(str(jsonl_path), io_config=io_config) raise FileNotFoundError(f"No tasks metadata found under {root}/meta (tasks.parquet or tasks.jsonl)") + + +__all__ = [ + "load_episode_frames", + "read_episodes", + "read_tasks", +] diff --git a/tests/datasets/test_lerobot.py b/tests/datasets/test_lerobot.py index b76d3b05dbb..05acb56fdd0 100644 --- a/tests/datasets/test_lerobot.py +++ b/tests/datasets/test_lerobot.py @@ -1,6 +1,8 @@ from __future__ import annotations import json +import pathlib +import shutil import pyarrow as pa import pyarrow.parquet as pq @@ -64,16 +66,114 @@ def tiny_lerobot_v3(tmp_path): return str(root) +@pytest.fixture +def tiny_lerobot_v3_video(tmp_path): + """Single-episode dataset with MP4 shards for ``decode_videos`` tests.""" + av = pytest.importorskip("av") + pytest.importorskip("PIL", reason="Pillow required for decoded image rows") + + root = tmp_path / "ds_vid" + (root / "meta" / "episodes" / "chunk-000").mkdir(parents=True) + (root / "data" / "chunk-000").mkdir(parents=True) + + video_key = "camera.test" + video_dir = 
root / "videos" / video_key / "chunk-000" + video_dir.mkdir(parents=True) + shutil.copy(pathlib.Path("tests/assets/sample_video.mp4"), video_dir / "file-000.mp4") + + with av.open(video_dir / "file-000.mp4") as c: + s = c.streams.video[0] + eps_from_ts = None + for fr in c.decode(s): + if fr.pts is not None: + eps_from_ts = float(fr.pts * s.time_base) + break + assert eps_from_ts is not None + + fps = 30 + n_frames = 3 + durations = [(i / fps) for i in range(n_frames)] + + episodes_tbl = pa.table( + { + "episode_index": [0], + "length": [n_frames], + "task_index": [0], + "data/chunk_index": [0], + "data/file_index": [0], + "dataset_from_index": [0], + "dataset_to_index": [n_frames], + f"videos/{video_key}/chunk_index": [0], + f"videos/{video_key}/file_index": [0], + f"videos/{video_key}/from_timestamp": [eps_from_ts], + } + ) + _write_table(root / "meta/episodes/chunk-000/file-000.parquet", episodes_tbl) + + frames_tbl = pa.table( + { + "index": list(range(n_frames)), + "episode_index": [0] * n_frames, + "frame_index": list(range(n_frames)), + "timestamp": durations, + "task_index": [0] * n_frames, + } + ) + _write_table(root / "data/chunk-000/file-000.parquet", frames_tbl) + + info = { + "codebase_version": "v3.0", + "fps": fps, + "features": { + video_key: {"dtype": "video"}, + }, + "total_episodes": 1, + "total_frames": n_frames, + "total_tasks": 1, + } + (root / "meta" / "info.json").write_text(json.dumps(info), encoding="utf-8") + (root / "meta" / "stats.json").write_text(json.dumps({"ok": True}), encoding="utf-8") + + tasks_tbl = pa.table({"task_index": [0], "task": ["pick"]}) + _write_table(root / "meta" / "tasks.parquet", tasks_tbl) + + return str(root) + + +def test_load_episode_frames_decode_videos_explicit_key(tiny_lerobot_v3_video): + ep = episodes(tiny_lerobot_v3_video) + df = ( + load_episode_frames(ep, tiny_lerobot_v3_video, decode_videos=True, video_keys=["camera.test"]) + .select("camera.test") + .collect() + ) + assert df.count_rows() == 3 + img0 
= df.to_pydict()["camera.test"][0] + if hasattr(img0, "mode"): + assert img0.mode == "RGB" + assert img0.size[0] > 10 and img0.size[1] > 10 + else: + import numpy as np + + assert isinstance(img0, np.ndarray) + assert img0.ndim == 3 and img0.shape[2] >= 3 + + +def test_load_episode_frames_decode_videos_inferred_keys(tiny_lerobot_v3_video): + ep = episodes(tiny_lerobot_v3_video) + df = load_episode_frames(ep, tiny_lerobot_v3_video, decode_videos=True).select("camera.test").collect() + assert df.count_rows() == 3 + + def test_episodes_and_load_episode_frames(tiny_lerobot_v3): ep = episodes(tiny_lerobot_v3).sort("episode_index") assert ep.count_rows() == 2 - assert "lerobot_dataset_root" in ep.column_names - frames = load_episode_frames(ep).sort("index") + frames = load_episode_frames(ep, tiny_lerobot_v3).sort("index") assert frames.count_rows() == 4 assert set(frames.to_pydict()["episode_index"]) == {0, 1} - f0 = load_episode_frames(ep.where(daft.col("episode_index") == 0)).sort("frame_index") + f0 = load_episode_frames(ep.where(daft.col("episode_index") == 0), tiny_lerobot_v3).sort("frame_index") assert f0.count_rows() == 2 assert f0.to_pydict()["episode_index"] == [0, 0] @@ -89,7 +189,6 @@ def test_read_tasks_parquet(tiny_lerobot_v3): assert t.count_rows() == 1 -def test_load_episode_frames_requires_root(tiny_lerobot_v3): - ep = daft.read_parquet(f"{tiny_lerobot_v3}/meta/episodes/**/*.parquet") - with pytest.raises(ValueError, match="lerobot_dataset_root"): - load_episode_frames(ep) +def test_read_episodes_has_no_dataset_root_column(tiny_lerobot_v3): + ep = episodes(tiny_lerobot_v3) + assert "lerobot_dataset_root" not in ep.column_names From 70ac6926c72905b4ec2a9caf31fbe730bcb77b40 Mon Sep 17 00:00:00 2001 From: Srinivas Lade Date: Tue, 2 Jun 2026 01:14:50 +0200 Subject: [PATCH 03/11] save one more time --- daft/datasets/lerobot.py | 329 +++++++++++---------------------------- daft/functions/video.py | 10 +- 2 files changed, 99 insertions(+), 240 deletions(-) diff 
--git a/daft/datasets/lerobot.py b/daft/datasets/lerobot.py index 274785afc32..627b3f7efc5 100644 --- a/daft/datasets/lerobot.py +++ b/daft/datasets/lerobot.py @@ -1,7 +1,7 @@ """LeRobot Dataset v3.0 helpers for `daft.datasets`. -This module reads the file-based LeRobot v3 layout (``meta/episodes``, ``data``, -``videos``) and exposes episode-level scans plus frame expansion utilities. +This module reads the file-based LeRobot v3 layout (`meta/episodes`, `data`, +`videos`) and exposes episode-level scans plus frame expansion utilities. See https://huggingface.co/docs/lerobot/lerobot-dataset-v3 for format details. """ @@ -9,15 +9,15 @@ from __future__ import annotations import re -from pathlib import Path -from typing import TYPE_CHECKING, Any +import json +from typing import TYPE_CHECKING, Any, TypedDict, cast import daft from daft.api_annotations import PublicAPI from daft.datatype import DataType -from daft.exceptions import DaftCoreException -from daft.expressions import Expression, col, lit +from daft.expressions import col, lit from daft.file import VideoFile +from daft.exceptions import DaftCoreException from daft.functions import lpad from daft.functions.file_ import video_file from daft.udf import func @@ -27,77 +27,18 @@ from daft.dataframe import DataFrame -# Column names used by Hugging Face LeRobot v3 metadata / data shards. -_DATA_CHUNK_CANDIDATES = ("data/chunk_index", "data_chunk_index") -_DATA_FILE_CANDIDATES = ("data/file_index", "data_file_index") - -_DATA_PATH_COL = "lerobot_data_parquet_path" - - -def _is_probable_hf_repo_id(uri: str) -> bool: - return bool(re.fullmatch(r"[\w.-]+/[\w.-]+", uri)) - def _normalize_dataset_root(uri: str) -> str: """Return a canonical dataset root prefix (no trailing slash) for path joins.""" u = uri.strip() - if _is_probable_hf_repo_id(u): + # Input looks like a Hugging Face repo ID, i.e. 
"org/name" + is_hf_repo_id = bool(re.fullmatch(r"[\w.-]+/[\w.-]+", uri)) + + if is_hf_repo_id: u = f"hf://datasets/{u}" return u.rstrip("/") -def _https_base_for_hf_datasets_root(root: str) -> str | None: - if not root.startswith("hf://datasets/"): - return None - repo_id = root.removeprefix("hf://datasets/") - return f"https://huggingface.co/datasets/{repo_id}/resolve/main" - - -def _pick_first_column(df: DataFrame, candidates: tuple[str, ...]) -> str: - names = set(df.column_names) - for c in candidates: - if c in names: - return c - raise ValueError( - "Expected one of columns " - + ", ".join(repr(c) for c in candidates) - + f" in episodes dataframe, but found columns: {sorted(names)}" - ) - - -def _data_parquet_path_expr(root_expr: Expression, chunk_col: str, file_col: str) -> Expression: - """Build ``{root}/data/chunk-XXX/file-YYY.parquet``.""" - chunk_str = lpad(col(chunk_col).cast(DataType.string), 3, "0") - file_str = lpad(col(file_col).cast(DataType.string), 3, "0") - return ( - root_expr.cast(DataType.string) + lit("/data/chunk-") + chunk_str + lit("/file-") + file_str + lit(".parquet") - ) - - -def _video_mp4_path_expr(root_expr: Expression, video_key: str) -> Expression: - """Build ``{root}/videos/{video_key}/chunk-XXX/file-YYY.mp4`` from episode parquet columns.""" - chunk_col = f"videos/{video_key}/chunk_index" - file_col = f"videos/{video_key}/file_index" - chunk_str = lpad(col(chunk_col).cast(DataType.string), 3, "0") - file_str = lpad(col(file_col).cast(DataType.string), 3, "0") - return ( - root_expr.cast(DataType.string) - + lit(f"/videos/{video_key}/chunk-") - + chunk_str - + lit("/file-") - + file_str - + lit(".mp4") - ) - - -def _video_feature_keys(features: dict[str, Any]) -> tuple[str, ...]: - keys: list[str] = [] - for name, meta in sorted(features.items()): - if isinstance(meta, dict) and meta.get("dtype") == "video": - keys.append(name) - return tuple(keys) - - @func(return_dtype=DataType.image()) def _decode_lerobot_video_timestamp( 
file: VideoFile, @@ -169,183 +110,110 @@ def _decode_lerobot_video_timestamp( return closest_img -def _assert_episodes_have_video_cols(episodes: DataFrame, video_keys: tuple[str, ...]) -> None: - names = episodes.column_names - missing = [ - candidate - for vk in video_keys - for candidate in ( - f"videos/{vk}/chunk_index", - f"videos/{vk}/file_index", - f"videos/{vk}/from_timestamp", - ) - if candidate not in names - ] - if missing: - raise ValueError( - "Episodes dataframe is missing LeRobot video index columns needed for decoding: " - + ", ".join(repr(x) for x in missing) - ) +class Feature(TypedDict): + dtype: str +class LeRobotInfo(TypedDict): + codebase_version: str + data_path: str + video_path: str + features: dict[str, Feature] -@PublicAPI -def read_episodes(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: - """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per episode). - This reads ``meta/episodes/**/*.parquet`` under the dataset root. +def _read_info(normalized_uri: str, io_config: IOConfig | None = None) -> LeRobotInfo: + with daft.open_file(f"{normalized_uri}/meta/info.json", io_config=io_config) as f: + info = cast(LeRobotInfo, json.load(f)) + if info["codebase_version"] != "v3.0": + raise ValueError("`daft.datasets.lerobot` currently only supports LeRobot datasets of v3 and above") + return info - Args: - dataset_uri: Local directory, ``hf://datasets/org/name`` URI, or bare - ``org/name`` which is treated as a Hub dataset id. - io_config: Optional IO configuration for remote reads. - Returns: - Lazy episode metadata DataFrame. 
- """ +@PublicAPI +def read( + dataset_uri: str, + io_config: IOConfig | None = None, + include_stats: bool = False, + load_video_frames: str | list[str] | bool = False, +) -> DataFrame: + """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per frame with episode metadata).""" root = _normalize_dataset_root(dataset_uri) - return daft.read_parquet(f"{root}/meta/episodes/**/*.parquet", io_config=io_config) + episode_df = daft.datasets.lerobot.read_episodes(dataset_uri, io_config=io_config, include_stats=include_stats) + frame_df = daft.read_parquet(f"{root}/data/**") + df = episode_df.join(frame_df, on=["episode_index"]) + df = df.exclude("data/chunk_index", "data/file_index") + + # Load video frames into memory + if load_video_frames is not False: + if load_video_frames is True: + video_keys = [] # TODO + elif isinstance(load_video_frames, str): + video_keys = [load_video_frames] + elif isinstance(load_video_frames, list) and all(isinstance(k, str) for k in load_video_frames): + video_keys = load_video_frames + else: + raise ValueError("TODO") -@PublicAPI -def read_frames(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: - """Read all frame data from a LeRobot v3 dataset into a lazy DataFrame (one row per frame). + # To increase parallelism, reduce batch size + df = df.into_batches(16) # TODO: Set it in the batch UDF instead? + for k in video_keys: + # TODO: Optimize by using a batch UDF to avoid opening the same video multiple times + df = df.with_column(k, get_video_frame_by_idx(f"videos/{k}/video", col("frame_idx"))) + df = df.exclude(f"videos/{k}/video") - This reads the ``data/chunk-XXX/file-YYY.parquet`` under the dataset root. If you only need a subset of the frames, use :func:`load_episode_frames` instead from a filtered episodes dataframe from :func:`read_episodes`. + # TODO: What about raw images, what do i do about them? 
Is that a thing in LeRobot v3 - Args: - dataset_uri: Same dataset root as passed to :func:`read_episodes` (local path, - ``hf://datasets/org/name``, or bare ``org/name`` Hub id). - io_config: Optional IO configuration for remote reads. - - Returns: - Lazy DataFrame of frame metadata. - """ - root = _normalize_dataset_root(dataset_uri) - return daft.read_parquet(f"{root}/data/**/*.parquet", io_config=io_config) + return df @PublicAPI -def load_episode_frames( - episodes: DataFrame, +def read_episodes( dataset_uri: str, - *, io_config: IOConfig | None = None, - columns: list[str] | None = None, - decode_videos: bool = False, - video_keys: list[str] | None = None, - timestamp_tolerance_seconds: float = 1e-4, - decode_image_width: int | None = None, - decode_image_height: int | None = None, + include_meta: bool = False, + include_stats: bool = False, + include_video_metadata: bool = False, ) -> DataFrame: - """Expand filtered episode rows into frame-level rows from ``data/`` Parquet shards. - - This executes a small eager step to discover distinct shard paths from the - current logical plan, then lazily reads only those Parquet files and keeps - rows whose ``episode_index`` appears in ``episodes``. - - Optionally decodes MP4 shards under ``videos//chunk-XXX/file-YYY.mp4`` into - :class:`~daft.datatype.DataType` ``image()`` columns keyed by LeRobot ``feature_key`` strings - (typically ``dtype: "video"`` entries in ``meta/info.json``), using the episode-level - ``videos//from_timestamp``, ``videos//chunk_index``, and ``videos//file_index`` - fields plus row ``timestamp`` values (matching how ``LeRobotDataset`` aligns frames). + """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per episode). - Preconditions: - - ``episodes`` must include either ``data/chunk_index`` / ``data/file_index`` - (canonical LeRobot v3) or the ``data_chunk_index`` / ``data_file_index`` spelling. + This reads the `meta/episodes/**/*.parquet` path under the dataset root. 
Args: - episodes: Episode-level dataframe (typically filtered) from :func:`read_episodes`. - dataset_uri: Same dataset root as passed to :func:`read_episodes` (local path, - ``hf://datasets/org/name``, or bare ``org/name`` Hub id). + dataset_uri: Huggingface repo id (`org/name`), + or a local / remote directory (`s3://...`, `hf://datasets/...`) io_config: Optional IO configuration for remote reads. - columns: Optional projection of frame columns (passed to :meth:`daft.DataFrame.select`). - decode_videos: When ``True``, add decoded camera images for each declared ``video_keys`` subset. - Requires PyAV and Pillow plus per-episode columns ``videos//{chunk_index,file_index,from_timestamp}``. - video_keys: Subset of video feature keys from ``meta/info.json`` (must have ``dtype: "video"``). - When ``None`` and ``decode_videos`` is enabled, all video features are decoded. - timestamp_tolerance_seconds: Maximum |PTS - (from_timestamp + timestamp)| in seconds (LeRobot default is ~1e-4). - decode_image_width: If set with ``decode_image_height``, nearest-neighbor resize decoded frames. - decode_image_height: See ``decode_image_width``. Returns: - Lazy frame-level dataframe. + Lazy DataFrame of episode metadata. """ root = _normalize_dataset_root(dataset_uri) - root_expr = lit(root) - - chunk_col = _pick_first_column(episodes, _DATA_CHUNK_CANDIDATES) - file_col = _pick_first_column(episodes, _DATA_FILE_CANDIDATES) - - with_paths = episodes.with_column( - _DATA_PATH_COL, - _data_parquet_path_expr(root_expr, chunk_col, file_col), - ) - - paths = with_paths.select(_DATA_PATH_COL).distinct().to_pydict()[_DATA_PATH_COL] - if len(paths) == 0: - empty_cols = list(columns) if columns is not None else ["episode_index", "frame_index", "timestamp"] - return daft.from_pydict({c: [] for c in empty_cols}) + info = _read_info(root, io_config=io_config) + + # TODO: What is the `meta` episodes into used for? How is it different from the `videos` info? 
+ df = daft.read_parquet(f"{root}/meta/episodes/**/*.parquet", io_config=io_config) + if not include_meta: + df = df.exclude(*(c for c in df.column_names if c.startswith("meta/"))) + if not include_stats: + df = df.exclude(*(c for c in df.column_names if c.startswith("stats/"))) + + # Get the video keys + video_keys = set(name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video") + + for key in video_keys: + file_name_expr = ( + lit(f"{root}/videos/{key}/chunk-") + + lpad(col(f"videos/{key}/chunk_index").cast(DataType.string), 3, "0") + + lit("/file-") + + lpad(col(f"videos/{key}/file_index").cast(DataType.string), 3, "0") + + lit(".mp4") + ) - frames = daft.read_parquet(paths, io_config=io_config) + df = df.with_column(f"videos/{key}/video", video_file(file_name_expr, verify=False, io_config=io_config)) - decode_w = decode_image_width or 0 - decode_h = decode_image_height or 0 - if (decode_image_width is None) != (decode_image_height is None): - raise ValueError("decode_image_width and decode_image_height must both be set or both omitted.") + if not include_video_metadata: + df = df.exclude(*(c for c in df.column_names if c.startswith("videos/") and not c.endswith("/video"))) - if decode_videos: - info = read_info(root) - meta_vkeys = _video_feature_keys(info.get("features", {})) - if video_keys is None: - selected_vkeys = meta_vkeys - else: - unknown = sorted(set(video_keys) - set(meta_vkeys)) - if unknown: - raise ValueError( - "video_keys contains keys not declared as dtype 'video' in meta/info.json: " - + ", ".join(repr(k) for k in unknown) - ) - selected_vkeys = tuple(video_keys) - - if selected_vkeys: - if "timestamp" not in frames.column_names: - raise ValueError( - "decode_videos requires a `timestamp` column on frame rows (LeRobot v3 Parquet default)." 
- ) - _assert_episodes_have_video_cols(episodes, selected_vkeys) - - join_cols = ["episode_index"] - for vk in selected_vkeys: - join_cols.extend( - [ - f"videos/{vk}/chunk_index", - f"videos/{vk}/file_index", - f"videos/{vk}/from_timestamp", - ] - ) - vid_meta = episodes.select(*join_cols).distinct() - frames = frames.join(vid_meta, on="episode_index", how="inner") - - for vk in selected_vkeys: - frames = frames.with_column( - vk, - _decode_lerobot_video_timestamp( - video_file( - _video_mp4_path_expr(root_expr, vk), - io_config=io_config, - ), - col(f"videos/{vk}/from_timestamp"), - col("timestamp"), - lit(timestamp_tolerance_seconds), - lit(decode_w), - lit(decode_h), - ), - ) - - if columns is not None: - frames = frames.select(*columns) - allowed = with_paths.select("episode_index").distinct() - return frames.join(allowed, on="episode_index", how="inner") + return df @PublicAPI @@ -357,27 +225,16 @@ def read_tasks(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame """ root = _normalize_dataset_root(dataset_uri) - https_base = _https_base_for_hf_datasets_root(root) - if https_base is not None: - pq_url = f"{https_base}/meta/tasks.parquet" - try: - return daft.read_parquet(pq_url, io_config=io_config) - except (OSError, DaftCoreException, FileNotFoundError): - return daft.read_json(f"{root}/meta/tasks.jsonl", io_config=io_config) - - pq_path = Path(root) / "meta" / "tasks.parquet" - if pq_path.is_file(): - return daft.read_parquet(str(pq_path), io_config=io_config) - - jsonl_path = Path(root) / "meta" / "tasks.jsonl" - if jsonl_path.is_file(): - return daft.read_json(str(jsonl_path), io_config=io_config) - - raise FileNotFoundError(f"No tasks metadata found under {root}/meta (tasks.parquet or tasks.jsonl)") + pq_url = f"{root}/meta/tasks.parquet" + try: + return daft.read_parquet(pq_url, io_config=io_config) + except (OSError, DaftCoreException, FileNotFoundError): + return daft.read_json(f"{root}/meta/tasks.jsonl", io_config=io_config) __all__ = 
[ "load_episode_frames", + "read", "read_episodes", "read_tasks", ] diff --git a/daft/functions/video.py b/daft/functions/video.py index 98fef6d093c..e5890364855 100644 --- a/daft/functions/video.py +++ b/daft/functions/video.py @@ -144,8 +144,8 @@ def frames_impl( def video_frames( file_expr: Expression, *, - start_time: float = 0, - end_time: float | None = None, + start_time: float | Expression = 0, + end_time: float | None | Expression = None, width: int | None = None, height: int | None = None, is_key_frame: bool | None = None, @@ -157,8 +157,10 @@ def video_frames( Args: file_expr (VideoFile Expression): The video file to decode frames from. - start_time (float, optional): Start of the time range in seconds. Defaults to 0. - end_time (float | None, optional): End of the time range in seconds. Defaults to None (all frames). + start_time (float | Expression, optional): Start of the time range in seconds. Defaults to 0. + If an expression is provided, the start time will be dynamic per row. + end_time (float | None | Expression, optional): End of the time range in seconds. Defaults to None (all frames). + If an expression is provided, the end time will be dynamic per row. width (int | None, optional): Target width for resizing frames. Must be provided with ``height``. height (int | None, optional): Target height for resizing frames. Must be provided with ``width``. is_key_frame (bool | None, optional): If True, decode only keyframes. 
If False, From 1fe73e0e5d03fc824e1afcb43f0804f1e423be84 Mon Sep 17 00:00:00 2001 From: Srinivas Lade Date: Sat, 6 Jun 2026 15:45:26 -0400 Subject: [PATCH 04/11] save again --- daft/datasets/lerobot.py | 4 ++-- daft/file/video.py | 41 ++++++++++++++++++++++++++++++++++++++++ daft/functions/video.py | 35 ++++++++++++++++++++++++++++++++++ tests/file/test_video.py | 21 ++++++++++++++++++++ 4 files changed, 99 insertions(+), 2 deletions(-) diff --git a/daft/datasets/lerobot.py b/daft/datasets/lerobot.py index 627b3f7efc5..4430e288456 100644 --- a/daft/datasets/lerobot.py +++ b/daft/datasets/lerobot.py @@ -20,6 +20,7 @@ from daft.exceptions import DaftCoreException from daft.functions import lpad from daft.functions.file_ import video_file +from daft.functions.video import get_video_frame_by_idx from daft.udf import func if TYPE_CHECKING: @@ -27,7 +28,6 @@ from daft.dataframe import DataFrame - def _normalize_dataset_root(uri: str) -> str: """Return a canonical dataset root prefix (no trailing slash) for path joins.""" u = uri.strip() @@ -152,7 +152,7 @@ def read( elif isinstance(load_video_frames, list) and all(isinstance(k, str) for k in load_video_frames): video_keys = load_video_frames else: - raise ValueError("TODO") + raise ValueError(f"Invalid value provided for argument load_video_frames=`{load_video_frames}`") # To increase parallelism, reduce batch size df = df.into_batches(16) # TODO: Set it in the batch UDF instead? diff --git a/daft/file/video.py b/daft/file/video.py index 4069e4832c0..4bc07145a59 100644 --- a/daft/file/video.py +++ b/daft/file/video.py @@ -220,3 +220,44 @@ def frames( ) frame_index += 1 + + def get_frame_by_idx(self, idx: int) -> PIL.Image.Image: + if not pil_image.module_available(): + raise ImportError( + "The 'pillow' module is required for frame decoding. Install it with `pip install daft[video]`." 
+ ) + if idx < 0: + raise IndexError(f"Frame index {idx} is out of range") + + with self.open() as f: + with av.open(f) as container: + video = next( + (stream for stream in container.streams if stream.type == "video"), + None, + ) + if video is None: + raise ValueError("No video stream found") + + time_base = float(video.time_base) if video.time_base else None + fps = float(video.average_rate) if video.average_rate else None + if fps is None and video.guessed_rate: + fps = float(video.guessed_rate) + start_pts = video.start_time or 0 + + # Seek to the nearest preceding keyframe at or before the target frame. + if idx > 0 and time_base is not None and fps is not None: + target_time = idx / fps + seek_timestamp = int(target_time / time_base) + container.seek(seek_timestamp, stream=video, backward=True) + + for frame_idx, frame in enumerate(container.decode(video)): + current_frame_index = frame_idx + if frame.pts is not None and time_base is not None and fps is not None: + current_frame_index = int(round((frame.pts - start_pts) * time_base * fps)) + + if current_frame_index == idx: + return frame.to_image() + if current_frame_index > idx: + break + + raise IndexError(f"Frame index {idx} is out of range") diff --git a/daft/functions/video.py b/daft/functions/video.py index e5890364855..c6d4ebba209 100644 --- a/daft/functions/video.py +++ b/daft/functions/video.py @@ -192,3 +192,38 @@ def video_frames( is_key_frame=is_key_frame, sample_interval_seconds=sample_interval_seconds, ) # type: ignore + + +def get_frame_by_idx_impl( + file: daft.VideoFile, + idx: int, +) -> PIL.Image.Image: + return file.get_frame_by_idx(idx) + +video_get_frame_by_idx_fn = Func._from_func( + get_frame_by_idx_impl, + return_dtype=daft.DataType.image(), + unnest=False, + use_process=None, + is_batch=False, + batch_size=None, + max_retries=None, + on_error=None, + name_override="video_get_frame_by_idx", +) + +def get_video_frame_by_idx( + file_expr: Expression, + idx: int | Expression, +) -> 
Expression: + """Get a frame from a video file by index. + + Args: + file_expr (VideoFile Expression): The video file to get the frame from. + idx (int | Integer Expression): The index of the frame to get. + If an expression is provided, the index will be dynamic per row. + + Returns: + Expression (Image Expression): The frame as an image. + """ + return video_get_frame_by_idx_fn(file_expr, idx) diff --git a/tests/file/test_video.py b/tests/file/test_video.py index 067a1cf8071..539c36946a9 100644 --- a/tests/file/test_video.py +++ b/tests/file/test_video.py @@ -338,3 +338,24 @@ def test_video_frames_expression_sample_interval(sample_video_path): result = df.to_pydict()["frames"][0] assert len(result) == 10 + + +def test_get_frame_by_idx(sample_video_path): + """get_frame_by_idx returns the same frame as frames() for a given index.""" + import numpy as np + + file = daft.VideoFile(sample_video_path) + all_frames = list(file.frames()) + + for idx in (0, 1, 50, 150, 289): + expected = all_frames[idx]["data"] + actual = file.get_frame_by_idx(idx) + np.testing.assert_array_equal(np.array(actual), np.array(expected)) + + +def test_get_frame_by_idx_out_of_range(sample_video_path): + file = daft.VideoFile(sample_video_path) + with pytest.raises(IndexError): + file.get_frame_by_idx(290) + with pytest.raises(IndexError): + file.get_frame_by_idx(-1) From ffe65e4842bc8dd305bc7aeccd1dd0c341202472 Mon Sep 17 00:00:00 2001 From: Shreyas Garimella Date: Thu, 11 Jun 2026 13:57:30 -0700 Subject: [PATCH 05/11] fix(lerobot): decode video frames by timestamp instead of frame index MP4 shards pack multiple episodes back to back, so a shard's internal frame numbering does not match the parquet's episode-local frame_index (it only lines up for the first episode in each shard). Seek by absolute timestamp instead: the episode's `from_timestamp` within the shard plus the frame's episode-local `timestamp`, accepting the closest decoded frame within half a frame period. 
Also: - populate `video_keys` from info.json features (was a TODO) - have read() reuse read_episodes() + load_episode_frames() instead of duplicating the episode/frame join - sync docs/api/datasets.md with the current public API (read / read_episodes / load_episode_frames / read_tasks) Co-Authored-By: Claude Fable 5 --- daft/datasets/lerobot.py | 83 ++++++++++++++++++++++++++++++++++------ docs/api/datasets.md | 13 ++----- 2 files changed, 75 insertions(+), 21 deletions(-) diff --git a/daft/datasets/lerobot.py b/daft/datasets/lerobot.py index 4430e288456..dfbb7ef2a5c 100644 --- a/daft/datasets/lerobot.py +++ b/daft/datasets/lerobot.py @@ -20,7 +20,6 @@ from daft.exceptions import DaftCoreException from daft.functions import lpad from daft.functions.file_ import video_file -from daft.functions.video import get_video_frame_by_idx from daft.udf import func if TYPE_CHECKING: @@ -109,7 +108,6 @@ def _decode_lerobot_video_timestamp( ) return closest_img - class Feature(TypedDict): dtype: str @@ -117,9 +115,9 @@ class LeRobotInfo(TypedDict): codebase_version: str data_path: str video_path: str + fps: float features: dict[str, Feature] - def _read_info(normalized_uri: str, io_config: IOConfig | None = None) -> LeRobotInfo: with daft.open_file(f"{normalized_uri}/meta/info.json", io_config=io_config) as f: info = cast(LeRobotInfo, json.load(f)) @@ -127,7 +125,6 @@ def _read_info(normalized_uri: str, io_config: IOConfig | None = None) -> LeRobo raise ValueError("`daft.datasets.lerobot` currently only supports LeRobot datasets of v3 and above") return info - @PublicAPI def read( dataset_uri: str, @@ -137,16 +134,21 @@ def read( ) -> DataFrame: """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per frame with episode metadata).""" root = _normalize_dataset_root(dataset_uri) + info = _read_info(root, io_config=io_config) - episode_df = daft.datasets.lerobot.read_episodes(dataset_uri, io_config=io_config, include_stats=include_stats) - frame_df = 
daft.read_parquet(f"{root}/data/**") - df = episode_df.join(frame_df, on=["episode_index"]) - df = df.exclude("data/chunk_index", "data/file_index") + # Keep the per-episode video metadata (notably `videos/{key}/from_timestamp`, + # the time within the shard where each episode's footage begins). We need it + # to translate episode-local frame timestamps into absolute shard timestamps + # when decoding, and drop these internal columns again before returning. + episode_df = read_episodes( + dataset_uri, io_config=io_config, include_stats=include_stats, include_video_metadata=True + ) + df = load_episode_frames(episode_df, dataset_uri, io_config=io_config) # Load video frames into memory if load_video_frames is not False: if load_video_frames is True: - video_keys = [] # TODO + video_keys = [name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video"] elif isinstance(load_video_frames, str): video_keys = [load_video_frames] elif isinstance(load_video_frames, list) and all(isinstance(k, str) for k in load_video_frames): @@ -154,15 +156,38 @@ def read( else: raise ValueError(f"Invalid value provided for argument load_video_frames=`{load_video_frames}`") + # An MP4 shard packs many episodes back to back, so the shard's internal + # frame numbering is NOT the parquet's episode-local `frame_index` (which + # resets to 0 each episode). Seeking by `frame_index` only happens to work + # for the first episode in each shard. Instead, seek by absolute timestamp: + # `from_timestamp` (where this episode begins in the shard) + the per-frame + # episode-local `timestamp`. That keeps a single coordinate system end to end. + fps = float(info["fps"]) + tolerance_s = 1.0 / fps / 2.0 # half a frame period: any closer frame is unambiguously "the" frame + # To increase parallelism, reduce batch size - df = df.into_batches(16) # TODO: Set it in the batch UDF instead? + df = df.into_batches(16) # TODO (for later): Set it in the batch UDF instead? 
for k in video_keys: - # TODO: Optimize by using a batch UDF to avoid opening the same video multiple times - df = df.with_column(k, get_video_frame_by_idx(f"videos/{k}/video", col("frame_idx"))) + # TODO (for later): Optimize by using a batch UDF to avoid opening the same video multiple times + df = df.with_column( + k, + _decode_lerobot_video_timestamp( + col(f"videos/{k}/video"), + col(f"videos/{k}/from_timestamp"), + col("timestamp"), + lit(tolerance_s), + lit(0), # image_width: 0 disables resize (decode at native resolution) + lit(0), # image_height: 0 disables resize + ), + ) df = df.exclude(f"videos/{k}/video") # TODO: What about raw images, what do i do about them? Is that a thing in LeRobot v3 + # Drop the internal per-episode video metadata we kept above (chunk/file index, + # from/to timestamp). This restores read_episodes' default of hiding these. + df = df.exclude(*(c for c in df.column_names if c.startswith("videos/") and not c.endswith("/video"))) + return df @@ -216,6 +241,40 @@ def read_episodes( return df +@PublicAPI +def load_episode_frames( + episodes: DataFrame, + dataset_uri: str, + io_config: IOConfig | None = None, +) -> DataFrame: + """Expand an episode-level DataFrame into a frame-level DataFrame. + + Reads the per-frame parquet under ``data/**`` and joins it to the provided + episode metadata on ``episode_index``, producing one row per frame. Episode + metadata is broadcast across each episode's frames. + + Filter ``episodes`` before calling this to expand only the episodes you need; + only the surviving episodes contribute to the join. + + Args: + episodes: Episode-level DataFrame, typically from :func:`read_episodes` + (optionally filtered). Must contain an ``episode_index`` column. + dataset_uri: The same dataset identifier passed to :func:`read_episodes` + (Huggingface repo id ``org/name``, or a local / remote directory such + as ``s3://...`` or ``hf://datasets/...``). + io_config: Optional IO configuration for remote reads. 
+ + Returns: + Lazy DataFrame with one row per frame. + """ + root = _normalize_dataset_root(dataset_uri) + + frame_df = daft.read_parquet(f"{root}/data/**", io_config=io_config) + df = episodes.join(frame_df, on=["episode_index"]) + df = df.exclude("data/chunk_index", "data/file_index") + return df + + @PublicAPI def read_tasks(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: """Load task metadata as a DataFrame. diff --git a/docs/api/datasets.md b/docs/api/datasets.md index 29d93ffe8d7..8adf8604ca5 100644 --- a/docs/api/datasets.md +++ b/docs/api/datasets.md @@ -15,27 +15,22 @@ Check out our [Common Crawl dataset guide](../datasets/common-crawl.md) for more See the [LeRobot v3 dataset guide](../datasets/lerobot.md) for episode vs frame workflows and Hub/local paths. -::: daft.datasets.lerobot.episodes +::: daft.datasets.lerobot.read options: filters: ["!^_"] heading_level: 3 -::: daft.datasets.lerobot.load_episode_frames - options: - filters: ["!^_"] - heading_level: 3 - -::: daft.datasets.lerobot.read_tasks +::: daft.datasets.lerobot.read_episodes options: filters: ["!^_"] heading_level: 3 -::: daft.datasets.lerobot.read_info +::: daft.datasets.lerobot.load_episode_frames options: filters: ["!^_"] heading_level: 3 -::: daft.datasets.lerobot.read_stats +::: daft.datasets.lerobot.read_tasks options: filters: ["!^_"] heading_level: 3 From 918f6107c867d6f5d65aa4d55549caee1daeef54 Mon Sep 17 00:00:00 2001 From: Shreyas Garimella Date: Thu, 11 Jun 2026 14:27:48 -0700 Subject: [PATCH 06/11] feat(examples): add LeRobot + H-RDT pose prediction example End-to-end example using daft.datasets.lerobot on the EgoDex test dataset: batched H-RDT inference as a @daft.cls UDF (predict_poses.py), EgoDex-paper keypoint-error metrics (compute_metrics.py), and overlay visualizations projecting predicted vs ground-truth hand poses onto the video frames (visualize_predictions.py). 
Includes a vendored copy of the reader so the scripts also run against released daft wheels. --- examples/lerobot_pose/README.md | 108 +++++++ examples/lerobot_pose/compute_metrics.py | 82 +++++ .../lerobot_pose/encode_task_embeddings.py | 108 +++++++ examples/lerobot_pose/lerobot.py | 305 ++++++++++++++++++ examples/lerobot_pose/predict_poses.py | 269 +++++++++++++++ .../lerobot_pose/visualize_predictions.py | 166 ++++++++++ 6 files changed, 1038 insertions(+) create mode 100644 examples/lerobot_pose/README.md create mode 100644 examples/lerobot_pose/compute_metrics.py create mode 100644 examples/lerobot_pose/encode_task_embeddings.py create mode 100644 examples/lerobot_pose/lerobot.py create mode 100644 examples/lerobot_pose/predict_poses.py create mode 100644 examples/lerobot_pose/visualize_predictions.py diff --git a/examples/lerobot_pose/README.md b/examples/lerobot_pose/README.md new file mode 100644 index 00000000000..67c95578b33 --- /dev/null +++ b/examples/lerobot_pose/README.md @@ -0,0 +1,108 @@ +# LeRobot + H-RDT: per-frame 48-D action prediction with Daft +This example reads the [EgoDex test dataset](https://huggingface.co/datasets/pepijn223/egodex-test) (LeRobot v3 format) as a lazy Daft DataFrame — one row per frame, with the decoded camera image and the 48-D hand state — runs the [H-RDT](https://github.com/HongzheBi/H_RDT) policy on every frame, and stores the predicted 48-D action vector as a new column. + +``` +LeRobot v3 dataset (parquet + mp4 shards on HF Hub) + │ lerobot.read(..., load_video_frames=True) # lazy, streaming + ▼ +one row per frame: observation.image · observation.state (48) · task_index · ... + │ HRDTPredictor.predict(...) # @daft.cls batch UDF on GPU + ▼ ++ predicted_action (48-D float vector) + │ write_parquet + ▼ +out/egodex_hrdt_predictions/ +``` + +## Files + +| File | Purpose | +| --- | --- | +| `lerobot.py` | Vendored copy of `daft.datasets.lerobot` (LeRobot v3 reader). Delete once it ships in a Daft release. 
| +| `encode_task_embeddings.py` | One-time preprocessing: encode each task instruction with T5-XXL and cache it (H-RDT consumes language *embeddings*, not text). | +| `predict_poses.py` | The pipeline: decode frames → batched H-RDT inference → write parquet (pure predictions). | +| `compute_metrics.py` | Score the predictions: per-frame `avg_keypoint_distance_m` (EgoDex paper metric, arXiv:2505.11709 §4.3) + per-episode and overall summaries. Torch-free, re-runnable in seconds. | +| `visualize_predictions.py` | Project predicted vs ground-truth hand poses onto frames; writes PNG overlays and per-episode mp4s. | + +## Setup + +1. **Clone H_RDT** as a sibling of this repo (or set `HRDT_PROJECT_ROOT`): + + ```bash + git clone https://github.com/HongzheBi/H_RDT ../../../H_RDT + ``` + +2. **Download the pretrained weights** into the clone (~8.8 GB). This pulls the + EgoDex pretrain checkpoint (`checkpoints/pretrain-0618/checkpoint-500000`, + ~4.1 GB) and the DinoSigLIP vision backbone weights, skipping the duplicate + safetensors copies the code doesn't load: + + ```bash + uvx --from huggingface_hub hf download embodiedfoundation/H-RDT \ + --include "checkpoints/*" \ + --include "bak/dino-siglip/vit_large_patch14_reg4_dinov2.lvd142m/pytorch_model.bin" \ + --include "bak/dino-siglip/vit_large_patch14_reg4_dinov2.lvd142m/config.json" \ + --include "bak/dino-siglip/vit_so400m_patch14_siglip_384/open_clip_pytorch_model.bin" \ + --include "bak/dino-siglip/vit_so400m_patch14_siglip_384/*.json" \ + --local-dir ../../../H_RDT + ``` + +3. **T5-XXL encoder** downloads automatically (~9.5 GB into the HF cache) the + first time you run `encode_task_embeddings.py`. The default model is + `city96/t5-v1_1-xxl-encoder-bf16`, an encoder-only bfloat16 conversion of + `google/t5-v1_1-xxl` — numerically equivalent for our purposes (H-RDT's own + pipeline ran the encoder in bfloat16) but 4.7x smaller than the official + fp32 encoder+decoder repo. Loading takes ~10 GB RAM, CPU is fine. 
Set + `T5_MODEL_PATH=google/t5-v1_1-xxl` to use the original instead (~44.5 GB). + +## Run + +```bash +# 1. One-time: cache a T5 embedding per task (3 tasks in egodex-test) +uv run encode_task_embeddings.py + +# 2. Quick local trial: predict only the first 8 frames +MAX_FRAMES=8 uv run predict_poses.py + +# 3. Full run: predict an action for every frame and write parquet +uv run predict_poses.py + +# 4. Score the predictions (EgoDex paper's 12-keypoint metric) — re-runnable +uv run compute_metrics.py + +# 5. Render overlay PNGs and per-episode mp4s of predicted vs ground truth +uv run visualize_predictions.py +``` + +## How it works + +- **Lazy frames.** `lerobot.read` only builds a plan. Execution streams shard + downloads, video decoding, inference, and writing — the full dataset never + sits in memory. +- **`@daft.cls` for the model.** Daft constructs the predictor once per worker + process (loading ~9 GB of weights in `__init__`) and reuses it for every + batch. A plain function UDF would have nowhere to keep the loaded model. +- **Batched inference.** `@daft.method.batch(batch_size=16)` hands the UDF + whole columns (`daft.Series`) at a time, so DinoSigLIP and the policy run one + forward pass per 16 frames instead of per frame. +- **Concurrency.** Frame decoding (CPU) is fanned out across cores by Daft — + the reader splits work with `into_batches(16)` — and overlaps with inference, + which runs on a single model instance (`@daft.cls(gpus=..., max_concurrency=1)`). + To run N concurrent model replicas on one GPU, use fractional GPUs: + `@daft.cls(gpus=1/N, max_concurrency=N)` — each replica holds its own copy of + the weights (~6.5 GB VRAM in bf16), so size VRAM accordingly. +- **Normalization contract.** Following H-RDT's EgoDex pretraining + (`datasets/pretrain/egodex_dataset.py`): the input state is min/max scaled to + `[-1, 1]` using `egodex_stat.json`, and the predicted chunk is denormalized + with the inverse mapping. 
The model predicts 16 future steps; we keep step 0, + so `predicted_action` is one 48-D vector per frame (24 dims per hand: wrist + pose + finger keypoints). + +## Output + +`out/egodex_hrdt_predictions/` — parquet with one row per frame: +`episode_index`, `frame_index`, `timestamp`, `task_index`, +`observation.state`, `ground_truth_action`, `predicted_action` (48-D +`embedding` column), ready for `daft.read_parquet` to evaluate prediction error +against the ground-truth actions. + diff --git a/examples/lerobot_pose/compute_metrics.py b/examples/lerobot_pose/compute_metrics.py new file mode 100644 index 00000000000..8fecdc89bd8 --- /dev/null +++ b/examples/lerobot_pose/compute_metrics.py @@ -0,0 +1,82 @@ +# /// script +# description = "Compute EgoDex-paper keypoint error metrics over H-RDT predictions" +# requires-python = ">=3.12, <3.13" +# dependencies = [ +# "daft>=0.7.15", +# "numpy", +# ] +# /// +"""Score the predictions written by predict_poses.py. + +Kept separate from the prediction pipeline on purpose: predictions cost +minutes of model time, metrics cost milliseconds. Splitting them means you can +re-score (or add new metrics) without re-running the model, and this script's +environment needs no torch at all. + +Reads `out/egodex_hrdt_predictions/`, writes per-frame metrics to +`out/egodex_hrdt_metrics/`, and prints overall + per-episode summaries. + + uv run compute_metrics.py +""" + +import os + +import numpy as np + +import daft +from daft import DataType, col + +PREDICTIONS_DIR = os.path.join(os.path.dirname(__file__), "out", "egodex_hrdt_predictions") +METRICS_DIR = os.path.join(os.path.dirname(__file__), "out", "egodex_hrdt_metrics") + + +@daft.func(return_dtype=DataType.float64()) +def avg_keypoint_distance_m(predicted_action: list[float], ground_truth_action: list[float]) -> float: + """EgoDex paper's trajectory-prediction metric (arXiv:2505.11709, Sec 4.3).
+ + "Euclidean distance between predicted 3D keypoint positions and their + ground truth 3D counterparts, averaged over ... each of the 12 keypoints" + (both wrists + all 10 fingertips), in meters. The 2x6 wrist-rotation dims + are excluded: they are unitless rotation-matrix columns. + + Ours is the metric at a 1-step horizon with K=1: the paper averages over + every timestep of the predicted chunk and scores the best of K sampled + trajectories, while we keep only the chunk's first step and sample once. + """ + predicted = np.asarray(predicted_action, dtype=np.float64) + ground_truth = np.asarray(ground_truth_action, dtype=np.float64) + distances = [] + for base in (0, 24): # left hand dims 0-23, right hand dims 24-47 + keypoint_starts = [base] + [base + 9 + 3 * i for i in range(5)] # wrist, then 5 fingertips + for start in keypoint_starts: + distances.append(np.linalg.norm(predicted[start : start + 3] - ground_truth[start : start + 3])) + return float(np.mean(distances)) + + +if __name__ == "__main__": + metrics = ( + daft.read_parquet(f"{PREDICTIONS_DIR}/**") + .with_column( + "avg_keypoint_distance_m", + avg_keypoint_distance_m(col("predicted_action"), col("ground_truth_action")), + ) + # Keep only identifiers + the metric; the predictions stay in their own files. 
+ .select("episode_index", "frame_index", "timestamp", "task_index", "avg_keypoint_distance_m") + ) + + metrics.write_parquet(METRICS_DIR) + print(f"Wrote per-frame metrics to {METRICS_DIR}\n") + + results = daft.read_parquet(f"{METRICS_DIR}/**") + + print("Per-episode:") + results.groupby("episode_index").agg( + col("avg_keypoint_distance_m").mean().alias("mean_m"), + col("avg_keypoint_distance_m").max().alias("worst_frame_m"), + col("avg_keypoint_distance_m").count().alias("frames"), + ).sort("episode_index").show() + + print("Overall:") + results.select( + col("avg_keypoint_distance_m").mean().alias("dataset_avg_keypoint_distance_m"), + ).show() diff --git a/examples/lerobot_pose/encode_task_embeddings.py b/examples/lerobot_pose/encode_task_embeddings.py new file mode 100644 index 00000000000..9503257c61b --- /dev/null +++ b/examples/lerobot_pose/encode_task_embeddings.py @@ -0,0 +1,108 @@ +# /// script +# description = "Precompute T5 language embeddings for every task in a LeRobot dataset (H-RDT input)" +# requires-python = ">=3.12, <3.13" +# dependencies = [ +# "daft>=0.7.15", +# "torch", +# "transformers", +# "sentencepiece", +# "protobuf", +# "accelerate", +# ] +# /// +"""Encode each task instruction in the dataset with T5 and cache it to disk. + +H-RDT never sees raw text. At train and inference time it consumes language +*embeddings*: the task instruction run through a frozen T5 encoder +(`t5-v1_1-xxl`, 4096-dim features). Encoding is expensive (T5-XXL is an ~11B +parameter model), but a dataset only has a handful of distinct task strings, +so we encode each one exactly once up front and store it as a `.pt` file keyed +by `task_index`. The prediction pipeline then does a dictionary lookup per row +instead of running T5 per row. + +This mirrors H_RDT's `models/encoder/t5_encoder.py` (same model, tokenizer +settings, and bfloat16 dtype) but calls `transformers` directly, because +`T5Embedder` hardcodes the author's local weight path in an assert. 
Only the +encoder half of T5 is loaded (~4.7B params, ~10 GB RAM in bfloat16), so this +runs fine on CPU. + +Run this once before `predict_poses.py`: + + uv run encode_task_embeddings.py +""" + +import os + +import lerobot # vendored copy of daft.datasets.lerobot +import torch + +DATASET_URI = "pepijn223/egodex-test" +# HF id of (or local path to) the T5 model. Must be the XXL variant: H-RDT's +# text adapter expects 4096-dim features. The default is an encoder-only +# bfloat16 conversion of google/t5-v1_1-xxl (~9.5 GB download instead of the +# official repo's 44.5 GB fp32 encoder+decoder). It is numerically equivalent +# here: H-RDT's own T5Embedder loaded the encoder in bfloat16 too. +T5_MODEL_PATH = os.environ.get("T5_MODEL_PATH", "city96/t5-v1_1-xxl-encoder-bf16") +# Matches `tokenizer_max_length` in H_RDT's configs/hrdt_pretrain.yaml. +TOKENIZER_MAX_LENGTH = 1024 +OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "task_embeddings") + + +def load_tasks(dataset_uri: str) -> dict[int, str]: + """Return {task_index: task instruction} from the dataset's `meta/tasks.parquet`.""" + df = lerobot.read_tasks(dataset_uri).collect() + data = df.to_pydict() + # The task string column name varies with how the dataset was exported + # (e.g. pandas writes the task as the index column `__index_level_0__`), + # so find it by dtype rather than by name.
+ text_col = next(name for name, values in data.items() if values and isinstance(values[0], str)) + return dict(zip(data["task_index"], data[text_col])) + + +def main() -> None: + from transformers import AutoTokenizer, T5EncoderModel + + tasks = load_tasks(DATASET_URI) + print(f"Found {len(tasks)} tasks in {DATASET_URI}:") + for idx, text in sorted(tasks.items()): + print(f" [{idx}] {text}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_PATH, model_max_length=TOKENIZER_MAX_LENGTH) + print(f"Loading T5 encoder on {device} (bfloat16)...") + model = ( + T5EncoderModel.from_pretrained( + T5_MODEL_PATH, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, # stream + convert weights instead of loading all fp32 at once + ) + .to(device) + .eval() + ) + + os.makedirs(OUTPUT_DIR, exist_ok=True) + with torch.no_grad(): + for idx, text in sorted(tasks.items()): + tokenized = tokenizer( + [text], + max_length=TOKENIZER_MAX_LENGTH, + padding="longest", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + embeddings = model( + input_ids=tokenized["input_ids"].to(device), + attention_mask=tokenized["attention_mask"].to(device), + )["last_hidden_state"] + # Trim padding so we only store (and later attend over) real tokens. + num_tokens = int(tokenized["attention_mask"][0].sum()) + trimmed = embeddings[0, :num_tokens].to(torch.float32).cpu() + out_path = os.path.join(OUTPUT_DIR, f"task_{idx:03d}.pt") + torch.save({"task_index": idx, "task": text, "embeddings": trimmed}, out_path) + print(f"Saved {out_path} (shape {tuple(trimmed.shape)})") + + +if __name__ == "__main__": + main() diff --git a/examples/lerobot_pose/lerobot.py b/examples/lerobot_pose/lerobot.py new file mode 100644 index 00000000000..b2abc4148a1 --- /dev/null +++ b/examples/lerobot_pose/lerobot.py @@ -0,0 +1,305 @@ +# ruff: noqa +"""LeRobot Dataset v3.0 helpers for `daft.datasets`. 
+ +NOTE: This is a vendored copy of `daft/datasets/lerobot.py` from the Daft repo, +included here so this example runs against released `daft` wheels. Once the +module ships in a Daft release, delete this file and switch the imports in this +directory to `from daft.datasets import lerobot`. + +This module reads the file-based LeRobot v3 layout (`meta/episodes`, `data`, +`videos`) and exposes episode-level scans plus frame expansion utilities. + +See https://huggingface.co/docs/lerobot/lerobot-dataset-v3 for format details. +""" + +from __future__ import annotations + +import re +import json +from typing import TYPE_CHECKING, Any, TypedDict, cast + +import daft +from daft.api_annotations import PublicAPI +from daft.datatype import DataType +from daft.expressions import col, lit +from daft.file import VideoFile +from daft.exceptions import DaftCoreException +from daft.functions import lpad +from daft.functions.file_ import video_file +from daft.udf import func + +if TYPE_CHECKING: + from daft.daft import IOConfig + from daft.dataframe import DataFrame + + +def _normalize_dataset_root(uri: str) -> str: + """Return a canonical dataset root prefix (no trailing slash) for path joins.""" + u = uri.strip() + # Input looks like a Hugging Face repo ID, i.e. "org/name" + is_hf_repo_id = bool(re.fullmatch(r"[\w.-]+/[\w.-]+", uri)) + + if is_hf_repo_id: + u = f"hf://datasets/{u}" + return u.rstrip("/") + + +@func(return_dtype=DataType.image()) +def _decode_lerobot_video_timestamp( + file: VideoFile, + episode_from_timestamp_s: float, + frame_timestamp_s: float, + tolerance_s: float, + image_width_i: int, + image_height_i: int, +): + """Pick the decoded frame closest in time to ``episode_from_timestamp_s + frame_timestamp_s``.""" + try: + import av as av_mod + except ImportError as err: + raise ImportError("Decoding LeRobot MP4 shards requires PyAV. 
Install with `pip install av`.") from err + try: + from PIL import Image as PILImage + except ImportError as err: + raise ImportError( + "Decoding LeRobot MP4 shards requires Pillow. Install with `pip install daft[video]` or `pip install pillow`." + ) from err + abs_ts = float(episode_from_timestamp_s) + float(frame_timestamp_s) + tolerance = float(tolerance_s) + width_i = int(image_width_i) + height_i = int(image_height_i) + width: int | None + height: int | None + if width_i > 0 and height_i > 0: + width, height = width_i, height_i + else: + width, height = None, None + + loaded: list[tuple[float, Any]] = [] + decode_cap = 20_000 + decoded = 0 + + with file.open() as f_open: + with av_mod.open(f_open) as container: + stream = container.streams.video[0] + # Match LeRobot: seek backwards to preceding keyframe, then decode forwards. + container.seek(max(0, int(abs_ts / av_mod.time_base)), backward=True) + + tail_s = max(0.1, tolerance * 50.0, 1.0 / 24.0) + for vf in container.decode(stream): + if vf.pts is None: + continue + current_ts = float(vf.pts * stream.time_base) + pil_img = PILImage.fromarray(vf.to_ndarray(format="rgb24"), mode="RGB") + if width is not None and height is not None: + pil_img = pil_img.resize((width, height), PILImage.Resampling.NEAREST) + + loaded.append((current_ts, pil_img)) + decoded += 1 + + if decoded >= decode_cap: + raise ValueError("Exceeded decode frame budget while aligning to parquet timestamps.") + if current_ts >= abs_ts + tail_s: + break + + if not loaded: + raise ValueError(f"No frames decoded from shard while seeking timestamp {abs_ts:.6f}s.") + + closest_ts, closest_img = min(loaded, key=lambda item: abs(item[0] - abs_ts)) + closest_dist = abs(closest_ts - abs_ts) + if closest_dist > tolerance: + raise ValueError( + f"No frame matched timestamp {abs_ts:.6f}s within tolerance {tolerance} " + f"(closest distance observed: {closest_dist})." 
+ ) + return closest_img + +class Feature(TypedDict): + dtype: str + +class LeRobotInfo(TypedDict): + codebase_version: str + data_path: str + video_path: str + fps: float + features: dict[str, Feature] + +def _read_info(normalized_uri: str, io_config: IOConfig | None = None) -> LeRobotInfo: + with daft.open_file(f"{normalized_uri}/meta/info.json", io_config=io_config) as f: + info = cast(LeRobotInfo, json.load(f)) + if info["codebase_version"] != "v3.0": + raise ValueError("`daft.datasets.lerobot` currently only supports LeRobot datasets of v3 and above") + return info + +@PublicAPI +def read( + dataset_uri: str, + io_config: IOConfig | None = None, + include_stats: bool = False, + load_video_frames: str | list[str] | bool = False, +) -> DataFrame: + """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per frame with episode metadata).""" + root = _normalize_dataset_root(dataset_uri) + info = _read_info(root, io_config=io_config) + + # Keep the per-episode video metadata (notably `videos/{key}/from_timestamp`, + # the time within the shard where each episode's footage begins). We need it + # to translate episode-local frame timestamps into absolute shard timestamps + # when decoding, and drop these internal columns again before returning. 
+ episode_df = read_episodes( + dataset_uri, io_config=io_config, include_stats=include_stats, include_video_metadata=True + ) + df = load_episode_frames(episode_df, dataset_uri, io_config=io_config) + + # Load video frames into memory + if load_video_frames is not False: + if load_video_frames is True: + video_keys = [name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video"] + elif isinstance(load_video_frames, str): + video_keys = [load_video_frames] + elif isinstance(load_video_frames, list) and all(isinstance(k, str) for k in load_video_frames): + video_keys = load_video_frames + else: + raise ValueError(f"Invalid value provided for argument load_video_frames=`{load_video_frames}`") + + # An MP4 shard packs many episodes back to back, so the shard's internal + # frame numbering is NOT the parquet's episode-local `frame_index` (which + # resets to 0 each episode). Seeking by `frame_index` only happens to work + # for the first episode in each shard. Instead, seek by absolute timestamp: + # `from_timestamp` (where this episode begins in the shard) + the per-frame + # episode-local `timestamp`. That keeps a single coordinate system end to end. + fps = float(info["fps"]) + tolerance_s = 1.0 / fps / 2.0 # half a frame period: any closer frame is unambiguously "the" frame + + # To increase parallelism, reduce batch size + df = df.into_batches(16) # TODO (for later): Set it in the batch UDF instead? + for k in video_keys: + # TODO (for later): Optimize by using a batch UDF to avoid opening the same video multiple times + df = df.with_column( + k, + _decode_lerobot_video_timestamp( + col(f"videos/{k}/video"), + col(f"videos/{k}/from_timestamp"), + col("timestamp"), + lit(tolerance_s), + lit(0), # image_width: 0 disables resize (decode at native resolution) + lit(0), # image_height: 0 disables resize + ), + ) + df = df.exclude(f"videos/{k}/video") + + # TODO: What about raw images, what do i do about them? 
Is that a thing in LeRobot v3 + + # Drop the internal per-episode video metadata we kept above (chunk/file index, + # from/to timestamp). This restores read_episodes' default of hiding these. + df = df.exclude(*(c for c in df.column_names if c.startswith("videos/") and not c.endswith("/video"))) + + return df + + +@PublicAPI +def read_episodes( + dataset_uri: str, + io_config: IOConfig | None = None, + include_meta: bool = False, + include_stats: bool = False, + include_video_metadata: bool = False, +) -> DataFrame: + """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per episode). + + This reads the `meta/episodes/**/*.parquet` path under the dataset root. + + Args: + dataset_uri: Huggingface repo id (`org/name`), + or a local / remote directory (`s3://...`, `hf://datasets/...`) + io_config: Optional IO configuration for remote reads. + + Returns: + Lazy DataFrame of episode metadata. + """ + root = _normalize_dataset_root(dataset_uri) + info = _read_info(root, io_config=io_config) + + # TODO: What is the `meta` episodes info used for? How is it different from the `videos` info?
+ df = daft.read_parquet(f"{root}/meta/episodes/**/*.parquet", io_config=io_config) + if not include_meta: + df = df.exclude(*(c for c in df.column_names if c.startswith("meta/"))) + if not include_stats: + df = df.exclude(*(c for c in df.column_names if c.startswith("stats/"))) + + # Get the video keys + video_keys = set(name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video") + + for key in video_keys: + file_name_expr = ( + lit(f"{root}/videos/{key}/chunk-") + + lpad(col(f"videos/{key}/chunk_index").cast(DataType.string), 3, "0") + + lit("/file-") + + lpad(col(f"videos/{key}/file_index").cast(DataType.string), 3, "0") + + lit(".mp4") + ) + + df = df.with_column(f"videos/{key}/video", video_file(file_name_expr, verify=False, io_config=io_config)) + + if not include_video_metadata: + df = df.exclude(*(c for c in df.column_names if c.startswith("videos/") and not c.endswith("/video"))) + + return df + + +@PublicAPI +def load_episode_frames( + episodes: DataFrame, + dataset_uri: str, + io_config: IOConfig | None = None, +) -> DataFrame: + """Expand an episode-level DataFrame into a frame-level DataFrame. + + Reads the per-frame parquet under ``data/**`` and joins it to the provided + episode metadata on ``episode_index``, producing one row per frame. Episode + metadata is broadcast across each episode's frames. + + Filter ``episodes`` before calling this to expand only the episodes you need; + only the surviving episodes contribute to the join. + + Args: + episodes: Episode-level DataFrame, typically from :func:`read_episodes` + (optionally filtered). Must contain an ``episode_index`` column. + dataset_uri: The same dataset identifier passed to :func:`read_episodes` + (Huggingface repo id ``org/name``, or a local / remote directory such + as ``s3://...`` or ``hf://datasets/...``). + io_config: Optional IO configuration for remote reads. + + Returns: + Lazy DataFrame with one row per frame. 
+ """ + root = _normalize_dataset_root(dataset_uri) + + frame_df = daft.read_parquet(f"{root}/data/**", io_config=io_config) + df = episodes.join(frame_df, on=["episode_index"]) + df = df.exclude("data/chunk_index", "data/file_index") + return df + + +@PublicAPI +def read_tasks(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: + """Load task metadata as a DataFrame. + + Prefers ``meta/tasks.parquet`` (current LeRobot default). Falls back to legacy + ``meta/tasks.jsonl`` when the Parquet file is missing. + """ + root = _normalize_dataset_root(dataset_uri) + + pq_url = f"{root}/meta/tasks.parquet" + try: + return daft.read_parquet(pq_url, io_config=io_config) + except (OSError, DaftCoreException, FileNotFoundError): + return daft.read_json(f"{root}/meta/tasks.jsonl", io_config=io_config) + + +__all__ = [ + "load_episode_frames", + "read", + "read_episodes", + "read_tasks", +] diff --git a/examples/lerobot_pose/predict_poses.py b/examples/lerobot_pose/predict_poses.py new file mode 100644 index 00000000000..ca694296cd2 --- /dev/null +++ b/examples/lerobot_pose/predict_poses.py @@ -0,0 +1,269 @@ +# /// script +# description = "Run H-RDT action prediction over a LeRobot dataset with Daft" +# requires-python = ">=3.12, <3.13" +# dependencies = [ +# "daft>=0.7.15", +# "torch", +# "torchvision", +# "timm", +# "transformers", +# "sentencepiece", +# "diffusers", +# "huggingface-hub", +# "av", +# "pillow", +# "opencv-python", +# "pyyaml", +# "numpy", +# ] +# /// +"""Predict 48-D hand actions for every frame of an EgoDex LeRobot dataset. + +The pipeline is as follows: 1) we load in a lazy Daft DataFrame +with 1 row per frame. Each row contains the decoded camera image, the +48-D observation state, and the episode/task metadata. 2) we wrap the +H-RDT model as a Daft class, which allows for a persistent state in which +each worker builds the model once in __init__ and reuses it for every batch. 
+3) we append the predicted step per row as a new column in the DataFrame. +4) we write the DataFrame to a parquet file. + +The model predicts a chunk of 16 future actions. However we keep only the first step of each chunk, +so the new column is a single 48-D float vector per frame. + +Run `encode_task_embeddings.py` first to cache the T5 task embeddings, then: + + uv run predict_poses.py +""" + +# ruff: noqa: E402 -- sys.path must be extended before the `models.*` imports below +import os +import sys + +# Let any operator MPS doesn't implement fall back to CPU instead of crashing. +# Must be set before torch initializes the MPS backend. +os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") + +HRDT_ROOT = os.environ.get( + "HRDT_PROJECT_ROOT", + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../H_RDT")), +) # the HRDT model repo must be cloned either in the directory specified by env or at the sibling level of the current repo +sys.path.append(HRDT_ROOT) + +import json + +import lerobot # vendored copy of daft.datasets.lerobot +import numpy as np +import torch +import yaml +from PIL import Image as PILImage + +# NOTE: the H_RDT imports (`from models...`) deliberately live inside +# HRDTPredictor.__init__, NOT here. Daft pickles the class to ship it to worker +# processes; pip-installed packages pickle by reference, but H_RDT is only +# importable via the sys.path hack above, so pickle tries to serialize its +# classes by value and fails ("cannot pickle 'GenericModule' object"). +# Importing inside __init__ defers the import to the worker process instead. +import daft +from daft import DataType, col + +DATASET_URI = "pepijn223/egodex-test" +OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "out", "egodex_hrdt_predictions") +TASK_EMBEDDINGS_DIR = os.path.join(os.path.dirname(__file__), "task_embeddings") + +# The EgoDex-pretrained checkpoint published in the H-RDT model repo. 
+CHECKPOINT_PATH = os.path.join(HRDT_ROOT, "checkpoints", "pretrain-0618", "checkpoint-500000") +CONFIG_PATH = os.path.join(HRDT_ROOT, "configs", "hrdt_pretrain.yaml") +STAT_PATH = os.path.join(HRDT_ROOT, "datasets", "pretrain", "egodex_stat.json") + +HAS_GPU = torch.cuda.is_available() +PREDICT_BATCH_SIZE = 16 +# For a quick local trial, set MAX_FRAMES to only predict the first N frames +# (e.g. `MAX_FRAMES=8 uv run predict_hrdt.py`). 0 means the whole dataset. +MAX_FRAMES = int(os.environ.get("MAX_FRAMES", "0")) +NUM_GPUS = 1 if HAS_GPU else 0 + +# Resolve the torch device HERE, at module level, not inside the class: +# `torch.backends.mps` is a property-object that cloudpickle cannot serialize, +# and Daft pickles the class (including everything its methods reference) to +# ship it to workers. A plain string global pickles fine. +if HAS_GPU: + DEVICE = "cuda" +elif torch.backends.mps.is_available(): + DEVICE = "mps" # Apple-silicon GPU via Metal +else: + DEVICE = "cpu" + +@daft.cls(gpus=NUM_GPUS, max_concurrency=1) +class HRDTPredictor: + def __init__(self, ckpt_path: str, config_path: str, stat_path: str, embeddings_dir: str): + # Runs once per replica, on the worker, when execution starts. + # H_RDT imports happen here so they're never pickled (see note at top). + from models.encoder.dinosiglip_vit import DinoSigLIPViTBackbone # from H_RDT + from models.hrdt_runner import HRDTRunner # from H_RDT + + self.device = torch.device(DEVICE) + print(f"Using device: {self.device}") + + if self.device.type == "cuda": + self.dtype = torch.bfloat16 + else: + # float32 on MPS and CPU: bfloat16 support on MPS is still patchy. + self.dtype = torch.float32 + + # task_index -> (num_tokens, 4096) T5 embedding, cached by + # encode_task_embeddings.py. Loaded FIRST so a missing/empty cache fails + # in milliseconds, before the minutes-long model weight loading below. 
+ self.lang_embeddings: dict[int, torch.Tensor] = {} + try: + fnames = sorted(os.listdir(embeddings_dir)) + except FileNotFoundError: + fnames = [] # missing dir is handled the same as an empty one, below + for fname in fnames: + if fname.endswith(".pt"): + payload = torch.load(os.path.join(embeddings_dir, fname), map_location="cpu") + self.lang_embeddings[int(payload["task_index"])] = payload["embeddings"].to( + self.device, dtype=self.dtype + ) + if not self.lang_embeddings: + raise FileNotFoundError( + f"No task embeddings found in {embeddings_dir}. Run encode_task_embeddings.py first." + ) + + with open(config_path) as f: + config = yaml.safe_load(f) + + + self.vision_encoder = DinoSigLIPViTBackbone( + vision_backbone_id="dino-siglip", + image_resize_strategy="letterbox" + if config["dataset"]["image_aspect_ratio"] == "pad" + else "resize-naive", + default_image_size=384, + ) + self.vision_encoder.to(self.device, dtype=self.dtype).eval() + self.image_transform = self.vision_encoder.get_image_transform() + + common = config["common"] + self.pred_horizon = common["action_chunk_size"] + self.policy = HRDTRunner( + state_dim=common["state_dim"], + action_dim=common["action_dim"], + pred_horizon=self.pred_horizon, + config=config["model"], + act_pos_emb_config=[ + ("state", 1), + ("action", self.pred_horizon), + ], + img_pos_emb_config=[ + ("image", (common["img_history_size"], common["num_cameras"], -self.vision_encoder.num_patches)), + ], + lang_pos_emb_config=[ + ("language", -config["dataset"]["tokenizer_max_length"]), + ], + max_img_len=common["img_history_size"] * common["num_cameras"] * self.vision_encoder.num_patches, + max_lang_len=config["dataset"]["tokenizer_max_length"], + training_mode="lang", + mode="pretrain", + dtype=self.dtype, + ) + state_dict = torch.load( + os.path.join(ckpt_path, "pytorch_model.bin"), map_location="cpu", weights_only=True + ) + + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("video_adapter.")} + 
self.policy.load_state_dict(state_dict) + self.policy.to(self.device, dtype=self.dtype).eval() + + with open(stat_path) as f: + stat = json.load(f)["egodex"] + self.action_min = np.array(stat["min"], dtype=np.float32) + self.action_max = np.array(stat["max"], dtype=np.float32) + + @daft.method.batch( + return_dtype=DataType.embedding(DataType.float32(), 48), + batch_size=PREDICT_BATCH_SIZE, + ) + def predict(self, images: daft.Series, states: daft.Series, task_indices: daft.Series): + """Predict the next 48-D action for a batch of frames. + + Batch methods receive whole columns as `daft.Series` so we can run the + vision encoder and the policy once per batch instead of once per row. + """ + image_arrays = images.to_pylist() # list of HWC uint8 numpy arrays + state_batch = np.asarray(states.to_pylist(), dtype=np.float32) # (B, 48) + task_batch = task_indices.to_pylist() + batch_size = len(image_arrays) + + with torch.no_grad(): + # State: normalize to [-1, 1] exactly like pretraining did. + normalized = (state_batch - self.action_min) / (self.action_max - self.action_min) * 2 - 1 + state_tokens = ( + torch.from_numpy(np.clip(normalized, -1, 1)) + .reshape(batch_size, 1, -1) + .to(self.device, dtype=self.dtype) + ) + + # Images: letterbox + normalize each frame, then encode the whole + # batch in one DinoSigLIP forward pass. + transformed = [self.image_transform(PILImage.fromarray(arr)) for arr in image_arrays] + image_inputs = { + key: torch.stack([t[key] for t in transformed]).to(self.device, dtype=self.dtype) + for key in transformed[0] + } + image_features = self.vision_encoder(image_inputs) # (B, num_patches, embed_dim) + image_tokens = image_features.view(batch_size, -1, self.vision_encoder.embed_dim) + + # Language: look up each row's cached T5 embedding and pad to the + # longest one in the batch, with an attention mask marking padding. 
+ embeds = [self.lang_embeddings[int(idx)] for idx in task_batch] + max_len = max(e.shape[0] for e in embeds) + lang_tokens = torch.zeros(batch_size, max_len, embeds[0].shape[1], device=self.device, dtype=self.dtype) + lang_attn_mask = torch.zeros(batch_size, max_len, device=self.device, dtype=torch.bool) + for i, e in enumerate(embeds): + lang_tokens[i, : e.shape[0]] = e + lang_attn_mask[i, : e.shape[0]] = True + + action_pred = self.policy.predict_action( + state_tokens=state_tokens, + image_tokens=image_tokens, + lang_tokens=lang_tokens, + lang_attn_mask=lang_attn_mask, + ) # (B, pred_horizon, 48), normalized to [-1, 1] + + chunk = action_pred.float().cpu().numpy() + # Denormalize (inverse of the [-1, 1] scaling) and keep only the + # first step of each predicted 16-step chunk. + denorm = (chunk + 1) / 2 * (self.action_max - self.action_min) + self.action_min + return [row for row in denorm[:, 0, :].astype(np.float32)] + + +if __name__ == "__main__": + predictor = HRDTPredictor(CHECKPOINT_PATH, CONFIG_PATH, STAT_PATH, TASK_EMBEDDINGS_DIR) + + df = lerobot.read(DATASET_URI, load_video_frames=True) + if MAX_FRAMES: + df = df.limit(MAX_FRAMES) + + df = ( + df.with_column( + "predicted_action", + predictor.predict(col("observation.image"), col("observation.state"), col("task_index")), + ) + # Keep the trajectory data and identifiers; drop the decoded frames so the + # output stays small (the images are reproducible from the dataset anyway). 
+ .select( + "episode_index", + "frame_index", + "timestamp", + "task_index", + col("observation.state"), + col("action").alias("ground_truth_action"), + "predicted_action", + ) + ) + + df.write_parquet(OUTPUT_DIR) + print(f"Wrote predictions to {OUTPUT_DIR}") + print("Score them with: uv run compute_metrics.py") + + daft.read_parquet(f"{OUTPUT_DIR}/**").show(8) \ No newline at end of file diff --git a/examples/lerobot_pose/visualize_predictions.py b/examples/lerobot_pose/visualize_predictions.py new file mode 100644 index 00000000000..2ec5584a116 --- /dev/null +++ b/examples/lerobot_pose/visualize_predictions.py @@ -0,0 +1,166 @@ +# /// script +# description = "Overlay predicted vs ground-truth 48-D hand poses on EgoDex video frames" +# requires-python = ">=3.12, <3.13" +# dependencies = [ +# "daft>=0.7.15", +# "av", +# "pillow", +# "numpy", +# ] +# /// +"""Project predicted and ground-truth hand poses onto the camera frames. + +The 48-D vectors are 3D *world-frame* points (per hand: wrist position, wrist +6D rotation, 5 fingertip positions — see H_RDT's precompute_48d_actions.py). +To draw them on a frame we follow Apple's reference visualizer +(ml-egodex/visualize_2d.py): + + 1. world -> camera: multiply by the inverse of that frame's camera pose + (`observation.extrinsics`, a 4x4 matrix). + 2. camera -> pixels: pinhole projection u = fx*X/Z + cx, v = fy*Y/Z + cy, + with EgoDex's constant intrinsics (fx = fy = 736.6339, cx = 960, cy = 540). + +Ground truth is drawn in green, the model's prediction in red. Note Apple's +caveat: Vision Pro video is synthesized from multiple cameras, so even +ground-truth reprojections can be a few pixels off the visible hands. 
+ +Run after predict_hrdt.py: + + uv run visualize_predictions.py +""" + +import os + +import lerobot # vendored copy of daft.datasets.lerobot +import numpy as np +from PIL import Image, ImageDraw + +import daft +from daft import col + +DATASET_URI = "pepijn223/egodex-test" +PREDICTIONS_DIR = os.path.join(os.path.dirname(__file__), "out", "egodex_hrdt_predictions") +OVERLAYS_DIR = os.path.join(os.path.dirname(__file__), "out", "overlays") + +# EgoDex camera intrinsics (constant across the dataset, from apple/ml-egodex). +FX = FY = 736.6339 +CX, CY = 960.0, 540.0 + +GROUND_TRUTH_COLOR = (0, 220, 0) # green +PREDICTION_COLOR = (255, 40, 40) # red + +# Playback speed of the stitched per-episode mp4s. The dataset is 30 fps, so +# 15 fps plays at half speed (e.g. 30 predicted frames -> a 2-second video). +VIDEO_FPS = int(os.environ.get("VIDEO_FPS", "15")) + + +def hand_points(vec48: np.ndarray, side: int) -> np.ndarray: + """Extract the 6 drawable 3D points of one hand (wrist + 5 fingertips). + + Layout per hand (24 dims): [0:3] wrist xyz, [3:9] wrist 6D rotation + (not drawable as a point, skipped), [9:24] thumb/index/middle/ring/little + fingertip xyz. side: 0 = left hand (dims 0-23), 1 = right hand (24-47). 
+ """ + base = side * 24 + wrist = vec48[base : base + 3] + fingertips = vec48[base + 9 : base + 24].reshape(5, 3) + return np.vstack([wrist, fingertips]) # (6, 3): wrist first + + +def project_to_pixels(points_world: np.ndarray, extrinsics: np.ndarray) -> np.ndarray: + """World-frame 3D points -> (u, v) pixel coordinates (NaN if behind camera).""" + cam_from_world = np.linalg.inv(extrinsics) + homogeneous = np.hstack([points_world, np.ones((len(points_world), 1))]) # (N, 4) + in_camera = (cam_from_world @ homogeneous.T).T[:, :3] + x, y, z = in_camera[:, 0], in_camera[:, 1], in_camera[:, 2] + with np.errstate(divide="ignore", invalid="ignore"): + u = FX * x / z + CX + v = FY * y / z + CY + uv = np.stack([u, v], axis=1) + uv[z <= 0] = np.nan # behind the camera + return uv + + +def write_episode_video(episode: int, frames: list[tuple[int, np.ndarray]], fps: int) -> None: + """Encode one episode's overlay frames (sorted by frame_index) into an mp4. + + Uses PyAV, which bundles its own FFmpeg — no ffmpeg binary required. 
+ """ + import av + + frames = sorted(frames, key=lambda pair: pair[0]) + out_path = os.path.join(OVERLAYS_DIR, f"episode{episode:03d}.mp4") + height, width = frames[0][1].shape[:2] + container = av.open(out_path, mode="w") + stream = container.add_stream("h264", rate=fps) + stream.width, stream.height = width, height + stream.pix_fmt = "yuv420p" + for _, image_array in frames: + for packet in stream.encode(av.VideoFrame.from_ndarray(image_array, format="rgb24")): + container.mux(packet) + for packet in stream.encode(): # flush buffered frames + container.mux(packet) + container.close() + print(f"Saved {out_path} ({len(frames)} frames @ {fps} fps = {len(frames) / fps:.2f}s)") + + +def draw_skeleton(draw: ImageDraw.ImageDraw, vec48: np.ndarray, extrinsics: np.ndarray, color: tuple) -> None: + """Draw both hands of one 48-D pose: wrist dot + lines fanning to fingertips.""" + for side in (0, 1): + uv = project_to_pixels(hand_points(vec48, side), extrinsics) + if np.isnan(uv).any(): + continue + wrist, fingertips = uv[0], uv[1:] + for tip in fingertips: + draw.line([tuple(wrist), tuple(tip)], fill=color, width=3) + draw.ellipse([tip[0] - 7, tip[1] - 7, tip[0] + 7, tip[1] + 7], fill=color) + draw.ellipse([wrist[0] - 11, wrist[1] - 11, wrist[0] + 11, wrist[1] + 11], fill=color) + + +if __name__ == "__main__": + # 1. Load the predictions and note which frames they belong to. + preds = daft.read_parquet(f"{PREDICTIONS_DIR}/**").select( + "episode_index", "frame_index", "ground_truth_action", "predicted_action" + ) + pred_rows = preds.to_pydict() + episodes = sorted(set(pred_rows["episode_index"])) + frames = sorted(set(pred_rows["frame_index"])) + print(f"Found {len(pred_rows['frame_index'])} predicted frames: episodes {episodes}, frames {frames}") + + # 2. Re-read just those frames from the dataset, keeping the image and the + # per-frame camera pose. The .where filter pushes down past the video + # decoder, so only the frames we need are decoded. 
+ frames_df = ( + lerobot.read(DATASET_URI, load_video_frames=True) + .where(col("episode_index").is_in(episodes) & col("frame_index").is_in(frames)) + .select("episode_index", "frame_index", col("observation.image"), col("observation.extrinsics")) + ) + + # 3. Join images to predictions on (episode, frame) and pull the handful of + # rows into plain Python for drawing. + rows = frames_df.join(preds, on=["episode_index", "frame_index"]).to_pydict() + + os.makedirs(OVERLAYS_DIR, exist_ok=True) + episode_frames: dict[int, list[tuple[int, np.ndarray]]] = {} + for i in range(len(rows["frame_index"])): + episode = rows["episode_index"][i] + frame = rows["frame_index"][i] + extrinsics = np.array(rows["observation.extrinsics"][i], dtype=np.float64).reshape(4, 4) + ground_truth = np.array(rows["ground_truth_action"][i], dtype=np.float64) + predicted = np.array(rows["predicted_action"][i], dtype=np.float64) + + image = Image.fromarray(np.asarray(rows["observation.image"][i])) + draw = ImageDraw.Draw(image) + draw_skeleton(draw, ground_truth, extrinsics, GROUND_TRUTH_COLOR) + draw_skeleton(draw, predicted, extrinsics, PREDICTION_COLOR) + draw.text((20, 20), "green = ground truth red = predicted", fill=(255, 255, 255)) + + out_path = os.path.join(OVERLAYS_DIR, f"episode{episode:03d}_frame{frame:05d}.png") + image.save(out_path) + print(f"Saved {out_path}") + episode_frames.setdefault(episode, []).append((frame, np.asarray(image))) + + # Stitch each episode's overlays into a watchable mp4. The join above does + # not guarantee row order, so write_episode_video sorts by frame_index. 
+ for episode, frames in sorted(episode_frames.items()): + write_episode_video(episode, frames, VIDEO_FPS) From 754326cbe5101fac7ae3d0852c5755529e058048 Mon Sep 17 00:00:00 2001 From: Shreyas Garimella Date: Thu, 11 Jun 2026 14:48:59 -0700 Subject: [PATCH 07/11] test(lerobot): update tests to the renamed reader API The module's public surface changed (episodes -> read_episodes, read_info/read_stats folded into include_meta/include_stats kwargs, new read() entry point, video decode moved from load_episode_frames flags to read(load_video_frames=...)), but the tests still imported the old names, failing at collection. - rename call sites to read_episodes / load_episode_frames(ep, uri) - replace the read_info/read_stats test with coverage for the include_meta / include_stats column toggles - add a read() frame-level test and a v2-dataset rejection test - port the two video decode tests to read(load_video_frames=...), exercising the new timestamp-based frame matching 8 tests, all passing locally with DAFT_RUNNER=native. --- tests/datasets/test_lerobot.py | 108 +++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 40 deletions(-) diff --git a/tests/datasets/test_lerobot.py b/tests/datasets/test_lerobot.py index 05acb56fdd0..81e451a7dc4 100644 --- a/tests/datasets/test_lerobot.py +++ b/tests/datasets/test_lerobot.py @@ -9,7 +9,7 @@ import pytest import daft -from daft.datasets.lerobot import episodes, load_episode_frames, read_info, read_stats, read_tasks +from daft.datasets.lerobot import load_episode_frames, read, read_episodes, read_tasks def _write_table(path, table: pa.Table) -> None: @@ -33,6 +33,9 @@ def tiny_lerobot_v3(tmp_path): "data/file_index": [0, 0], "dataset_from_index": [0, 2], "dataset_to_index": [2, 4], + # Hidden by default; exposed via include_meta / include_stats. 
+ "meta/episode_uuid": ["a", "b"], + "stats/action_mean": [0.1, 0.2], } ) _write_table(root / "meta/episodes/chunk-000/file-000.parquet", episodes_tbl) @@ -51,6 +54,8 @@ def tiny_lerobot_v3(tmp_path): info = { "codebase_version": "v3.0", "fps": 30, + "data_path": "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet", + "video_path": "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4", "features": {}, "total_episodes": 2, "total_frames": 4, @@ -58,7 +63,6 @@ def tiny_lerobot_v3(tmp_path): } (root / "meta").mkdir(parents=True, exist_ok=True) (root / "meta" / "info.json").write_text(json.dumps(info), encoding="utf-8") - (root / "meta" / "stats.json").write_text(json.dumps({"ok": True}), encoding="utf-8") tasks_tbl = pa.table({"task_index": [0], "task": ["pick"]}) _write_table(root / "meta" / "tasks.parquet", tasks_tbl) @@ -68,7 +72,7 @@ def tiny_lerobot_v3(tmp_path): @pytest.fixture def tiny_lerobot_v3_video(tmp_path): - """Single-episode dataset with MP4 shards for ``decode_videos`` tests.""" + """Single-episode dataset with an MP4 shard for ``load_video_frames`` tests.""" av = pytest.importorskip("av") pytest.importorskip("PIL", reason="Pillow required for decoded image rows") @@ -81,8 +85,11 @@ def tiny_lerobot_v3_video(tmp_path): video_dir.mkdir(parents=True) shutil.copy(pathlib.Path("tests/assets/sample_video.mp4"), video_dir / "file-000.mp4") - with av.open(video_dir / "file-000.mp4") as c: + # The episode's footage starts at the shard's first frame timestamp; frame + # timestamps in the data parquet are episode-local (start at 0). 
+ with av.open(str(video_dir / "file-000.mp4")) as c: s = c.streams.video[0] + fps = float(s.average_rate) eps_from_ts = None for fr in c.decode(s): if fr.pts is not None: @@ -90,9 +97,8 @@ def tiny_lerobot_v3_video(tmp_path): break assert eps_from_ts is not None - fps = 30 n_frames = 3 - durations = [(i / fps) for i in range(n_frames)] + timestamps = [i / fps for i in range(n_frames)] episodes_tbl = pa.table( { @@ -115,7 +121,7 @@ def tiny_lerobot_v3_video(tmp_path): "index": list(range(n_frames)), "episode_index": [0] * n_frames, "frame_index": list(range(n_frames)), - "timestamp": durations, + "timestamp": timestamps, "task_index": [0] * n_frames, } ) @@ -124,6 +130,8 @@ def tiny_lerobot_v3_video(tmp_path): info = { "codebase_version": "v3.0", "fps": fps, + "data_path": "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet", + "video_path": "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4", "features": { video_key: {"dtype": "video"}, }, @@ -132,7 +140,6 @@ def tiny_lerobot_v3_video(tmp_path): "total_tasks": 1, } (root / "meta" / "info.json").write_text(json.dumps(info), encoding="utf-8") - (root / "meta" / "stats.json").write_text(json.dumps({"ok": True}), encoding="utf-8") tasks_tbl = pa.table({"task_index": [0], "task": ["pick"]}) _write_table(root / "meta" / "tasks.parquet", tasks_tbl) @@ -140,33 +147,8 @@ def tiny_lerobot_v3_video(tmp_path): return str(root) -def test_load_episode_frames_decode_videos_explicit_key(tiny_lerobot_v3_video): - ep = episodes(tiny_lerobot_v3_video) - df = ( - load_episode_frames(ep, tiny_lerobot_v3_video, decode_videos=True, video_keys=["camera.test"]) - .select("camera.test") - .collect() - ) - assert df.count_rows() == 3 - img0 = df.to_pydict()["camera.test"][0] - if hasattr(img0, "mode"): - assert img0.mode == "RGB" - assert img0.size[0] > 10 and img0.size[1] > 10 - else: - import numpy as np - - assert isinstance(img0, np.ndarray) - assert img0.ndim == 3 and img0.shape[2] >= 3 - - -def 
test_load_episode_frames_decode_videos_inferred_keys(tiny_lerobot_v3_video): - ep = episodes(tiny_lerobot_v3_video) - df = load_episode_frames(ep, tiny_lerobot_v3_video, decode_videos=True).select("camera.test").collect() - assert df.count_rows() == 3 - - -def test_episodes_and_load_episode_frames(tiny_lerobot_v3): - ep = episodes(tiny_lerobot_v3).sort("episode_index") +def test_read_episodes_and_load_episode_frames(tiny_lerobot_v3): + ep = read_episodes(tiny_lerobot_v3).sort("episode_index") assert ep.count_rows() == 2 frames = load_episode_frames(ep, tiny_lerobot_v3).sort("index") @@ -178,10 +160,34 @@ def test_episodes_and_load_episode_frames(tiny_lerobot_v3): assert f0.to_pydict()["episode_index"] == [0, 0] -def test_read_info_and_stats(tiny_lerobot_v3): - info = read_info(tiny_lerobot_v3) - assert info["total_episodes"] == 2 - assert read_stats(tiny_lerobot_v3) == {"ok": True} +def test_read_frame_rows(tiny_lerobot_v3): + df = read(tiny_lerobot_v3).sort("index").collect() + assert df.count_rows() == 4 + cols = df.column_names + for expected in ("episode_index", "frame_index", "timestamp"): + assert expected in cols + + +def test_read_episodes_meta_and_stats_columns(tiny_lerobot_v3): + default_cols = read_episodes(tiny_lerobot_v3).column_names + assert "meta/episode_uuid" not in default_cols + assert "stats/action_mean" not in default_cols + + with_meta = read_episodes(tiny_lerobot_v3, include_meta=True).column_names + assert "meta/episode_uuid" in with_meta + + with_stats = read_episodes(tiny_lerobot_v3, include_stats=True).column_names + assert "stats/action_mean" in with_stats + + +def test_read_episodes_rejects_non_v3(tmp_path): + root = tmp_path / "ds_old" + (root / "meta").mkdir(parents=True) + info = {"codebase_version": "v2.1", "fps": 30, "features": {}} + (root / "meta" / "info.json").write_text(json.dumps(info), encoding="utf-8") + + with pytest.raises(ValueError, match="v3"): + read_episodes(str(root)) def test_read_tasks_parquet(tiny_lerobot_v3): @@ 
-190,5 +196,27 @@ def test_read_tasks_parquet(tiny_lerobot_v3): def test_read_episodes_has_no_dataset_root_column(tiny_lerobot_v3): - ep = episodes(tiny_lerobot_v3) + ep = read_episodes(tiny_lerobot_v3) assert "lerobot_dataset_root" not in ep.column_names + + +def test_read_load_video_frames_explicit_key(tiny_lerobot_v3_video): + df = read(tiny_lerobot_v3_video, load_video_frames=["camera.test"]).select("camera.test").collect() + assert df.count_rows() == 3 + img0 = df.to_pydict()["camera.test"][0] + if hasattr(img0, "mode"): + assert img0.mode == "RGB" + assert img0.size[0] > 10 and img0.size[1] > 10 + else: + import numpy as np + + assert isinstance(img0, np.ndarray) + assert img0.ndim == 3 and img0.shape[2] >= 3 + + +def test_read_load_video_frames_inferred_keys(tiny_lerobot_v3_video): + df = read(tiny_lerobot_v3_video, load_video_frames=True).collect() + assert df.count_rows() == 3 + assert "camera.test" in df.column_names + # Internal per-episode video metadata must not leak into the result. + assert not any(c.startswith("videos/") for c in df.column_names) From 820b0f8c03297cd0982563e2117bd28a0c210e8a Mon Sep 17 00:00:00 2001 From: Shreyas Garimella Date: Thu, 11 Jun 2026 16:07:01 -0700 Subject: [PATCH 08/11] chore(examples): move H-RDT example out of this PR Keep this PR scoped to the daft.datasets.lerobot reader itself. The end-to-end H-RDT pose prediction example (prediction, metrics, visualization scripts) moves to the daft-examples repository. 
--- examples/lerobot_pose/README.md | 108 ------- examples/lerobot_pose/compute_metrics.py | 82 ----- .../lerobot_pose/encode_task_embeddings.py | 108 ------- examples/lerobot_pose/lerobot.py | 305 ------------------ examples/lerobot_pose/predict_poses.py | 269 --------------- .../lerobot_pose/visualize_predictions.py | 166 ---------- 6 files changed, 1038 deletions(-) delete mode 100644 examples/lerobot_pose/README.md delete mode 100644 examples/lerobot_pose/compute_metrics.py delete mode 100644 examples/lerobot_pose/encode_task_embeddings.py delete mode 100644 examples/lerobot_pose/lerobot.py delete mode 100644 examples/lerobot_pose/predict_poses.py delete mode 100644 examples/lerobot_pose/visualize_predictions.py diff --git a/examples/lerobot_pose/README.md b/examples/lerobot_pose/README.md deleted file mode 100644 index 67c95578b33..00000000000 --- a/examples/lerobot_pose/README.md +++ /dev/null @@ -1,108 +0,0 @@ -# LeRobot + H-RDT: per-frame 48-D action prediction with Daft -This example reads the [EgoDex test dataset](https://huggingface.co/datasets/pepijn223/egodex-test) (LeRobot v3 format) as a lazy Daft DataFrame — one row per frame, with the decoded camera image and the 48-D hand state — runs the [H-RDT](https://github.com/HongzheBi/H_RDT) policy on every frame, and stores the predicted 48-D action vector as a new column. - -``` -LeRobot v3 dataset (parquet + mp4 shards on HF Hub) - │ lerobot.read(..., load_video_frames=True) # lazy, streaming - ▼ -one row per frame: observation.image · observation.state (48) · task_index · ... - │ HRDTPredictor.predict(...) # @daft.cls batch UDF on GPU - ▼ -+ predicted_action (48-D float vector) - │ write_parquet - ▼ -out/egodex_hrdt_predictions/ -``` - -## Files - -| File | Purpose | -| --- | --- | -| `lerobot.py` | Vendored copy of `daft.datasets.lerobot` (LeRobot v3 reader). Delete once it ships in a Daft release. 
| -| `encode_task_embeddings.py` | One-time preprocessing: encode each task instruction with T5-XXL and cache it (H-RDT consumes language *embeddings*, not text). | -| `predict_poses.py` | The pipeline: decode frames → batched H-RDT inference → write parquet (pure predictions). | -| `compute_metrics.py` | Score the predictions: per-frame `avg_keypoint_distance_m` (EgoDex paper metric, arXiv:2505.11709 §4.3) + per-episode and overall summaries. Torch-free, re-runnable in seconds. | -| `visualize_predictions.py` | Project predicted vs ground-truth hand poses onto frames; writes PNG overlays and per-episode mp4s. | - -## Setup - -1. **Clone H_RDT** as a sibling of this repo (or set `HRDT_PROJECT_ROOT`): - - ```bash - git clone https://github.com/HongzheBi/H_RDT ../../../H_RDT - ``` - -2. **Download the pretrained weights** into the clone (~8.8 GB). This pulls the - EgoDex pretrain checkpoint (`checkpoints/pretrain-0618/checkpoint-500000`, - ~4.1 GB) and the DinoSigLIP vision backbone weights, skipping the duplicate - safetensors copies the code doesn't load: - - ```bash - uvx --from huggingface_hub hf download embodiedfoundation/H-RDT \ - --include "checkpoints/*" \ - --include "bak/dino-siglip/vit_large_patch14_reg4_dinov2.lvd142m/pytorch_model.bin" \ - --include "bak/dino-siglip/vit_large_patch14_reg4_dinov2.lvd142m/config.json" \ - --include "bak/dino-siglip/vit_so400m_patch14_siglip_384/open_clip_pytorch_model.bin" \ - --include "bak/dino-siglip/vit_so400m_patch14_siglip_384/*.json" \ - --local-dir ../../../H_RDT - ``` - -3. **T5-XXL encoder** downloads automatically (~9.5 GB into the HF cache) the - first time you run `encode_task_embeddings.py`. The default model is - `city96/t5-v1_1-xxl-encoder-bf16`, an encoder-only bfloat16 conversion of - `google/t5-v1_1-xxl` — numerically equivalent for our purposes (H-RDT's own - pipeline ran the encoder in bfloat16) but 4.7x smaller than the official - fp32 encoder+decoder repo. Loading takes ~10 GB RAM, CPU is fine. 
Set - `T5_MODEL_PATH=google/t5-v1_1-xxl` to use the original instead (~44.5 GB). - -## Run - -```bash -# 1. One-time: cache a T5 embedding per task (3 tasks in egodex-test) -uv run encode_task_embeddings.py - -# 2. Quick local trial: predict only the first 8 frames -MAX_FRAMES=8 uv run predict_poses.py - -# 3. Full run: predict an action for every frame and write parquet -uv run predict_poses.py - -# 4. Score the predictions (EgoDex paper's 12-keypoint metric) — re-runnable -uv run compute_metrics.py - -# 5. Render overlay PNGs and per-episode mp4s of predicted vs ground truth -uv run visualize_predictions.py -``` - -## How it works - -- **Lazy frames.** `lerobot.read` only builds a plan. Execution streams shard - downloads, video decoding, inference, and writing — the full dataset never - sits in memory. -- **`@daft.cls` for the model.** Daft constructs the predictor once per worker - process (loading ~9 GB of weights in `__init__`) and reuses it for every - batch. A plain function UDF would have nowhere to keep the loaded model. -- **Batched inference.** `@daft.method.batch(batch_size=16)` hands the UDF - whole columns (`daft.Series`) at a time, so DinoSigLIP and the policy run one - forward pass per 16 frames instead of per frame. -- **Concurrency.** Frame decoding (CPU) is fanned out across cores by Daft — - the reader splits work with `into_batches(16)` — and overlaps with inference, - which runs on a single model instance (`@daft.cls(gpus=..., max_concurrency=1)`). - To run N concurrent model replicas on one GPU, use fractional GPUs: - `@daft.cls(gpus=1/N, max_concurrency=N)` — each replica holds its own copy of - the weights (~6.5 GB VRAM in bf16), so size VRAM accordingly. -- **Normalization contract.** Following H-RDT's EgoDex pretraining - (`datasets/pretrain/egodex_dataset.py`): the input state is min/max scaled to - `[-1, 1]` using `egodex_stat.json`, and the predicted chunk is denormalized - with the inverse mapping. 
The model predicts 16 future steps; we keep step 0, - so `predicted_action` is one 48-D vector per frame (24 dims per hand: wrist - pose + finger keypoints). - -## Output - -`out/egodex_hrdt_predictions/` — parquet with one row per frame: -`episode_index`, `frame_index`, `timestamp`, `task_index`, -`observation.state`, `ground_truth_action`, `predicted_action` (48-D -`embedding` column), ready for `daft.read_parquet` to evaluate prediction error -against the ground-truth actions. - diff --git a/examples/lerobot_pose/compute_metrics.py b/examples/lerobot_pose/compute_metrics.py deleted file mode 100644 index 8fecdc89bd8..00000000000 --- a/examples/lerobot_pose/compute_metrics.py +++ /dev/null @@ -1,82 +0,0 @@ -# /// script -# description = "Compute EgoDex-paper keypoint error metrics over H-RDT predictions" -# requires-python = ">=3.12, <3.13" -# dependencies = [ -# "daft>=0.7.15", -# "numpy", -# ] -# /// -"""Score the predictions written by predict_hrdt.py. - -Kept separate from the prediction pipeline on purpose: predictions cost -minutes of model time, metrics cost milliseconds. Splitting them means you can -re-score (or add new metrics) without re-running the model, and this script's -environment needs no torch at all. - -Reads `out/egodex_hrdt_predictions/`, writes per-frame metrics to -`out/egodex_hrdt_metrics/`, and prints overall + per-episode summaries. - - uv run compute_metrics.py -""" - -import os - -import numpy as np - -import daft -from daft import DataType, col - -PREDICTIONS_DIR = os.path.join(os.path.dirname(__file__), "out", "egodex_hrdt_predictions") -METRICS_DIR = os.path.join(os.path.dirname(__file__), "out", "egodex_hrdt_metrics") - - -@daft.func(return_dtype=DataType.float64()) -def avg_keypoint_distance_m(predicted_action: list[float], ground_truth_action: list[float]) -> float: - """EgoDex paper's trajectory-prediction metric (arXiv:2505.11709, Sec 4.3). 
- - "Euclidean distance between predicted 3D keypoint positions and their - ground truth 3D counterparts, averaged over ... each of the 12 keypoints" - (both wrists + all 10 fingertips), in meters. The 2x6 wrist-rotation dims - are excluded: they are unitless rotation-matrix columns. - - Ours is the metric at a 1-step horizon with K=1: the paper averages over - every timestep of the predicted chunk and scores the best of K sampled - trajectories, while we keep only the chunk's first step and sample once. - """ - predicted = np.asarray(predicted_action, dtype=np.float64) - ground_truth = np.asarray(ground_truth_action, dtype=np.float64) - distances = [] - for base in (0, 24): # left hand dims 0-23, right hand dims 24-47 - keypoint_starts = [base] + [base + 9 + 3 * i for i in range(5)] # wrist, then 5 fingertips - for start in keypoint_starts: - distances.append(np.linalg.norm(predicted[start : start + 3] - ground_truth[start : start + 3])) - return float(np.mean(distances)) - - -if __name__ == "__main__": - metrics = ( - daft.read_parquet(f"{PREDICTIONS_DIR}/**") - .with_column( - "avg_keypoint_distance_m", - avg_keypoint_distance_m(col("predicted_action"), col("ground_truth_action")), - ) - # Keep only identifiers + the metric; the predictions stay in their own files. 
- .select("episode_index", "frame_index", "timestamp", "task_index", "avg_keypoint_distance_m") - ) - - metrics.write_parquet(METRICS_DIR) - print(f"Wrote per-frame metrics to {METRICS_DIR}\n") - - results = daft.read_parquet(f"{METRICS_DIR}/**") - - print("Per-episode:") - results.groupby("episode_index").agg( - col("avg_keypoint_distance_m").mean().alias("mean_m"), - col("avg_keypoint_distance_m").max().alias("worst_frame_m"), - col("avg_keypoint_distance_m").count().alias("frames"), - ).sort("episode_index").show() - - print("Overall:") - results.select( - col("avg_keypoint_distance_m").mean().alias("dataset_avg_keypoint_distance_m"), - ).show() diff --git a/examples/lerobot_pose/encode_task_embeddings.py b/examples/lerobot_pose/encode_task_embeddings.py deleted file mode 100644 index 9503257c61b..00000000000 --- a/examples/lerobot_pose/encode_task_embeddings.py +++ /dev/null @@ -1,108 +0,0 @@ -# /// script -# description = "Precompute T5 language embeddings for every task in a LeRobot dataset (H-RDT input)" -# requires-python = ">=3.12, <3.13" -# dependencies = [ -# "daft>=0.7.15", -# "torch", -# "transformers", -# "sentencepiece", -# "protobuf", -# "accelerate", -# ] -# /// -"""Encode each task instruction in the dataset with T5 and cache it to disk. - -H-RDT never sees raw text. At train and inference time it consumes language -*embeddings*: the task instruction run through a frozen T5 encoder -(`t5-v1_1-xxl`, 4096-dim features). Encoding is expensive (T5-XXL is an ~11B -parameter model), but a dataset only has a handful of distinct task strings, -so we encode each one exactly once up front and store it as a `.pt` file keyed -by `task_index`. The prediction pipeline then does a dictionary lookup per row -instead of running T5 per row. - -This mirrors H_RDT's `models/encoder/t5_encoder.py` (same model, tokenizer -settings, and bfloat16 dtype) but calls `transformers` directly, because -`T5Embedder` hardcodes the author's local weight path in an assert. 
Only the -encoder half of T5 is loaded (~4.7B params, ~10 GB RAM in bfloat16), so this -runs fine on CPU. - -Run this once before `predict_hrdt.py`: - - uv run encode_task_embeddings.py -""" - -import os - -import lerobot # vendored copy of daft.datasets.lerobot -import torch - -DATASET_URI = "pepijn223/egodex-test" -# HF id of (or local path to) the T5 model. Must be the XXL variant: H-RDT's -# text adapter expects 4096-dim features. The default is an encoder-only -# bfloat16 conversion of google/t5-v1_1-xxl (~9.5 GB download instead of the -# official repo's 44.5 GB fp32 encoder+decoder). It is numerically equivalent -# here: H-RDT's own T5Embedder loaded the encoder in bfloat16 too. -T5_MODEL_PATH = os.environ.get("T5_MODEL_PATH", "city96/t5-v1_1-xxl-encoder-bf16") -# Matches `tokenizer_max_length` in H_RDT's configs/hrdt_pretrain.yaml. -TOKENIZER_MAX_LENGTH = 1024 -OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "task_embeddings") - - -def load_tasks(dataset_uri: str) -> dict[int, str]: - """Return {task_index: task instruction} from the dataset's `meta/tasks.parquet`.""" - df = lerobot.read_tasks(dataset_uri).collect() - data = df.to_pydict() - # The task string column name varies with how the dataset was exported - # (e.g. pandas writes the task as the index column `__index_level_0__`), - # so find it by dtype rather than by name. 
- text_col = next(name for name, values in data.items() if values and isinstance(values[0], str)) - return dict(zip(data["task_index"], data[text_col])) - - -def main() -> None: - from transformers import AutoTokenizer, T5EncoderModel - - tasks = load_tasks(DATASET_URI) - print(f"Found {len(tasks)} tasks in {DATASET_URI}:") - for idx, text in sorted(tasks.items()): - print(f" [{idx}] {text}") - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_PATH, model_max_length=TOKENIZER_MAX_LENGTH) - print(f"Loading T5 encoder on {device} (bfloat16)...") - model = ( - T5EncoderModel.from_pretrained( - T5_MODEL_PATH, - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, # stream + convert weights instead of loading all fp32 at once - ) - .to(device) - .eval() - ) - - os.makedirs(OUTPUT_DIR, exist_ok=True) - with torch.no_grad(): - for idx, text in sorted(tasks.items()): - tokenized = tokenizer( - [text], - max_length=TOKENIZER_MAX_LENGTH, - padding="longest", - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - embeddings = model( - input_ids=tokenized["input_ids"].to(device), - attention_mask=tokenized["attention_mask"].to(device), - )["last_hidden_state"] - # Trim padding so we only store (and later attend over) real tokens. 
- num_tokens = int(tokenized["attention_mask"][0].sum()) - trimmed = embeddings[0, :num_tokens].to(torch.float32).cpu() - out_path = os.path.join(OUTPUT_DIR, f"task_{idx:03d}.pt") - torch.save({"task_index": idx, "task": text, "embeddings": trimmed}, out_path) - print(f"Saved {out_path} (shape {tuple(trimmed.shape)})") - - -if __name__ == "__main__": - main() diff --git a/examples/lerobot_pose/lerobot.py b/examples/lerobot_pose/lerobot.py deleted file mode 100644 index b2abc4148a1..00000000000 --- a/examples/lerobot_pose/lerobot.py +++ /dev/null @@ -1,305 +0,0 @@ -# ruff: noqa -"""LeRobot Dataset v3.0 helpers for `daft.datasets`. - -NOTE: This is a vendored copy of `daft/datasets/lerobot.py` from the Daft repo, -included here so this example runs against released `daft` wheels. Once the -module ships in a Daft release, delete this file and switch the imports in this -directory to `from daft.datasets import lerobot`. - -This module reads the file-based LeRobot v3 layout (`meta/episodes`, `data`, -`videos`) and exposes episode-level scans plus frame expansion utilities. - -See https://huggingface.co/docs/lerobot/lerobot-dataset-v3 for format details. -""" - -from __future__ import annotations - -import re -import json -from typing import TYPE_CHECKING, Any, TypedDict, cast - -import daft -from daft.api_annotations import PublicAPI -from daft.datatype import DataType -from daft.expressions import col, lit -from daft.file import VideoFile -from daft.exceptions import DaftCoreException -from daft.functions import lpad -from daft.functions.file_ import video_file -from daft.udf import func - -if TYPE_CHECKING: - from daft.daft import IOConfig - from daft.dataframe import DataFrame - - -def _normalize_dataset_root(uri: str) -> str: - """Return a canonical dataset root prefix (no trailing slash) for path joins.""" - u = uri.strip() - # Input looks like a Hugging Face repo ID, i.e. 
"org/name" - is_hf_repo_id = bool(re.fullmatch(r"[\w.-]+/[\w.-]+", uri)) - - if is_hf_repo_id: - u = f"hf://datasets/{u}" - return u.rstrip("/") - - -@func(return_dtype=DataType.image()) -def _decode_lerobot_video_timestamp( - file: VideoFile, - episode_from_timestamp_s: float, - frame_timestamp_s: float, - tolerance_s: float, - image_width_i: int, - image_height_i: int, -): - """Pick the decoded frame closest in time to ``episode_from_timestamp_s + frame_timestamp_s``.""" - try: - import av as av_mod - except ImportError as err: - raise ImportError("Decoding LeRobot MP4 shards requires PyAV. Install with `pip install av`.") from err - try: - from PIL import Image as PILImage - except ImportError as err: - raise ImportError( - "Decoding LeRobot MP4 shards requires Pillow. Install with `pip install daft[video]` or `pip install pillow`." - ) from err - abs_ts = float(episode_from_timestamp_s) + float(frame_timestamp_s) - tolerance = float(tolerance_s) - width_i = int(image_width_i) - height_i = int(image_height_i) - width: int | None - height: int | None - if width_i > 0 and height_i > 0: - width, height = width_i, height_i - else: - width, height = None, None - - loaded: list[tuple[float, Any]] = [] - decode_cap = 20_000 - decoded = 0 - - with file.open() as f_open: - with av_mod.open(f_open) as container: - stream = container.streams.video[0] - # Match LeRobot: seek backwards to preceding keyframe, then decode forwards. 
- container.seek(max(0, int(abs_ts / av_mod.time_base)), backward=True) - - tail_s = max(0.1, tolerance * 50.0, 1.0 / 24.0) - for vf in container.decode(stream): - if vf.pts is None: - continue - current_ts = float(vf.pts * stream.time_base) - pil_img = PILImage.fromarray(vf.to_ndarray(format="rgb24"), mode="RGB") - if width is not None and height is not None: - pil_img = pil_img.resize((width, height), PILImage.Resampling.NEAREST) - - loaded.append((current_ts, pil_img)) - decoded += 1 - - if decoded >= decode_cap: - raise ValueError("Exceeded decode frame budget while aligning to parquet timestamps.") - if current_ts >= abs_ts + tail_s: - break - - if not loaded: - raise ValueError(f"No frames decoded from shard while seeking timestamp {abs_ts:.6f}s.") - - closest_ts, closest_img = min(loaded, key=lambda item: abs(item[0] - abs_ts)) - closest_dist = abs(closest_ts - abs_ts) - if closest_dist > tolerance: - raise ValueError( - f"No frame matched timestamp {abs_ts:.6f}s within tolerance {tolerance} " - f"(closest distance observed: {closest_dist})." 
- ) - return closest_img - -class Feature(TypedDict): - dtype: str - -class LeRobotInfo(TypedDict): - codebase_version: str - data_path: str - video_path: str - fps: float - features: dict[str, Feature] - -def _read_info(normalized_uri: str, io_config: IOConfig | None = None) -> LeRobotInfo: - with daft.open_file(f"{normalized_uri}/meta/info.json", io_config=io_config) as f: - info = cast(LeRobotInfo, json.load(f)) - if info["codebase_version"] != "v3.0": - raise ValueError("`daft.datasets.lerobot` currently only supports LeRobot datasets of v3 and above") - return info - -@PublicAPI -def read( - dataset_uri: str, - io_config: IOConfig | None = None, - include_stats: bool = False, - load_video_frames: str | list[str] | bool = False, -) -> DataFrame: - """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per frame with episode metadata).""" - root = _normalize_dataset_root(dataset_uri) - info = _read_info(root, io_config=io_config) - - # Keep the per-episode video metadata (notably `videos/{key}/from_timestamp`, - # the time within the shard where each episode's footage begins). We need it - # to translate episode-local frame timestamps into absolute shard timestamps - # when decoding, and drop these internal columns again before returning. 
- episode_df = read_episodes( - dataset_uri, io_config=io_config, include_stats=include_stats, include_video_metadata=True - ) - df = load_episode_frames(episode_df, dataset_uri, io_config=io_config) - - # Load video frames into memory - if load_video_frames is not False: - if load_video_frames is True: - video_keys = [name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video"] - elif isinstance(load_video_frames, str): - video_keys = [load_video_frames] - elif isinstance(load_video_frames, list) and all(isinstance(k, str) for k in load_video_frames): - video_keys = load_video_frames - else: - raise ValueError(f"Invalid value provided for argument load_video_frames=`{load_video_frames}`") - - # An MP4 shard packs many episodes back to back, so the shard's internal - # frame numbering is NOT the parquet's episode-local `frame_index` (which - # resets to 0 each episode). Seeking by `frame_index` only happens to work - # for the first episode in each shard. Instead, seek by absolute timestamp: - # `from_timestamp` (where this episode begins in the shard) + the per-frame - # episode-local `timestamp`. That keeps a single coordinate system end to end. - fps = float(info["fps"]) - tolerance_s = 1.0 / fps / 2.0 # half a frame period: any closer frame is unambiguously "the" frame - - # To increase parallelism, reduce batch size - df = df.into_batches(16) # TODO (for later): Set it in the batch UDF instead? - for k in video_keys: - # TODO (for later): Optimize by using a batch UDF to avoid opening the same video multiple times - df = df.with_column( - k, - _decode_lerobot_video_timestamp( - col(f"videos/{k}/video"), - col(f"videos/{k}/from_timestamp"), - col("timestamp"), - lit(tolerance_s), - lit(0), # image_width: 0 disables resize (decode at native resolution) - lit(0), # image_height: 0 disables resize - ), - ) - df = df.exclude(f"videos/{k}/video") - - # TODO: What about raw images, what do i do about them? 
Is that a thing in LeRobot v3 - - # Drop the internal per-episode video metadata we kept above (chunk/file index, - # from/to timestamp). This restores read_episodes' default of hiding these. - df = df.exclude(*(c for c in df.column_names if c.startswith("videos/") and not c.endswith("/video"))) - - return df - - -@PublicAPI -def read_episodes( - dataset_uri: str, - io_config: IOConfig | None = None, - include_meta: bool = False, - include_stats: bool = False, - include_video_metadata: bool = False, -) -> DataFrame: - """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per episode). - - This reads the `meta/episodes/**/*.parquet` path under the dataset root. - - Args: - dataset_uri: Huggingface repo id (`org/name`), - or a local / remote directory (`s3://...`, `hf://datasets/...`) - io_config: Optional IO configuration for remote reads. - - Returns: - Lazy DataFrame of episode metadata. - """ - root = _normalize_dataset_root(dataset_uri) - info = _read_info(root, io_config=io_config) - - # TODO: What is the `meta` episodes into used for? How is it different from the `videos` info? 
- df = daft.read_parquet(f"{root}/meta/episodes/**/*.parquet", io_config=io_config) - if not include_meta: - df = df.exclude(*(c for c in df.column_names if c.startswith("meta/"))) - if not include_stats: - df = df.exclude(*(c for c in df.column_names if c.startswith("stats/"))) - - # Get the video keys - video_keys = set(name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video") - - for key in video_keys: - file_name_expr = ( - lit(f"{root}/videos/{key}/chunk-") - + lpad(col(f"videos/{key}/chunk_index").cast(DataType.string), 3, "0") - + lit("/file-") - + lpad(col(f"videos/{key}/file_index").cast(DataType.string), 3, "0") - + lit(".mp4") - ) - - df = df.with_column(f"videos/{key}/video", video_file(file_name_expr, verify=False, io_config=io_config)) - - if not include_video_metadata: - df = df.exclude(*(c for c in df.column_names if c.startswith("videos/") and not c.endswith("/video"))) - - return df - - -@PublicAPI -def load_episode_frames( - episodes: DataFrame, - dataset_uri: str, - io_config: IOConfig | None = None, -) -> DataFrame: - """Expand an episode-level DataFrame into a frame-level DataFrame. - - Reads the per-frame parquet under ``data/**`` and joins it to the provided - episode metadata on ``episode_index``, producing one row per frame. Episode - metadata is broadcast across each episode's frames. - - Filter ``episodes`` before calling this to expand only the episodes you need; - only the surviving episodes contribute to the join. - - Args: - episodes: Episode-level DataFrame, typically from :func:`read_episodes` - (optionally filtered). Must contain an ``episode_index`` column. - dataset_uri: The same dataset identifier passed to :func:`read_episodes` - (Huggingface repo id ``org/name``, or a local / remote directory such - as ``s3://...`` or ``hf://datasets/...``). - io_config: Optional IO configuration for remote reads. - - Returns: - Lazy DataFrame with one row per frame. 
- """ - root = _normalize_dataset_root(dataset_uri) - - frame_df = daft.read_parquet(f"{root}/data/**", io_config=io_config) - df = episodes.join(frame_df, on=["episode_index"]) - df = df.exclude("data/chunk_index", "data/file_index") - return df - - -@PublicAPI -def read_tasks(dataset_uri: str, io_config: IOConfig | None = None) -> DataFrame: - """Load task metadata as a DataFrame. - - Prefers ``meta/tasks.parquet`` (current LeRobot default). Falls back to legacy - ``meta/tasks.jsonl`` when the Parquet file is missing. - """ - root = _normalize_dataset_root(dataset_uri) - - pq_url = f"{root}/meta/tasks.parquet" - try: - return daft.read_parquet(pq_url, io_config=io_config) - except (OSError, DaftCoreException, FileNotFoundError): - return daft.read_json(f"{root}/meta/tasks.jsonl", io_config=io_config) - - -__all__ = [ - "load_episode_frames", - "read", - "read_episodes", - "read_tasks", -] diff --git a/examples/lerobot_pose/predict_poses.py b/examples/lerobot_pose/predict_poses.py deleted file mode 100644 index ca694296cd2..00000000000 --- a/examples/lerobot_pose/predict_poses.py +++ /dev/null @@ -1,269 +0,0 @@ -# /// script -# description = "Run H-RDT action prediction over a LeRobot dataset with Daft" -# requires-python = ">=3.12, <3.13" -# dependencies = [ -# "daft>=0.7.15", -# "torch", -# "torchvision", -# "timm", -# "transformers", -# "sentencepiece", -# "diffusers", -# "huggingface-hub", -# "av", -# "pillow", -# "opencv-python", -# "pyyaml", -# "numpy", -# ] -# /// -"""Predict 48-D hand actions for every frame of an EgoDex LeRobot dataset. - -The pipeline is as follows: 1) we load in a lazy Daft DataFrame -with 1 row per frame. Each row contains the decoded camera image, the -48-D observation state, and the episode/task metadata. 2) we wrap the -H-RDT model as a Daft class, which allows for a persistent state in which -each worker builds the model once in __init__ and reuses it for every batch. 
-3) we append the predicted step per row as a new column in the DataFrame. -4) we write the DataFrame to a parquet file. - -The model predicts a chunk of 16 future actions. However we keep only the first step of each chunk, -so the new column is a single 48-D float vector per frame. - -Run `encode_task_embeddings.py` first to cache the T5 task embeddings, then: - - uv run predict_poses.py -""" - -# ruff: noqa: E402 -- sys.path must be extended before the `models.*` imports below -import os -import sys - -# Let any operator MPS doesn't implement fall back to CPU instead of crashing. -# Must be set before torch initializes the MPS backend. -os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") - -HRDT_ROOT = os.environ.get( - "HRDT_PROJECT_ROOT", - os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../H_RDT")), -) # the HRDT model repo must be cloned either in the directory specified by env or at the sibling level of the current repo -sys.path.append(HRDT_ROOT) - -import json - -import lerobot # vendored copy of daft.datasets.lerobot -import numpy as np -import torch -import yaml -from PIL import Image as PILImage - -# NOTE: the H_RDT imports (`from models...`) deliberately live inside -# HRDTPredictor.__init__, NOT here. Daft pickles the class to ship it to worker -# processes; pip-installed packages pickle by reference, but H_RDT is only -# importable via the sys.path hack above, so pickle tries to serialize its -# classes by value and fails ("cannot pickle 'GenericModule' object"). -# Importing inside __init__ defers the import to the worker process instead. -import daft -from daft import DataType, col - -DATASET_URI = "pepijn223/egodex-test" -OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "out", "egodex_hrdt_predictions") -TASK_EMBEDDINGS_DIR = os.path.join(os.path.dirname(__file__), "task_embeddings") - -# The EgoDex-pretrained checkpoint published in the H-RDT model repo. 
-CHECKPOINT_PATH = os.path.join(HRDT_ROOT, "checkpoints", "pretrain-0618", "checkpoint-500000") -CONFIG_PATH = os.path.join(HRDT_ROOT, "configs", "hrdt_pretrain.yaml") -STAT_PATH = os.path.join(HRDT_ROOT, "datasets", "pretrain", "egodex_stat.json") - -HAS_GPU = torch.cuda.is_available() -PREDICT_BATCH_SIZE = 16 -# For a quick local trial, set MAX_FRAMES to only predict the first N frames -# (e.g. `MAX_FRAMES=8 uv run predict_hrdt.py`). 0 means the whole dataset. -MAX_FRAMES = int(os.environ.get("MAX_FRAMES", "0")) -NUM_GPUS = 1 if HAS_GPU else 0 - -# Resolve the torch device HERE, at module level, not inside the class: -# `torch.backends.mps` is a property-object that cloudpickle cannot serialize, -# and Daft pickles the class (including everything its methods reference) to -# ship it to workers. A plain string global pickles fine. -if HAS_GPU: - DEVICE = "cuda" -elif torch.backends.mps.is_available(): - DEVICE = "mps" # Apple-silicon GPU via Metal -else: - DEVICE = "cpu" - -@daft.cls(gpus=NUM_GPUS, max_concurrency=1) -class HRDTPredictor: - def __init__(self, ckpt_path: str, config_path: str, stat_path: str, embeddings_dir: str): - # Runs once per replica, on the worker, when execution starts. - # H_RDT imports happen here so they're never pickled (see note at top). - from models.encoder.dinosiglip_vit import DinoSigLIPViTBackbone # from H_RDT - from models.hrdt_runner import HRDTRunner # from H_RDT - - self.device = torch.device(DEVICE) - print(f"Using device: {self.device}") - - if self.device.type == "cuda": - self.dtype = torch.bfloat16 - else: - # float32 on MPS and CPU: bfloat16 support on MPS is still patchy. - self.dtype = torch.float32 - - # task_index -> (num_tokens, 4096) T5 embedding, cached by - # encode_task_embeddings.py. Loaded FIRST so a missing/empty cache fails - # in milliseconds, before the minutes-long model weight loading below. 
- self.lang_embeddings: dict[int, torch.Tensor] = {} - try: - fnames = sorted(os.listdir(embeddings_dir)) - except FileNotFoundError: - fnames = [] # missing dir is handled the same as an empty one, below - for fname in fnames: - if fname.endswith(".pt"): - payload = torch.load(os.path.join(embeddings_dir, fname), map_location="cpu") - self.lang_embeddings[int(payload["task_index"])] = payload["embeddings"].to( - self.device, dtype=self.dtype - ) - if not self.lang_embeddings: - raise FileNotFoundError( - f"No task embeddings found in {embeddings_dir}. Run encode_task_embeddings.py first." - ) - - with open(config_path) as f: - config = yaml.safe_load(f) - - - self.vision_encoder = DinoSigLIPViTBackbone( - vision_backbone_id="dino-siglip", - image_resize_strategy="letterbox" - if config["dataset"]["image_aspect_ratio"] == "pad" - else "resize-naive", - default_image_size=384, - ) - self.vision_encoder.to(self.device, dtype=self.dtype).eval() - self.image_transform = self.vision_encoder.get_image_transform() - - common = config["common"] - self.pred_horizon = common["action_chunk_size"] - self.policy = HRDTRunner( - state_dim=common["state_dim"], - action_dim=common["action_dim"], - pred_horizon=self.pred_horizon, - config=config["model"], - act_pos_emb_config=[ - ("state", 1), - ("action", self.pred_horizon), - ], - img_pos_emb_config=[ - ("image", (common["img_history_size"], common["num_cameras"], -self.vision_encoder.num_patches)), - ], - lang_pos_emb_config=[ - ("language", -config["dataset"]["tokenizer_max_length"]), - ], - max_img_len=common["img_history_size"] * common["num_cameras"] * self.vision_encoder.num_patches, - max_lang_len=config["dataset"]["tokenizer_max_length"], - training_mode="lang", - mode="pretrain", - dtype=self.dtype, - ) - state_dict = torch.load( - os.path.join(ckpt_path, "pytorch_model.bin"), map_location="cpu", weights_only=True - ) - - state_dict = {k: v for k, v in state_dict.items() if not k.startswith("video_adapter.")} - 
self.policy.load_state_dict(state_dict) - self.policy.to(self.device, dtype=self.dtype).eval() - - with open(stat_path) as f: - stat = json.load(f)["egodex"] - self.action_min = np.array(stat["min"], dtype=np.float32) - self.action_max = np.array(stat["max"], dtype=np.float32) - - @daft.method.batch( - return_dtype=DataType.embedding(DataType.float32(), 48), - batch_size=PREDICT_BATCH_SIZE, - ) - def predict(self, images: daft.Series, states: daft.Series, task_indices: daft.Series): - """Predict the next 48-D action for a batch of frames. - - Batch methods receive whole columns as `daft.Series` so we can run the - vision encoder and the policy once per batch instead of once per row. - """ - image_arrays = images.to_pylist() # list of HWC uint8 numpy arrays - state_batch = np.asarray(states.to_pylist(), dtype=np.float32) # (B, 48) - task_batch = task_indices.to_pylist() - batch_size = len(image_arrays) - - with torch.no_grad(): - # State: normalize to [-1, 1] exactly like pretraining did. - normalized = (state_batch - self.action_min) / (self.action_max - self.action_min) * 2 - 1 - state_tokens = ( - torch.from_numpy(np.clip(normalized, -1, 1)) - .reshape(batch_size, 1, -1) - .to(self.device, dtype=self.dtype) - ) - - # Images: letterbox + normalize each frame, then encode the whole - # batch in one DinoSigLIP forward pass. - transformed = [self.image_transform(PILImage.fromarray(arr)) for arr in image_arrays] - image_inputs = { - key: torch.stack([t[key] for t in transformed]).to(self.device, dtype=self.dtype) - for key in transformed[0] - } - image_features = self.vision_encoder(image_inputs) # (B, num_patches, embed_dim) - image_tokens = image_features.view(batch_size, -1, self.vision_encoder.embed_dim) - - # Language: look up each row's cached T5 embedding and pad to the - # longest one in the batch, with an attention mask marking padding. 
- embeds = [self.lang_embeddings[int(idx)] for idx in task_batch] - max_len = max(e.shape[0] for e in embeds) - lang_tokens = torch.zeros(batch_size, max_len, embeds[0].shape[1], device=self.device, dtype=self.dtype) - lang_attn_mask = torch.zeros(batch_size, max_len, device=self.device, dtype=torch.bool) - for i, e in enumerate(embeds): - lang_tokens[i, : e.shape[0]] = e - lang_attn_mask[i, : e.shape[0]] = True - - action_pred = self.policy.predict_action( - state_tokens=state_tokens, - image_tokens=image_tokens, - lang_tokens=lang_tokens, - lang_attn_mask=lang_attn_mask, - ) # (B, pred_horizon, 48), normalized to [-1, 1] - - chunk = action_pred.float().cpu().numpy() - # Denormalize (inverse of the [-1, 1] scaling) and keep only the - # first step of each predicted 16-step chunk. - denorm = (chunk + 1) / 2 * (self.action_max - self.action_min) + self.action_min - return [row for row in denorm[:, 0, :].astype(np.float32)] - - -if __name__ == "__main__": - predictor = HRDTPredictor(CHECKPOINT_PATH, CONFIG_PATH, STAT_PATH, TASK_EMBEDDINGS_DIR) - - df = lerobot.read(DATASET_URI, load_video_frames=True) - if MAX_FRAMES: - df = df.limit(MAX_FRAMES) - - df = ( - df.with_column( - "predicted_action", - predictor.predict(col("observation.image"), col("observation.state"), col("task_index")), - ) - # Keep the trajectory data and identifiers; drop the decoded frames so the - # output stays small (the images are reproducible from the dataset anyway). 
- .select( - "episode_index", - "frame_index", - "timestamp", - "task_index", - col("observation.state"), - col("action").alias("ground_truth_action"), - "predicted_action", - ) - ) - - df.write_parquet(OUTPUT_DIR) - print(f"Wrote predictions to {OUTPUT_DIR}") - print("Score them with: uv run compute_metrics.py") - - daft.read_parquet(f"{OUTPUT_DIR}/**").show(8) \ No newline at end of file diff --git a/examples/lerobot_pose/visualize_predictions.py b/examples/lerobot_pose/visualize_predictions.py deleted file mode 100644 index 2ec5584a116..00000000000 --- a/examples/lerobot_pose/visualize_predictions.py +++ /dev/null @@ -1,166 +0,0 @@ -# /// script -# description = "Overlay predicted vs ground-truth 48-D hand poses on EgoDex video frames" -# requires-python = ">=3.12, <3.13" -# dependencies = [ -# "daft>=0.7.15", -# "av", -# "pillow", -# "numpy", -# ] -# /// -"""Project predicted and ground-truth hand poses onto the camera frames. - -The 48-D vectors are 3D *world-frame* points (per hand: wrist position, wrist -6D rotation, 5 fingertip positions — see H_RDT's precompute_48d_actions.py). -To draw them on a frame we follow Apple's reference visualizer -(ml-egodex/visualize_2d.py): - - 1. world -> camera: multiply by the inverse of that frame's camera pose - (`observation.extrinsics`, a 4x4 matrix). - 2. camera -> pixels: pinhole projection u = fx*X/Z + cx, v = fy*Y/Z + cy, - with EgoDex's constant intrinsics (fx = fy = 736.6339, cx = 960, cy = 540). - -Ground truth is drawn in green, the model's prediction in red. Note Apple's -caveat: Vision Pro video is synthesized from multiple cameras, so even -ground-truth reprojections can be a few pixels off the visible hands. 
- -Run after predict_hrdt.py: - - uv run visualize_predictions.py -""" - -import os - -import lerobot # vendored copy of daft.datasets.lerobot -import numpy as np -from PIL import Image, ImageDraw - -import daft -from daft import col - -DATASET_URI = "pepijn223/egodex-test" -PREDICTIONS_DIR = os.path.join(os.path.dirname(__file__), "out", "egodex_hrdt_predictions") -OVERLAYS_DIR = os.path.join(os.path.dirname(__file__), "out", "overlays") - -# EgoDex camera intrinsics (constant across the dataset, from apple/ml-egodex). -FX = FY = 736.6339 -CX, CY = 960.0, 540.0 - -GROUND_TRUTH_COLOR = (0, 220, 0) # green -PREDICTION_COLOR = (255, 40, 40) # red - -# Playback speed of the stitched per-episode mp4s. The dataset is 30 fps, so -# 15 fps plays at half speed (e.g. 30 predicted frames -> a 2-second video). -VIDEO_FPS = int(os.environ.get("VIDEO_FPS", "15")) - - -def hand_points(vec48: np.ndarray, side: int) -> np.ndarray: - """Extract the 6 drawable 3D points of one hand (wrist + 5 fingertips). - - Layout per hand (24 dims): [0:3] wrist xyz, [3:9] wrist 6D rotation - (not drawable as a point, skipped), [9:24] thumb/index/middle/ring/little - fingertip xyz. side: 0 = left hand (dims 0-23), 1 = right hand (24-47). 
- """ - base = side * 24 - wrist = vec48[base : base + 3] - fingertips = vec48[base + 9 : base + 24].reshape(5, 3) - return np.vstack([wrist, fingertips]) # (6, 3): wrist first - - -def project_to_pixels(points_world: np.ndarray, extrinsics: np.ndarray) -> np.ndarray: - """World-frame 3D points -> (u, v) pixel coordinates (NaN if behind camera).""" - cam_from_world = np.linalg.inv(extrinsics) - homogeneous = np.hstack([points_world, np.ones((len(points_world), 1))]) # (N, 4) - in_camera = (cam_from_world @ homogeneous.T).T[:, :3] - x, y, z = in_camera[:, 0], in_camera[:, 1], in_camera[:, 2] - with np.errstate(divide="ignore", invalid="ignore"): - u = FX * x / z + CX - v = FY * y / z + CY - uv = np.stack([u, v], axis=1) - uv[z <= 0] = np.nan # behind the camera - return uv - - -def write_episode_video(episode: int, frames: list[tuple[int, np.ndarray]], fps: int) -> None: - """Encode one episode's overlay frames (sorted by frame_index) into an mp4. - - Uses PyAV, which bundles its own FFmpeg — no ffmpeg binary required. 
- """ - import av - - frames = sorted(frames, key=lambda pair: pair[0]) - out_path = os.path.join(OVERLAYS_DIR, f"episode{episode:03d}.mp4") - height, width = frames[0][1].shape[:2] - container = av.open(out_path, mode="w") - stream = container.add_stream("h264", rate=fps) - stream.width, stream.height = width, height - stream.pix_fmt = "yuv420p" - for _, image_array in frames: - for packet in stream.encode(av.VideoFrame.from_ndarray(image_array, format="rgb24")): - container.mux(packet) - for packet in stream.encode(): # flush buffered frames - container.mux(packet) - container.close() - print(f"Saved {out_path} ({len(frames)} frames @ {fps} fps = {len(frames) / fps:.2f}s)") - - -def draw_skeleton(draw: ImageDraw.ImageDraw, vec48: np.ndarray, extrinsics: np.ndarray, color: tuple) -> None: - """Draw both hands of one 48-D pose: wrist dot + lines fanning to fingertips.""" - for side in (0, 1): - uv = project_to_pixels(hand_points(vec48, side), extrinsics) - if np.isnan(uv).any(): - continue - wrist, fingertips = uv[0], uv[1:] - for tip in fingertips: - draw.line([tuple(wrist), tuple(tip)], fill=color, width=3) - draw.ellipse([tip[0] - 7, tip[1] - 7, tip[0] + 7, tip[1] + 7], fill=color) - draw.ellipse([wrist[0] - 11, wrist[1] - 11, wrist[0] + 11, wrist[1] + 11], fill=color) - - -if __name__ == "__main__": - # 1. Load the predictions and note which frames they belong to. - preds = daft.read_parquet(f"{PREDICTIONS_DIR}/**").select( - "episode_index", "frame_index", "ground_truth_action", "predicted_action" - ) - pred_rows = preds.to_pydict() - episodes = sorted(set(pred_rows["episode_index"])) - frames = sorted(set(pred_rows["frame_index"])) - print(f"Found {len(pred_rows['frame_index'])} predicted frames: episodes {episodes}, frames {frames}") - - # 2. Re-read just those frames from the dataset, keeping the image and the - # per-frame camera pose. The .where filter pushes down past the video - # decoder, so only the frames we need are decoded. 
- frames_df = ( - lerobot.read(DATASET_URI, load_video_frames=True) - .where(col("episode_index").is_in(episodes) & col("frame_index").is_in(frames)) - .select("episode_index", "frame_index", col("observation.image"), col("observation.extrinsics")) - ) - - # 3. Join images to predictions on (episode, frame) and pull the handful of - # rows into plain Python for drawing. - rows = frames_df.join(preds, on=["episode_index", "frame_index"]).to_pydict() - - os.makedirs(OVERLAYS_DIR, exist_ok=True) - episode_frames: dict[int, list[tuple[int, np.ndarray]]] = {} - for i in range(len(rows["frame_index"])): - episode = rows["episode_index"][i] - frame = rows["frame_index"][i] - extrinsics = np.array(rows["observation.extrinsics"][i], dtype=np.float64).reshape(4, 4) - ground_truth = np.array(rows["ground_truth_action"][i], dtype=np.float64) - predicted = np.array(rows["predicted_action"][i], dtype=np.float64) - - image = Image.fromarray(np.asarray(rows["observation.image"][i])) - draw = ImageDraw.Draw(image) - draw_skeleton(draw, ground_truth, extrinsics, GROUND_TRUTH_COLOR) - draw_skeleton(draw, predicted, extrinsics, PREDICTION_COLOR) - draw.text((20, 20), "green = ground truth red = predicted", fill=(255, 255, 255)) - - out_path = os.path.join(OVERLAYS_DIR, f"episode{episode:03d}_frame{frame:05d}.png") - image.save(out_path) - print(f"Saved {out_path}") - episode_frames.setdefault(episode, []).append((frame, np.asarray(image))) - - # Stitch each episode's overlays into a watchable mp4. The join above does - # not guarantee row order, so write_episode_video sorts by frame_index. 
- for episode, frames in sorted(episode_frames.items()): - write_episode_video(episode, frames, VIDEO_FPS) From df2861cbf1c1da5e744419d27a25c1c0afc05c44 Mon Sep 17 00:00:00 2001 From: Shreyas Garimella Date: Thu, 11 Jun 2026 16:48:21 -0700 Subject: [PATCH 09/11] Fix style and update lerobot docs to current API --- daft/datasets/lerobot.py | 16 ++++++++----- docs/datasets/lerobot.md | 50 ++++++++++++++++++++-------------------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/daft/datasets/lerobot.py b/daft/datasets/lerobot.py index dfbb7ef2a5c..ea1f89a96b8 100644 --- a/daft/datasets/lerobot.py +++ b/daft/datasets/lerobot.py @@ -8,16 +8,16 @@ from __future__ import annotations -import re import json +import re from typing import TYPE_CHECKING, Any, TypedDict, cast import daft from daft.api_annotations import PublicAPI from daft.datatype import DataType +from daft.exceptions import DaftCoreException from daft.expressions import col, lit from daft.file import VideoFile -from daft.exceptions import DaftCoreException from daft.functions import lpad from daft.functions.file_ import video_file from daft.udf import func @@ -46,7 +46,7 @@ def _decode_lerobot_video_timestamp( tolerance_s: float, image_width_i: int, image_height_i: int, -): +) -> Any: # returns a PIL.Image; PIL is an optional dependency imported lazily below """Pick the decoded frame closest in time to ``episode_from_timestamp_s + frame_timestamp_s``.""" try: import av as av_mod @@ -108,9 +108,11 @@ def _decode_lerobot_video_timestamp( ) return closest_img + class Feature(TypedDict): dtype: str + class LeRobotInfo(TypedDict): codebase_version: str data_path: str @@ -118,13 +120,15 @@ class LeRobotInfo(TypedDict): fps: float features: dict[str, Feature] + def _read_info(normalized_uri: str, io_config: IOConfig | None = None) -> LeRobotInfo: with daft.open_file(f"{normalized_uri}/meta/info.json", io_config=io_config) as f: - info = cast(LeRobotInfo, json.load(f)) + info = cast("LeRobotInfo", 
json.load(f)) if info["codebase_version"] != "v3.0": raise ValueError("`daft.datasets.lerobot` currently only supports LeRobot datasets of v3 and above") return info + @PublicAPI def read( dataset_uri: str, @@ -220,14 +224,14 @@ def read_episodes( df = df.exclude(*(c for c in df.column_names if c.startswith("meta/"))) if not include_stats: df = df.exclude(*(c for c in df.column_names if c.startswith("stats/"))) - + # Get the video keys video_keys = set(name for name, feat_info in info["features"].items() if feat_info["dtype"] == "video") for key in video_keys: file_name_expr = ( lit(f"{root}/videos/{key}/chunk-") - + lpad(col(f"videos/{key}/chunk_index").cast(DataType.string), 3, "0") + + lpad(col(f"videos/{key}/chunk_index").cast(DataType.string), 3, "0") + lit("/file-") + lpad(col(f"videos/{key}/file_index").cast(DataType.string), 3, "0") + lit(".mp4") diff --git a/docs/datasets/lerobot.md b/docs/datasets/lerobot.md index 98bc7d95f0f..1b54a4cd1f5 100644 --- a/docs/datasets/lerobot.md +++ b/docs/datasets/lerobot.md @@ -1,14 +1,21 @@ # LeRobot v3 datasets with Daft -[LeRobot Dataset v3.0](https://huggingface.co/docs/lerobot/lerobot-dataset-v3) stores robot learning data as chunked Parquet (`meta/episodes`, `data/`) and per-camera MP4 shards under `videos/`. Daft exposes this layout under [`daft.datasets.lerobot`](../../api/datasets.md) so you can stay at **episode granularity** for filtering, then expand to **frames** only for the episodes you need. +[LeRobot Dataset v3.0](https://huggingface.co/docs/lerobot/lerobot-dataset-v3) stores robot learning data as chunked Parquet (`meta/episodes`, `data/`) and per-camera MP4 shards under `videos/`. Daft exposes this layout under [`daft.datasets.lerobot`](../api/datasets.md) so you can stay at **episode granularity** for filtering, then expand to **frames** only for the episodes you need. !!! 
warning "Beta" This API is new and may evolve as we add optimizations (for example deeper integration with Parquet predicate pushdown). -## Episode metadata +## Frame-level reads + +Use [`daft.datasets.lerobot.read`](../api/datasets.md#daft.datasets.lerobot.read) for the common case: a lazy DataFrame with one row per frame, episode metadata broadcast onto each frame. Pass `load_video_frames=True` (or a camera key / list of keys) to also decode each row's camera image from the MP4 shards. + +```python +import daft +from daft.datasets import lerobot -Use [`daft.datasets.lerobot.episodes`](../api/datasets.md#daft.datasets.lerobot.episodes) to scan `meta/episodes/**/*.parquet`. The dataframe includes a `lerobot_dataset_root` column used by frame expansion helpers. +df = lerobot.read("your-org/your-robot-dataset", load_video_frames=True) +``` `dataset_uri` can be: @@ -16,55 +23,48 @@ Use [`daft.datasets.lerobot.episodes`](../api/datasets.md#daft.datasets.lerobot. - An `hf://datasets/org/name` URI (Hub layout matches the on-disk v3 tree) - A bare `org/name` string, which is interpreted as `hf://datasets/org/name` +## Episode metadata + +Use [`daft.datasets.lerobot.read_episodes`](../api/datasets.md#daft.datasets.lerobot.read_episodes) to scan `meta/episodes/**/*.parquet` (one row per episode). Per-episode `meta/` and `stats/` columns are hidden by default; opt in with `include_meta=True` / `include_stats=True`. 
+ ```python import daft -from daft.datasets.lerobot import episodes, load_episode_frames +from daft.datasets.lerobot import load_episode_frames, read_episodes repo = "hf://datasets/your-org/your-robot-dataset" -ep = episodes(repo) +ep = read_episodes(repo) long = ep.where(daft.col("length") > 100) -frames = load_episode_frames(long) +frames = load_episode_frames(long, repo) ``` -[`load_episode_frames`](../api/datasets.md#daft.datasets.lerobot.load_episode_frames) reads only the `data/chunk-*/file-*.parquet` shards referenced by the (possibly filtered) episode rows, then keeps rows whose `episode_index` is still present. It runs a small eager step to list **distinct shard paths**; the heavy Parquet scan stays lazy afterward. Pass `columns=[...]` to project frame fields with :meth:`daft.DataFrame.select` semantics. - -## Dataset-level JSON and tasks +[`load_episode_frames`](../api/datasets.md#daft.datasets.lerobot.load_episode_frames) reads the per-frame Parquet under `data/**` and joins it to the provided episode rows on `episode_index`, producing one row per frame. Filter the episode DataFrame first so only the surviving episodes contribute frames. -Bounded metadata files are exposed as small helpers: +## Tasks -- [`read_info`](../api/datasets.md#daft.datasets.lerobot.read_info) → `meta/info.json` -- [`read_stats`](../api/datasets.md#daft.datasets.lerobot.read_stats) → `meta/stats.json` -- [`read_tasks`](../api/datasets.md#daft.datasets.lerobot.read_tasks) → prefers `meta/tasks.parquet`, falls back to `meta/tasks.jsonl` - -For Hub datasets, JSON is fetched over HTTPS from `resolve/main` (public files only unless your environment supplies credentials via your HTTP stack). +[`read_tasks`](../api/datasets.md#daft.datasets.lerobot.read_tasks) loads task metadata, preferring `meta/tasks.parquet` and falling back to legacy `meta/tasks.jsonl`. 
## Video frames -Daft already decodes video via [`read_video_frames`](../api/io.md#daft.read_video_frames) and [`daft.VideoFile`](../modalities/videos.md). Episode metadata includes per-camera chunk/file indices and timestamp offsets (`videos/{camera}/...` fields in LeRobot v3). Build the MP4 path from those columns (plus `lerobot_dataset_root`), then call `read_video_frames` with `sample_interval_seconds` or decode with `daft.functions.video_frames` for precise timestamps. +With `load_video_frames`, [`read`](../api/datasets.md#daft.datasets.lerobot.read) decodes each frame from its MP4 shard by **timestamp**: a shard packs many episodes back to back, so Daft combines the episode's `from_timestamp` offset within the shard with the frame's episode-local `timestamp`, and matches the closest decoded frame within half a frame period. Decoding requires PyAV and Pillow (`pip install av pillow`). ## API reference -::: daft.datasets.lerobot.episodes +::: daft.datasets.lerobot.read options: filters: ["!^_"] heading_level: 3 -::: daft.datasets.lerobot.load_episode_frames +::: daft.datasets.lerobot.read_episodes options: filters: ["!^_"] heading_level: 3 -::: daft.datasets.lerobot.read_tasks - options: - filters: ["!^_"] - heading_level: 3 - -::: daft.datasets.lerobot.read_info +::: daft.datasets.lerobot.load_episode_frames options: filters: ["!^_"] heading_level: 3 -::: daft.datasets.lerobot.read_stats +::: daft.datasets.lerobot.read_tasks options: filters: ["!^_"] heading_level: 3 From 0f84f2581b1ccd1eee5ee79b3ff009a7380f5901 Mon Sep 17 00:00:00 2001 From: Shreyas Garimella Date: Thu, 11 Jun 2026 16:59:21 -0700 Subject: [PATCH 10/11] Format daft/functions/video.py --- daft/functions/video.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/daft/functions/video.py b/daft/functions/video.py index c6d4ebba209..419fde030e9 100644 --- a/daft/functions/video.py +++ b/daft/functions/video.py @@ -200,6 +200,7 @@ def get_frame_by_idx_impl( ) -> PIL.Image.Image: return 
file.get_frame_by_idx(idx) + video_get_frame_by_idx_fn = Func._from_func( get_frame_by_idx_impl, return_dtype=daft.DataType.image(), @@ -212,6 +213,7 @@ def get_frame_by_idx_impl( name_override="video_get_frame_by_idx", ) + def get_video_frame_by_idx( file_expr: Expression, idx: int | Expression, From fae15849a7869ac8f9862d5465a2fa5bd33a22b9 Mon Sep 17 00:00:00 2001 From: Shreyas Garimella Date: Mon, 15 Jun 2026 02:51:31 -0700 Subject: [PATCH 11/11] lerobot with formatting changes --- daft/datasets/lerobot.py | 46 +++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/daft/datasets/lerobot.py b/daft/datasets/lerobot.py index ea1f89a96b8..c64c1c38b58 100644 --- a/daft/datasets/lerobot.py +++ b/daft/datasets/lerobot.py @@ -77,7 +77,7 @@ def _decode_lerobot_video_timestamp( with av_mod.open(f_open) as container: stream = container.streams.video[0] # Match LeRobot: seek backwards to preceding keyframe, then decode forwards. - container.seek(max(0, int(abs_ts / av_mod.time_base)), backward=True) + container.seek(max(0, int(abs_ts * av_mod.time_base)), backward=True) tail_s = max(0.1, tolerance * 50.0, 1.0 / 24.0) for vf in container.decode(stream): @@ -136,7 +136,29 @@ def read( include_stats: bool = False, load_video_frames: str | list[str] | bool = False, ) -> DataFrame: - """Read LeRobot v3 episode metadata as a lazy DataFrame (one row per frame with episode metadata).""" + """Read a LeRobot v3 dataset as a lazy DataFrame with one row per frame. + + Reads the per-episode metadata under ``meta/episodes`` and the per-frame + sensor data under ``data``, joins them on ``episode_index``, and broadcasts + each episode's metadata across its frames. Optionally decodes the matching + video frame for one or more camera keys into an image column. + + Args: + dataset_uri: Huggingface repo id (``org/name``), or a local / remote + directory (``s3://...``, ``hf://datasets/...``). 
+ io_config: Optional IO configuration for remote reads. + include_stats: If True, keep the per-episode ``stats/*`` columns + (per-feature min/max/mean/std/quantiles). Defaults to False. + load_video_frames: Which camera keys to decode into image columns, + aligned to each frame's timestamp. Defaults to False (decode + nothing). Pass True to decode every video feature, a single key + (``"observation.image"``), or a list of keys. Decoding requires the + optional ``av`` (PyAV) and ``Pillow`` dependencies. + + Returns: + Lazy DataFrame with one row per frame: the frame's sensor columns, the + broadcast episode metadata, and one image column per decoded video key. + """ root = _normalize_dataset_root(dataset_uri) info = _read_info(root, io_config=io_config) @@ -170,9 +192,8 @@ def read( tolerance_s = 1.0 / fps / 2.0 # half a frame period: any closer frame is unambiguously "the" frame # To increase parallelism, reduce batch size - df = df.into_batches(16) # TODO (for later): Set it in the batch UDF instead? + df = df.into_batches(16) for k in video_keys: - # TODO (for later): Optimize by using a batch UDF to avoid opening the same video multiple times df = df.with_column( k, _decode_lerobot_video_timestamp( @@ -186,7 +207,6 @@ def read( ) df = df.exclude(f"videos/{k}/video") - # TODO: What about raw images, what do i do about them? Is that a thing in LeRobot v3 # Drop the internal per-episode video metadata we kept above (chunk/file index, # from/to timestamp). This restores read_episodes' default of hiding these. @@ -211,14 +231,24 @@ def read_episodes( dataset_uri: Huggingface repo id (`org/name`), or a local / remote directory (`s3://...`, `hf://datasets/...`) io_config: Optional IO configuration for remote reads. + include_meta: If True, keep the internal ``meta/episodes/*`` columns + (the chunk/file indices locating each episode's own metadata shard). 
+ These are bookkeeping for random access into the sharded metadata + and carry no analytical value once the rows are loaded. Defaults to + False. + include_stats: If True, keep the per-episode ``stats/*`` columns + (per-feature min/max/mean/std/quantiles). Defaults to False. + include_video_metadata: If True, keep the per-episode ``videos/{key}/*`` + columns (the chunk/file indices and from/to timestamps locating each + episode's footage within its video shard). Defaults to False. Returns: - Lazy DataFrame of episode metadata. + Lazy DataFrame of episode metadata, one row per episode. Always includes + a ``videos/{key}/video`` file-handle column per video feature; the + ``include_*`` flags control which additional column families are kept. """ root = _normalize_dataset_root(dataset_uri) info = _read_info(root, io_config=io_config) - - # TODO: What is the `meta` episodes into used for? How is it different from the `videos` info? df = daft.read_parquet(f"{root}/meta/episodes/**/*.parquet", io_config=io_config) if not include_meta: df = df.exclude(*(c for c in df.column_names if c.startswith("meta/")))