jadechoghari commited on
Commit
3284218
1 Parent(s): 1371451

update renderer,

Browse files

we include math utils functions here to avoid import issues

Files changed (1) hide show
  1. renderer.py +103 -2
renderer.py CHANGED
@@ -29,6 +29,104 @@ import torch.nn.functional as F
29
  from .ray_marcher import MipRayMarcher2
30
  from . import math_utils
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def generate_planes():
33
  """
34
  Defines planes by the three vectors that form the "axes" of the
@@ -47,6 +145,7 @@ def generate_planes():
47
  [0, 1, 0],
48
  [1, 0, 0]]], dtype=torch.float32)
49
 
 
50
  def project_onto_planes(planes, coordinates):
51
  """
52
  Does a projection of a 3D point onto a batch of 2D planes,
@@ -64,6 +163,7 @@ def project_onto_planes(planes, coordinates):
64
  projections = torch.bmm(coordinates, inv_planes)
65
  return projections[..., :2]
66
 
 
67
  def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
68
  assert padding_mode == 'zeros'
69
  N, n_planes, C, H, W = plane_features.shape
@@ -77,6 +177,7 @@ def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear',
77
  output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
78
  return output_features
79
 
 
80
  def sample_from_3dgrid(grid, coordinates):
81
  """
82
  Expects coordinates in shape (batch_size, num_points_per_batch, 3)
@@ -156,7 +257,7 @@ class ImportanceRenderer(torch.nn.Module):
156
  # self.plane_axes = self.plane_axes.to(ray_origins.device)
157
 
158
  if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
159
- ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
160
  is_ray_valid = ray_end > ray_start
161
  if torch.any(is_ray_valid).item():
162
  ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
@@ -242,7 +343,7 @@ class ImportanceRenderer(torch.nn.Module):
242
  depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
243
  else:
244
  if type(ray_start) == torch.Tensor:
245
- depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
246
  depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
247
  depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
248
  else:
 
29
  from .ray_marcher import MipRayMarcher2
30
  from . import math_utils
31
 
32
+ # Copied from .math_utils.transform_vectors
33
+ def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
34
+ """
35
+ Left-multiplies MxM @ NxM. Returns NxM.
36
+ """
37
+ res = torch.matmul(vectors4, matrix.T)
38
+ return res
39
+
40
+ # Copied from .math_utils.normalize_vecs
41
+ def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ Normalize vector lengths.
44
+ """
45
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
46
+
47
+ # Copied from .math_utils.torch_dot
48
+ def torch_dot(x: torch.Tensor, y: torch.Tensor):
49
+ """
50
+ Dot product of two tensors.
51
+ """
52
+ return (x * y).sum(-1)
53
+
54
+ # Copied from .math_utils.get_ray_limits_box
55
+ def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
56
+ """
57
+ Author: Petr Kellnhofer
58
+ Intersects rays with the [-1, 1] NDC volume.
59
+ Returns min and max distance of entry.
60
+ Returns -1 for no intersection.
61
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
62
+ """
63
+ o_shape = rays_o.shape
64
+ rays_o = rays_o.detach().reshape(-1, 3)
65
+ rays_d = rays_d.detach().reshape(-1, 3)
66
+
67
+
68
+ bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
69
+ bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
70
+ bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
71
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
72
+
73
+ # Precompute inverse for stability.
74
+ invdir = 1 / rays_d
75
+ sign = (invdir < 0).long()
76
+
77
+ # Intersect with YZ plane.
78
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
79
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
80
+
81
+ # Intersect with XZ plane.
82
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
83
+ tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
84
+
85
+ # Resolve parallel rays.
86
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
87
+
88
+ # Use the shortest intersection.
89
+ tmin = torch.max(tmin, tymin)
90
+ tmax = torch.min(tmax, tymax)
91
+
92
+ # Intersect with XY plane.
93
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
94
+ tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
95
+
96
+ # Resolve parallel rays.
97
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
98
+
99
+ # Use the shortest intersection.
100
+ tmin = torch.max(tmin, tzmin)
101
+ tmax = torch.min(tmax, tzmax)
102
+
103
+ # Mark invalid.
104
+ tmin[torch.logical_not(is_valid)] = -1
105
+ tmax[torch.logical_not(is_valid)] = -2
106
+
107
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
108
+
109
+ # Copied from .math_utils.linspace
110
+ def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
111
+ """
112
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
113
+ Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
114
+ """
115
+ # create a tensor of 'num' steps from 0 to 1
116
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
117
+
118
+ # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
119
+ # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
120
+ # "cannot statically infer the expected size of a list in this contex", hence the code below
121
+ for i in range(start.ndim):
122
+ steps = steps.unsqueeze(-1)
123
+
124
+ # the output starts at 'start' and increments until 'stop' in each dimension
125
+ out = start[None] + steps * (stop - start)[None]
126
+
127
+ return out
128
+
129
+ # Copied from .math_utils.generate_planes
130
  def generate_planes():
131
  """
132
  Defines planes by the three vectors that form the "axes" of the
 
145
  [0, 1, 0],
146
  [1, 0, 0]]], dtype=torch.float32)
147
 
148
+ # Copied from .math_utils.project_onto_planes
149
  def project_onto_planes(planes, coordinates):
150
  """
151
  Does a projection of a 3D point onto a batch of 2D planes,
 
163
  projections = torch.bmm(coordinates, inv_planes)
164
  return projections[..., :2]
165
 
166
+ # Copied from .math_utils.sample_from_planes
167
  def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
168
  assert padding_mode == 'zeros'
169
  N, n_planes, C, H, W = plane_features.shape
 
177
  output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
178
  return output_features
179
 
180
+ # Copied from .math_utils.sample_from_3dgrid
181
  def sample_from_3dgrid(grid, coordinates):
182
  """
183
  Expects coordinates in shape (batch_size, num_points_per_batch, 3)
 
257
  # self.plane_axes = self.plane_axes.to(ray_origins.device)
258
 
259
  if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
260
+ ray_start, ray_end = get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
261
  is_ray_valid = ray_end > ray_start
262
  if torch.any(is_ray_valid).item():
263
  ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
 
343
  depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
344
  else:
345
  if type(ray_start) == torch.Tensor:
346
+ depths_coarse = linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
347
  depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
348
  depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
349
  else: