#### modeling.py

```python
import math

import numpy as np
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

from .dino_wrapper2 import DinoWrapper
from .transformer import TriplaneTransformer
from .synthesizer_part import TriplaneSynthesizer


class CameraEmbedder(nn.Module):
    # Embeds the raw camera vector (flattened extrinsics + intrinsics) with a small MLP.
    def __init__(self, raw_dim: int, embed_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(raw_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        return self.mlp(x)


class LRMGeneratorConfig(PretrainedConfig):
    model_type = "lrm_generator"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.camera_embed_dim = kwargs.get("camera_embed_dim", 1024)
        self.rendering_samples_per_ray = kwargs.get("rendering_samples_per_ray", 128)
        self.transformer_dim = kwargs.get("transformer_dim", 1024)
        self.transformer_layers = kwargs.get("transformer_layers", 16)
        self.transformer_heads = kwargs.get("transformer_heads", 16)
        self.triplane_low_res = kwargs.get("triplane_low_res", 32)
        self.triplane_high_res = kwargs.get("triplane_high_res", 64)
        self.triplane_dim = kwargs.get("triplane_dim", 80)
        self.encoder_freeze = kwargs.get("encoder_freeze", False)
        self.encoder_model_name = kwargs.get("encoder_model_name", 'facebook/dinov2-base')
        self.encoder_feat_dim = kwargs.get("encoder_feat_dim", 768)


class LRMGenerator(PreTrainedModel):
    config_class = LRMGeneratorConfig

    def __init__(self, config: LRMGeneratorConfig):
        super().__init__(config)
        self.encoder_feat_dim = config.encoder_feat_dim
        self.camera_embed_dim = config.camera_embed_dim

        self.encoder = DinoWrapper(
            model_name=config.encoder_model_name,
            freeze=config.encoder_freeze,
        )
        self.camera_embedder = CameraEmbedder(
            # 12 flattened extrinsic values + 4 intrinsic values (fx, fy, cx, cy)
            raw_dim=12 + 4, embed_dim=config.camera_embed_dim,
        )
        self.transformer = TriplaneTransformer(
            inner_dim=config.transformer_dim,
            num_layers=config.transformer_layers,
            num_heads=config.transformer_heads,
            image_feat_dim=config.encoder_feat_dim,
            camera_embed_dim=config.camera_embed_dim,
            triplane_low_res=config.triplane_low_res,
            triplane_high_res=config.triplane_high_res,
            triplane_dim=config.triplane_dim,
        )
        self.synthesizer = TriplaneSynthesizer(
            triplane_dim=config.triplane_dim,
            samples_per_ray=config.rendering_samples_per_ray,
        )

    def forward(self, image, camera, export_mesh=False, mesh_size=512,
                render_size=384, export_video=False, fps=30):
        assert image.shape[0] == camera.shape[0], "Batch size mismatch"
        N = image.shape[0]

        # encode image
        image_feats = self.encoder(image)
        assert image_feats.shape[-1] == self.encoder_feat_dim, \
            f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"

        # embed camera
        camera_embeddings = self.camera_embedder(camera)
        assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
            f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"

        with torch.no_grad():
            # transformer generating planes
            planes = self.transformer(image_feats, camera_embeddings)
            assert planes.shape[0] == N, "Batch size mismatch for planes"
            assert planes.shape[1] == 3, "Planes should have 3 channels"

            # Generate the mesh
            if export_mesh:
                import mcubes
                import trimesh

                grid_out = self.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
                vtx, faces = mcubes.marching_cubes(
                    grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
                vtx = vtx / (mesh_size - 1) * 2 - 1
                vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=image.device).unsqueeze(0)
                vtx_colors = self.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
                vtx_colors = (vtx_colors * 255).astype(np.uint8)
                mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
                mesh_path = "awesome_mesh.obj"
                mesh.export(mesh_path, 'obj')
                return planes, mesh_path

            # Generate video
            if export_video:
                import imageio  # only needed for video export

                render_cameras = self._default_render_cameras(batch_size=N).to(image.device)
                frames = []
                chunk_size = 1  # Adjust chunk size as needed
                for i in range(0, render_cameras.shape[1], chunk_size):
                    frame_chunk = self.synthesizer(
                        planes,
                        render_cameras[:, i:i + chunk_size],
                        render_size,
                        render_size,
                        0,
                        0
                    )
                    frames.append(frame_chunk['images_rgb'])

                frames = torch.cat(frames, dim=1)
                frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)

                # Save video
                video_path = "awesome_video.mp4"
                imageio.mimwrite(video_path, frames, fps=fps)
                return planes, video_path

        return planes

    # Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py
    # and https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py
    def _default_intrinsics(self):
        # packed as [[fx, fy], [cx, cy], [width, height]]
        fx = fy = 384
        cx = cy = 256
        w = h = 512
        intrinsics = torch.tensor([
            [fx, fy],
            [cx, cy],
            [w, h],
        ], dtype=torch.float32)
        return intrinsics

    def _default_render_cameras(self, batch_size=1):
        M = 160  # Number of views
        radius = 1.5
        elevation = 0
        camera_positions = []
        rand_theta = np.random.uniform(0, np.pi / 180)
        elevation = math.radians(elevation)
        for i in range(M):
            theta = 2 * math.pi * i / M + rand_theta
            x = radius * math.cos(theta) * math.cos(elevation)
            y = radius * math.sin(theta) * math.cos(elevation)
            z = radius * math.sin(elevation)
            camera_positions.append([x, y, z])
        camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
        extrinsics = self.center_looking_at_camera_pose(camera_positions)
        intrinsics = self._default_intrinsics().unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
        render_cameras = self.build_camera_standard(extrinsics, intrinsics)
        return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)

    def center_looking_at_camera_pose(self, camera_position: torch.Tensor,
                                      look_at: torch.Tensor = None, up_world: torch.Tensor = None):
        if look_at is None:
            look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
        if up_world is None:
            up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)

        z_axis = camera_position - look_at
        z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
        x_axis = torch.cross(up_world, z_axis)
        x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
        y_axis = torch.cross(z_axis, x_axis)
        y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
        extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
        return extrinsics

    def get_normalized_camera_intrinsics(self, intrinsics: torch.Tensor):
        fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
        cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
        width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
        fx, fy = fx / width, fy / height
        cx, cy = cx / width, cy / height
        return fx, fy, cx, cy

    def build_camera_standard(self, RT: torch.Tensor, intrinsics: torch.Tensor):
        E = self.compose_extrinsic_RT(RT)
        fx, fy, cx, cy = self.get_normalized_camera_intrinsics(intrinsics)
        I = torch.stack([
            torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
            torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
            torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
        ], dim=1)
        return torch.cat([
            E.reshape(-1, 16),
            I.reshape(-1, 9),
        ], dim=-1)

    def compose_extrinsic_RT(self, RT: torch.Tensor):
        return torch.cat([
            RT,
            torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(RT.shape[0], 1, 1).to(RT.device)
        ], dim=1)
```
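For orientation, here is a minimal usage sketch of the class above. It is not the official inference script: it assumes the sibling modules (`dino_wrapper2`, `transformer`, `synthesizer_part`) are importable next to this file, uses a randomly initialized model (swap in real weights for meaningful output), and the image resolution and the random `camera` tensor are placeholders. The camera layout (12 flattened extrinsic values plus 4 intrinsics) follows `raw_dim=12 + 4` in `CameraEmbedder`.

```python
# Minimal usage sketch; shapes, paths, and weight handling are illustrative assumptions.
import torch

from modeling import LRMGenerator, LRMGeneratorConfig  # adjust to the package this file lives in

config = LRMGeneratorConfig()        # defaults: DINOv2-base encoder, 1024-dim triplane transformer
model = LRMGenerator(config).eval()  # randomly initialized; load pretrained weights for real use

# One conditioning view and its camera: 12 flattened extrinsic values + 4 intrinsics (fx, fy, cx, cy).
# 448x448 is a placeholder resolution, not necessarily the model's preprocessing size.
image = torch.rand(1, 3, 448, 448)
camera = torch.rand(1, 16)

with torch.no_grad():
    planes = model(image, camera)                                     # triplane features only
    planes, mesh_path = model(image, camera, export_mesh=True,
                              mesh_size=128)                          # smaller grid for a quick test
    planes, video_path = model(image, camera, export_video=True)      # writes awesome_video.mp4
```

Note that mesh export imports `mcubes` (PyMCubes) and `trimesh` lazily, and video export imports `imageio` the same way, so those packages only need to be installed when the corresponding flags are set.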