|
| 1 | +import marimo |
| 2 | + |
| 3 | +__generated_with = "0.20.2" |
| 4 | +app = marimo.App(width="full") |
| 5 | + |
| 6 | +with app.setup: |
| 7 | + import marimo as mo |
| 8 | + from openptv_python.calibration import Calibration |
| 9 | + from openptv_python.correspondences import MatchedCoords, correspondences |
| 10 | + from openptv_python.epi import epipolar_curve |
| 11 | + from openptv_python.image_processing import preprocess_image |
| 12 | + from openptv_python.imgcoord import image_coordinates |
| 13 | + from openptv_python.parameters import ControlPar as ControlParams, VolumePar as VolumeParams |
| 14 | + from openptv_python.segmentation import target_recognition |
| 15 | + from openptv_python.trafo import arr_metric_to_pixel as convert_arr_metric_to_pixel |
| 16 | + from openptv_python.tracking_frame_buf import Frame |
| 17 | + |
| 18 | + from .parameter_manager import ParameterManager |
| 19 | + from .experiment import Experiment |
| 20 | + from pathlib import Path |
| 21 | + import matplotlib.pyplot as plt |
| 22 | + import imageio.v3 as iio |
| 23 | + import numpy as np |
| 24 | + import matplotlib |
| 25 | + from . import ptv |
| 26 | + |
| 27 | + |
@app.cell
def _():
    # Load the run configuration from a hard-coded YAML parameter file.
    yaml_path = Path(
        '/home/user/Downloads/Illmenau/pyPTV_folder/parameters_Run4.yaml'
    ).expanduser().resolve()
    assert yaml_path.exists()

    pm = ParameterManager()
    pm.from_yaml(yaml_path)
    exp = Experiment(pm=pm)  # constructed here; kept for its side effects

    params = pm.parameters
    # The YAML's "num_cams" entry wins; fall back to pm.num_cams, then 0.
    num_cams = int(params.get("num_cams", pm.num_cams or 0) or 0)
    print(f"Number of cameras: {num_cams} in {yaml_path}")
    return num_cams, params, pm, yaml_path


@app.cell
def _(num_cams, pm, yaml_path):
    # Load the per-camera raw images and calibration (.ori / .addpar) files.
    # Fix vs. original: the unused `cal_img_name = cal_ori.get('img_cal_name',
    # cal_img_names)[i]` lookup was removed — its result was never used and it
    # could raise IndexError when the 'img_cal' list is missing or short.
    cals = []
    images = []

    ptv_params = pm.parameters.get('ptv', {})
    img_names = ptv_params.get('img_name', [])

    cal_ori = pm.parameters.get('cal_ori', {})
    ori_names = cal_ori.get('img_ori', [])

    # Relative paths in the YAML are resolved against the YAML file's folder.
    base_path = Path(yaml_path).parent

    for i in range(num_cams):
        # --- image ---
        img_path = Path(img_names[i])
        if not img_path.is_absolute():
            img_path = base_path / img_path

        try:
            images.append(iio.imread(img_path))
        except Exception as e:
            print(f"Failed to load image {img_path}: {e}")
            # Fall back to a blank frame so downstream cells keep working.
            images.append(np.zeros((ptv_params.get('imy', 1024), ptv_params.get('imx', 1024))))

        # --- calibration ---
        cal = Calibration()
        ori_file_path = base_path / ori_names[i]

        # The .addpar file shares the orientation file's base name; try the
        # '.ori'-stripped form first, then the '.tif.ori'-stripped form.
        addpar_file_path = Path(str(ori_file_path).replace('.ori', '') + '.addpar')
        if not addpar_file_path.exists():
            addpar_file_path = Path(str(ori_file_path).replace('.tif.ori', '') + '.addpar')

        if ori_file_path.exists() and addpar_file_path.exists():
            cal.from_file(ori_file_path, addpar_file_path)
            print(f"Loaded calibration from {ori_file_path} and {addpar_file_path}")
        else:
            print(f"Missing calibration files for camera {i+1}: {ori_file_path} / {addpar_file_path}")

        cals.append(cal)
    return cals, images


@app.cell
def _(num_cams, params, pm):
    # Build the OpenPTV parameter objects from their YAML sections.
    ptv_section = pm.parameters['ptv']
    criteria_section = pm.parameters['criteria']

    cpar = ptv._populate_cpar(ptv_section, num_cams)  # control parameters
    vpar = ptv._populate_vpar(criteria_section)  # volume / matching criteria
    tpar = ptv._populate_tpar({'targ_rec': params['targ_rec']}, num_cams)

    print("cpar image size:", cpar.get_image_size())
    return cpar, tpar, vpar


@app.cell
def _(cpar, images, pm):
    # Convert to 8-bit, optionally invert ("negative" flag), then highpass.
    images_8bit = [ptv.img_as_ubyte(frame_img) for frame_img in images]

    if pm.parameters.get('ptv', {}).get('negative', False):
        # Invert intensities so particles appear bright on dark background.
        images_8bit = [np.clip(255 - frame_img, 0, 255) for frame_img in images_8bit]
        print("Applied negative inversion to images.")

    images_8bit = [ptv.simple_highpass(frame_img, cpar) for frame_img in images_8bit]
    return (images_8bit,)


@app.cell
def _():
    # Disabled: quick visual check of the first highpass-filtered image.
    # _fig, _ax = plt.subplots(figsize=(8, 6))
    # _ax.imshow(images_8bit[0], cmap='gray')
    # _ax.set_title("Highpass Filtered Image (Camera 1)")
    # _ax.axis('off')
    # _ax
    return


@app.cell
def _(cals, cpar, images_8bit, tpar, vpar):
    # Detect targets in each camera, then find stereo correspondences.
    n_cams_local = len(cals)
    frame = Frame(n_cams_local)
    targets = []
    matched = []

    for cam, img in enumerate(images_8bit):
        detections = target_recognition(img, tpar, cam, cpar)
        # The correspondence search expects targets sorted by y coordinate.
        if hasattr(detections, "sort_y"):
            detections.sort_y()
        else:
            detections.sort(key=lambda t: t.y)
        targets.append(detections)
        frame.targets[cam] = detections
        frame.num_targets[cam] = len(detections)
        matched.append(MatchedCoords(detections, cpar, cals[cam]))

    # MatchedCoords may expose .coords or be the coordinate array itself.
    flat_coords = [m.coords if hasattr(m, "coords") else m for m in matched]
    sorted_pos, sorted_corresp, num_targs = correspondences(
        frame,
        flat_coords,
        vpar,
        cpar,
        cals,
        [0] * n_cams_local,
    )

    print(f"Total targets used: {num_targs}")
    print("cpar image size:", cpar.get_image_size())
    print(sorted_pos[0][0, 0, :])
    return matched, sorted_corresp, sorted_pos


@app.cell
def _(cals, cpar, images, num_cams, sorted_pos, vpar):
    # Interactive 2x2 grid of all camera views with correspondence overlays.
    # Right-clicking a point in one view draws its epipolar curves in the others.
    fig_corr, axes_corr = plt.subplots(2, 2, figsize=(12, 10))
    axes_flat_corr = axes_corr.flatten()

    # One color per camera (click markers and epipolar curves).
    colors_corr = ['red', 'green', 'blue', 'yellow']

    for cam_idx in range(num_cams):
        panel = axes_flat_corr[cam_idx]
        panel.imshow(images[cam_idx], cmap='gray')
        panel.set_title(f"Camera {cam_idx+1}")
        panel.axis('on')
        # Pin the view to the image bounds so overlays cannot expand it.
        img_h, img_w = images[cam_idx].shape[:2]
        panel.set_xlim(0, img_w)
        panel.set_ylim(img_h, 0)
        panel.autoscale(False)

    # Overlay correspondence results, one color per clique size.
    clique_colors_corr = ['red', 'green', 'yellow']
    clique_labels_corr = ['Quadruplets', 'Triplets', 'Pairs']

    for clique_idx_corr, pos_type_corr in enumerate(sorted_pos):
        c_color_corr = clique_colors_corr[clique_idx_corr]
        c_label_corr = clique_labels_corr[clique_idx_corr]

        for cam_idx in range(num_cams):
            if len(pos_type_corr) == 0:
                continue

            pts_corr = pos_type_corr[cam_idx]
            # Entries flagged with -999 are invalid placeholders; drop them.
            keep = (pts_corr[:, 0] > -900) & (pts_corr[:, 1] > -900)
            shown = pts_corr[keep]
            if len(shown) > 0:
                axes_flat_corr[cam_idx].scatter(
                    shown[:, 0], shown[:, 1],
                    facecolors='none', edgecolors=c_color_corr, s=60,
                    label=c_label_corr if cam_idx == 0 else ""
                )

    # Legend on the first panel explains the clique colors.
    axes_flat_corr[0].legend(loc='upper right', fontsize=8)

    def onclick_corr(event):
        # Right mouse button only, and only inside one of the camera axes;
        # the left button stays free for matplotlib pan/zoom.
        if not event.inaxes or event.button != 3:
            return

        # Identify which camera panel was clicked.
        clicked = next(
            (k for k, a in enumerate(axes_flat_corr) if a == event.inaxes),
            None,
        )
        if clicked is None:
            return

        x, y = event.xdata, event.ydata

        # Mark the clicked location in that camera's color.
        event.inaxes.plot(x, y, 'o', color=colors_corr[clicked], markersize=6)

        point = np.array([x, y])
        n_curve_pts = 100

        # Project the clicked point as an epipolar curve into the other views.
        for other in range(num_cams):
            if other == clicked:
                continue
            try:
                curve = epipolar_curve(
                    point,
                    cals[clicked],
                    cals[other],
                    n_curve_pts,
                    cpar,
                    vpar,
                )
                if len(curve) > 1:
                    # autoscale(False) and the fixed axis limits keep the
                    # curve clipped to the image, so no manual filtering.
                    axes_flat_corr[other].plot(
                        curve[:, 0], curve[:, 1],
                        color=colors_corr[clicked], linewidth=1.5,
                    )
            except Exception as e:
                print(f"Error drawing epipolar line for camera {other+1}: {e}")

        fig_corr.canvas.draw_idle()

    cid_corr = fig_corr.canvas.mpl_connect('button_press_event', onclick_corr)

    plt.tight_layout()
    # mo.mpl.interactive keeps the figure clickable in marimo's browser UI.
    mo.mpl.interactive(fig_corr)
    return


@app.cell
def _(cals, cpar, matched, sorted_corresp, sorted_pos, vpar):
    # Triangulate 3D positions from the matched 2D target coordinates.
    from openptv_python.orientation import point_positions

    all_pos_2d = np.concatenate(sorted_pos, axis=1)
    all_corresp = np.concatenate(sorted_corresp, axis=1)

    # Gather flat (metric) coordinates per camera in correspondence order.
    flat = np.array(
        [mc.get_by_pnrs(pnrs) for mc, pnrs in zip(matched, all_corresp)]
    )

    pos, _ = point_positions(flat.transpose(1, 0, 2), cpar, cals, vpar)
    return (pos,)


@app.cell
def _(pos):
    # 3D scatter of the triangulated points, each labeled with its x value.
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(projection="3d")

    for row in pos:
        px, py, pz = row[0], row[1], row[2]
        ax.plot(px, py, pz, "ro")
        ax.text(px, py, pz, f"{px:.0f}", None)

    # Fit each axis tightly around the data cloud.
    for set_lim, col in ((ax.set_xlim, 0), (ax.set_ylim, 1), (ax.set_zlim, 2)):
        set_lim(pos[:, col].min(), pos[:, col].max())

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    mo.mpl.interactive(ax.figure)
    return


@app.cell
def _():
    # Intentionally empty placeholder cell.
    return


if __name__ == "__main__":
    # Run the notebook as a standalone marimo app.
    app.run()