jadechoghari committed on
Commit
82e91bf
1 Parent(s): 7b3f3a7

Create renderer.py

Browse files
Files changed (1) hide show
  1. renderer.py +314 -0
renderer.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
7
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
8
+ #
9
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
10
+ # property and proprietary rights in and to this material, related
11
+ # documentation and any modifications thereto. Any use, reproduction,
12
+ # disclosure or distribution of this material and related documentation
13
+ # without an express license agreement from NVIDIA CORPORATION or
14
+ # its affiliates is strictly prohibited.
15
+ #
16
+ # Modified by Zexin He
17
+ # The modifications are subject to the same license as the original.
18
+
19
+
20
+ """
21
+ The renderer is a module that takes in rays, decides where to sample along each
22
+ ray, and computes pixel colors using the volume rendering equation.
23
+ """
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+ from .ray_marcher import MipRayMarcher2
30
+ from . import math_utils
31
+
32
def generate_planes():
    """
    Build the axes of the three feature planes.

    Each plane is described by three row vectors that form its "axes".
    Should work with arbitrary number of planes and planes of arbitrary
    orientation.

    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67

    Returns a float32 tensor of shape (3, 3, 3): one 3x3 axes matrix per plane.
    """
    x_axis = [1, 0, 0]
    y_axis = [0, 1, 0]
    z_axis = [0, 0, 1]
    plane_axes = [
        [x_axis, y_axis, z_axis],
        [x_axis, z_axis, y_axis],
        [z_axis, y_axis, x_axis],
    ]
    return torch.tensor(plane_axes, dtype=torch.float32)
49
+
50
def project_onto_planes(planes, coordinates):
    """
    Project 3D points onto a batch of 2D planes and return the 2D plane
    coordinates.

    planes: plane axes of shape (n_planes, 3, 3)
    coordinates: points of shape (N, M, 3)
    returns: projections of shape (N*n_planes, M, 2), ordered batch-major
             (all planes of batch item 0 first, then batch item 1, ...)
    """
    batch, n_points, _ = coordinates.shape
    n_planes, _, _ = planes.shape
    flat_batch = batch * n_planes

    # Repeat every point set once per plane: (N, M, 3) -> (N*n_planes, M, 3).
    tiled_points = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1)
    tiled_points = tiled_points.reshape(flat_batch, n_points, 3)

    # Repeat the inverted axes once per batch item: (n_planes, 3, 3) -> (N*n_planes, 3, 3).
    inverse_axes = torch.linalg.inv(planes)
    tiled_inverse = inverse_axes.unsqueeze(0).expand(batch, -1, -1, -1)
    tiled_inverse = tiled_inverse.reshape(flat_batch, 3, 3)

    # The points follow the planes' device so the batched matmul is valid.
    tiled_points = tiled_points.to(tiled_inverse.device)

    # Only the first two components are the in-plane coordinates.
    return torch.bmm(tiled_points, tiled_inverse)[..., :2]
66
+
67
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
    """
    Sample a feature vector from every plane at each 3D query point.

    Args:
        plane_axes: plane axes of shape (n_planes, 3, 3), forwarded to
            `project_onto_planes`.
        plane_features: feature planes of shape (N, n_planes, C, H, W).
        coordinates: query points of shape (N, M, 3).
        mode: interpolation mode forwarded to `grid_sample`.
        padding_mode: only 'zeros' is supported.
        box_warp: the coordinates are multiplied by `2 / box_warp` before
            projection. Must be provided even though the signature defaults
            to None.

    Returns:
        float32 tensor of shape (N, n_planes, M, C).

    Raises:
        AssertionError: if `padding_mode` is not 'zeros'.
        ValueError: if `box_warp` is None.
    """
    assert padding_mode == 'zeros'
    if box_warp is None:
        # Previously this surfaced as an opaque TypeError from `2 / None`.
        raise ValueError("`box_warp` must be provided to rescale the sample coordinates")

    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    # reshape (not view) so that non-contiguous feature planes are accepted too
    plane_features = plane_features.reshape(N*n_planes, C, H, W)

    coordinates = (2/box_warp) * coordinates  # add specific box bounds

    # (N*n_planes, M, 2) -> (N*n_planes, 1, M, 2): a one-row sampling grid per plane
    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)

    # grid_sample returns (N*n_planes, C, 1, M); both inputs are cast to float32 first
    sampled = torch.nn.functional.grid_sample(
        plane_features.float(), projected_coordinates.float(),
        mode=mode, padding_mode=padding_mode, align_corners=False)
    output_features = sampled.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
    return output_features
79
+
80
def sample_from_3dgrid(grid, coordinates):
    """
    Sample features from a dense 3D grid at the given points.

    Expects coordinates in shape (batch_size, num_points_per_batch, 3)
    Expects grid in shape (1, channels, H, W, D)
    (Also works if grid has batch size)
    Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
    """
    batch_size, _, n_dims = coordinates.shape

    batched_grid = grid.expand(batch_size, -1, -1, -1, -1)
    # grid_sample needs a 5D sampling grid; all query points are laid out
    # along the last spatial axis.
    query_points = coordinates.reshape(batch_size, 1, 1, -1, n_dims)
    sampled = F.grid_sample(batched_grid, query_points,
                            mode='bilinear', padding_mode='zeros', align_corners=False)

    # (N, C, 1, 1, M) -> (N, M, 1, 1, C) -> (N, M, C)
    n, channels, out_d, out_h, out_w = sampled.shape
    return sampled.permute(0, 4, 3, 2, 1).reshape(n, out_d * out_h * out_w, channels)
94
+
95
class ImportanceRenderer(torch.nn.Module):
    """
    Volume renderer with a coarse pass and an optional importance-sampled
    fine pass over plane features.

    Modified original version to filter out-of-box samples as TensoRF does.

    Reference:
    TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
    """
    def __init__(self):
        super().__init__()
        self.activation_factory = self._build_activation_factory()
        self.ray_marcher = MipRayMarcher2(self.activation_factory)
        # Plain attribute (not a registered buffer); moved to the right device in `run_model`.
        self.plane_axes = generate_planes()

    def _build_activation_factory(self):
        """
        Return a function mapping rendering options to the density activation.

        The returned factory raises AssertionError for any `clamp_mode` other
        than 'softplus'.
        """
        def activation_factory(options: dict):
            if options['clamp_mode'] == 'softplus':
                return lambda x: F.softplus(x - 1)  # activation bias of -1 makes things initialize better
            # Raised explicitly instead of `assert False` so the check is not
            # stripped when Python runs with -O.
            raise AssertionError("Renderer only supports `clamp_mode`=`softplus`!")
        return activation_factory

    def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
                      planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
        """
        Evaluate the decoder at the samples defined by `depths` along each ray.

        Additional filtering is applied to filter out-of-box samples.
        Modifications made by Zexin He.

        Args:
            depths: sample depths of shape (batch_size, num_rays, samples_per_ray, 1).
            ray_directions: per-ray directions; last dimension is 3.
            ray_origins: per-ray origins; last dimension is 3.
            planes: feature planes; every other tensor is moved to its device.
            decoder: module called by `run_model`; its output must provide
                'rgb' and 'sigma'.
            rendering_options: must contain 'sampler_bbox_min',
                'sampler_bbox_max' and whatever `run_model` reads.

        Returns:
            (colors, densities) of shapes (batch_size, num_rays, samples_per_ray, 3)
            and (batch_size, num_rays, samples_per_ray, 1).
        """

        # context related variables
        batch_size, num_rays, samples_per_ray, _ = depths.shape
        device = planes.device
        depths = depths.to(device)
        ray_directions = ray_directions.to(device)
        ray_origins = ray_origins.to(device)

        # define sample points with depths
        sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
        sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)

        # filter out-of-box samples: a sample is kept only if all three
        # components lie within the sampler bounding box
        mask_inbox = \
            (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
            (sample_coordinates <= rendering_options['sampler_bbox_max'])
        mask_inbox = mask_inbox.all(-1)

        # forward model according to all samples
        _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)

        # set out-of-box samples to zeros(rgb) & -inf(sigma)
        # nan_to_num turns -inf into the most negative finite value of the
        # dtype; dividing by SAFE_GUARD keeps it away from that limit.
        SAFE_GUARD = 3
        DATA_TYPE = _out['sigma'].dtype
        colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
        densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
        colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]

        # reshape back
        colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
        densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])

        return colors_pass, densities_pass

    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
        """
        Render rays: coarse pass, optional importance-sampled fine pass, then
        aggregation by the ray marcher.

        Args:
            planes: feature planes passed on to `run_model`.
            decoder: module producing 'rgb' and 'sigma' for sampled features.
            ray_origins: ray origins of shape (N, M, 3).
            ray_directions: ray directions, same leading shape as `ray_origins`.
            rendering_options: dict providing 'ray_start', 'ray_end',
                'box_warp', 'depth_resolution', 'disparity_space_sampling',
                'depth_resolution_importance' plus the keys read by
                `_forward_pass`, `run_model` and the ray marcher.

        Returns:
            (rgb_final, depth_final, weights.sum(2)) as produced by the ray marcher.
        """
        if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
            ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
            is_ray_valid = ray_end > ray_start
            if torch.any(is_ray_valid).item():
                # Give invalid rays a usable interval taken from the valid ones.
                # NOTE(review): the far bound is filled from `ray_start` of the
                # valid rays, not `ray_end`; left unchanged - confirm this is intended.
                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
        else:
            # Create stratified depth samples
            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])

        depths_coarse = depths_coarse.to(planes.device)

        # Coarse Pass
        colors_coarse, densities_coarse = self._forward_pass(
            depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
            planes=planes, decoder=decoder, rendering_options=rendering_options)

        # Fine Pass
        N_importance = rendering_options['depth_resolution_importance']
        if N_importance > 0:
            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)

            colors_fine, densities_fine = self._forward_pass(
                depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
                planes=planes, decoder=decoder, rendering_options=rendering_options)

            all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
                                                                       depths_fine, colors_fine, densities_fine)

            # Aggregate
            rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
        else:
            rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

        return rgb_final, depth_final, weights.sum(2)

    def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
        """
        Sample plane features at the given coordinates and decode them.

        Adds Gaussian noise scaled by `options['density_noise']` to
        `out['sigma']` (in place) when that option is present and positive.
        Returns the decoder output.
        """
        plane_axes = self.plane_axes.to(planes.device)
        sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])

        out = decoder(sampled_features, sample_directions)
        if options.get('density_noise', 0) > 0:
            out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
        return out

    def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
        """Same as `run_model`, with the density activation applied to `out['sigma']`."""
        out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
        out['sigma'] = self.activation_factory(options)(out['sigma'])
        return out

    def sort_samples(self, all_depths, all_colors, all_densities):
        """
        Sort samples along the per-ray sample axis (dim -2) by depth and
        reorder colors and densities to match.
        """
        _, indices = torch.sort(all_depths, dim=-2)
        all_depths = torch.gather(all_depths, -2, indices)
        all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
        all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
        return all_depths, all_colors, all_densities

    def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
        """
        Concatenate two sample sets along the per-ray sample axis (dim -2)
        and return them sorted by depth.
        """
        all_depths = torch.cat([depths1, depths2], dim=-2)
        all_colors = torch.cat([colors1, colors2], dim=-2)
        all_densities = torch.cat([densities1, densities2], dim=-2)

        # Sorting is shared with `sort_samples` instead of being duplicated here.
        return self.sort_samples(all_depths, all_colors, all_densities)

    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
        """
        Return depths of approximately uniformly spaced samples along rays.

        Args:
            ray_origins: tensor of shape (N, M, 3); supplies N, M and the device.
            ray_start, ray_end: near/far bounds, either plain numbers or tensors.
            depth_resolution: number of samples per ray.
            disparity_space_sampling: if True, space the samples uniformly in
                inverse depth instead of depth.

        Returns:
            Tensor of jittered depths; shape (N, M, depth_resolution, 1) in the
            scalar-bounds and disparity branches.
        """
        N, M, _ = ray_origins.shape
        if disparity_space_sampling:
            depths_coarse = torch.linspace(0,
                                           1,
                                           depth_resolution,
                                           device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
            depth_delta = 1/(depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta
            depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
        elif isinstance(ray_start, torch.Tensor):
            # isinstance (rather than an exact type comparison) also accepts Tensor subclasses.
            depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1, 2, 0, 3)
            depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
        else:
            depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
            depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta

        return depths_coarse

    def sample_importance(self, z_vals, weights, N_importance):
        """
        Return depths of importance sampled points along rays. See NeRF importance sampling for more.

        Args:
            z_vals: coarse depths of shape (batch_size, num_rays, samples_per_ray, 1).
            weights: per-ray weights from the ray marcher; reshaped to one row per ray.
            N_importance: number of new depths to draw per ray.

        Returns:
            Tensor of shape (batch_size, num_rays, N_importance, 1), computed without gradients.
        """
        with torch.no_grad():
            batch_size, num_rays, samples_per_ray, _ = z_vals.shape

            z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
            weights = weights.reshape(batch_size * num_rays, -1)  # -1 to account for loss of 1 sample in MipRayMarcher

            # smooth weights
            weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
            # Drop only the channel dimension: a bare squeeze() would also drop
            # the ray dimension when there is a single ray and break the
            # 2D indexing below.
            weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze(1)
            weights = weights + 0.01

            z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
                                                N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
        return importance_z_vals

    def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
        """
        Sample @N_importance samples from @bins with distribution defined by @weights.
        Inputs:
            bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
            weights: (N_rays, N_samples_)
            N_importance: the number of samples to draw from the distribution
            det: deterministic or not
            eps: a small number to prevent division by zero
        Outputs:
            samples: the sampled samples
        """
        N_rays, N_samples_ = weights.shape
        weights = weights + eps  # prevent division by zero (don't do inplace op!)
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (N_rays, N_samples_)
        cdf = torch.cumsum(pdf, -1)  # (N_rays, N_samples), cumulative distribution function
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (N_rays, N_samples_+1)
        # padded to 0~1 inclusive

        if det:
            u = torch.linspace(0, 1, N_importance, device=bins.device)
            u = u.expand(N_rays, N_importance)
        else:
            u = torch.rand(N_rays, N_importance, device=bins.device)
        u = u.contiguous()

        # Locate, for every u, the CDF interval [below, above] that contains it.
        inds = torch.searchsorted(cdf, u, right=True)
        below = torch.clamp_min(inds-1, 0)
        above = torch.clamp_max(inds, N_samples_)

        inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
        cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
        bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)

        denom = cdf_g[..., 1]-cdf_g[..., 0]
        denom[denom < eps] = 1  # denom equals 0 means a bin has weight 0, in which case it will not be sampled
        # anyway, therefore any value for it is fine (set to 1 here)

        # Linear interpolation inside the selected bin.
        samples = bins_g[..., 0] + (u-cdf_g[..., 0])/denom * (bins_g[..., 1]-bins_g[..., 0])
        return samples