
Commit 18b056f

Merge pull request #5 from OpenDriveLab/dev
Dev sa code but not sa data
2 parents b490017 + bf793ff

File tree

15 files changed: +2482 additions, −158 deletions


.vscode/settings.json

Lines changed: 0 additions & 11 deletions
This file was deleted.
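The README changes below describe Model Arithmetic as a weight-space merging strategy over models trained on different data subsets. As a rough illustration of that idea only, here is a hypothetical numpy helper (not the repo's `model_arithmetic/arithmetic_torch.py`; parameter names are invented):

```python
import numpy as np

def interpolate_checkpoints(ckpt_a, ckpt_b, alpha=0.5):
    """Weight-space merging sketch: elementwise interpolation
    (1 - alpha) * A + alpha * B over matching parameter names.
    Hypothetical helper; see model_arithmetic/README.md for the
    repo's actual methods (including gradient-based optimization)."""
    assert ckpt_a.keys() == ckpt_b.keys(), "checkpoints must share parameter names"
    return {k: (1 - alpha) * ckpt_a[k] + alpha * ckpt_b[k] for k in ckpt_a}

# Midpoint merge of two toy "checkpoints"
merged = interpolate_checkpoints(
    {"w": np.array([0.0, 2.0])},
    {"w": np.array([2.0, 0.0])},
    alpha=0.5,
)
```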

README.md

Lines changed: 51 additions & 7 deletions

@@ -20,7 +20,7 @@
 χ₀ addresses the systematic distributional shift among the human demonstration distribution ($P_\text{train}$), the inductive bias learned by the policy ($Q_\text{model}$), and the test-time execution distribution ($P_\text{test}$) through three technical modules:

 - **[Model Arithmetic](#model-arithmetic)**: A weight-space merging strategy that combines models trained on different data subsets, efficiently capturing diverse knowledge without architectural complexity. **[Released]**
-- **[Stage Advantage](#stage-advantage-coming-soon)**: A stage-aware advantage estimator that provides stable, dense progress signals for policy training. **[Coming Soon]**
+- **[Stage Advantage](#stage-advantage)**: A stage-aware advantage estimator that provides stable, dense progress signals for policy training. **[Released]**
 - **[Train-Deploy Alignment](#train-deploy-alignment-coming-soon)**: Bridges the distribution gap via spatio-temporal augmentation, heuristic DAgger corrections, and temporal chunk-wise smoothing. **[Coming Soon]**

 χ₀ enables two sets of dual-arm robots to collaboratively orchestrate long-horizon garment manipulation — flattening, folding, and hanging — surpassing the state-of-the-art $\pi_{0.5}$ baseline by approximately 250% in success rate, with `only 20 hours of data and 8 A100 GPUs`.
@@ -46,14 +46,15 @@ https://github.com/user-attachments/assets/3f5f0c48-ff3f-4b9b-985b-59ad0b2ea97c
 - [Model Arithmetic](#model-arithmetic)
 - [Workflow](#workflow)
 - [Quick Start](#quick-start)
-- [Stage Advantage (Coming Soon)](#stage-advantage-coming-soon)
+- [Stage Advantage](#stage-advantage)
 - [Train-Deploy Alignment (Coming Soon)](#train-deploy-alignment-coming-soon)
 - [Citation](#licenseandcitation)
 - [Troubleshooting](#troubleshooting)
 - [Links and Community](#links-and-community)

 ## Update

+- [Feb 14 2026] Release of the **Stage Advantage** module: advantage estimator training, evaluation, GT labeling, and AWBC training pipeline.
 - [Feb 10 2026] Initial release of the **Model Arithmetic** module with support for both JAX and PyTorch checkpoints (not tested thoroughly).
 - [Feb 10 2026] χ₀ paper released.

@@ -208,9 +209,9 @@ Checkpoints are written to the config’s checkpoint directory. You can then use

 - [x] kai0 oracle: training and inference code with non-advantage data of three tasks
 - [x] Model Arithmetic: code of different baselines for weight-space interpolation
-- [ ] Stage Advantage: code, data (advantage labels), and checkpoints — **Feb 12**
-- [ ] HuggingFace & ModelScope: upload Stage Advantage data and checkpoints — **Feb 12**
-- [ ] Train-Deploy Alignment — **Feb 15**
+- [x] Stage Advantage: code, data (advantage labels), and checkpoints
+- [ ] HuggingFace & ModelScope: upload Stage Advantage data and checkpoints — **Feb 14**
+- [ ] Train-Deploy Alignment — **Feb 14**

 ## Model Arithmetic

@@ -265,11 +266,54 @@ python model_arithmetic/arithmetic_torch.py \

 For gradient-based optimization, dataset splitting, and all other methods, see the full documentation in [`model_arithmetic/README.md`](model_arithmetic/README.md).

-## Stage Advantage (Coming Soon)
+## Stage Advantage

 Stage Advantage decomposes long-horizon tasks into semantic stages and provides stage-aware advantage signals for policy training. It addresses the numerical instability of prior non-stage approaches by computing advantage as progress differentials within each stage, yielding smoother and more stable supervision.

-**This module is currently under refinement and will be released soon.**
+The full pipeline has four stages:
+
+```
+Stage 0: GT Labeling → Stage 1: Train Advantage Estimator → Stage 2: Advantage Estimation → Stage 3: AWBC Training
+```
+
+### Quick Start
+
+**Stage 0 — GT Data Labeling**: Compute advantage values and discretize into `task_index` labels.
+
+```bash
+cd stage_advantage/annotation
+python gt_label.py <dataset_path> \
+    --threshold 30 --chunk-size 50 --discretion-type binary \
+    --advantage-source absolute_advantage
+```
+
+For batch labeling across multiple dataset variants, see `stage_advantage/annotation/gt_labeling.sh`.
+
+**Stage 1 — Train Advantage Estimator**: Fine-tune a pi0-based model to predict advantage from observations.
+
+```bash
+uv run python scripts/train_pytorch.py ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD --exp_name=run1 --save_interval 10000
+```
+
+For a ready-to-use script with environment setup (conda/venv activation, DDP configuration) and automatic log management, see `stage_advantage/annotation/train_estimator.sh`.
+
+**Stage 2 — Advantage Estimation on New Data**: Use the trained estimator to label datasets with predicted advantage values.
+
+```bash
+uv run python stage_advantage/annotation/eval.py Flatten-Fold KAI0 /path/to/dataset
+```
+
+For a ready-to-use script with environment setup and status logging, see `stage_advantage/annotation/eval.sh`.
+
+**Stage 3 — AWBC Training**: Train a policy with Advantage-Weighted Behavior Cloning.
+
+```bash
+XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_flatten_fold_awbc --exp_name=run1
+```
+
+For a ready-to-use script with environment setup and automatic log management, see `stage_advantage/awbc/train_awbc.sh`.
+
+For the full pipeline details, configuration instructions, and all parameters, see [`stage_advantage/README.md`](stage_advantage/README.md).

 ## Train-Deploy Alignment (Coming Soon)
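The Stage Advantage text in the README hunk above defines advantage as a progress differential computed within each semantic stage. A minimal numpy sketch of that idea follows; the function and parameter names are hypothetical, and the real implementation lives in `stage_advantage/`:

```python
import numpy as np

def stage_advantage(progress, stage_ids, chunk_size=50):
    """Sketch: advantage at step t is the progress gained over the
    next `chunk_size` steps, with the lookahead clipped at the end
    of the current semantic stage (hypothetical helper, not the
    repo's estimator)."""
    progress = np.asarray(progress, dtype=float)
    stage_ids = np.asarray(stage_ids)
    n = len(progress)
    adv = np.zeros(n)
    for t in range(n):
        k = t
        # advance the lookahead while staying inside the chunk and the stage
        while k + 1 < n and k + 1 - t < chunk_size and stage_ids[k + 1] == stage_ids[t]:
            k += 1
        adv[t] = progress[k] - progress[t]
    return adv
```

Clipping the lookahead at stage boundaries is what keeps the signal dense and stable: the differential never spans a discontinuity between stages.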

src/openpi/policies/agilex_policy.py

Lines changed: 51 additions & 14 deletions

@@ -15,7 +15,9 @@ class AgilexInputs(transforms.DataTransformFn):
     """Inputs for the Agilex policy.

     Expected inputs:
-    - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
+    - images: dict[name, img] where img is [channel, height, width]. For normal pi05
+      training, names must be exactly the keys of required_rename_map. For advantage
+      estimator, optional_rename_map keys may be included as well.
     - state: [14]
     - actions: [action_horizon, 14]
     """
@@ -28,13 +30,23 @@ class AgilexInputs(transforms.DataTransformFn):

     # The expected cameras names. All input cameras must be in this set. Missing cameras will be
     # replaced with black images and the corresponding `image_mask` will be set to False.
-    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("top_head", "hand_left", "hand_right")

-    rename_map = {
+    required_rename_map = {
         "top_head": "base_0_rgb",
         "hand_left": "left_wrist_0_rgb",
         "hand_right": "right_wrist_0_rgb"
     }
+    # Optional cameras for advantage-estimator training (history frames).
+    optional_rename_map = {
+        "his_-100_top_head": "base_-100_rgb",
+        "his_-100_hand_left": "left_wrist_-100_rgb",
+        "his_-100_hand_right": "right_wrist_-100_rgb",
+    }
+
+    all_rename_map = {**required_rename_map, **optional_rename_map}
+
+    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = tuple(required_rename_map.keys())
+    EXTRA_CAMERAS: ClassVar[tuple[str, ...]] = tuple(optional_rename_map.keys())

     # if set all state to zeros
     mask_state: bool = False
@@ -43,16 +55,22 @@ def __call__(self, data: dict) -> dict:
         # We only mask padding for pi0 model, not pi0-FAST
         mask_padding = self.model_type == _model.ModelType.PI0

+        in_images = data["images"]
+
+        if set(in_images) - set(self.EXPECTED_CAMERAS) - set(self.EXTRA_CAMERAS):
+            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")
+
         # Pad the proprioceptive input to the action dimension of the model
         state = transforms.pad_to_dim(data["state"], self.action_dim)
         # Ensure state has correct shape [batch_size, state_dim]
         state = state.squeeze()

         # Parse images to uint8 (H,W,C) since LeRobot automatically stores as float32 (C,H,W)
         images = {}
-        for camera in self.EXPECTED_CAMERAS:
-            if camera in data["images"]:
-                img = data["images"][camera]
+        image_masks = {}
+        for camera in self.EXPECTED_CAMERAS + self.EXTRA_CAMERAS:
+            if camera in in_images:
+                img = in_images[camera]
                 # Convert torch tensor to numpy array if needed
                 if isinstance(img, torch.Tensor):
                     img = img.cpu().numpy()
@@ -62,12 +80,14 @@ def __call__(self, data: dict) -> dict:
                 # Convert from [C,H,W] to [H,W,C] if needed
                 if img.shape[0] == 3:
                     img = np.transpose(img, (1, 2, 0))
-                images[self.rename_map[camera]] = img
+                images[self.all_rename_map[camera]] = img
+                image_masks[self.all_rename_map[camera]] = np.True_
+
+            elif camera not in in_images and camera in self.EXTRA_CAMERAS:
+                continue  # optional camera can be skipped
             else:
                 raise ValueError(f"Camera {camera} not found in data")

-        # Create image mask based on available cameras
-        image_mask = {self.rename_map[camera]: np.True_ for camera in self.EXPECTED_CAMERAS}

         # filter unnormal state / action value, set to 0
         state = np.where(state > np.pi, 0, state)
@@ -77,7 +97,7 @@ def __call__(self, data: dict) -> dict:
         masked_state = np.zeros_like(state) if self.mask_state else state
         inputs = {
             "image": images,
-            "image_mask": image_mask,
+            "image_mask": image_masks,
             "state": masked_state,
         }

@@ -91,17 +111,34 @@ def __call__(self, data: dict) -> dict:
                 action_mask = np.ones_like(actions, dtype=bool)
                 action_mask[:, self.action_dim:] = False
                 inputs["action_mask"] = action_mask
-
-            if self.convert_to_eef_position:
-                actions[..., :14] = batch_qpos_to_eef_pos(actions[..., :14])
+
             inputs["actions"] = actions.squeeze()

         # Add prompt if present
         if "prompt" in data:
             inputs["prompt"] = data["prompt"]
-
+
+        # Advantage-estimator optional fields: passthrough or convert to tensor
+        for key in ("frame_index", "episode_length", "progress", "image_original", "episode_index"):
+            if key in data:
+                inputs[key] = data[key]
+
+        def _to_tensor(x, default=None):
+            if x is None and default is not None:
+                return default
+            if isinstance(x, np.ndarray):
+                return torch.from_numpy(x)
+            if isinstance(x, torch.Tensor):
+                return x.detach().clone()
+            raise NotImplementedError(f"Unsupported type: {type(x)}")
+
+        if "action_advantage" in data:
+            inputs["action_advantage"] = _to_tensor(data["action_advantage"], default=torch.tensor(1.0))
+        if "action_advantage_original" in data:
+            inputs["action_advantage_original"] = _to_tensor(data["action_advantage_original"])
         return inputs

+
 @dataclasses.dataclass(frozen=True)
 class AgilexOutputs(transforms.DataTransformFn):
     """Outputs for the Agilex policy."""

src/openpi/policies/arx_policy.py

Lines changed: 53 additions & 16 deletions

@@ -9,12 +9,15 @@
 import openpi.models.model as _model
 import openpi.transforms as transforms

+
 @dataclasses.dataclass(frozen=True)
 class ARXInputs(transforms.DataTransformFn):
     """Inputs for the ARX policy.

     Expected inputs:
-    - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
+    - images: dict[name, img] where img is [channel, height, width]. For normal pi05
+      training, names must be exactly the keys of required_rename_map. For advantage
+      estimator, optional_rename_map keys may be included as well.
     - state: [14]
     - actions: [action_horizon, 14]
     """
@@ -27,32 +30,47 @@ class ARXInputs(transforms.DataTransformFn):

     # The expected cameras names. All input cameras must be in this set. Missing cameras will be
     # replaced with black images and the corresponding `image_mask` will be set to False.
-    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("top_head", "hand_left", "hand_right")

-    rename_map = {
+    required_rename_map = {
         "top_head": "base_0_rgb",
         "hand_left": "left_wrist_0_rgb",
         "hand_right": "right_wrist_0_rgb"
     }
+    # Optional cameras for advantage-estimator training (history frames).
+    optional_rename_map = {
+        "his_-100_top_head": "base_-100_rgb",
+        "his_-100_hand_left": "left_wrist_-100_rgb",
+        "his_-100_hand_right": "right_wrist_-100_rgb",
+    }
+
+    all_rename_map = {**required_rename_map, **optional_rename_map}
+
+    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = tuple(required_rename_map.keys())
+    EXTRA_CAMERAS: ClassVar[tuple[str, ...]] = tuple(optional_rename_map.keys())

     # if set all state to zeros
     mask_state: bool = False

-
     def __call__(self, data: dict) -> dict:
         # We only mask padding for pi0 model, not pi0-FAST
         mask_padding = self.model_type == _model.ModelType.PI0

+        in_images = data["images"]
+
+        if set(in_images) - set(self.EXPECTED_CAMERAS) - set(self.EXTRA_CAMERAS):
+            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")
+
         # Pad the proprioceptive input to the action dimension of the model
         state = transforms.pad_to_dim(data["state"], self.action_dim)
         # Ensure state has correct shape [batch_size, state_dim]
         state = state.squeeze()

         # Parse images to uint8 (H,W,C) since LeRobot automatically stores as float32 (C,H,W)
         images = {}
-        for camera in self.EXPECTED_CAMERAS:
-            if camera in data["images"]:
-                img = data["images"][camera]
+        image_masks = {}
+        for camera in self.EXPECTED_CAMERAS + self.EXTRA_CAMERAS:
+            if camera in in_images:
+                img = in_images[camera]
                 # Convert torch tensor to numpy array if needed
                 if isinstance(img, torch.Tensor):
                     img = img.cpu().numpy()
@@ -62,38 +80,57 @@ def __call__(self, data: dict) -> dict:
                 # Convert from [C,H,W] to [H,W,C] if needed
                 if img.shape[0] == 3:
                     img = np.transpose(img, (1, 2, 0))
-                images[self.rename_map[camera]] = img
+                images[self.all_rename_map[camera]] = img
+                image_masks[self.all_rename_map[camera]] = np.True_
+
+            elif camera not in in_images and camera in self.EXTRA_CAMERAS:
+                continue  # optional camera can be skipped
             else:
                 raise ValueError(f"Camera {camera} not found in data")

-        # Create image mask based on available cameras
-        image_mask = {self.rename_map[camera]: np.True_ for camera in self.EXPECTED_CAMERAS}
-
         # Prepare inputs dictionary
         masked_state = np.zeros_like(state) if self.mask_state else state
         inputs = {
             "image": images,
-            "image_mask": image_mask,
+            "image_mask": image_masks,
             "state": masked_state,
         }

         # Add actions if present
         if "actions" in data:
             actions = transforms.pad_to_dim(data["actions"], self.action_dim)
-            # actions = np.where(actions > np.pi, 0, actions)
-            # actions = np.where(actions < -np.pi, 0, actions)
+            actions = np.where(actions > np.pi, 0, actions)
+            actions = np.where(actions < -np.pi, 0, actions)
             if mask_padding:
                 # Create action mask for padding
                 action_mask = np.ones_like(actions, dtype=bool)
                 action_mask[:, self.action_dim:] = False
                 inputs["action_mask"] = action_mask
-
+
             inputs["actions"] = actions.squeeze()

         # Add prompt if present
         if "prompt" in data:
             inputs["prompt"] = data["prompt"]
-
+
+        # Advantage-estimator optional fields: passthrough or convert to tensor
+        for key in ("frame_index", "episode_length", "progress", "image_original", "episode_index"):
+            if key in data:
+                inputs[key] = data[key]
+
+        def _to_tensor(x, default=None):
+            if x is None and default is not None:
+                return default
+            if isinstance(x, np.ndarray):
+                return torch.from_numpy(x)
+            if isinstance(x, torch.Tensor):
+                return x.detach().clone()
+            raise NotImplementedError(f"Unsupported type: {type(x)}")
+
+        if "action_advantage" in data:
+            inputs["action_advantage"] = _to_tensor(data["action_advantage"], default=torch.tensor(1.0))
+        if "action_advantage_original" in data:
+            inputs["action_advantage_original"] = _to_tensor(data["action_advantage_original"])
         return inputs

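The ARX hunk above re-enables two previously commented-out `np.where` filters on actions. Their effect in isolation, as a standalone sketch:

```python
import numpy as np

def filter_actions(actions):
    """Zero out joint values outside [-pi, pi], matching the two
    re-enabled np.where lines in ARXInputs (standalone sketch)."""
    actions = np.where(actions > np.pi, 0, actions)
    actions = np.where(actions < -np.pi, 0, actions)
    return actions
```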