Drexubery committed
Commit df13f4b
1 Parent(s): 2895855
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +167 -13
  2. app.py +99 -0
  3. configs/__pycache__/infer_config.cpython-38.pyc +0 -0
  4. configs/__pycache__/infer_config.cpython-39.pyc +0 -0
  5. configs/infer_config.py +58 -0
  6. configs/inference_pvd_1024.yaml +113 -0
  7. configs/inference_pvd_512.yaml +114 -0
  8. docs/config_help.md +43 -0
  9. docs/render_help.md +169 -0
  10. extern/dust3r/LICENSE +7 -0
  11. extern/dust3r/croco/LICENSE +52 -0
  12. extern/dust3r/croco/NOTICE +21 -0
  13. extern/dust3r/croco/README.MD +124 -0
  14. extern/dust3r/croco/assets/Chateau1.png +0 -0
  15. extern/dust3r/croco/assets/Chateau2.png +0 -0
  16. extern/dust3r/croco/assets/arch.jpg +0 -0
  17. extern/dust3r/croco/croco-stereo-flow-demo.ipynb +191 -0
  18. extern/dust3r/croco/datasets/__init__.py +0 -0
  19. extern/dust3r/croco/datasets/crops/README.MD +104 -0
  20. extern/dust3r/croco/datasets/crops/extract_crops_from_images.py +159 -0
  21. extern/dust3r/croco/datasets/habitat_sim/README.MD +76 -0
  22. extern/dust3r/croco/datasets/habitat_sim/__init__.py +0 -0
  23. extern/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py +92 -0
  24. extern/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py +27 -0
  25. extern/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py +177 -0
  26. extern/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +390 -0
  27. extern/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py +69 -0
  28. extern/dust3r/croco/datasets/habitat_sim/paths.py +129 -0
  29. extern/dust3r/croco/datasets/pairs_dataset.py +109 -0
  30. extern/dust3r/croco/datasets/transforms.py +95 -0
  31. extern/dust3r/croco/demo.py +55 -0
  32. extern/dust3r/croco/interactive_demo.ipynb +271 -0
  33. extern/dust3r/croco/models/__pycache__/blocks.cpython-310.pyc +0 -0
  34. extern/dust3r/croco/models/__pycache__/blocks.cpython-38.pyc +0 -0
  35. extern/dust3r/croco/models/__pycache__/blocks.cpython-39.pyc +0 -0
  36. extern/dust3r/croco/models/__pycache__/croco.cpython-310.pyc +0 -0
  37. extern/dust3r/croco/models/__pycache__/croco.cpython-38.pyc +0 -0
  38. extern/dust3r/croco/models/__pycache__/croco.cpython-39.pyc +0 -0
  39. extern/dust3r/croco/models/__pycache__/dpt_block.cpython-310.pyc +0 -0
  40. extern/dust3r/croco/models/__pycache__/dpt_block.cpython-38.pyc +0 -0
  41. extern/dust3r/croco/models/__pycache__/dpt_block.cpython-39.pyc +0 -0
  42. extern/dust3r/croco/models/__pycache__/masking.cpython-310.pyc +0 -0
  43. extern/dust3r/croco/models/__pycache__/masking.cpython-38.pyc +0 -0
  44. extern/dust3r/croco/models/__pycache__/masking.cpython-39.pyc +0 -0
  45. extern/dust3r/croco/models/__pycache__/pos_embed.cpython-310.pyc +0 -0
  46. extern/dust3r/croco/models/__pycache__/pos_embed.cpython-38.pyc +0 -0
  47. extern/dust3r/croco/models/__pycache__/pos_embed.cpython-39.pyc +0 -0
  48. extern/dust3r/croco/models/blocks.py +241 -0
  49. extern/dust3r/croco/models/criterion.py +37 -0
  50. extern/dust3r/croco/models/croco.py +249 -0
README.md CHANGED
@@ -1,13 +1,167 @@
1
- ---
2
- title: ViewCrafter
3
- emoji: 😻
4
- colorFrom: pink
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 4.42.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ## ___***ViewCrafter: Taming Video Diffusion Models for High-fidelity Novel View Synthesis***___
2
+ <div align="center">
3
+
4
+ <!-- <a href='https://arxiv.org/abs/2310.12190'><img src='https://img.shields.io/badge/arXiv-2310.12190-b31b1b.svg'></a> &nbsp; -->
5
+ <a href='https://drexubery.github.io/ViewCrafter/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp;
6
+ <a href=''><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Page-blue'></a> &nbsp;
7
+
8
+ _**[Wangbo Yu*](https://scholar.google.com/citations?user=UOE8-qsAAAAJ&hl=zh-CN), [Jinbo Xing*](https://menghanxia.github.io), [Li Yuan*](), [Wenbo Hu&dagger;](https://wbhu.github.io/), [Xiaoyu Li](https://xiaoyu258.github.io/), [Zhipeng Huang](), <br> [Xiangjun Gao](https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en/), [Tien-Tsin Wong](https://www.cse.cuhk.edu.hk/~ttwong/myself.html), [Ying Shan](https://scholar.google.com/citations?hl=en&user=4oXBp9UAAAAJ&view_op=list_works&sortby=pubdate), [Yonghong Tian&dagger;]()**_
9
+ <br><br>
10
+
11
+ </div>
12
+
13
+ ## 🔆 Introduction
14
+
15
+ ViewCrafter can generate high-fidelity novel views from <strong>a single reference image or sparse reference images</strong>, while also supporting highly precise pose control. Some examples are shown below:
16
+
17
+
18
+ ### Zero-shot novel view synthesis (single view)
19
+ <table class="center">
20
+ <tr style="font-weight: bolder;text-align:center;">
21
+ <td>Reference image</td>
22
+ <td>Camera trajectory</td>
23
+ <td>Generated novel view video</td>
24
+ </tr>
25
+
26
+ <tr>
27
+ <td>
28
+ <img src=assets/train.png width="250">
29
+ </td>
30
+ <td>
31
+ <img src=assets/ctrain.gif width="150">
32
+ </td>
33
+ <td>
34
+ <img src=assets/train.gif width="250">
35
+ </td>
36
+ </tr>
37
+ <tr>
38
+ <td>
39
+ <img src=assets/wst.png width="250">
40
+ </td>
41
+ <td>
42
+ <img src=assets/cwst.gif width="150">
43
+ </td>
44
+ <td>
45
+ <img src=assets/wst.gif width="250">
46
+ </td>
47
+ </tr>
48
+ <tr>
49
+ <td>
50
+ <img src=assets/flower.png width="250">
51
+ </td>
52
+ <td>
53
+ <img src=assets/cflower.gif width="150">
54
+ </td>
55
+ <td>
56
+ <img src=assets/flower.gif width="250">
57
+ </td>
58
+ </tr>
59
+ </table>
60
+
61
+ ### Zero-shot novel view synthesis (2 views)
62
+ <table class="center">
63
+ <tr style="font-weight: bolder;text-align:center;">
64
+ <td>Reference image 1</td>
65
+ <td>Reference image 2</td>
66
+ <td>Generated novel view video</td>
67
+ </tr>
68
+
69
+ <tr>
70
+ <td>
71
+ <img src=assets/car2_1.png width="250">
72
+ </td>
73
+ <td>
74
+ <img src=assets/car2_2.png width="250">
75
+ </td>
76
+ <td>
77
+ <img src=assets/car2.gif width="250">
78
+ </td>
79
+ </tr>
80
+ <tr>
81
+ <td>
82
+ <img src=assets/barn_1.png width="250">
83
+ </td>
84
+ <td>
85
+ <img src=assets/barn_2.png width="250">
86
+ </td>
87
+ <td>
88
+ <img src=assets/barn.gif width="250">
89
+ </td>
90
+ </tr>
91
+ <tr>
92
+ <td>
93
+ <img src=assets/house_1.png width="250">
94
+ </td>
95
+ <td>
96
+ <img src=assets/house_2.png width="250">
97
+ </td>
98
+ <td>
99
+ <img src=assets/house.gif width="250">
100
+ </td>
101
+ </tr>
102
+ </table>
103
+
104
+ ## 🗓️ TODO
105
+ - [x] [2024-09-01] Launch the project page and update the arXiv preprint.
106
+ - [x] [2024-09-01] Release pretrained models and the code for single-view novel view synthesis.
107
+ - [ ] Release the code for sparse-view novel view synthesis.
108
+ - [ ] Release the code for iterative novel view synthesis.
109
+ - [ ] Release the code for 3D-GS reconstruction.
110
+ <br>
111
+
112
+ ## 🧰 Models
113
+
114
+ |Model|Resolution|Frames|GPU Mem. & Inference Time (A100, 50 DDIM steps)|Checkpoint|
115
+ |:---------|:---------|:--------|:--------|:--------|
116
+ |ViewCrafter_25|576x1024|25| 23.5GB & 120s (`perframe_ae=True`)|[Hugging Face](https://huggingface.co/Drexubery/ViewCrafter_25/blob/main/model.ckpt)|
117
+ |ViewCrafter_16|576x1024|16| 18.3GB & 75s (`perframe_ae=True`)|[Hugging Face](https://huggingface.co/Drexubery/ViewCrafter_16/blob/main/model.ckpt)|
118
+
119
+
120
+ Currently, we provide two versions of the model: a base model that generates 16 frames at a time and an enhanced model that generates 25 frames at a time. The inference time can be reduced by using fewer DDIM steps.
121
+
122
+ ## ⚙️ Setup
123
+
124
+ ### 1. Clone ViewCrafter
125
+ ```bash
126
+ git clone https://github.com/Drexubery/ViewCrafter.git
127
+ cd ViewCrafter
128
+ ```
129
+ ### 2. Installation
130
+
131
+ ```bash
132
+ # Create conda environment
133
+ conda create -n viewcrafter python=3.9.16
134
+ conda activate viewcrafter
135
+ pip install -r requirements.txt
136
+
137
+ # Install PyTorch3D
138
+ conda install https://anaconda.org/pytorch3d/pytorch3d/0.7.5/download/linux-64/pytorch3d-0.7.5-py39_cu117_pyt1131.tar.bz2
139
+
140
+ # Download DUSt3R
141
+ mkdir -p checkpoints/
142
+ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/
143
+
144
+ ```
145
+
146
+ ## 💫 Inference
147
+ ### 1. Command line
148
+
149
+ (1) Download a pretrained model (ViewCrafter_25, for example) and put its `model.ckpt` at `checkpoints/model.ckpt`. \
150
+ (2) Run [inference.py](./inference.py) using the following script. Please refer to the [configuration document](docs/config_help.md) and [render document](docs/render_help.md) to set up inference parameters and camera trajectory.
151
+ ```bash
152
+ sh run.sh
153
+ ```
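A rough sketch of what the options in `run.sh` amount to, using the argument parser this commit adds in `configs/infer_config.py` (the trajectory file path below is an assumption; the other defaults come from that file):

```python
# Sketch: assemble the options that inference.py consumes, via configs/infer_config.py.
# Paths marked "assumed" are illustrative, not shipped by this commit.
from configs.infer_config import get_parser

opts = get_parser().parse_args([
    '--image_dir', './test/images/fruit.png',         # default sample image from infer_config.py
    '--mode', 'single_view_txt',
    '--traj_txt', './trajs/loop1.txt',                 # assumed: your trajectory file (see docs/render_help.md)
    '--ckpt_path', './checkpoints/model.ckpt',
    '--config', './configs/inference_pvd_1024.yaml',
    '--out_dir', './output',
    '--video_length', '25',                            # use 16 with the 16-frame model
])
print(opts.mode, opts.image_dir, opts.video_length)
```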
154
+
155
+ ### 2. Local Gradio demo
156
+
157
+ Download the pretrained model and put it in the corresponding directory as described above, then run:
158
+ ```bash
159
+ python gradio_app.py
160
+ ```
161
+
162
+ <a name="disc"></a>
163
+ ## 📢 Disclaimer
164
+ ⚠️ This is an open-source research exploration rather than a commercial product, so it may not meet all your expectations. Because of the variability of the video diffusion model, you may encounter failure cases. Try different seeds and adjust the render configs if the results are not satisfactory.
165
+ Users are free to create videos using this tool, but they must comply with local laws and use it responsibly. The developers do not assume any responsibility for potential misuse by users.
166
+ ****
167
+
app.py ADDED
@@ -0,0 +1,99 @@
1
+ import os
2
+ import sys
3
+ import gradio as gr
4
+ import random
5
+ from viewcrafter import ViewCrafter
6
+ from configs.infer_config import get_parser
7
+ import torch
8
+
9
+
10
+ i2v_examples = [
11
+ ['test/images/boy.png', 0, 1.0, '0 40', '0 0', '0 0', 50, 123],
12
+ ['test/images/car.jpeg', 0, 1.0, '0 -35', '0 0', '0 -0.1', 50, 123],
13
+ ['test/images/fruit.jpg', 0, 1.0, '0 -3 -15 -20 -17 -5 0', '0 -2 -5 -10 -8 -5 0 2 5 3 0', '0 0', 50, 123],
14
+ ['test/images/room.png', 5, 1.0, '0 3 10 20 17 10 0', '0 -2 -8 -6 0 2 5 3 0', '0 -0.02 -0.09 -0.16 -0.09 0', 50, 123],
15
+ ['test/images/castle.png', 0, 1.0, '0 30', '0 -1 -5 -4 0 1 5 4 0', '0 -0.2', 50, 123],
16
+ ]
17
+
18
+ max_seed = 2 ** 31
19
+
20
+ os.system('conda install https://anaconda.org/pytorch3d/pytorch3d/0.7.5/download/linux-64/pytorch3d-0.7.5-py39_cu117_pyt1131.tar.bz2')
21
+
22
+
23
+
24
+ def viewcrafter_demo(opts):
25
+ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px} #random_button {max-width: 100px !important}"""
26
+ image2video = ViewCrafter(opts, gradio = True)
27
+ with gr.Blocks(analytics_enabled=False, css=css) as viewcrafter_iface:
28
+ gr.Markdown("<div align='center'> <h1> ViewCrafter: Taming Video Diffusion Models for High-fidelity Novel View Synthesis </span> </h1> \
29
+ <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
30
+ <a href='https://scholar.google.com/citations?user=UOE8-qsAAAAJ&hl=zh-CN'>Wangbo Yu</a>, \
31
+ <a href='https://doubiiu.github.io/'>Jinbo Xing</a>, <a href=''>Li Yuan</a>, \
32
+ <a href='https://wbhu.github.io/'>Wenbo Hu</a>, <a href='https://xiaoyu258.github.io/'>Xiaoyu Li</a>,\
33
+ <a href=''>Zhipeng Huang</a>, <a href='https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en/'>Xiangjun Gao</a>,\
34
+ <a href='https://www.cse.cuhk.edu.hk/~ttwong/myself.html/'>Tien-Tsin Wong</a>,\
35
+ <a href='https://scholar.google.com/citations?hl=en&user=4oXBp9UAAAAJ&view_op=list_works&sortby=pubdate/'>Ying Shan</a>\
36
+ <a href=''>Yonghong Tian</a>\
37
+ </h2> \
38
+ <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Drexubery/ViewCrafter/blob/main/docs/render_help.md'> [Guideline] </a>\
39
+ <a style='font-size:18px;color: #000000' href=''> [ArXiv] </a>\
40
+ <a style='font-size:18px;color: #000000' href='https://drexubery.github.io/ViewCrafter/'> [Project Page] </a>\
41
+ <a style='font-size:18px;color: #000000' href='https://github.com/Drexubery/ViewCrafter'> [Github] </a> </div>")
42
+
43
+ #######image2video######
44
+ with gr.Tab(label="ViewCrafter_25, 'single_view_txt' mode"):
45
+ with gr.Column():
46
+ with gr.Row():
47
+ with gr.Column():
48
+ with gr.Row():
49
+ i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
50
+ with gr.Row():
51
+ i2v_elevation = gr.Slider(minimum=-45, maximum=45, step=1, elem_id="elevation", label="elevation", value=5)
52
+ with gr.Row():
53
+ i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale", label="center_scale", value=1)
54
+ with gr.Row():
55
+ i2v_d_phi = gr.Text(label='d_phi sequence, should start with 0')
56
+ with gr.Row():
57
+ i2v_d_theta = gr.Text(label='d_theta sequence, should start with 0')
58
+ with gr.Row():
59
+ i2v_d_r = gr.Text(label='d_r sequence, should start with 0')
60
+ with gr.Row():
61
+ i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
62
+ with gr.Row():
63
+ i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=max_seed, step=1, value=123)
64
+ i2v_end_btn = gr.Button("Generate")
65
+ # with gr.Tab(label='Result'):
66
+ with gr.Column():
67
+ with gr.Row():
68
+ i2v_traj_video = gr.Video(label="Camera Trajectory",elem_id="traj_vid",autoplay=True,show_share_button=True)
69
+ with gr.Row():
70
+ i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
71
+
72
+ gr.Examples(examples=i2v_examples,
73
+ inputs=[i2v_input_image, i2v_elevation, i2v_center_scale, i2v_d_phi, i2v_d_theta, i2v_d_r, i2v_steps, i2v_seed],
74
+ outputs=[i2v_traj_video,i2v_output_video],
75
+ fn = image2video.run_gradio,
76
+ cache_examples=False,
77
+ )
78
+
79
+ # image2video.run_gradio(i2v_input_image='test/images/boy.png', i2v_elevation='10', i2v_d_phi='0 40', i2v_d_theta='0 0', i2v_d_r='0 0', i2v_center_scale=1, i2v_steps=50, i2v_seed=123)
80
+ i2v_end_btn.click(inputs=[i2v_input_image, i2v_elevation, i2v_center_scale, i2v_d_phi, i2v_d_theta, i2v_d_r, i2v_steps, i2v_seed],
81
+ outputs=[i2v_traj_video,i2v_output_video],
82
+ fn = image2video.run_gradio
83
+ )
84
+
85
+ return viewcrafter_iface
86
+
87
+
88
+ if __name__ == "__main__":
89
+ parser = get_parser() # infer_config.py
90
+ opts = parser.parse_args() # default device: 'cuda:0'
91
+ opts.save_dir = './gradio_tmp'
92
+ os.makedirs(opts.save_dir,exist_ok=True)
93
+ test_tensor = torch.Tensor([0]).cuda()
94
+ opts.device = str(test_tensor.device)
95
+ viewcrafter_iface = viewcrafter_demo(opts)
96
+ viewcrafter_iface.queue(max_size=10)
97
+ viewcrafter_iface.launch()
98
+ # viewcrafter_iface.launch(server_name='127.0.0.1', server_port=80, max_threads=1,debug=False)
99
+
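For orientation, each row of `i2v_examples` above follows the same field order as the `inputs` list passed to `run_gradio`; a minimal sketch of what one row contains:

```python
# Sketch: field order shared by i2v_examples and the Gradio inputs list above.
image_path, elevation, center_scale, d_phi, d_theta, d_r, steps, seed = (
    'test/images/boy.png', 0, 1.0, '0 40', '0 0', '0 0', 50, 123
)
# d_phi, d_theta and d_r are space-separated sequences that must start with 0.
print(image_path, elevation, center_scale, d_phi, d_theta, d_r, steps, seed)
```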
configs/__pycache__/infer_config.cpython-38.pyc ADDED
Binary file (1.33 kB).
 
configs/__pycache__/infer_config.cpython-39.pyc ADDED
Binary file (4.22 kB).
 
configs/infer_config.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import argparse
3
+
4
+ def get_parser():
5
+ parser = argparse.ArgumentParser()
6
+
7
+ ## general
8
+ parser.add_argument('--image_dir', type=str, default='./test/images/fruit.png', help='Image file path')
9
+ parser.add_argument('--out_dir', type=str, default='./output', help='Output directory')
10
+ parser.add_argument('--device', type=str, default='cuda:0', help='The device to use')
11
+ parser.add_argument('--exp_name', type=str, default=None, help='Experiment name, use image file name by default')
12
+
13
+ ## renderer
14
+ parser.add_argument('--mode', type=str, default='single_view_txt', help="Currently we support 'single_view_txt' and 'single_view_target'")
15
+ parser.add_argument('--traj_txt', type=str, help="Required for 'single_view_txt' mode, a txt file that specifies the camera trajectory")
16
+ parser.add_argument('--elevation', type=float, default=5., help='The elevation angle of the input image in degrees. Estimate a rough value based on your visual judgment')
17
+ parser.add_argument('--center_scale', type=float, default=1., help='Range: (0, 2]. Scale factor for the spherical radius (r). By default, r is set to the depth value of the center pixel (H//2, W//2) of the reference image')
18
+ parser.add_argument('--d_theta', nargs='+', type=int, default=10., help="Range: [-40, 40]. Required for 'single_view_target' mode, specify target theta angle as theta + d_theta")
19
+ parser.add_argument('--d_phi', nargs='+', type=int, default=30., help="Range: [-45, 45]. Required for 'single_view_target' mode, specify target phi angle as phi + d_phi")
20
+ parser.add_argument('--d_r', nargs='+', type=float, default=-.2, help="Range: [-.5, .5]. Required for 'single_view_target' mode, specify target radius as r + r*dr")
21
+ parser.add_argument('--mask_image', type=bool, default=False, help='Required for multiple reference images and iterative mode')
22
+ parser.add_argument('--mask_pc', type=bool, default=True, help='Required for multiple reference images and iterative mode')
23
+ parser.add_argument('--reduce_pc', default=False, help='Required for multiple reference images and iterative mode')
24
+ parser.add_argument('--bg_trd', type=float, default=0., help='Required for multiple reference images and iterative mode; set to 0. for no mask')
25
+ parser.add_argument('--dpt_trd', type=float, default=1., help='Required for multiple reference images and iterative mode; limits the max depth by multiplying it by dpt_trd')
26
+
27
+
28
+ ## diffusion
29
+ parser.add_argument("--ckpt_path", type=str, default='./checkpoints/model.ckpt', help="checkpoint path")
30
+ parser.add_argument("--config", type=str, default='./configs/inference_pvd_1024.yaml', help="config (yaml) path")
31
+ parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM, reduce to 10 to speed up inference")
32
+ parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)")
33
+ parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one")
34
+ parser.add_argument("--height", type=int, default=576, help="image height, in pixel space")
35
+ parser.add_argument("--width", type=int, default=1024, help="image width, in pixel space")
36
+ parser.add_argument("--frame_stride", type=int, default=10, help="Fixed")
37
+ parser.add_argument("--unconditional_guidance_scale", type=float, default=7.5, help="prompt classifier-free guidance")
38
+ parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything")
39
+ parser.add_argument("--video_length", type=int, default=25, help="inference video length, change to 16 if you use 16 frame model")
40
+ parser.add_argument("--negative_prompt", default=False, help="unused")
41
+ parser.add_argument("--text_input", default=True, help="unused")
42
+ parser.add_argument("--prompt", type=str, default='Rotating view of a scene', help="Fixed")
43
+ parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not")
44
+ parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning")
45
+ parser.add_argument("--timestep_spacing", type=str, default="uniform_trailing", help="The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.")
46
+ parser.add_argument("--guidance_rescale", type=float, default=0.7, help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)")
47
+ parser.add_argument("--perframe_ae", default=True, help="use per-frame AE decoding; set to True to save GPU memory, especially for the 576x1024 model")
48
+ parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt")
49
+
50
+ ## dust3r
51
+ parser.add_argument('--model_path', type=str, default='./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth', help='The path of the model')
52
+ parser.add_argument('--batch_size', default=1)
53
+ parser.add_argument('--schedule', type=str, default='linear')
54
+ parser.add_argument('--niter', default=300)
55
+ parser.add_argument('--lr', default=0.01)
56
+ parser.add_argument('--min_conf_thr', default=3.0) # minimum=1.0, maximum=20
57
+
58
+ return parser
configs/inference_pvd_1024.yaml ADDED
@@ -0,0 +1,113 @@
1
+ model:
2
+ pretrained_checkpoint: /apdcephfs_cq10/share_1290939/vg_share/vip3d_share/DC_1024/model.ckpt
3
+ base_learning_rate: 1.0e-05
4
+ scale_lr: False
5
+ target: lvdm.models.ddpm3d.VIPLatentDiffusion
6
+ params:
7
+ rescale_betas_zero_snr: True
8
+ parameterization: "v"
9
+ linear_start: 0.00085
10
+ linear_end: 0.012
11
+ num_timesteps_cond: 1
12
+ log_every_t: 200
13
+ timesteps: 1000
14
+ first_stage_key: video
15
+ cond_stage_key: caption
16
+ cond_stage_trainable: False
17
+ image_proj_model_trainable: True
18
+ conditioning_key: hybrid
19
+ image_size: [72, 128]
20
+ channels: 4
21
+ scale_by_std: False
22
+ scale_factor: 0.18215
23
+ use_ema: False
24
+ uncond_prob: 0.05
25
+ uncond_type: 'empty_seq'
26
+ rand_cond_frame: true
27
+ use_dynamic_rescale: true
28
+ base_scale: 0.3
29
+ fps_condition_type: 'fps'
30
+ perframe_ae: True
31
+ loop_video: False
32
+
33
+ unet_config:
34
+ target: lvdm.modules.networks.openaimodel3d.UNetModel
35
+ params:
36
+ in_channels: 8
37
+ out_channels: 4
38
+ model_channels: 320
39
+ attention_resolutions:
40
+ - 4
41
+ - 2
42
+ - 1
43
+ num_res_blocks: 2
44
+ channel_mult:
45
+ - 1
46
+ - 2
47
+ - 4
48
+ - 4
49
+ dropout: 0.1
50
+ num_head_channels: 64
51
+ transformer_depth: 1
52
+ context_dim: 1024
53
+ use_linear: true
54
+ use_checkpoint: True
55
+ temporal_conv: True
56
+ temporal_attention: True
57
+ temporal_selfatt_only: true
58
+ use_relative_position: false
59
+ use_causal_attention: False
60
+ temporal_length: 16
61
+ addition_attention: true
62
+ image_cross_attention: true
63
+ default_fs: 10
64
+ fs_condition: true
65
+
66
+ first_stage_config:
67
+ target: lvdm.models.autoencoder.AutoencoderKL
68
+ params:
69
+ embed_dim: 4
70
+ monitor: val/rec_loss
71
+ ddconfig:
72
+ double_z: True
73
+ z_channels: 4
74
+ resolution: 256
75
+ in_channels: 3
76
+ out_ch: 3
77
+ ch: 128
78
+ ch_mult:
79
+ - 1
80
+ - 2
81
+ - 4
82
+ - 4
83
+ num_res_blocks: 2
84
+ attn_resolutions: []
85
+ dropout: 0.0
86
+ lossconfig:
87
+ target: torch.nn.Identity
88
+
89
+ cond_stage_config:
90
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
91
+ params:
92
+ version: /apdcephfs_cq10/share_1290939/vg_share/vip3d_share/OpenCLIP-ViT-H-14-laion2B-s32B-b79K/blobs/9a78ef8e8c73fd0df621682e7a8e8eb36c6916cb3c16b291a082ecd52ab79cc4
93
+ freeze: true
94
+ layer: "penultimate"
95
+
96
+ img_cond_stage_config:
97
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
98
+ params:
99
+ version: /apdcephfs_cq10/share_1290939/vg_share/vip3d_share/OpenCLIP-ViT-H-14-laion2B-s32B-b79K/blobs/9a78ef8e8c73fd0df621682e7a8e8eb36c6916cb3c16b291a082ecd52ab79cc4
100
+ freeze: true
101
+
102
+ image_proj_stage_config:
103
+ target: lvdm.modules.encoders.resampler.Resampler
104
+ params:
105
+ dim: 1024
106
+ depth: 4
107
+ dim_head: 64
108
+ heads: 12
109
+ num_queries: 16
110
+ embedding_dim: 1280
111
+ output_dim: 1024
112
+ ff_mult: 4
113
+ video_length: 16
configs/inference_pvd_512.yaml ADDED
@@ -0,0 +1,114 @@
 
1
+ model:
2
+ pretrained_checkpoint: /apdcephfs_cq10/share_1290939/vg_share/vip3d_share/3d_320_512_SD-IPA_ztsnr_v_builton11_10k_DL3DVdust3r_fps10_allmode/epoch=7-step=60000.ckpt
3
+ base_learning_rate: 1.0e-05
4
+ scale_lr: False
5
+ target: lvdm.models.ddpm3d.VIPLatentDiffusion
6
+ params:
7
+ rescale_betas_zero_snr: True
8
+ parameterization: "v"
9
+ linear_start: 0.00085
10
+ linear_end: 0.012
11
+ num_timesteps_cond: 1
12
+ log_every_t: 200
13
+ timesteps: 1000
14
+ first_stage_key: video
15
+ cond_stage_key: caption
16
+ cond_stage_trainable: False
17
+ image_proj_model_trainable: True
18
+ conditioning_key: hybrid
19
+ image_size: [40, 64]
20
+ channels: 4
21
+ scale_by_std: False
22
+ scale_factor: 0.18215
23
+ use_ema: False
24
+ uncond_prob: 0.05
25
+ uncond_type: 'empty_seq'
26
+ rand_cond_frame: true
27
+ use_dynamic_rescale: true
28
+ base_scale: 0.7
29
+ fps_condition_type: 'fps'
30
+ perframe_ae: True
31
+ loop_video: False
32
+ fix_temporal: True
33
+
34
+ unet_config:
35
+ target: lvdm.modules.networks.openaimodel3d.UNetModel
36
+ params:
37
+ in_channels: 8
38
+ out_channels: 4
39
+ model_channels: 320
40
+ attention_resolutions:
41
+ - 4
42
+ - 2
43
+ - 1
44
+ num_res_blocks: 2
45
+ channel_mult:
46
+ - 1
47
+ - 2
48
+ - 4
49
+ - 4
50
+ dropout: 0.1
51
+ num_head_channels: 64
52
+ transformer_depth: 1
53
+ context_dim: 1024
54
+ use_linear: true
55
+ use_checkpoint: True
56
+ temporal_conv: True
57
+ temporal_attention: True
58
+ temporal_selfatt_only: true
59
+ use_relative_position: false
60
+ use_causal_attention: False
61
+ temporal_length: 16
62
+ addition_attention: true
63
+ image_cross_attention: true
64
+ default_fs: 10
65
+ fs_condition: true
66
+
67
+ first_stage_config:
68
+ target: lvdm.models.autoencoder.AutoencoderKL
69
+ params:
70
+ embed_dim: 4
71
+ monitor: val/rec_loss
72
+ ddconfig:
73
+ double_z: True
74
+ z_channels: 4
75
+ resolution: 256
76
+ in_channels: 3
77
+ out_ch: 3
78
+ ch: 128
79
+ ch_mult:
80
+ - 1
81
+ - 2
82
+ - 4
83
+ - 4
84
+ num_res_blocks: 2
85
+ attn_resolutions: []
86
+ dropout: 0.0
87
+ lossconfig:
88
+ target: torch.nn.Identity
89
+
90
+ cond_stage_config:
91
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
92
+ params:
93
+ version: /apdcephfs_cq10/share_1290939/vg_share/vip3d_share/OpenCLIP-ViT-H-14-laion2B-s32B-b79K/blobs/9a78ef8e8c73fd0df621682e7a8e8eb36c6916cb3c16b291a082ecd52ab79cc4
94
+ freeze: true
95
+ layer: "penultimate"
96
+
97
+ img_cond_stage_config:
98
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
99
+ params:
100
+ version: /apdcephfs_cq10/share_1290939/vg_share/vip3d_share/OpenCLIP-ViT-H-14-laion2B-s32B-b79K/blobs/9a78ef8e8c73fd0df621682e7a8e8eb36c6916cb3c16b291a082ecd52ab79cc4
101
+ freeze: true
102
+
103
+ image_proj_stage_config:
104
+ target: lvdm.modules.encoders.resampler.Resampler
105
+ params:
106
+ dim: 1024
107
+ depth: 4
108
+ dim_head: 64
109
+ heads: 12
110
+ num_queries: 16
111
+ embedding_dim: 1280
112
+ output_dim: 1024
113
+ ff_mult: 4
114
+ video_length: 16
docs/config_help.md ADDED
@@ -0,0 +1,43 @@
1
+ ## Important configuration options for [inference.py](../inference.py):
2
+
3
+ ### 1. General configs
4
+ | Configuration | default | Explanation |
5
+ |:------------- |:----- | :------------- |
6
+ | `--image_dir` | './test/images/fruit.png' | Image file path |
7
+ | `--out_dir` | './output' | Output directory |
8
+ | `--device` | 'cuda:0' | The device to use |
9
+ | `--exp_name` | None | Experiment name, use image file name by default |
10
+ ### 2. Point cloud render configs
11
+ #### The definition of world coordinate system and tips for adjusting point cloud render configs are illustrated in [render document](./render_help.md).
12
+ | Configuration | default | Explanation |
13
+ |:------------- |:----- | :------------- |
14
+ | `--mode` | 'single_view_txt' | Currently we support 'single_view_txt' and 'single_view_target' |
15
+ | `--traj_txt` | None | Required for 'single_view_txt' mode, a txt file that specifies the camera trajectory |
16
+ | `--elevation` | 5. | The elevation angle of the input image in degrees. Estimate a rough value based on your visual judgment |
17
+ | `--center_scale` | 1. | Scale factor for the spherical radius (r). By default, r is set to the depth value of the center pixel (H//2, W//2) of the reference image |
18
+ | `--d_theta` | 10. | Required for 'single_view_target' mode, specify target theta angle as (theta + d_theta) |
19
+ | `--d_phi` | 30. | Required for 'single_view_target' mode, specify target phi angle as (phi + d_phi) |
20
+ | `--d_r` | -.2 | Required for 'single_view_target' mode, specify target radius as (r + r*dr) |
21
+ ### 3. Diffusion configs
22
+ | Configuration | default | Explanation |
23
+ |:------------- |:----- | :------------- |
24
+ | `--ckpt_path` | './checkpoints/ViewCrafter_25.ckpt' | Checkpoint path |
25
+ | `--config` | './configs/inference_pvd_1024.yaml' | Config (yaml) path |
26
+ | `--ddim_steps` | 50 | Steps of ddim if positive, otherwise use DDPM, reduce to 10 to speed up inference |
27
+ | `--ddim_eta` | 1.0 | Eta for ddim sampling (0.0 yields deterministic sampling) |
28
+ | `--bs` | 1 | Batch size for inference, should be one |
29
+ | `--height` | 576 | Image height, in pixel space |
30
+ | `--width` | 1024 | Image width, in pixel space |
31
+ | `--frame_stride` | 10 | Fixed |
32
+ | `--unconditional_guidance_scale` | 7.5 | Prompt classifier-free guidance |
33
+ | `--seed` | 123 | Seed for seed_everything |
34
+ | `--video_length` | 25 | Inference video length, change to 16 if you use 16 frame model |
35
+ | `--negative_prompt` | False | Unused |
36
+ | `--text_input` | False | Unused |
37
+ | `--prompt` | 'Rotating view of a scene' | Fixed |
38
+ | `--multiple_cond_cfg` | False | Use multi-condition cfg or not |
39
+ | `--cfg_img` | None | Guidance scale for image conditioning |
40
+ | `--timestep_spacing` | "uniform_trailing" | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. |
41
+ | `--guidance_rescale` | 0.7 | Guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) |
42
+ | `--perframe_ae` | True | Use per-frame AE decoding; set to True to save GPU memory, especially for the 576x1024 model |
43
+ | `--n_samples` | 1 | Num of samples per prompt |
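As a small sketch of how the renderer options above combine in 'single_view_target' mode (the values are illustrative, chosen within the documented ranges):

```python
# Sketch: parse target-pose options; d_phi, d_theta and d_r accept sequences (nargs='+').
from configs.infer_config import get_parser

opts = get_parser().parse_args([
    '--mode', 'single_view_target',
    '--elevation', '5',
    '--center_scale', '1.0',
    '--d_phi', '45',      # positive: move right, within [-45, 45]
    '--d_theta', '-30',   # negative: move up, within [-40, 40]
    '--d_r', '-0.2',      # negative: move closer, within [-0.5, 0.5]
])
print(opts.d_phi, opts.d_theta, opts.d_r)  # [45] [-30] [-0.2]
```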
docs/render_help.md ADDED
@@ -0,0 +1,169 @@
1
+ ## Point cloud render configurations
2
+ | Configuration | default | Explanation |
3
+ |:------------- |:----- | :------------- |
4
+ | `--mode` | 'single_view_txt' | Currently we support 'single_view_txt' and 'single_view_target' modes |
5
+ | `--traj_txt` | None | Required for 'single_view_txt' mode, a txt file that specifies the camera trajectory |
6
+ | `--elevation` | 5. | The elevation angle of the input image in degrees. Estimate a rough value based on your visual judgment |
7
+ | `--center_scale` | 1. | Range: (0, 2]. Scale factor for the spherical radius (r). By default, r is set to the depth value of the center pixel (H//2, W//2) of the reference image |
8
+ | `--d_theta` | 10. | Range: [-40, 40]. Required for 'single_view_target' mode, specify target theta angle as (theta + d_theta) |
9
+ | `--d_phi` | 30. | Range: [-45, 45]. Required for 'single_view_target' mode, specify target phi angle as (phi + d_phi) |
10
+ | `--d_r` | -.2 | Range: [-0.5, 0.5]. Required for 'single_view_target' mode, specify target radius as (r + r*dr) |
11
+
12
+ <hr>
13
+
14
+ ![fig](../assets/doc_world.png)
15
+
16
+ The image above illustrates the definition of the world coordinate system.
17
+
18
+ **1.** Taking a single reference image as an example, you first need to estimate an elevation angle `--elevation`, which represents the angle at which the image was taken. A value greater than 0 indicates a top-down view; the estimate doesn't need to be precise.
19
+
20
+ **2.** The origin of the world coordinate system is by default defined at the point cloud corresponding to the center pixel of the reference image. You can adjust the position of the origin by modifying `--center_scale`; a value less than 1 brings the origin closer to the reference camera.
21
+
22
+ **3.** We use spherical coordinates to represent the camera pose. The initial camera is located at (r, 0, 0). You can specify a target camera pose by setting `--mode` to 'single_view_target'. As shown in the figure above, a positive `--d_phi` moves the camera to the right, a positive `--d_theta` moves the camera down, and a negative `--d_r` moves the camera forward (closer to the origin). The program will interpolate a smooth trajectory between the initial pose and the target pose, then render the point cloud along that trajectory. Some examples are shown below:
23
+ <table class="center">
24
+ <tr style="font-weight: bolder;text-align:center;">
25
+ <td> --center_scale </td>
26
+ <td> --d_phi </td>
27
+ <td> --d_theta </td>
28
+ <td> --d_r </td>
29
+ <td>Render results</td>
30
+ </tr>
31
+ <tr>
32
+ <td>
33
+ 0.5
34
+ </td>
35
+ <td>
36
+ 45.
37
+ </td>
38
+ <td>
39
+ 0.
40
+ </td>
41
+ <td>
42
+ 0.
43
+ </td>
44
+ <td>
45
+ <img src=../assets/doc_tgt_scale5.gif width="250">
46
+ </td>
47
+ </tr>
48
+ <tr>
49
+ <td>
50
+ 1.
51
+ </td>
52
+ <td>
53
+ 45.
54
+ </td>
55
+ <td>
56
+ 0.
57
+ </td>
58
+ <td>
59
+ 0.
60
+ </td>
61
+ <td>
62
+ <img src=../assets/doc_tgt_phi45.gif width="250">
63
+ </td>
64
+ </tr>
65
+ <tr>
66
+ <td>
67
+ 1.
68
+ </td>
69
+ <td>
70
+ 0.
71
+ </td>
72
+ <td>
73
+ -30.
74
+ </td>
75
+ <td>
76
+ 0.
77
+ </td>
78
+ <td>
79
+ <img src=../assets/doc_tgt_theta30.gif width="250">
80
+ </td>
81
+ </tr>
82
+ <tr>
83
+ <td>
84
+ 1.
85
+ </td>
86
+ <td>
87
+ 0.
88
+ </td>
89
+ <td>
90
+ 0.
91
+ </td>
92
+ <td>
93
+ -0.5
94
+ </td>
95
+ <td>
96
+ <img src=../assets/doc_tgt_r5.gif width="250">
97
+ </td>
98
+ </tr>
99
+ <tr>
100
+ <td>
101
+ 1.
102
+ </td>
103
+ <td>
104
+ 45.
105
+ </td>
106
+ <td>
107
+ -30.
108
+ </td>
109
+ <td>
110
+ -0.5
111
+ </td>
112
+ <td>
113
+ <img src=../assets/doc_tgt_combine.gif width="250">
114
+ </td>
115
+ </tr>
116
+ </table>
117
+
118
+ **4.** You can also create a camera trajectory by specifying sequences of d_phi, d_theta, and d_r values. Set `--mode` to 'single_view_txt' and write the sequences in a txt file (example: [loop1.txt](../assets/loop1.txt)). The first line of the txt file should contain the d_phi sequence, the second line the d_theta sequence, and the third line the d_r sequence. Each sequence should start with 0, and each should contain between 2 and 25 values. Then pass the txt file path to `--traj_txt`. The program will interpolate a smooth trajectory from the sequences you provide. Some examples are shown below, and a short scripted sketch of the file format appears at the end of this section:
119
+ <table class="center">
120
+ <tr style="font-weight: bolder;text-align:center;">
121
+ <td> Target sequences </td>
122
+ <td> Trajectory visualization </td>
123
+ <td>Render results</td>
124
+ </tr>
125
+ <tr>
126
+ <td>
127
+ 0 -3 -15 -20 -17 -5 0 <br>
128
+ 0 -2 -5 -10 -8 -5 0 2 5 10 8 5 0 <br>
129
+ 0 0
130
+ </td>
131
+ <td>
132
+ <img src=../assets/loop1_traj.gif width="100">
133
+ </td>
134
+ <td>
135
+ <img src=../assets/loop1_render.gif width="250">
136
+ </td>
137
+ </tr>
138
+ <tr>
139
+ <td>
140
+ 0 3 10 20 17 10 0 <br>
141
+ 0 -2 -8 -6 0 2 8 6 0 <br>
142
+ 0 -0.02 -0.09 -0.18 -0.16 -0.09 0
143
+ </td>
144
+
145
+ <td>
146
+ <img src=../assets/loop2_traj.gif width="100">
147
+ </td>
148
+ <td>
149
+ <img src=../assets/loop2_render.gif width="250">
150
+ </td>
151
+ </tr>
152
+ <tr>
153
+ <td>
154
+ 0 40 <br>
155
+ 0 -1 -3 -7 -6 -4 0 1 3 7 6 4 0 -1 -3 -7 -6 -4 0 1 3 7 6 4 0 <br>
156
+ 0 0
157
+ </td>
158
+ <td>
159
+ <img src=../assets/wave_traj.gif width="100">
160
+ </td>
161
+ <td>
162
+ <img src=../assets/wave_render.gif width="250">
163
+ </td>
164
+ </tr>
165
+ </table>
166
+
167
+ - **Tips:** A sequence in which the differences between adjacent values increase in one direction results in a smoother trajectory. Ensure that these differences are not too large; otherwise, they may lead to abrupt camera movements, causing the model to produce artifacts such as content drift.
168
+
169
+
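As referenced in point 4 above, here is a small scripted sketch of the trajectory file format (the sequences are copied from the first loop example; the output filename is arbitrary):

```python
# Sketch: write a trajectory txt file for 'single_view_txt' mode.
# Line 1: d_phi sequence, line 2: d_theta sequence, line 3: d_r sequence;
# each sequence starts with 0 and contains between 2 and 25 values.
d_phi   = [0, -3, -15, -20, -17, -5, 0]
d_theta = [0, -2, -5, -10, -8, -5, 0, 2, 5, 10, 8, 5, 0]
d_r     = [0, 0]

with open('loop1.txt', 'w') as f:  # pass this path to --traj_txt
    for seq in (d_phi, d_theta, d_r):
        f.write(' '.join(str(v) for v in seq) + '\n')
```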
extern/dust3r/LICENSE ADDED
@@ -0,0 +1,7 @@
1
+ DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
extern/dust3r/croco/LICENSE ADDED
@@ -0,0 +1,52 @@
1
+ CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
8
+
9
+
10
+ SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
11
+
12
+ ***************************
13
+
14
+ NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
15
+
16
+ This software is being redistributed in a modified form. The original form is available here:
17
+
18
+ https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+
20
+ This software in this file incorporates parts of the following software available here:
21
+
22
+ Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
23
+ available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
24
+
25
+ MoCo v3: https://github.com/facebookresearch/moco-v3
26
+ available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
27
+
28
+ DeiT: https://github.com/facebookresearch/deit
29
+ available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
30
+
31
+
32
+ ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCED BELOW:
33
+
34
+ https://github.com/facebookresearch/mae/blob/main/LICENSE
35
+
36
+ Attribution-NonCommercial 4.0 International
37
+
38
+ ***************************
39
+
40
+ NOTICE WITH RESPECT TO THE FILE: models/blocks.py
41
+
42
+ This software is being redistributed in a modified form. The original form is available here:
43
+
44
+ https://github.com/rwightman/pytorch-image-models
45
+
46
+ ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCED BELOW:
47
+
48
+ https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
49
+
50
+ Apache License
51
+ Version 2.0, January 2004
52
+ http://www.apache.org/licenses/
extern/dust3r/croco/NOTICE ADDED
@@ -0,0 +1,21 @@
1
+ CroCo
2
+ Copyright 2022-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ facebookresearch/mae
10
+ https://github.com/facebookresearch/mae
11
+
12
+ Attribution-NonCommercial 4.0 International
13
+
14
+ ====
15
+
16
+ rwightman/pytorch-image-models
17
+ https://github.com/rwightman/pytorch-image-models
18
+
19
+ Apache License
20
+ Version 2.0, January 2004
21
+ http://www.apache.org/licenses/
extern/dust3r/croco/README.MD ADDED
@@ -0,0 +1,124 @@
1
+ # CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
2
+
3
+ [[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
4
+
5
+ This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), referred to as CroCo v2:
6
+
7
+ ![image](assets/arch.jpg)
8
+
9
+ ```bibtex
10
+ @inproceedings{croco,
11
+ title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
12
+ author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
13
+ booktitle={{NeurIPS}},
14
+ year={2022}
15
+ }
16
+
17
+ @inproceedings{croco_v2,
18
+ title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
19
+ author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
20
+ booktitle={ICCV},
21
+ year={2023}
22
+ }
23
+ ```
24
+
25
+ ## License
26
+
27
+ The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
28
+ Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
29
+ Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
30
+
31
+ ## Preparation
32
+
33
+ 1. Install dependencies on a machine with an NVIDIA GPU using, e.g., conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can skip the line installing it and use a more recent Python version.
34
+
35
+ ```bash
36
+ conda create -n croco python=3.7 cmake=3.14.0
37
+ conda activate croco
38
+ conda install habitat-sim headless -c conda-forge -c aihabitat
39
+ conda install pytorch torchvision -c pytorch
40
+ conda install notebook ipykernel matplotlib
41
+ conda install ipywidgets widgetsnbextension
42
+ conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
43
+
44
+ ```
45
+
46
+ 2. Compile cuda kernels for RoPE
47
+
48
+ CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
49
+ ```bash
50
+ cd models/curope/
51
+ python setup.py build_ext --inplace
52
+ cd ../../
53
+ ```
54
+
55
+ This can take a while, as we compile for all CUDA architectures; feel free to update L9 of `models/curope/setup.py` to compile only for specific architectures.
56
+ You might also need to set the environment variable `CUDA_HOME` if you use a custom CUDA installation.
57
+
58
+ If you cannot compile the kernels, we also provide a slower PyTorch version, which will be loaded automatically.
59
+
60
+ 3. Download pre-trained model
61
+
62
+ We provide several pre-trained models:
63
+
64
+ | modelname | pre-training data | pos. embed. | Encoder | Decoder |
65
+ |------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
66
+ | [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
67
+ | [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
68
+ | [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
69
+ | [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
70
+
71
+ To download a specific model, e.g., the first one (`CroCo.pth`):
72
+ ```bash
73
+ mkdir -p pretrained_models/
74
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
75
+ ```
76
+
77
+ ## Reconstruction example
78
+
79
+ After downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or updating the corresponding line in `demo.py`), simply run:
80
+ ```bash
81
+ python demo.py
82
+ ```
83
+
84
+ ## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
85
+
86
+ First download the test scene from Habitat:
87
+ ```bash
88
+ python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
89
+ ```
90
+
91
+ Then, run the Notebook demo `interactive_demo.ipynb`.
92
+
93
+ In this demo, you should be able to sample a random reference viewpoint from a [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change the viewpoint and select a masked target view to reconstruct using CroCo.
94
+ ![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg)
95
+
96
+ ## Pre-training
97
+
98
+ ### CroCo
99
+
100
+ To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
101
+ ```
102
+ torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
103
+ ```
104
+
105
+ Our CroCo pre-training was launched on a single server with 4 GPUs.
106
+ It should take around 10 days on A100 or 15 days on V100 to run the 400 pre-training epochs, but decent performance is obtained earlier in training.
107
+ Note that, while the code contains the same learning-rate scaling rule as MAE when changing the effective batch size, we did not experiment to check whether it is valid in our case.
108
+ The first run can take a few minutes to start, as it parses all available pre-training pairs.
109
+
110
+ ### CroCo v2
111
+
112
+ For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
113
+ Then, run the following command for the largest model (ViT-L encoder, Base decoder):
114
+ ```
115
+ torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
116
+ ```
117
+
118
+ Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per GPU in all cases.
119
+ The largest model should take around 12 days on A100.
120
+ Note that, while the code contains the same learning-rate scaling rule as MAE when changing the effective batch size, we did not experiment to check whether it is valid in our case.
121
+
122
+ ## Stereo matching and Optical flow downstream tasks
123
+
124
+ For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
extern/dust3r/croco/assets/Chateau1.png ADDED
extern/dust3r/croco/assets/Chateau2.png ADDED
extern/dust3r/croco/assets/arch.jpg ADDED
extern/dust3r/croco/croco-stereo-flow-demo.ipynb ADDED
@@ -0,0 +1,191 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9bca0f41",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Simple inference example with CroCo-Stereo or CroCo-Flow"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "80653ef7",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
19
+ "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "4f033862",
25
+ "metadata": {},
26
+ "source": [
27
+ "First download the model(s) of your choice by running\n",
28
+ "```\n",
29
+ "bash stereoflow/download_model.sh crocostereo.pth\n",
30
+ "bash stereoflow/download_model.sh crocoflow.pth\n",
31
+ "```"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "id": "1fb2e392",
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "import torch\n",
42
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
43
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
44
+ "import matplotlib.pylab as plt"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "e0e25d77",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "from stereoflow.test import _load_model_and_criterion\n",
55
+ "from stereoflow.engine import tiled_pred\n",
56
+ "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
57
+ "from stereoflow.datasets_flow import flowToColor\n",
58
+ "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "id": "86a921f5",
64
+ "metadata": {},
65
+ "source": [
66
+ "### CroCo-Stereo example"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "id": "64e483cb",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
77
+ "image2 = np.asarray(Image.open('<path_to_right_image>'))"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "f0d04303",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "47dc14b5",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
98
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
99
+ "with torch.inference_mode():\n",
100
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
101
+ "pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "id": "583b9f16",
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "plt.imshow(vis_disparity(pred))\n",
112
+ "plt.axis('off')"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "id": "d2df5d70",
118
+ "metadata": {},
119
+ "source": [
120
+ "### CroCo-Flow example"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "id": "9ee257a7",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
131
+ "image2 = np.asarray(Image.open('<path_to_second_image>'))"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "d5edccf0",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "b19692c3",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
152
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
153
+ "with torch.inference_mode():\n",
154
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
155
+ "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "26f79db3",
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "plt.imshow(flowToColor(pred))\n",
166
+ "plt.axis('off')"
167
+ ]
168
+ }
169
+ ],
170
+ "metadata": {
171
+ "kernelspec": {
172
+ "display_name": "Python 3 (ipykernel)",
173
+ "language": "python",
174
+ "name": "python3"
175
+ },
176
+ "language_info": {
177
+ "codemirror_mode": {
178
+ "name": "ipython",
179
+ "version": 3
180
+ },
181
+ "file_extension": ".py",
182
+ "mimetype": "text/x-python",
183
+ "name": "python",
184
+ "nbconvert_exporter": "python",
185
+ "pygments_lexer": "ipython3",
186
+ "version": "3.9.7"
187
+ }
188
+ },
189
+ "nbformat": 4,
190
+ "nbformat_minor": 5
191
+ }
extern/dust3r/croco/datasets/__init__.py ADDED
File without changes
extern/dust3r/croco/datasets/crops/README.MD ADDED
@@ -0,0 +1,104 @@
1
+ ## Generation of crops from the real datasets
2
+
3
+ The instructions below allow you to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
4
+
5
+ ### Download the metadata of the crops to generate
6
+
7
+ First, download the metadata and put them in `./data/`:
8
+ ```
9
+ mkdir -p data
10
+ cd data/
11
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
12
+ unzip crop_metadata.zip
13
+ rm crop_metadata.zip
14
+ cd ..
15
+ ```
16
+
17
+ ### Prepare the original datasets
18
+
19
+ Second, download the original datasets in `./data/original_datasets/`.
20
+ ```
21
+ mkdir -p data/original_datasets
22
+ ```
23
+
24
+ ##### ARKitScenes
25
+
26
+ Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
27
+ The resulting file structure should be like:
28
+ ```
29
+ ./data/original_datasets/ARKitScenes/
30
+ └───Training
31
+ └───40753679
32
+ │ │ ultrawide
33
+ │ │ ...
34
+ └───40753686
35
+
36
+ ...
37
+ ```
38
+
39
+ ##### MegaDepth
40
+
41
+ Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
42
+ The resulting file structure should be like:
43
+
44
+ ```
45
+ ./data/original_datasets/MegaDepth/
46
+ └───0000
47
+ │ └───images
48
+ │ │ │ 1000557903_87fa96b8a4_o.jpg
49
+ │ │ └ ...
50
+ │ └─── ...
51
+ └───0001
52
+ │ │
53
+ │ └ ...
54
+ └─── ...
55
+ ```
56
+
57
+ ##### 3DStreetView
58
+
59
+ Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
60
+ The resulting file structure should be like:
61
+
62
+ ```
63
+ ./data/original_datasets/3DStreetView/
64
+ └───dataset_aligned
65
+ │ └───0002
66
+ │ │ │ 0000002_0000001_0000002_0000001.jpg
67
+ │ │ └ ...
68
+ │ └─── ...
69
+ └───dataset_unaligned
70
+ │ └───0003
71
+ │ │ │ 0000003_0000001_0000002_0000001.jpg
72
+ │ │ └ ...
73
+ │ └─── ...
74
+ ```
75
+
76
+ ##### IndoorVL
77
+
78
+ Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
79
+
80
+ ```
81
+ pip install kapture
82
+ mkdir -p ./data/original_datasets/IndoorVL
83
+ cd ./data/original_datasets/IndoorVL
84
+ kapture_download_dataset.py update
85
+ kapture_download_dataset.py install "HyundaiDepartmentStore_*"
86
+ kapture_download_dataset.py install "GangnamStation_*"
87
+ cd -
88
+ ```
89
+
90
+ ### Extract the crops
91
+
92
+ Now, extract the crops for each of the dataset:
93
+ ```
94
+ for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
95
+ do
96
+ python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
97
+ done
98
+ ```
99
+
100
+ ##### Note for IndoorVL
101
+
102
+ Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
103
+ To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively.
104
+ The impact on the performance is negligible.
extern/dust3r/croco/datasets/crops/extract_crops_from_images.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Extracting crops for pre-training
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import argparse
10
+ from tqdm import tqdm
11
+ from PIL import Image
12
+ import functools
13
+ from multiprocessing import Pool
14
+ import math
15
+
16
+
17
+ def arg_parser():
18
+ parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
19
+
20
+ parser.add_argument('--crops', type=str, required=True, help='crop file')
21
+ parser.add_argument('--root-dir', type=str, required=True, help='root directory')
22
+ parser.add_argument('--output-dir', type=str, required=True, help='output directory')
23
+ parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
24
+ parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
25
+ parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
26
+ parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
27
+ return parser
28
+
29
+
30
+ def main(args):
31
+ listing_path = os.path.join(args.output_dir, 'listing.txt')
32
+
33
+ print(f'Loading list of crops ... ({args.nthread} threads)')
34
+ crops, num_crops_to_generate = load_crop_file(args.crops)
35
+
36
+ print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
37
+ num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
38
+ num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
39
+
40
+ jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
41
+ del crops
42
+
43
+ os.makedirs(args.output_dir, exist_ok=True)
44
+ mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
45
+ call = functools.partial(save_image_crops, args)
46
+
47
+ print(f"Generating cropped images to {args.output_dir} ...")
48
+ with open(listing_path, 'w') as listing:
49
+ listing.write('# pair_path\n')
50
+ for results in tqdm(mmap(call, jobs), total=len(jobs)):
51
+ for path in results:
52
+ listing.write(f'{path}\n')
53
+ print('Finished writing listing to', listing_path)
54
+
55
+
56
+ def load_crop_file(path):
57
+ data = open(path).read().splitlines()
58
+ pairs = []
59
+ num_crops_to_generate = 0
60
+ for line in tqdm(data):
61
+ if line.startswith('#'):
62
+ continue
63
+ line = line.split(', ')
64
+ if len(line) < 8:
65
+ img1, img2, rotation = line
66
+ pairs.append((img1, img2, int(rotation), []))
67
+ else:
68
+ l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
69
+ rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
70
+ pairs[-1][-1].append((rect1, rect2))
71
+ num_crops_to_generate += 1
72
+ return pairs, num_crops_to_generate
73
+
74
+
75
+ def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
76
+ jobs = []
77
+ powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
78
+
79
+ def get_path(idx):
80
+ idx_array = []
81
+ d = idx
82
+ for level in range(num_levels - 1):
83
+ idx_array.append(idx // powers[level])
84
+ idx = idx % powers[level]
85
+ idx_array.append(d)
86
+ return '/'.join(map(lambda x: hex(x)[2:], idx_array))
87
+
88
+ idx = 0
89
+ for pair_data in tqdm(pairs):
90
+ img1, img2, rotation, crops = pair_data
91
+ if -60 <= rotation and rotation <= 60:
92
+ rotation = 0 # most likely not a true rotation
93
+ paths = [get_path(idx + k) for k in range(len(crops))]
94
+ idx += len(crops)
95
+ jobs.append(((img1, img2), rotation, crops, paths))
96
+ return jobs
97
+
98
+
99
+ def load_image(path):
100
+ try:
101
+ return Image.open(path).convert('RGB')
102
+ except Exception as e:
103
+ print('skipping', path, e)
104
+ raise OSError()
105
+
106
+
107
+ def save_image_crops(args, data):
108
+ # load images
109
+ img_pair, rot, crops, paths = data
110
+ try:
111
+ img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
112
+ except OSError as e:
113
+ return []
114
+
115
+ def area(sz):
116
+ return sz[0] * sz[1]
117
+
118
+ tgt_size = (args.imsize, args.imsize)
119
+
120
+ def prepare_crop(img, rect, rot=0):
121
+ # actual crop
122
+ img = img.crop(rect)
123
+
124
+ # resize to desired size
125
+ interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
126
+ img = img.resize(tgt_size, resample=interp)
127
+
128
+ # rotate the image
129
+ rot90 = (round(rot/90) % 4) * 90
130
+ if rot90 == 90:
131
+ img = img.transpose(Image.Transpose.ROTATE_90)
132
+ elif rot90 == 180:
133
+ img = img.transpose(Image.Transpose.ROTATE_180)
134
+ elif rot90 == 270:
135
+ img = img.transpose(Image.Transpose.ROTATE_270)
136
+ return img
137
+
138
+ results = []
139
+ for (rect1, rect2), path in zip(crops, paths):
140
+ crop1 = prepare_crop(img1, rect1)
141
+ crop2 = prepare_crop(img2, rect2, rot)
142
+
143
+ fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
144
+ fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
145
+ os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
146
+
147
+ assert not os.path.isfile(fullpath1), fullpath1
148
+ assert not os.path.isfile(fullpath2), fullpath2
149
+ crop1.save(fullpath1)
150
+ crop2.save(fullpath2)
151
+ results.append(path)
152
+
153
+ return results
154
+
155
+
156
+ if __name__ == '__main__':
157
+ args = arg_parser().parse_args()
158
+ main(args)
159
+
extern/dust3r/croco/datasets/habitat_sim/README.MD ADDED
@@ -0,0 +1,76 @@
+ ## Generation of synthetic image pairs using Habitat-Sim
+
+ These instructions allow you to generate pre-training pairs from the Habitat simulator.
+ As we did not save the metadata of the pairs used in the original paper, the pairs generated here are not strictly the same, but they use the same settings and are equivalent.
+
+ ### Download Habitat-Sim scenes
+ Download Habitat-Sim scenes:
+ - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
+ - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
+ - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or manually update the paths in `paths.py`.
+ ```
+ ./data/
+ └──habitat-sim-data/
+    └──scene_datasets/
+       ├──hm3d/
+       ├──gibson/
+       ├──habitat-test-scenes/
+       ├──replica_cad_baked_lighting/
+       ├──replica_cad/
+       ├──ReplicaDataset/
+       └──scannet/
+ ```
+
+ ### Image pair generation
+ We provide metadata to generate reproducible image pairs for pre-training and validation.
+ Experiments described in the paper used similar data whose generation was not reproducible at the time.
+
+ Specifications:
+ - 256x256 images with a 60-degree field of view.
+ - Up to 1000 image pairs per scene.
+ - Number of scenes / number of image pairs per dataset:
+   - ScanNet: 1097 scenes / 985,209 pairs
+   - HM3D:
+     - hm3d/train: 800 scenes / 800k pairs
+     - hm3d/val: 100 scenes / 100k pairs
+     - hm3d/minival: 10 scenes / 10k pairs
+   - habitat-test-scenes: 3 scenes / 3k pairs
+   - replica_cad_baked_lighting: 13 scenes / 13k pairs
+
+ - Scenes from hm3d/val and hm3d/minival were not used for pre-training but were kept for validation purposes.
+
+ Download the metadata and extract it:
+ ```bash
+ mkdir -p data/habitat_release_metadata/
+ cd data/habitat_release_metadata/
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
+ tar -xvf multiview_habitat_metadata.tar.gz
+ cd ../..
+ # Location of the metadata
+ METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
+ ```
+
+ Generate image pairs from the metadata:
+ - The following command prints a list of command lines that generate the image pairs for each scene:
+ ```bash
+ # Target output directory
+ PAIRS_DATASET_DIR="./data/habitat_release/"
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
+ ```
+ - Multiple such commands can be run in parallel, e.g. using GNU Parallel:
+ ```bash
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
+ ```
+
+ ## Metadata generation
+
+ Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
+ ```bash
+ # Print command lines to generate image pairs from the different scenes available.
+ PAIRS_DATASET_DIR=MY_CUSTOM_PATH
+ python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
+
+ # Once a dataset has been generated, pack the metadata files for reproducibility.
+ METADATA_DIR=MY_CUSTOM_PATH
+ python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
+ ```
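Editor's note: as a minimal sketch (not part of the committed files), the per-scene `metadata.json` written by the generation scripts below can be inspected with only the `multiviews`, `views_count` and `generate_depth` keys that those scripts emit; `metadata_path` is a placeholder for any generated scene folder.

```python
import json

# Placeholder path: point it at any <scene>/metadata.json produced by the commands above.
metadata_path = "metadata.json"

with open(metadata_path, "r") as f:
    metadata = json.load(f)

# Each entry of 'multiviews' stores the sampled camera positions/orientations of one acquisition.
print(f"{len(metadata['multiviews'])} multiview acquisitions, "
      f"{metadata['views_count']} views each, "
      f"depth maps saved: {metadata['generate_depth']}")
```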
extern/dust3r/croco/datasets/habitat_sim/__init__.py ADDED
File without changes
extern/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ """
+ Script to generate image pairs for a given scene, reproducing poses provided in a metadata file.
+ """
+ import os
+ from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator
+ from datasets.habitat_sim.paths import SCENES_DATASET
+ import argparse
+ import quaternion
+ import PIL.Image
+ import cv2
+ import json
+ from tqdm import tqdm
+
+ def generate_multiview_images_from_metadata(metadata_filename,
+                                             output_dir,
+                                             overload_params = dict(),
+                                             scene_datasets_paths=None,
+                                             exist_ok=False):
+     """
+     Generate images from a metadata file for reproducibility purposes.
+     """
+     # Reorder paths by decreasing label length, to avoid collisions when testing if a string starts with such a label
+     if scene_datasets_paths is not None:
+         scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True))
+
+     with open(metadata_filename, 'r') as f:
+         input_metadata = json.load(f)
+     metadata = dict()
+     for key, value in input_metadata.items():
+         # Optionally replace some paths
+         if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
+             if scene_datasets_paths is not None:
+                 for dataset_label, dataset_path in scene_datasets_paths.items():
+                     if value.startswith(dataset_label):
+                         value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
+                         break
+         metadata[key] = value
+
+     # Overload some parameters
+     for key, value in overload_params.items():
+         metadata[key] = value
+
+     generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))])
+     generate_depth = metadata["generate_depth"]
+
+     os.makedirs(output_dir, exist_ok=exist_ok)
+
+     generator = MultiviewHabitatSimGenerator(**generation_entries)
+
+     # Generate views
+     for idx_label, data in tqdm(metadata['multiviews'].items()):
+         positions = data["positions"]
+         orientations = data["orientations"]
+         n = len(positions)
+         for oidx in range(n):
+             observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx]))
+             observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
+             # Color image saved using PIL
+             img = PIL.Image.fromarray(observation['color'][:,:,:3])
+             filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
+             img.save(filename)
+             if generate_depth:
+                 # Depth image as EXR file
+                 filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
+                 cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+                 # Camera parameters
+                 camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
+                 filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
+                 with open(filename, "w") as f:
+                     json.dump(camera_params, f)
+     # Save metadata
+     with open(os.path.join(output_dir, "metadata.json"), "w") as f:
+         json.dump(metadata, f)
+
+     generator.close()
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--metadata_filename", required=True)
+     parser.add_argument("--output_dir", required=True)
+     args = parser.parse_args()
+
+     generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename,
+                                             output_dir=args.output_dir,
+                                             scene_datasets_paths=SCENES_DATASET,
+                                             overload_params=dict(),
+                                             exist_ok=True)
+
extern/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ """
+ Script generating the command lines to generate image pairs from metadata files.
+ """
+ import os
+ import glob
+ from tqdm import tqdm
+ import argparse
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--input_dir", required=True)
+     parser.add_argument("--output_dir", required=True)
+     parser.add_argument("--prefix", default="", help="Commandline prefix, useful e.g. to set up the environment.")
+     args = parser.parse_args()
+
+     input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True)
+
+     for metadata_filename in tqdm(input_metadata_filenames):
+         output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir))
+         # Do not process the scene if its output metadata file already exists
+         if os.path.exists(os.path.join(output_dir, "metadata.json")):
+             continue
+         commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
+         print(commandline)
extern/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from tqdm import tqdm
6
+ import argparse
7
+ import PIL.Image
8
+ import numpy as np
9
+ import json
10
+ from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError
11
+ from datasets.habitat_sim.paths import list_scenes_available
12
+ import cv2
13
+ import quaternion
14
+ import shutil
15
+
16
+ def generate_multiview_images_for_scene(scene_dataset_config_file,
17
+ scene,
18
+ navmesh,
19
+ output_dir,
20
+ views_count,
21
+ size,
22
+ exist_ok=False,
23
+ generate_depth=False,
24
+ **kwargs):
25
+ """
26
+ Generate tuples of overlapping views for a given scene.
27
+ generate_depth: generate depth images and camera parameters.
28
+ """
29
+ if os.path.exists(output_dir) and not exist_ok:
30
+ print(f"Scene {scene}: data already generated. Ignoring generation.")
31
+ return
32
+ try:
33
+ print(f"Scene {scene}: {size} multiview acquisitions to generate...")
34
+ os.makedirs(output_dir, exist_ok=exist_ok)
35
+
36
+ metadata_filename = os.path.join(output_dir, "metadata.json")
37
+
38
+ metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file,
39
+ scene=scene,
40
+ navmesh=navmesh,
41
+ views_count=views_count,
42
+ size=size,
43
+ generate_depth=generate_depth,
44
+ **kwargs)
45
+ metadata_template["multiviews"] = dict()
46
+
47
+ if os.path.exists(metadata_filename):
48
+ print("Metadata file already exists:", metadata_filename)
49
+ print("Loading already generated metadata file...")
50
+ with open(metadata_filename, "r") as f:
51
+ metadata = json.load(f)
52
+
53
+ for key in metadata_template.keys():
54
+ if key != "multiviews":
55
+ assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
56
+ else:
57
+ print("No temporary file found. Starting generation from scratch...")
58
+ metadata = metadata_template
59
+
60
+ starting_id = len(metadata["multiviews"])
61
+ print(f"Starting generation from index {starting_id}/{size}...")
62
+ if starting_id >= size:
63
+ print("Generation already done.")
64
+ return
65
+
66
+ generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file,
67
+ scene=scene,
68
+ navmesh=navmesh,
69
+ views_count = views_count,
70
+ size = size,
71
+ **kwargs)
72
+
73
+ for idx in tqdm(range(starting_id, size)):
74
+ # Generate / re-generate the observations
75
+ try:
76
+ data = generator[idx]
77
+ observations = data["observations"]
78
+ positions = data["positions"]
79
+ orientations = data["orientations"]
80
+
81
+ idx_label = f"{idx:08}"
82
+ for oidx, observation in enumerate(observations):
83
+ observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
84
+ # Color image saved using PIL
85
+ img = PIL.Image.fromarray(observation['color'][:,:,:3])
86
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
87
+ img.save(filename)
88
+ if generate_depth:
89
+ # Depth image as EXR file
90
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
91
+ cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
92
+ # Camera parameters
93
+ camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
94
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
95
+ with open(filename, "w") as f:
96
+ json.dump(camera_params, f)
97
+ metadata["multiviews"][idx_label] = {"positions": positions.tolist(),
98
+ "orientations": orientations.tolist(),
99
+ "covisibility_ratios": data["covisibility_ratios"].tolist(),
100
+ "valid_fractions": data["valid_fractions"].tolist(),
101
+ "pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()}
102
+ except RecursionError:
103
+ print("Recursion error: unable to sample observations for this scene. We will stop there.")
104
+ break
105
+
106
+ # Regularly save a temporary metadata file, in case we need to restart the generation
107
+ if idx % 10 == 0:
108
+ with open(metadata_filename, "w") as f:
109
+ json.dump(metadata, f)
110
+
111
+ # Save metadata
112
+ with open(metadata_filename, "w") as f:
113
+ json.dump(metadata, f)
114
+
115
+ generator.close()
116
+ except NoNaviguableSpaceError:
117
+ pass
118
+
119
+ def create_commandline(scene_data, generate_depth, exist_ok=False):
120
+ """
121
+ Create a commandline string to generate a scene.
122
+ """
123
+ def my_formatting(val):
124
+ if val is None or val == "":
125
+ return '""'
126
+ else:
127
+ return val
128
+ commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
129
+ --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
130
+ --navmesh {my_formatting(scene_data.navmesh)}
131
+ --output_dir {my_formatting(scene_data.output_dir)}
132
+ --generate_depth {int(generate_depth)}
133
+ --exist_ok {int(exist_ok)}
134
+ """
135
+ commandline = " ".join(commandline.split())
136
+ return commandline
137
+
138
+ if __name__ == "__main__":
139
+ os.umask(2)
140
+
141
+ parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for scenes available:
142
+ > python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands
143
+ """)
144
+
145
+ parser.add_argument("--output_dir", type=str, required=True)
146
+ parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true")
147
+ parser.add_argument("--scene", type=str, default="")
148
+ parser.add_argument("--scene_dataset_config_file", type=str, default="")
149
+ parser.add_argument("--navmesh", type=str, default="")
150
+
151
+ parser.add_argument("--generate_depth", type=int, default=1)
152
+ parser.add_argument("--exist_ok", type=int, default=0)
153
+
154
+ kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000)
155
+
156
+ args = parser.parse_args()
157
+ generate_depth=bool(args.generate_depth)
158
+ exist_ok = bool(args.exist_ok)
159
+
160
+ if args.list_commands:
161
+ # Listing scenes available...
162
+ scenes_data = list_scenes_available(base_output_dir=args.output_dir)
163
+
164
+ for scene_data in scenes_data:
165
+ print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok))
166
+ else:
167
+ if args.scene == "" or args.output_dir == "":
168
+ print("Missing scene or output dir argument!")
169
+ print(parser.format_help())
170
+ else:
171
+ generate_multiview_images_for_scene(scene=args.scene,
172
+ scene_dataset_config_file = args.scene_dataset_config_file,
173
+ navmesh = args.navmesh,
174
+ output_dir = args.output_dir,
175
+ exist_ok=exist_ok,
176
+ generate_depth=generate_depth,
177
+ **kwargs)
extern/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ import numpy as np
6
+ import quaternion
7
+ import habitat_sim
8
+ import json
9
+ from sklearn.neighbors import NearestNeighbors
10
+ import cv2
11
+
12
+ # OpenCV to habitat camera convention transformation
13
+ R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)
14
+ R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
15
+ DEG2RAD = np.pi / 180
16
+
17
+ def compute_camera_intrinsics(height, width, hfov):
18
+ f = width/2 / np.tan(hfov/2 * np.pi/180)
19
+ cu, cv = width/2, height/2
20
+ return f, cu, cv
21
+
22
+ def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
23
+ R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
24
+ t_cam2world = np.asarray(camera_position)
25
+ return R_cam2world, t_cam2world
26
+
27
+ def compute_pointmap(depthmap, hfov):
28
+ """ Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
29
+ height, width = depthmap.shape
30
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
31
+ # Cast depth map to point
32
+ z_cam = depthmap
33
+ u, v = np.meshgrid(range(width), range(height))
34
+ x_cam = (u - cu) / f * z_cam
35
+ y_cam = (v - cv) / f * z_cam
36
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
37
+ return X_cam
38
+
39
+ def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
40
+ """Return a 3D point cloud corresponding to valid pixels of the depth map"""
41
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation)
42
+
43
+ X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
44
+ valid_mask = (X_cam[:,:,2] != 0.0)
45
+
46
+ X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
47
+ X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
48
+ return X_world
49
+
50
+ def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False):
51
+ """
52
+ Compute 'overlapping' metrics based on a distance threshold between two point clouds.
53
+ """
54
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2)
55
+ distances, indices = nbrs.kneighbors(pointcloud1)
56
+ intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
57
+
58
+ data = {"intersection1": intersection1,
59
+ "size1": len(pointcloud1)}
60
+ if compute_symmetric:
61
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1)
62
+ distances, indices = nbrs.kneighbors(pointcloud2)
63
+ intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
64
+ data["intersection2"] = intersection2
65
+ data["size2"] = len(pointcloud2)
66
+
67
+ return data
68
+
69
+ def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
70
+ """
71
+ Add camera parameters to the observation dictionnary produced by Habitat-Sim
72
+ In-place modifications.
73
+ """
74
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation)
75
+ height, width = observation['depth'].shape
76
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
77
+ K = np.asarray([[f, 0, cu],
78
+ [0, f, cv],
79
+ [0, 0, 1.0]])
80
+ observation["camera_intrinsics"] = K
81
+ observation["t_cam2world"] = t_cam2world
82
+ observation["R_cam2world"] = R_cam2world
83
+
84
+ def look_at(eye, center, up, return_cam2world=True):
85
+ """
86
+ Return camera pose looking at a given center point.
87
+ Analogous of gluLookAt function, using OpenCV camera convention.
88
+ """
89
+ z = center - eye
90
+ z /= np.linalg.norm(z, axis=-1, keepdims=True)
91
+ y = -up
92
+ y = y - np.sum(y * z, axis=-1, keepdims=True) * z
93
+ y /= np.linalg.norm(y, axis=-1, keepdims=True)
94
+ x = np.cross(y, z, axis=-1)
95
+
96
+ if return_cam2world:
97
+ R = np.stack((x, y, z), axis=-1)
98
+ t = eye
99
+ else:
100
+ # World to camera transformation
101
+ # Transposed matrix
102
+ R = np.stack((x, y, z), axis=-2)
103
+ t = - np.einsum('...ij, ...j', R, eye)
104
+ return R, t
105
+
106
+ def look_at_for_habitat(eye, center, up, return_cam2world=True):
107
+ R, t = look_at(eye, center, up)
108
+ orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
109
+ return orientation, t
110
+
111
+ def generate_orientation_noise(pan_range, tilt_range, roll_range):
112
+ return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP)
113
+ * quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT)
114
+ * quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT))
115
+
116
+
117
+ class NoNaviguableSpaceError(RuntimeError):
118
+ def __init__(self, *args):
119
+ super().__init__(*args)
120
+
121
+ class MultiviewHabitatSimGenerator:
122
+ def __init__(self,
123
+ scene,
124
+ navmesh,
125
+ scene_dataset_config_file,
126
+ resolution = (240, 320),
127
+ views_count=2,
128
+ hfov = 60,
129
+ gpu_id = 0,
130
+ size = 10000,
131
+ minimum_covisibility = 0.5,
132
+ transform = None):
133
+ self.scene = scene
134
+ self.navmesh = navmesh
135
+ self.scene_dataset_config_file = scene_dataset_config_file
136
+ self.resolution = resolution
137
+ self.views_count = views_count
138
+ assert(self.views_count >= 1)
139
+ self.hfov = hfov
140
+ self.gpu_id = gpu_id
141
+ self.size = size
142
+ self.transform = transform
143
+
144
+ # Noise added to camera orientation
145
+ self.pan_range = (-3, 3)
146
+ self.tilt_range = (-10, 10)
147
+ self.roll_range = (-5, 5)
148
+
149
+ # Height range to sample cameras
150
+ self.height_range = (1.2, 1.8)
151
+
152
+ # Random steps between the camera views
153
+ self.random_steps_count = 5
154
+ self.random_step_variance = 2.0
155
+
156
+ # Minimum fraction of the scene which should be valid (well defined depth)
157
+ self.minimum_valid_fraction = 0.7
158
+
159
+ # Distance threshold to see to select pairs
160
+ self.distance_threshold = 0.05
161
+ # Minimum IoU of a view point cloud with respect to the reference view to be kept.
162
+ self.minimum_covisibility = minimum_covisibility
163
+
164
+ # Maximum number of retries.
165
+ self.max_attempts_count = 100
166
+
167
+ self.seed = None
168
+ self._lazy_initialization()
169
+
170
+ def _lazy_initialization(self):
171
+ # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
172
+ if self.seed == None:
173
+ # Re-seed numpy generator
174
+ np.random.seed()
175
+ self.seed = np.random.randint(2**32-1)
176
+ sim_cfg = habitat_sim.SimulatorConfiguration()
177
+ sim_cfg.scene_id = self.scene
178
+ if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
179
+ sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
180
+ sim_cfg.random_seed = self.seed
181
+ sim_cfg.load_semantic_mesh = False
182
+ sim_cfg.gpu_device_id = self.gpu_id
183
+
184
+ depth_sensor_spec = habitat_sim.CameraSensorSpec()
185
+ depth_sensor_spec.uuid = "depth"
186
+ depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
187
+ depth_sensor_spec.resolution = self.resolution
188
+ depth_sensor_spec.hfov = self.hfov
189
+ depth_sensor_spec.position = [0.0, 0.0, 0]
190
+ depth_sensor_spec.orientation
191
+
192
+ rgb_sensor_spec = habitat_sim.CameraSensorSpec()
193
+ rgb_sensor_spec.uuid = "color"
194
+ rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
195
+ rgb_sensor_spec.resolution = self.resolution
196
+ rgb_sensor_spec.hfov = self.hfov
197
+ rgb_sensor_spec.position = [0.0, 0.0, 0]
198
+ agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec])
199
+
200
+ cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
201
+ self.sim = habitat_sim.Simulator(cfg)
202
+ if self.navmesh is not None and self.navmesh != "":
203
+ # Use pre-computed navmesh when available (usually better than those generated automatically)
204
+ self.sim.pathfinder.load_nav_mesh(self.navmesh)
205
+
206
+ if not self.sim.pathfinder.is_loaded:
207
+ # Try to compute a navmesh
208
+ navmesh_settings = habitat_sim.NavMeshSettings()
209
+ navmesh_settings.set_defaults()
210
+ self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
211
+
212
+ # Ensure that the navmesh is not empty
213
+ if not self.sim.pathfinder.is_loaded:
214
+ raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})")
215
+
216
+ self.agent = self.sim.initialize_agent(agent_id=0)
217
+
218
+ def close(self):
219
+ self.sim.close()
220
+
221
+ def __del__(self):
222
+ self.sim.close()
223
+
224
+ def __len__(self):
225
+ return self.size
226
+
227
+ def sample_random_viewpoint(self):
228
+ """ Sample a random viewpoint using the navmesh """
229
+ nav_point = self.sim.pathfinder.get_random_navigable_point()
230
+
231
+ # Sample a random viewpoint height
232
+ viewpoint_height = np.random.uniform(*self.height_range)
233
+ viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
234
+ viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
235
+ return viewpoint_position, viewpoint_orientation, nav_point
236
+
237
+ def sample_other_random_viewpoint(self, observed_point, nav_point):
238
+ """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
239
+ other_nav_point = nav_point
240
+
241
+ walk_directions = self.random_step_variance * np.asarray([1,0,1])
242
+ for i in range(self.random_steps_count):
243
+ temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3))
244
+ # Snapping may return nan when it fails
245
+ if not np.isnan(temp[0]):
246
+ other_nav_point = temp
247
+
248
+ other_viewpoint_height = np.random.uniform(*self.height_range)
249
+ other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
250
+
251
+ # Set viewing direction towards the central point
252
+ rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True)
253
+ rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
254
+ return position, rotation, other_nav_point
255
+
256
+ def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
257
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
258
+ # Observation
259
+ pixels_count = self.resolution[0] * self.resolution[1]
260
+ valid_fraction = len(other_pointcloud) / pixels_count
261
+ assert valid_fraction <= 1.0 and valid_fraction >= 0.0
262
+ overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True)
263
+ covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count)
264
+ is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility)
265
+ return is_valid, valid_fraction, covisibility
266
+
267
+ def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation):
268
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
269
+ # Observation
270
+ other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation)
271
+ return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
272
+
273
+ def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
274
+ agent_state = habitat_sim.AgentState()
275
+ agent_state.position = viewpoint_position
276
+ agent_state.rotation = viewpoint_orientation
277
+ self.agent.set_state(agent_state)
278
+ viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
279
+ _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation)
280
+ return viewpoint_observations
281
+
282
+ def __getitem__(self, useless_idx):
283
+ ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
284
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
285
+ # Extract point cloud
286
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
287
+ camera_position=ref_position, camera_rotation=ref_orientation)
288
+
289
+ pixels_count = self.resolution[0] * self.resolution[1]
290
+ ref_valid_fraction = len(ref_pointcloud) / pixels_count
291
+ assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
292
+ if ref_valid_fraction < self.minimum_valid_fraction:
293
+ # This should produce a recursion error at some point when something is very wrong.
294
+ return self[0]
295
+ # Pick an reference observed point in the point cloud
296
+ observed_point = np.mean(ref_pointcloud, axis=0)
297
+
298
+ # Add the first image as reference
299
+ viewpoints_observations = [ref_observations]
300
+ viewpoints_covisibility = [ref_valid_fraction]
301
+ viewpoints_positions = [ref_position]
302
+ viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
303
+ viewpoints_clouds = [ref_pointcloud]
304
+ viewpoints_valid_fractions = [ref_valid_fraction]
305
+
306
+ for _ in range(self.views_count - 1):
307
+ # Generate an other viewpoint using some dummy random walk
308
+ successful_sampling = False
309
+ for sampling_attempt in range(self.max_attempts_count):
310
+ position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point)
311
+ # Observation
312
+ other_viewpoint_observations = self.render_viewpoint(position, rotation)
313
+ other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation)
314
+
315
+ is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
316
+ if is_valid:
317
+ successful_sampling = True
318
+ break
319
+ if not successful_sampling:
320
+ print("WARNING: Maximum number of attempts reached.")
321
+ # Dirty hack, try using a novel original viewpoint
322
+ return self[0]
323
+ viewpoints_observations.append(other_viewpoint_observations)
324
+ viewpoints_covisibility.append(covisibility)
325
+ viewpoints_positions.append(position)
326
+ viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding.
327
+ viewpoints_clouds.append(other_pointcloud)
328
+ viewpoints_valid_fractions.append(valid_fraction)
329
+
330
+ # Estimate relations between all pairs of images
331
+ pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations)))
332
+ for i in range(len(viewpoints_observations)):
333
+ pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i]
334
+ for j in range(i+1, len(viewpoints_observations)):
335
+ overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True)
336
+ pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count
337
+ pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count
338
+
339
+ # IoU is relative to the image 0
340
+ data = {"observations": viewpoints_observations,
341
+ "positions": np.asarray(viewpoints_positions),
342
+ "orientations": np.asarray(viewpoints_orientations),
343
+ "covisibility_ratios": np.asarray(viewpoints_covisibility),
344
+ "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
345
+ "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float),
346
+ }
347
+
348
+ if self.transform is not None:
349
+ data = self.transform(data)
350
+ return data
351
+
352
+ def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False):
353
+ """
354
+ Return a list of images corresponding to a spiral trajectory from a random starting point.
355
+ Useful to generate nice visualisations.
356
+ Use an even number of half turns to get a nice "C1-continuous" loop effect
357
+ """
358
+ ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
359
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
360
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
361
+ camera_position=ref_position, camera_rotation=ref_orientation)
362
+ pixels_count = self.resolution[0] * self.resolution[1]
363
+ if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
364
+ # Dirty hack: ensure that the valid part of the image is significant
365
+ return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation)
366
+
367
+ # Pick an observed point in the point cloud
368
+ observed_point = np.mean(ref_pointcloud, axis=0)
369
+ ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation)
370
+
371
+ images = []
372
+ is_valid = []
373
+ # Spiral trajectory, use_constant orientation
374
+ for i, alpha in enumerate(np.linspace(0, 1, images_count)):
375
+ r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius
376
+ theta = alpha * half_turns * np.pi
377
+ x = r * np.cos(theta)
378
+ y = r * np.sin(theta)
379
+ z = 0.0
380
+ position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten()
381
+ if use_constant_orientation:
382
+ orientation = ref_orientation
383
+ else:
384
+ # trajectory looking at a mean point in front of the ref observation
385
+ orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP)
386
+ observations = self.render_viewpoint(position, orientation)
387
+ images.append(observations['color'][...,:3])
388
+ _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation)
389
+ is_valid.append(_is_valid)
390
+ return images, np.all(is_valid)
extern/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py ADDED
@@ -0,0 +1,69 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ """
+ Utility script to pack the metadata files of the dataset, in order to be able to re-generate it elsewhere.
+ """
+ import os
+ import glob
+ from tqdm import tqdm
+ import shutil
+ import json
+ from datasets.habitat_sim.paths import *
+ import argparse
+ import collections
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("input_dir")
+     parser.add_argument("output_dir")
+     args = parser.parse_args()
+
+     input_dirname = args.input_dir
+     output_dirname = args.output_dir
+
+     input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True)
+
+     images_count = collections.defaultdict(lambda : 0)
+
+     os.makedirs(output_dirname)
+     for input_filename in tqdm(input_metadata_filenames):
+         # Ignore empty files
+         with open(input_filename, "r") as f:
+             original_metadata = json.load(f)
+         if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0:
+             print("No views in", input_filename)
+             continue
+
+         relpath = os.path.relpath(input_filename, input_dirname)
+         print(relpath)
+
+         # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
+         # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting with the same string pattern.
+         scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True))
+         metadata = dict()
+         for key, value in original_metadata.items():
+             if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
+                 known_path = False
+                 for dataset, dataset_path in scenes_dataset_paths.items():
+                     if value.startswith(dataset_path):
+                         value = os.path.join(dataset, os.path.relpath(value, dataset_path))
+                         known_path = True
+                         break
+                 if not known_path:
+                     raise KeyError("Unknown path: " + value)
+             metadata[key] = value
+
+         # Compile some general statistics while packing data
+         scene_split = metadata["scene"].split("/")
+         upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
+         images_count[upper_level] += len(metadata["multiviews"])
+
+         output_filename = os.path.join(output_dirname, relpath)
+         os.makedirs(os.path.dirname(output_filename), exist_ok=True)
+         with open(output_filename, "w") as f:
+             json.dump(metadata, f)
+
+     # Print statistics
+     print("Images count:")
+     for upper_level, count in images_count.items():
+         print(f"- {upper_level}: {count}")
extern/dust3r/croco/datasets/habitat_sim/paths.py ADDED
@@ -0,0 +1,129 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ """
+ Paths to Habitat-Sim scenes
+ """
+
+ import os
+ import json
+ import collections
+ from tqdm import tqdm
+
+
+ # Hardcoded paths to the different scene datasets
+ SCENES_DATASET = {
+     "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
+     "gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
+     "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
+     "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
+     "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
+     "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
+     "scannet": "./data/habitat-sim/scene_datasets/scannet/"
+ }
+
+ SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])
+
+ def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
+     scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json")
+     scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
+     navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
+     scenes_data = []
+     for idx in range(len(scenes)):
+         output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
+         # Add scene
+         data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
+                          scene = scenes[idx] + ".scene_instance.json",
+                          navmesh = os.path.join(base_path, navmeshes[idx]),
+                          output_dir = output_dir)
+         scenes_data.append(data)
+     return scenes_data
+
+ def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]):
+     scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json")
+     scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [])
+     navmeshes = ""  # [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
+     scenes_data = []
+     for idx in range(len(scenes)):
+         output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx])
+         data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
+                          scene = scenes[idx],
+                          navmesh = "",
+                          output_dir = output_dir)
+         scenes_data.append(data)
+     return scenes_data
+
+ def list_replica_scenes(base_output_dir, base_path):
+     scenes_data = []
+     for scene_id in os.listdir(base_path):
+         scene = os.path.join(base_path, scene_id, "mesh.ply")
+         navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it
+         scene_dataset_config_file = ""
+         output_dir = os.path.join(base_output_dir, scene_id)
+         # Add scene only if it does not exist already, or if exist_ok
+         data = SceneData(scene_dataset_config_file = scene_dataset_config_file,
+                          scene = scene,
+                          navmesh = navmesh,
+                          output_dir = output_dir)
+         scenes_data.append(data)
+     return scenes_data
+
+
+ def list_scenes(base_output_dir, base_path):
+     """
+     Generic method iterating through a base_path folder to find scenes.
+     """
+     scenes_data = []
+     for root, dirs, files in os.walk(base_path, followlinks=True):
+         folder_scenes_data = []
+         for file in files:
+             name, ext = os.path.splitext(file)
+             if ext == ".glb":
+                 scene = os.path.join(root, name + ".glb")
+                 navmesh = os.path.join(root, name + ".navmesh")
+                 if not os.path.exists(navmesh):
+                     navmesh = ""
+                 relpath = os.path.relpath(root, base_path)
+                 output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name))
+                 data = SceneData(scene_dataset_config_file="",
+                                  scene = scene,
+                                  navmesh = navmesh,
+                                  output_dir = output_dir)
+                 folder_scenes_data.append(data)
+
+         # Specific check for HM3D:
+         # when two meshes xxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
+         basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")]
+         if len(basis_scenes) != 0:
+             folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)]
+
+         scenes_data.extend(folder_scenes_data)
+     return scenes_data
+
+ def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
+     scenes_data = []
+
+     # HM3D
+     for split in ("minival", "train", "val", "examples"):
+         scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
+                                    base_path=f"{scenes_dataset_paths['hm3d']}/{split}")
+
+     # Gibson
+     scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"),
+                                base_path=scenes_dataset_paths["gibson"])
+
+     # Habitat test scenes (just a few)
+     scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
+                                base_path=scenes_dataset_paths["habitat-test-scenes"])
+
+     # ReplicaCAD (baked lighting)
+     scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir)
+
+     # ScanNet
+     scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"),
+                                base_path=scenes_dataset_paths["scannet"])
+
+     # Replica
+     scenes_data += list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"),
+                                        base_path=scenes_dataset_paths["replica"])
+     return scenes_data
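Editor's note: for orientation, a minimal sketch (not part of the committed files) of how the helpers in `paths.py` are typically consumed; it assumes the `./data/habitat-sim-data/scene_datasets/` layout described in the Habitat-Sim README above.

```python
# Enumerate the Habitat scenes visible under the hardcoded SCENES_DATASET paths
# and show where each scene's generated pairs would be written.
from datasets.habitat_sim.paths import list_scenes_available

scenes_data = list_scenes_available(base_output_dir="./data/habitat_release/")
for scene_data in scenes_data[:5]:  # print only the first few SceneData entries
    print(scene_data.scene, "->", scene_data.output_dir)
```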
extern/dust3r/croco/datasets/pairs_dataset.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+
8
+ from datasets.transforms import get_pair_transforms
9
+
10
+ def load_image(impath):
11
+ return Image.open(impath)
12
+
13
+ def load_pairs_from_cache_file(fname, root=''):
14
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
15
+ with open(fname, 'r') as fid:
16
+ lines = fid.read().strip().splitlines()
17
+ pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines]
18
+ return pairs
19
+
20
+ def load_pairs_from_list_file(fname, root=''):
21
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
22
+ with open(fname, 'r') as fid:
23
+ lines = fid.read().strip().splitlines()
24
+ pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')]
25
+ return pairs
26
+
27
+
28
+ def write_cache_file(fname, pairs, root=''):
29
+ if len(root)>0:
30
+ if not root.endswith('/'): root+='/'
31
+ assert os.path.isdir(root)
32
+ s = ''
33
+ for im1, im2 in pairs:
34
+ if len(root)>0:
35
+ assert im1.startswith(root), im1
36
+ assert im2.startswith(root), im2
37
+ s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):])
38
+ with open(fname, 'w') as fid:
39
+ fid.write(s[:-1])
40
+
41
+ def parse_and_cache_all_pairs(dname, data_dir='./data/'):
42
+ if dname=='habitat_release':
43
+ dirname = os.path.join(data_dir, 'habitat_release')
44
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
45
+ cache_file = os.path.join(dirname, 'pairs.txt')
46
+ assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file
47
+
48
+ print('Parsing pairs for dataset: '+dname)
49
+ pairs = []
50
+ for root, dirs, files in os.walk(dirname):
51
+ if 'val' in root: continue
52
+ dirs.sort()
53
+ pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
54
+ print('Found {:,} pairs'.format(len(pairs)))
55
+ print('Writing cache to: '+cache_file)
56
+ write_cache_file(cache_file, pairs, root=dirname)
57
+
58
+ else:
59
+ raise NotImplementedError('Unknown dataset: '+dname)
60
+
61
+ def dnames_to_image_pairs(dnames, data_dir='./data/'):
62
+ """
63
+ dnames: list of datasets with image pairs, separated by +
64
+ """
65
+ all_pairs = []
66
+ for dname in dnames.split('+'):
67
+ if dname=='habitat_release':
68
+ dirname = os.path.join(data_dir, 'habitat_release')
69
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
70
+ cache_file = os.path.join(dirname, 'pairs.txt')
71
+ assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
72
+ pairs = load_pairs_from_cache_file(cache_file, root=dirname)
73
+ elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
74
+ dirname = os.path.join(data_dir, dname+'_crops')
75
+ assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
76
+ list_file = os.path.join(dirname, 'listing.txt')
77
+ assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
78
+ pairs = load_pairs_from_list_file(list_file, root=dirname)
79
+ print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
80
+ all_pairs += pairs
81
+ if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
82
+ return all_pairs
83
+
84
+
85
+ class PairsDataset(Dataset):
86
+
87
+ def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
88
+ super().__init__()
89
+ self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
90
+ self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)
91
+
92
+ def __len__(self):
93
+ return len(self.image_pairs)
94
+
95
+ def __getitem__(self, index):
96
+ im1path, im2path = self.image_pairs[index]
97
+ im1 = load_image(im1path)
98
+ im2 = load_image(im2path)
99
+ if self.transforms is not None: im1, im2 = self.transforms(im1, im2)
100
+ return im1, im2
101
+
102
+
103
+ if __name__=="__main__":
104
+ import argparse
105
+ parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
106
+ parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
107
+ parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
108
+ args = parser.parse_args()
109
+ parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
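For orientation, a minimal usage sketch of the PairsDataset defined above (a sketch only, not part of the commit; it assumes the habitat_release cache file has already been written by parse_and_cache_all_pairs):

    # Sketch: load cached pairs and iterate over them with a standard DataLoader.
    from torch.utils.data import DataLoader
    from datasets.pairs_dataset import PairsDataset

    dataset = PairsDataset('habitat_release', trfs='crop224+acolor', data_dir='./data/')
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
    for im1, im2 in loader:
        # im1, im2: (8, 3, 224, 224) float tensors, ImageNet-normalized by NormalizeBoth
        break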
extern/dust3r/croco/datasets/transforms.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ import torchvision.transforms
6
+ import torchvision.transforms.functional as F
7
+
8
+ # "Pair": apply a transform on a pair
9
+ # "Both": apply the exact same transform to both images
10
+
11
+ class ComposePair(torchvision.transforms.Compose):
12
+ def __call__(self, img1, img2):
13
+ for t in self.transforms:
14
+ img1, img2 = t(img1, img2)
15
+ return img1, img2
16
+
17
+ class NormalizeBoth(torchvision.transforms.Normalize):
18
+ def forward(self, img1, img2):
19
+ img1 = super().forward(img1)
20
+ img2 = super().forward(img2)
21
+ return img1, img2
22
+
23
+ class ToTensorBoth(torchvision.transforms.ToTensor):
24
+ def __call__(self, img1, img2):
25
+ img1 = super().__call__(img1)
26
+ img2 = super().__call__(img2)
27
+ return img1, img2
28
+
29
+ class RandomCropPair(torchvision.transforms.RandomCrop):
30
+ # the crop will be intentionally different for the two images with this class
31
+ def forward(self, img1, img2):
32
+ img1 = super().forward(img1)
33
+ img2 = super().forward(img2)
34
+ return img1, img2
35
+
36
+ class ColorJitterPair(torchvision.transforms.ColorJitter):
37
+ # can be symmetric (same jitter for both images) or asymmetric (different jitter params for each image) depending on assymetric_prob
38
+ def __init__(self, assymetric_prob, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.assymetric_prob = assymetric_prob
41
+ def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor):
42
+ for fn_id in fn_idx:
43
+ if fn_id == 0 and brightness_factor is not None:
44
+ img = F.adjust_brightness(img, brightness_factor)
45
+ elif fn_id == 1 and contrast_factor is not None:
46
+ img = F.adjust_contrast(img, contrast_factor)
47
+ elif fn_id == 2 and saturation_factor is not None:
48
+ img = F.adjust_saturation(img, saturation_factor)
49
+ elif fn_id == 3 and hue_factor is not None:
50
+ img = F.adjust_hue(img, hue_factor)
51
+ return img
52
+
53
+ def forward(self, img1, img2):
54
+
55
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
56
+ self.brightness, self.contrast, self.saturation, self.hue
57
+ )
58
+ img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
59
+ if torch.rand(1) < self.assymetric_prob: # assymetric:
60
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
61
+ self.brightness, self.contrast, self.saturation, self.hue
62
+ )
63
+ img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
64
+ return img1, img2
65
+
66
+ def get_pair_transforms(transform_str, totensor=True, normalize=True):
67
+ # transform_str is eg crop224+color
68
+ trfs = []
69
+ for s in transform_str.split('+'):
70
+ if s.startswith('crop'):
71
+ size = int(s[len('crop'):])
72
+ trfs.append(RandomCropPair(size))
73
+ elif s=='acolor':
74
+ trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0))
75
+ elif s=='': # if transform_str was ""
76
+ pass
77
+ else:
78
+ raise NotImplementedError('Unknown augmentation: '+s)
79
+
80
+ if totensor:
81
+ trfs.append( ToTensorBoth() )
82
+ if normalize:
83
+ trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )
84
+
85
+ if len(trfs)==0:
86
+ return None
87
+ elif len(trfs)==1:
88
+ return trfs[0] # a single transform can be applied directly; a bare list would not be callable
89
+ else:
90
+ return ComposePair(trfs)
91
+
92
+
93
+
94
+
95
+
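The transform string is a '+'-separated list: 'crop224+acolor', for instance, gives an independent random 224-pixel crop per image followed by asymmetric color jitter, with tensor conversion and ImageNet normalization appended by default. A small sketch with hypothetical file names:

    from PIL import Image
    from datasets.transforms import get_pair_transforms

    trfs = get_pair_transforms('crop224+acolor')        # ComposePair of pair-aware transforms
    im1 = Image.open('example_1.jpg').convert('RGB')    # hypothetical image pair
    im2 = Image.open('example_2.jpg').convert('RGB')
    t1, t2 = trfs(im1, im2)                             # each: (3, 224, 224) normalized tensor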
extern/dust3r/croco/demo.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ from models.croco import CroCoNet
6
+ from PIL import Image
7
+ import torchvision.transforms
8
+ from torchvision.transforms import ToTensor, Normalize, Compose
9
+
10
+ def main():
11
+ device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu')
12
+
13
+ # load 224x224 images and transform them to tensor
14
+ imagenet_mean = [0.485, 0.456, 0.406]
15
+ imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True)
16
+ imagenet_std = [0.229, 0.224, 0.225]
17
+ imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True)
18
+ trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])
19
+ image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
20
+ image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
21
+
22
+ # load model
23
+ ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')
24
+ model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device)
25
+ model.eval()
26
+ msg = model.load_state_dict(ckpt['model'], strict=True)
27
+
28
+ # forward
29
+ with torch.inference_mode():
30
+ out, mask, target = model(image1, image2)
31
+
32
+ # the output is normalized, thus use the mean/std of the actual image to go back to RGB space
33
+ patchified = model.patchify(image1)
34
+ mean = patchified.mean(dim=-1, keepdim=True)
35
+ var = patchified.var(dim=-1, keepdim=True)
36
+ decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean)
37
+ # undo imagenet normalization, prepare masked image
38
+ decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor
39
+ input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
40
+ ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor
41
+ image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])
42
+ masked_input_image = ((1 - image_masks) * input_image)
43
+
44
+ # make visualization
45
+ visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4
46
+ B, C, H, W = visualization.shape
47
+ visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W)
48
+ visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1))
49
+ fname = "demo_output.png"
50
+ visualization.save(fname)
51
+ print('Visualization saved to '+fname)
52
+
53
+
54
+ if __name__=="__main__":
55
+ main()
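The reconstruction step above exploits the fact that the model predicts per-patch normalized values, so the statistics of the real patches are reapplied before undoing the ImageNet normalization. The same step in isolation (a sketch using the names from demo.py):

    patches = model.patchify(image1)                         # (B, L, patch_size**2 * 3)
    mean = patches.mean(dim=-1, keepdim=True)
    std = (patches.var(dim=-1, keepdim=True) + 1e-6) ** 0.5
    rgb = model.unpatchify(out * std + mean)                 # (B, 3, H, W), still ImageNet-normalized
    rgb = rgb * imagenet_std_tensor + imagenet_mean_tensor   # back to displayable RGB in [0, 1]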
extern/dust3r/croco/interactive_demo.ipynb ADDED
@@ -0,0 +1,271 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Interactive demo of Cross-view Completion."
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
17
+ "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "import torch\n",
27
+ "import numpy as np\n",
28
+ "from models.croco import CroCoNet\n",
29
+ "from ipywidgets import interact, interactive, fixed, interact_manual\n",
30
+ "import ipywidgets as widgets\n",
31
+ "import matplotlib.pyplot as plt\n",
32
+ "import quaternion\n",
33
+ "import models.masking"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {},
39
+ "source": [
40
+ "### Load CroCo model"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n",
50
+ "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n",
51
+ "msg = model.load_state_dict(ckpt['model'], strict=True)\n",
52
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
53
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
54
+ "model = model.eval()\n",
55
+ "model = model.to(device=device)\n",
56
+ "print(msg)\n",
57
+ "\n",
58
+ "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n",
59
+ " \"\"\"\n",
60
+ " Perform Cross-View completion using two input images, specified using Numpy arrays.\n",
61
+ " \"\"\"\n",
62
+ " # Replace the mask generator\n",
63
+ " model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n",
64
+ "\n",
65
+ " # ImageNet-1k color normalization\n",
66
+ " imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n",
67
+ " imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n",
68
+ "\n",
69
+ " normalize_input_colors = True\n",
70
+ " is_output_normalized = True\n",
71
+ " with torch.no_grad():\n",
72
+ " # Cast data to torch\n",
73
+ " target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
74
+ " ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
75
+ "\n",
76
+ " if normalize_input_colors:\n",
77
+ " ref_image = (ref_image - imagenet_mean) / imagenet_std\n",
78
+ " target_image = (target_image - imagenet_mean) / imagenet_std\n",
79
+ "\n",
80
+ " out, mask, _ = model(target_image, ref_image)\n",
81
+ " # # get target\n",
82
+ " if not is_output_normalized:\n",
83
+ " predicted_image = model.unpatchify(out)\n",
84
+ " else:\n",
85
+ " # The output only contains higher order information,\n",
86
+ " # we retrieve mean and standard deviation from the actual target image\n",
87
+ " patchified = model.patchify(target_image)\n",
88
+ " mean = patchified.mean(dim=-1, keepdim=True)\n",
89
+ " var = patchified.var(dim=-1, keepdim=True)\n",
90
+ " pred_renorm = out * (var + 1.e-6)**.5 + mean\n",
91
+ " predicted_image = model.unpatchify(pred_renorm)\n",
92
+ "\n",
93
+ " image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n",
94
+ " masked_target_image = (1 - image_masks) * target_image\n",
95
+ " \n",
96
+ " if not reconstruct_unmasked_patches:\n",
97
+ " # Replace unmasked patches by their actual values\n",
98
+ " predicted_image = predicted_image * image_masks + masked_target_image\n",
99
+ "\n",
100
+ " # Unapply color normalization\n",
101
+ " if normalize_input_colors:\n",
102
+ " predicted_image = predicted_image * imagenet_std + imagenet_mean\n",
103
+ " masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n",
104
+ " \n",
105
+ " # Cast to Numpy\n",
106
+ " masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
107
+ " predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
108
+ " return masked_target_image, predicted_image"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "markdown",
113
+ "metadata": {},
114
+ "source": [
115
+ "### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "import os\n",
125
+ "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n",
126
+ "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n",
127
+ "import habitat_sim\n",
128
+ "\n",
129
+ "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n",
130
+ "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n",
131
+ "\n",
132
+ "sim_cfg = habitat_sim.SimulatorConfiguration()\n",
133
+ "if use_gpu: sim_cfg.gpu_device_id = 0\n",
134
+ "sim_cfg.scene_id = scene\n",
135
+ "sim_cfg.load_semantic_mesh = False\n",
136
+ "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n",
137
+ "rgb_sensor_spec.uuid = \"color\"\n",
138
+ "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n",
139
+ "rgb_sensor_spec.resolution = (224,224)\n",
140
+ "rgb_sensor_spec.hfov = 56.56\n",
141
+ "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n",
142
+ "rgb_sensor_spec.orientation = [0, 0, 0]\n",
143
+ "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n",
144
+ "\n",
145
+ "\n",
146
+ "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n",
147
+ "sim = habitat_sim.Simulator(cfg)\n",
148
+ "if navmesh is not None:\n",
149
+ " sim.pathfinder.load_nav_mesh(navmesh)\n",
150
+ "agent = sim.initialize_agent(agent_id=0)\n",
151
+ "\n",
152
+ "def sample_random_viewpoint():\n",
153
+ " \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n",
154
+ " nav_point = sim.pathfinder.get_random_navigable_point()\n",
155
+ " # Sample a random viewpoint height\n",
156
+ " viewpoint_height = np.random.uniform(1.0, 1.6)\n",
157
+ " viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n",
158
+ " viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n",
159
+ " return viewpoint_position, viewpoint_orientation\n",
160
+ "\n",
161
+ "def render_viewpoint(position, orientation):\n",
162
+ " agent_state = habitat_sim.AgentState()\n",
163
+ " agent_state.position = position\n",
164
+ " agent_state.rotation = orientation\n",
165
+ " agent.set_state(agent_state)\n",
166
+ " viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n",
167
+ " image = viewpoint_observations['color'][:,:,:3]\n",
168
+ " image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n",
169
+ " return image"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "metadata": {},
175
+ "source": [
176
+ "### Sample a random reference view"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "ref_position, ref_orientation = sample_random_viewpoint()\n",
186
+ "ref_image = render_viewpoint(ref_position, ref_orientation)\n",
187
+ "plt.clf()\n",
188
+ "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n",
189
+ "axes[0,0].imshow(ref_image)\n",
190
+ "for ax in axes.flatten():\n",
191
+ " ax.set_xticks([])\n",
192
+ " ax.set_yticks([])"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {},
198
+ "source": [
199
+ "### Interactive cross-view completion using CroCo"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "reconstruct_unmasked_patches = False\n",
209
+ "\n",
210
+ "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n",
211
+ " R = quaternion.as_rotation_matrix(ref_orientation)\n",
212
+ " target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n",
213
+ " target_orientation = (ref_orientation\n",
214
+ " * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n",
215
+ " * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n",
216
+ " \n",
217
+ " ref_image = render_viewpoint(ref_position, ref_orientation)\n",
218
+ " target_image = render_viewpoint(target_position, target_orientation)\n",
219
+ "\n",
220
+ " masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n",
221
+ "\n",
222
+ " fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n",
223
+ " axes[0].imshow(ref_image)\n",
224
+ " axes[0].set_xlabel(\"Reference\")\n",
225
+ " axes[1].imshow(masked_target_image)\n",
226
+ " axes[1].set_xlabel(\"Masked target\")\n",
227
+ " axes[2].imshow(predicted_image)\n",
228
+ " axes[2].set_xlabel(\"Reconstruction\") \n",
229
+ " axes[3].imshow(target_image)\n",
230
+ " axes[3].set_xlabel(\"Target\")\n",
231
+ " for ax in axes.flatten():\n",
232
+ " ax.set_xticks([])\n",
233
+ " ax.set_yticks([])\n",
234
+ "\n",
235
+ "interact(show_demo,\n",
236
+ " masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n",
237
+ " x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
238
+ " y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
239
+ " z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
240
+ " panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n",
241
+ " elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));"
242
+ ]
243
+ }
244
+ ],
245
+ "metadata": {
246
+ "kernelspec": {
247
+ "display_name": "Python 3 (ipykernel)",
248
+ "language": "python",
249
+ "name": "python3"
250
+ },
251
+ "language_info": {
252
+ "codemirror_mode": {
253
+ "name": "ipython",
254
+ "version": 3
255
+ },
256
+ "file_extension": ".py",
257
+ "mimetype": "text/x-python",
258
+ "name": "python",
259
+ "nbconvert_exporter": "python",
260
+ "pygments_lexer": "ipython3",
261
+ "version": "3.7.13"
262
+ },
263
+ "vscode": {
264
+ "interpreter": {
265
+ "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67"
266
+ }
267
+ }
268
+ },
269
+ "nbformat": 4,
270
+ "nbformat_minor": 2
271
+ }
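Note that the process_images helper defined in the notebook only needs two same-size RGB arrays (224x224 for this checkpoint); the Habitat cells merely render such views. A sketch of calling it directly on the bundled assets instead:

    import numpy as np
    from PIL import Image

    ref = np.asarray(Image.open('assets/Chateau2.png').convert('RGB'))   # reference view
    tgt = np.asarray(Image.open('assets/Chateau1.png').convert('RGB'))   # view to mask and complete
    masked_tgt, pred = process_images(ref, tgt, masking_ratio=0.9)       # two uint8 H x W x 3 arrays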
extern/dust3r/croco/models/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (9.12 kB). View file
 
extern/dust3r/croco/models/__pycache__/blocks.cpython-38.pyc ADDED
Binary file (9.15 kB). View file
 
extern/dust3r/croco/models/__pycache__/blocks.cpython-39.pyc ADDED
Binary file (9.14 kB). View file
 
extern/dust3r/croco/models/__pycache__/croco.cpython-310.pyc ADDED
Binary file (7.45 kB). View file
 
extern/dust3r/croco/models/__pycache__/croco.cpython-38.pyc ADDED
Binary file (7.43 kB). View file
 
extern/dust3r/croco/models/__pycache__/croco.cpython-39.pyc ADDED
Binary file (7.42 kB). View file
 
extern/dust3r/croco/models/__pycache__/dpt_block.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
extern/dust3r/croco/models/__pycache__/dpt_block.cpython-38.pyc ADDED
Binary file (10.1 kB). View file
 
extern/dust3r/croco/models/__pycache__/dpt_block.cpython-39.pyc ADDED
Binary file (10 kB). View file
 
extern/dust3r/croco/models/__pycache__/masking.cpython-310.pyc ADDED
Binary file (932 Bytes). View file
 
extern/dust3r/croco/models/__pycache__/masking.cpython-38.pyc ADDED
Binary file (915 Bytes). View file
 
extern/dust3r/croco/models/__pycache__/masking.cpython-39.pyc ADDED
Binary file (923 Bytes). View file
 
extern/dust3r/croco/models/__pycache__/pos_embed.cpython-310.pyc ADDED
Binary file (4.89 kB). View file
 
extern/dust3r/croco/models/__pycache__/pos_embed.cpython-38.pyc ADDED
Binary file (4.9 kB). View file
 
extern/dust3r/croco/models/__pycache__/pos_embed.cpython-39.pyc ADDED
Binary file (4.87 kB). View file
 
extern/dust3r/croco/models/blocks.py ADDED
@@ -0,0 +1,241 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Main encoder/decoder blocks
7
+ # --------------------------------------------------------
8
+ # References:
9
+ # timm
10
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
12
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
13
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
14
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from itertools import repeat
21
+ import collections.abc
22
+
23
+
24
+ def _ntuple(n):
25
+ def parse(x):
26
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
27
+ return x
28
+ return tuple(repeat(x, n))
29
+ return parse
30
+ to_2tuple = _ntuple(2)
31
+
32
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
33
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34
+ """
35
+ if drop_prob == 0. or not training:
36
+ return x
37
+ keep_prob = 1 - drop_prob
38
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
39
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
40
+ if keep_prob > 0.0 and scale_by_keep:
41
+ random_tensor.div_(keep_prob)
42
+ return x * random_tensor
43
+
44
+ class DropPath(nn.Module):
45
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
+ """
47
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
48
+ super(DropPath, self).__init__()
49
+ self.drop_prob = drop_prob
50
+ self.scale_by_keep = scale_by_keep
51
+
52
+ def forward(self, x):
53
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
54
+
55
+ def extra_repr(self):
56
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
57
+
58
+ class Mlp(nn.Module):
59
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
60
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
61
+ super().__init__()
62
+ out_features = out_features or in_features
63
+ hidden_features = hidden_features or in_features
64
+ bias = to_2tuple(bias)
65
+ drop_probs = to_2tuple(drop)
66
+
67
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
68
+ self.act = act_layer()
69
+ self.drop1 = nn.Dropout(drop_probs[0])
70
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
71
+ self.drop2 = nn.Dropout(drop_probs[1])
72
+
73
+ def forward(self, x):
74
+ x = self.fc1(x)
75
+ x = self.act(x)
76
+ x = self.drop1(x)
77
+ x = self.fc2(x)
78
+ x = self.drop2(x)
79
+ return x
80
+
81
+ class Attention(nn.Module):
82
+
83
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
84
+ super().__init__()
85
+ self.num_heads = num_heads
86
+ head_dim = dim // num_heads
87
+ self.scale = head_dim ** -0.5
88
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
89
+ self.attn_drop = nn.Dropout(attn_drop)
90
+ self.proj = nn.Linear(dim, dim)
91
+ self.proj_drop = nn.Dropout(proj_drop)
92
+ self.rope = rope
93
+
94
+ def forward(self, x, xpos):
95
+ B, N, C = x.shape
96
+
97
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
98
+ q, k, v = [qkv[:,:,i] for i in range(3)]
99
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
100
+
101
+ if self.rope is not None:
102
+ q = self.rope(q, xpos)
103
+ k = self.rope(k, xpos)
104
+
105
+ attn = (q @ k.transpose(-2, -1)) * self.scale
106
+ attn = attn.softmax(dim=-1)
107
+ attn = self.attn_drop(attn)
108
+
109
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
110
+ x = self.proj(x)
111
+ x = self.proj_drop(x)
112
+ return x
113
+
114
+ class Block(nn.Module):
115
+
116
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
117
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
118
+ super().__init__()
119
+ self.norm1 = norm_layer(dim)
120
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
121
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
122
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
123
+ self.norm2 = norm_layer(dim)
124
+ mlp_hidden_dim = int(dim * mlp_ratio)
125
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
126
+
127
+ def forward(self, x, xpos):
128
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
129
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
130
+ return x
131
+
132
+ class CrossAttention(nn.Module):
133
+
134
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
135
+ super().__init__()
136
+ self.num_heads = num_heads
137
+ head_dim = dim // num_heads
138
+ self.scale = head_dim ** -0.5
139
+
140
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
141
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
142
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
143
+ self.attn_drop = nn.Dropout(attn_drop)
144
+ self.proj = nn.Linear(dim, dim)
145
+ self.proj_drop = nn.Dropout(proj_drop)
146
+
147
+ self.rope = rope
148
+
149
+ def forward(self, query, key, value, qpos, kpos):
150
+ B, Nq, C = query.shape
151
+ Nk = key.shape[1]
152
+ Nv = value.shape[1]
153
+
154
+ q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
155
+ k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
156
+ v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
157
+
158
+ if self.rope is not None:
159
+ q = self.rope(q, qpos)
160
+ k = self.rope(k, kpos)
161
+
162
+ attn = (q @ k.transpose(-2, -1)) * self.scale
163
+ attn = attn.softmax(dim=-1)
164
+ attn = self.attn_drop(attn)
165
+
166
+ x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
167
+ x = self.proj(x)
168
+ x = self.proj_drop(x)
169
+ return x
170
+
171
+ class DecoderBlock(nn.Module):
172
+
173
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
174
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
175
+ super().__init__()
176
+ self.norm1 = norm_layer(dim)
177
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
178
+ self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
179
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
180
+ self.norm2 = norm_layer(dim)
181
+ self.norm3 = norm_layer(dim)
182
+ mlp_hidden_dim = int(dim * mlp_ratio)
183
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
184
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
185
+
186
+ def forward(self, x, y, xpos, ypos):
187
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
188
+ y_ = self.norm_y(y)
189
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
190
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
191
+ return x, y
192
+
193
+
194
+ # patch embedding
195
+ class PositionGetter(object):
196
+ """ return positions of patches """
197
+
198
+ def __init__(self):
199
+ self.cache_positions = {}
200
+
201
+ def __call__(self, b, h, w, device):
202
+ if not (h,w) in self.cache_positions:
203
+ x = torch.arange(w, device=device)
204
+ y = torch.arange(h, device=device)
205
+ self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
206
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
207
+ return pos
208
+
209
+ class PatchEmbed(nn.Module):
210
+ """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
211
+
212
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
213
+ super().__init__()
214
+ img_size = to_2tuple(img_size)
215
+ patch_size = to_2tuple(patch_size)
216
+ self.img_size = img_size
217
+ self.patch_size = patch_size
218
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
219
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
220
+ self.flatten = flatten
221
+
222
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
223
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
224
+
225
+ self.position_getter = PositionGetter()
226
+
227
+ def forward(self, x):
228
+ B, C, H, W = x.shape
229
+ torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
230
+ torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
231
+ x = self.proj(x)
232
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
233
+ if self.flatten:
234
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
235
+ x = self.norm(x)
236
+ return x, pos
237
+
238
+ def _init_weights(self):
239
+ w = self.proj.weight.data
240
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
241
+
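As a quick shape check of the building blocks above (a sketch, not part of the commit): PatchEmbed maps a 224x224 image to 14x14 = 196 tokens plus their 2-D grid positions, and Block keeps that token shape.

    import torch
    from models.blocks import PatchEmbed, Block

    embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
    block = Block(dim=768, num_heads=12)         # rope=None, so the positions are unused here
    x = torch.randn(2, 3, 224, 224)
    tokens, pos = embed(x)                       # tokens: (2, 196, 768), pos: (2, 196, 2)
    tokens = block(tokens, pos)                  # same (2, 196, 768) shape out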
extern/dust3r/croco/models/criterion.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Criterion to train CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+ class MaskedMSE(torch.nn.Module):
14
+
15
+ def __init__(self, norm_pix_loss=False, masked=True):
16
+ """
17
+ norm_pix_loss: normalize each patch by their pixel mean and variance
18
+ masked: compute loss over the masked patches only
19
+ """
20
+ super().__init__()
21
+ self.norm_pix_loss = norm_pix_loss
22
+ self.masked = masked
23
+
24
+ def forward(self, pred, mask, target):
25
+
26
+ if self.norm_pix_loss:
27
+ mean = target.mean(dim=-1, keepdim=True)
28
+ var = target.var(dim=-1, keepdim=True)
29
+ target = (target - mean) / (var + 1.e-6)**.5
30
+
31
+ loss = (pred - target) ** 2
32
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
33
+ if self.masked:
34
+ loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches
35
+ else:
36
+ loss = loss.mean() # mean loss
37
+ return loss
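The loss is a plain per-patch MSE restricted to the masked patches (optionally with targets normalized per patch); a small sketch with shapes matching the default 16-pixel patches:

    import torch
    from models.criterion import MaskedMSE

    criterion = MaskedMSE(norm_pix_loss=True, masked=True)
    pred   = torch.randn(2, 196, 16 * 16 * 3)        # (B, num_patches, patch_size**2 * 3)
    target = torch.randn(2, 196, 16 * 16 * 3)
    mask   = (torch.rand(2, 196) < 0.9).float()      # 1.0 where the patch was masked
    loss = criterion(pred, mask, target)             # scalar: MSE averaged over masked patches only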
extern/dust3r/croco/models/croco.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # CroCo model during pretraining
7
+ # --------------------------------------------------------
8
+
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
14
+ from functools import partial
15
+
16
+ from models.blocks import Block, DecoderBlock, PatchEmbed
17
+ from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
18
+ from models.masking import RandomMask
19
+
20
+
21
+ class CroCoNet(nn.Module):
22
+
23
+ def __init__(self,
24
+ img_size=224, # input image size
25
+ patch_size=16, # patch_size
26
+ mask_ratio=0.9, # ratio of masked tokens
27
+ enc_embed_dim=768, # encoder feature dimension
28
+ enc_depth=12, # encoder depth
29
+ enc_num_heads=12, # encoder number of heads in the transformer block
30
+ dec_embed_dim=512, # decoder feature dimension
31
+ dec_depth=8, # decoder depth
32
+ dec_num_heads=16, # decoder number of heads in the transformer block
33
+ mlp_ratio=4,
34
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
35
+ norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
36
+ pos_embed='cosine', # positional embedding (either cosine or RoPE100)
37
+ ):
38
+
39
+ super(CroCoNet, self).__init__()
40
+
41
+ # patch embeddings (with initialization done as in MAE)
42
+ self._set_patch_embed(img_size, patch_size, enc_embed_dim)
43
+
44
+ # mask generations
45
+ self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)
46
+
47
+ self.pos_embed = pos_embed
48
+ if pos_embed=='cosine':
49
+ # positional embedding of the encoder
50
+ enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
51
+ self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
52
+ # positional embedding of the decoder
53
+ dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
54
+ self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
55
+ # pos embedding in each block
56
+ self.rope = None # nothing for cosine
57
+ elif pos_embed.startswith('RoPE'): # eg RoPE100
58
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
59
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
60
+ if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
61
+ freq = float(pos_embed[len('RoPE'):])
62
+ self.rope = RoPE2D(freq=freq)
63
+ else:
64
+ raise NotImplementedError('Unknown pos_embed '+pos_embed)
65
+
66
+ # transformer for the encoder
67
+ self.enc_depth = enc_depth
68
+ self.enc_embed_dim = enc_embed_dim
69
+ self.enc_blocks = nn.ModuleList([
70
+ Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
71
+ for i in range(enc_depth)])
72
+ self.enc_norm = norm_layer(enc_embed_dim)
73
+
74
+ # masked tokens
75
+ self._set_mask_token(dec_embed_dim)
76
+
77
+ # decoder
78
+ self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)
79
+
80
+ # prediction head
81
+ self._set_prediction_head(dec_embed_dim, patch_size)
82
+
83
+ # initialize weights
84
+ self.initialize_weights()
85
+
86
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
87
+ self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
88
+
89
+ def _set_mask_generator(self, num_patches, mask_ratio):
90
+ self.mask_generator = RandomMask(num_patches, mask_ratio)
91
+
92
+ def _set_mask_token(self, dec_embed_dim):
93
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
94
+
95
+ def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
96
+ self.dec_depth = dec_depth
97
+ self.dec_embed_dim = dec_embed_dim
98
+ # transfer from encoder to decoder
99
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
100
+ # transformer for the decoder
101
+ self.dec_blocks = nn.ModuleList([
102
+ DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
103
+ for i in range(dec_depth)])
104
+ # final norm layer
105
+ self.dec_norm = norm_layer(dec_embed_dim)
106
+
107
+ def _set_prediction_head(self, dec_embed_dim, patch_size):
108
+ self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
109
+
110
+
111
+ def initialize_weights(self):
112
+ # patch embed
113
+ self.patch_embed._init_weights()
114
+ # mask tokens
115
+ if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
116
+ # linears and layer norms
117
+ self.apply(self._init_weights)
118
+
119
+ def _init_weights(self, m):
120
+ if isinstance(m, nn.Linear):
121
+ # we use xavier_uniform following official JAX ViT:
122
+ torch.nn.init.xavier_uniform_(m.weight)
123
+ if isinstance(m, nn.Linear) and m.bias is not None:
124
+ nn.init.constant_(m.bias, 0)
125
+ elif isinstance(m, nn.LayerNorm):
126
+ nn.init.constant_(m.bias, 0)
127
+ nn.init.constant_(m.weight, 1.0)
128
+
129
+ def _encode_image(self, image, do_mask=False, return_all_blocks=False):
130
+ """
131
+ image has B x 3 x img_size x img_size
132
+ do_mask: whether to perform masking or not
133
+ return_all_blocks: if True, return the features at the end of every block
134
+ instead of just the features from the last block (eg for some prediction heads)
135
+ """
136
+ # embed the image into patches (x has size B x Npatches x C)
137
+ # and get the position of each patch (pos has size B x Npatches x 2)
138
+ x, pos = self.patch_embed(image)
139
+ # add positional embedding without cls token
140
+ if self.enc_pos_embed is not None:
141
+ x = x + self.enc_pos_embed[None,...]
142
+ # apply masking
143
+ B,N,C = x.size()
144
+ if do_mask:
145
+ masks = self.mask_generator(x)
146
+ x = x[~masks].view(B, -1, C)
147
+ posvis = pos[~masks].view(B, -1, 2)
148
+ else:
149
+ B,N,C = x.size()
150
+ masks = torch.zeros((B,N), dtype=bool)
151
+ posvis = pos
152
+ # now apply the transformer encoder and normalization
153
+ if return_all_blocks:
154
+ out = []
155
+ for blk in self.enc_blocks:
156
+ x = blk(x, posvis)
157
+ out.append(x)
158
+ out[-1] = self.enc_norm(out[-1])
159
+ return out, pos, masks
160
+ else:
161
+ for blk in self.enc_blocks:
162
+ x = blk(x, posvis)
163
+ x = self.enc_norm(x)
164
+ return x, pos, masks
165
+
166
+ def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
167
+ """
168
+ return_all_blocks: if True, return the features at the end of every block
169
+ instead of just the features from the last block (eg for some prediction heads)
170
+
171
+ masks1 can be None => assume image1 fully visible
172
+ """
173
+ # encoder to decoder layer
174
+ visf1 = self.decoder_embed(feat1)
175
+ f2 = self.decoder_embed(feat2)
176
+ # append masked tokens to the sequence
177
+ B,Nenc,C = visf1.size()
178
+ if masks1 is None: # downstreams
179
+ f1_ = visf1
180
+ else: # pretraining
181
+ Ntotal = masks1.size(1)
182
+ f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
183
+ f1_[~masks1] = visf1.view(B * Nenc, C)
184
+ # add positional embedding
185
+ if self.dec_pos_embed is not None:
186
+ f1_ = f1_ + self.dec_pos_embed
187
+ f2 = f2 + self.dec_pos_embed
188
+ # apply Transformer blocks
189
+ out = f1_
190
+ out2 = f2
191
+ if return_all_blocks:
192
+ _out, out = out, []
193
+ for blk in self.dec_blocks:
194
+ _out, out2 = blk(_out, out2, pos1, pos2)
195
+ out.append(_out)
196
+ out[-1] = self.dec_norm(out[-1])
197
+ else:
198
+ for blk in self.dec_blocks:
199
+ out, out2 = blk(out, out2, pos1, pos2)
200
+ out = self.dec_norm(out)
201
+ return out
202
+
203
+ def patchify(self, imgs):
204
+ """
205
+ imgs: (B, 3, H, W)
206
+ x: (B, L, patch_size**2 *3)
207
+ """
208
+ p = self.patch_embed.patch_size[0]
209
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
210
+
211
+ h = w = imgs.shape[2] // p
212
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
213
+ x = torch.einsum('nchpwq->nhwpqc', x)
214
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
215
+
216
+ return x
217
+
218
+ def unpatchify(self, x, channels=3):
219
+ """
220
+ x: (N, L, patch_size**2 *channels)
221
+ imgs: (N, 3, H, W)
222
+ """
223
+ patch_size = self.patch_embed.patch_size[0]
224
+ h = w = int(x.shape[1]**.5)
225
+ assert h * w == x.shape[1]
226
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
227
+ x = torch.einsum('nhwpqc->nchpwq', x)
228
+ imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
229
+ return imgs
230
+
231
+ def forward(self, img1, img2):
232
+ """
233
+ img1: tensor of size B x 3 x img_size x img_size
234
+ img2: tensor of size B x 3 x img_size x img_size
235
+
236
+ out will be B x N x (3*patch_size*patch_size)
237
+ masks are also returned as B x N just in case
238
+ """
239
+ # encoder of the masked first image
240
+ feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
241
+ # encoder of the second image
242
+ feat2, pos2, _ = self._encode_image(img2, do_mask=False)
243
+ # decoder
244
+ decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
245
+ # prediction head
246
+ out = self.prediction_head(decfeat)
247
+ # get target
248
+ target = self.patchify(img1)
249
+ return out, mask1, target
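Putting the pieces together, the pretraining forward pass masks the first view, conditions the decoder on the second view, and predicts the masked patches. A minimal shape check with the default configuration (a sketch, not part of the commit):

    import torch
    from models.croco import CroCoNet

    model = CroCoNet()                               # defaults: 224x224 input, 16-pixel patches, 90% masking
    img1 = torch.randn(2, 3, 224, 224)               # view to be masked and reconstructed
    img2 = torch.randn(2, 3, 224, 224)               # unmasked reference view
    out, mask, target = model(img1, img2)
    # out, target: (2, 196, 768) patch predictions / ground truth; mask: (2, 196), True where masked

The MaskedMSE criterion above would then be applied to (out, mask, target) during pretraining.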