NeRF是什么?基于NeRF的三维重建是基于体素吗?

人工智能 新闻
NeRF是一种生成模型,以图像和精确姿势为条件,生成给定图像的3D场景的新视图,这一过程通常被称为“新视图合成”。

本文经自动驾驶之心公众号授权转载,转载请联系出处。

1介绍

神经辐射场(NeRF)是深度学习和计算机视觉领域的一个相当新的范式。ECCV 2020论文《NeRF:将场景表示为视图合成的神经辐射场》(该论文获得了最佳论文奖)中介绍了这项技术,该技术自此大受欢迎,迄今已获得近800次引用[1]。该方法标志着机器学习处理3D数据的传统方式发生了巨大变化。

神经辐射场场景表示和可微分渲染过程:

通过沿着相机射线采样5D坐标(位置和观看方向)来合成图像; 将这些位置输入MLP以产生颜色和体积密度; 并使用体积渲染技术将这些值合成图像; 该渲染函数是可微分的,因此可以通过最小化合成图像和真实观测图像之间的残差来优化场景表示。

2 What is a NeRF?

NeRF是一种生成模型,以图像和精确姿势为条件,生成给定图像的3D场景的新视图,这一过程通常被称为“新视图合成”。不仅如此,它还将场景的3D形状和外观明确定义为连续函数,可以通过marching cubes生成3D网格。尽管它们直接从图像数据中学习,但它们既不使用convolutional层,也不使用transformer层。

多年来,机器学习应用中表示3D数据的方法很多,从3D体素到点云,再到符号距离(signed distance )函数。他们最大的共同缺点是需要预先假设一个3D模型,要么使用摄影测量或激光雷达等工具来生成3D数据,要么手工制作3D模型。然而,许多类型的物体,如高反射物体、“网格状”物体或透明物体,都无法按比例扫描。3D重建方法通常也具有重建误差,这可能导致影响模型精度的阶梯效应或漂移。

相比之下,NeRF基于射线光场的概念。光场是描述光传输如何在整个3D体积中发生的函数。它描述了光线在空间中的每个x=(x,y,z)坐标和每个方向d上移动的方向,描述为θ和ξ角或单位向量。它们共同形成了描述3D场景中的光传输的5D特征空间。受此表示的启发,NeRF试图近似一个函数,该函数从该空间映射到由颜色c=(R,G,B)和浓度(density)σ组成的4D空间,可以将其视为该5D坐标空间处的光线终止的可能性(例如通过遮挡)。因此,标准NeRF是形式F:(x,d)->(c,σ)的函数。

原始的NeRF论文使用多层感知器将该函数参数化,该感知器基于一组姿势已知的图像上训练得到。这是一类称为generalized scene reconstruction的技术中的一种方法,旨在直接从图像集合中描述3D场景。这种方法具备一些非常好的特性:

  • 直接从数据中学习
  • 场景的连续表示允许非常薄和复杂的结构,例如树叶或网格
  • 隐含物理特性,如镜面性和粗糙度
  • 隐式呈现场景中的照明

此后,一系列的改进论文随之涌现,例如,少镜头和单镜头学习[2,3]、对动态场景的支持[4,5]、将光场推广到特征场[6]、从网络上的未校准图像集合中学习[7]、结合激光雷达数据[8]、大规模场景表示[9]、在没有神经网络的情况下学习[10],诸如此类。

3 NeRF Architecture

总体而言,给定一个经过训练的NeRF模型和一个具有已知姿势和图像维度的相机,我们通过以下过程构建场景:

  • 对于每个像素,从相机光心射出光线穿过场景,以在(x,d)位置收集一组样本
  • 使用每个样本的点和视线方向(x,d)作为输入,以产生输出(c,σ)值(rgbσ)
  • 使用经典的体积渲染技术构建图像

光射场(很多文献翻译为"辐射场",但译者认为"光射场"更直观)函数只是几个组件中的一个,一旦组合起来,就可以创建之前看到的视频中的视觉效果。总体而言,本文包括以下几个部分:

  • 位置编码(Positional encoding)
  • 光射场函数近似器(MLP)
  • 可微分体渲染器(Differentiable volume renderer)
  • 分层(Stratified)取样 层次(Hierarchical)体积采样

为了最大限度地清晰讲述,本文将每个组件的关键元素以尽可能简洁的代码展示。参考了bmild的原始实现和yenchenlin和krrish94的PyTorch实现。

3.1 Positional Encoder

就像2017年推出的transformer模型[11]一样,NeRF也受益于位置编码器作为其输入。它使用高频函数将其连续输入映射到更高维的空间,以帮助模型学习数据中的高频变化,从而产生更清晰的模型。这种方法避开(circumvent)了神经网络对低频函数偏置(bias),使NeRF能够表示更清晰的细节。作者参考了ICML 2019上的一篇论文[12]。

如果熟悉transformerd的位置编码,NeRF的相关实现是很标准的,它具有相同的交替正弦和余弦表达式。位置编码器实现:

# py
class PositionalEncoder(nn.Module):
  # sine-cosine positional encoder for input points.
  def __init__( self,
                d_input: int,
                n_freqs: int,
                log_space: bool = False ):
    super().__init__()
    self.d_input = d_input
    self.n_freqs = n_freqs         # 是不是视线上的采样频率?
    self.log_space = log_space
    self.d_output = d_input * (1 + 2 * self.n_freqs)
    self.embed_fns = [lambda x: x] # 冒号前面的x表示函数参数,后面的表示匿名函数运算

    # Define frequencies in either linear or log scale
    if self.log_space:
      freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs)
    else:
      freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs)

    # Alternate sin and cos
    for freq in freq_bands:
      self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
      self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))
  
  def forward(self, x) -> torch.Tensor:
    # Apply positional encoding to input.
    return torch.concat([fn(x) for fn in self.embed_fns], dim=-1)

思考:这个位置编码,是对输入点(input points)进行编码,这个输入点是视线上的采样点?还是不同的视角位置点? self.n_freqs是不是视线上的采样频率?由此理解,应该就是视线上的采样位置,因为如果不对视线上的采样位置进行编码,就无法有效表示这些位置,也就无法对它们的RGBA进行训练。

3.2 Radiance Field Function

在原文中,光射场函数由NeRF模型表示,NeRF模型是一种典型的多层感知器,以编码的3D点和视角方向作为输入,并返回RGBA值作为输出。虽然本文使用的是神经网络,但这里可以使用任何函数逼近器(function approximator)。例如,Yu等人的后续论文Plenoxels使用球面谐波(spherical harmonics)实现了数量级的更快训练,同时获得有竞争力的结果[10]。

图片图片

NeRF模型有8层深,大多数层的特征维度为256。剩余的连接被放置在层4处。在这些层之后,产生RGB和σ值。RGB值用线性层进一步处理,然后与视线方向连接,然后通过另一个线性层,最后在输出处与σ重新组合。NeRF模型的PyTorch模块实现:

class NeRF(nn.Module):
  # Neural radiance fields module.
  def __init__( self,
                d_input: int = 3,                  
                n_layers: int = 8,
                d_filter: int = 256,
                skip: Tuple[int] = (4,), # (4,)只有一个元素4的元组       
                d_viewdirs: Optional[int] = None): 

    super().__init__()
    self.d_input = d_input        # 这里是3D XYZ,?
    self.skip = skip              # 是要跳过什么?为啥要跳过?被遮挡?
    self.act = nn.functional.relu
    self.d_viewdirs = d_viewdirs  # d_viewdirs 是2D方向?

    # Create model layers
    # [if_true 就执行的指令] if [if_true条件] else [if_false]
    # 是否skip的区别是,训练输入维度是否多3维,
    # if i in skip =  if i in (4,),似乎是判断i是否等于4
    # self.d_input=3 :如果层id=4,网络输入要加3维,这是为什么?第4层有何特殊的?
    self.layers = nn.ModuleList(
      [nn.Linear(self.d_input, d_filter)] +
      [nn.Linear(d_filter + self.d_input, d_filter) if i in skip else \
       nn.Linear(d_filter  , d_filter) for i in range(n_layers - 1)]
    )

    # Bottleneck layers
    if self.d_viewdirs is not None:
      # If using viewdirs, split alpha and RGB
      self.alpha_out = nn.Linear(d_filter, 1)
      self.rgb_filters = nn.Linear(d_filter, d_filter)
      self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)
      self.output = nn.Linear(d_filter // 2, 3) # 为啥要取一半?
    else:
      # If no viewdirs, use simpler output
      self.output = nn.Linear(d_filter, 4) # d_filter=256,输出是4维RGBA
  
  def forward(self,
              x: torch.Tensor, # ?
              viewdirs: Optional[torch.Tensor] = None) -> torch.Tensor:
   
    # Forward pass with optional view direction.
    if self.d_viewdirs is None and viewdirs is not None:
      raise ValueError('Cannot input x_direction')

    # Apply forward pass up to bottleneck
    x_input = x    # 这里的x是几维?从下面的分离RGB和A看,应该是4D
    # 下面通过8层MLP训练RGBA
    for i, layer in enumerate(self.layers):  # 8层,每一层进行运算
      x = self.act(layer(x)) 
      if i in self.skip:
        x = torch.cat([x, x_input], dim=-1)

    # Apply bottleneck  bottleneck 瓶颈是啥?是不是最费算力的模块?
    if self.d_viewdirs is not None:
      # 从网络输出分离A,RGB还需要经过更多训练
      alpha = self.alpha_out(x)  
      
      # Pass through bottleneck to get RGB
      x = self.rgb_filters(x) 
      x = torch.concat([x, viewdirs], dim=-1)
      x = self.act(self.branch(x)) # self.branch shape: (d_filter // 2)
      x = self.output(x)           # self.output shape: (3)

      # Concatenate alphas to output
      x = torch.concat([x, alpha], dim=-1)
    else:
      # Simple output
      x = self.output(x)
    return x

思考:这个NERF类的输入输出是什么?通过这个类发生了啥?从__init__函数参数看出,主要是对神经网络的输入、层次和维度等进行设置,输入了5D数据,也就是视点位置和视线方向,输出的是RGBA。问题,这个输出的RGBA是一个点的?还是视线上一串的?如果是一串的,没有看到位置编码如何确定每个采样点的RGBA?

也没看到采样间隔之类的说明;如果是一个点,那这个RGBA是视线上哪个点的?是不是眼睛看到的视线采样点集合成后的点RGBA?从NERF类代码可以看出,主要是根据视点位置和视线方向进行了多层前馈训练,输入5D的视点位置和视线方向,输出4D的RGBA。

3.3 可微分体渲染器(Differentiable Volume Renderer)

RGBA输出点位于3D空间中,因此要将它们合成图像,需要应用论文第4节中方程1-3中描述的体积积分。本质上,沿着每个像素的视线对所有样本进行加权求和,以获得该像素的估计颜色值。每个RGB采样都按其透明度alpha值进行加权:α值越高,表示采样区域不透明的可能性越高,因此沿射线更远的点更有可能被遮挡。累积乘积运算确保了这些进一步的点被抑制。

原始NeRF模型输出的体绘制:

def raw2outputs(raw: torch.Tensor,
                z_vals: torch.Tensor,
                rays_d: torch.Tensor,
                raw_noise_std: float = 0.0,
                white_bkgd: bool = False) 
  -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

  # 将原始的NeRF输出转为RGB和其他映射
  # Difference between consecutive elements of `z_vals`. [n_rays, n_samples]
  dists = z_vals[..., 1:] - z_vals[..., :-1]# ?这里减法的意义是啥?
  dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1)

  # 将每个距离乘以其对应方向光线的范数,以转换为真实世界的距离(考虑非单位方向)
  dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

  # 将噪声添加到模型对密度的预测中,用于在训练期间规范网络(防止漂浮物伪影)
  noise = 0.
  if raw_noise_std > 0.:
    noise = torch.randn(raw[..., 3].shape) * raw_noise_std

  # Predict density of each sample along each ray. Higher values imply
  # higher likelihood of being absorbed at this point. [n_rays, n_samples]
  alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists)

  # Compute weight for RGB of each sample along each ray. [n_rays, n_samples]
  # The higher the alpha, the lower subsequent weights are driven.
  weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)

  # Compute weighted RGB map.
  rgb = torch.sigmoid(raw[..., :3])  # [n_rays, n_samples, 3]
  rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)  # [n_rays, 3]

  # Estimated depth map is predicted distance.
  depth_map = torch.sum(weights * z_vals, dim=-1)

  # Disparity map is inverse depth.
  disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                            depth_map / torch.sum(weights, -1))

  # Sum of weights along each ray. In [0, 1] up to numerical error.
  acc_map = torch.sum(weights, dim=-1)

  # To composite onto a white background, use the accumulated alpha map.
  if white_bkgd:
    rgb_map = rgb_map + (1. - acc_map[..., None])

  return rgb_map, depth_map, acc_map, weights


def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
  # (Courtesy of https://github.com/krrish94/nerf-pytorch)
  # Compute regular cumprod first.
  cumprod = torch.cumprod(tensor, -1)
  # "Roll" the elements along dimension 'dim' by 1 element.
  cumprod = torch.roll(cumprod, 1, -1)
  # Replace the first element by "1" as this is what tf.cumprod(..., exclusive=True) does.
  cumprod[..., 0] = 1.
  return cumprod

问题:这里的主要功能是啥?输入了什么?输出了什么?

3.4 Stratified Sampling

相机最终拾取到的RGB值是沿着穿过该像素视线的光样本的累积,经典的体积渲染方法是沿着该视线累积点,然后对点进行积分,在每个点估计光线在不撞击任何粒子的情况下射行的概率。因此,每个像素都需要沿着穿过它的光线对点进行采样。为了最好地近似积分,他们的分层采样方法是将空间均匀地划分为N个仓(bins),并从每个仓中均匀地抽取一个样本。stratified sampling方法不是简单地以相等的间隔绘制样本,而是允许模型在连续空间中采样,从而调节网络在连续空间上学习。

图片图片

分层采样PyTorch中实现:

def sample_stratified(rays_o: torch.Tensor,
                      rays_d: torch.Tensor,
                      near: float,
                      far: float,
                      n_samples: int,
                      perturb: Optional[bool] = True,
                      inverse_depth: bool = False)
   -> Tuple[torch.Tensor, torch.Tensor]:
  # Sample along ray from regularly-spaced bins.
  # Grab samples for space integration along ray
  t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
  if not inverse_depth:
    # Sample linearly between `near` and `far`
    z_vals = near * (1.-t_vals) + far * (t_vals)
  else:
    # Sample linearly in inverse depth (disparity)
    z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

  # Draw uniform samples from bins along ray
  if perturb:
    mids = .5 * (z_vals[1:] + z_vals[:-1])
    upper = torch.concat([mids, z_vals[-1:]], dim=-1)
    lower = torch.concat([z_vals[:1], mids], dim=-1)
    t_rand = torch.rand([n_samples], device=z_vals.device)
    z_vals = lower + (upper - lower) * t_rand
  z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])

  # Apply scale from `rays_d` and offset from `rays_o` to samples
  # pts: (width, height, n_samples, 3)
  pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
  return pts, z_vals

3.5 层次体积采样(Hierarchical Volume Sampling)

辐射场由两个多层感知器表示:一个是在粗略级别上操作,对场景的广泛结构属性进行编码;另一个是在精细的层面上细化细节,从而实现网格和分支等薄而复杂的结构。此外,他们接收的样本是不同的,粗模型在整个射线中处理宽的、大多是规则间隔的样本,而精细模型在具有强先验的区域中珩磨(honing in)以获得显著信息。

这种“珩磨”过程是通过层次体积采样流程完成的。3D空间实际上非常稀疏,存在遮挡,因此大多数点对渲染图像的贡献不大。因此,对具有对积分贡献可能性高的区域进行过采样(oversample)更有好处。他们将学习到的归一化权重应用于第一组样本,以在光线上创建PDF,然后再将inverse transform sampling应用于该PDF以收集第二组样本。该集合与第一集合相结合,并被馈送到精细网络以产生最终输出。

分层采样PyTorch实现:

def sample_hierarchical(rays_o: torch.Tensor,
                        rays_d: torch.Tensor,
                        z_vals: torch.Tensor,
                        weights: torch.Tensor,
                        n_samples: int,
                        perturb: bool = False) 
  -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  # Apply hierarchical sampling to the rays.
  # Draw samples from PDF using z_vals as bins and weights as probabilities.
  z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
  new_z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], n_samples, perturb=perturb)
  new_z_samples = new_z_samples.detach()

  # Resample points from ray based on PDF.
  z_vals_combined, _ = torch.sort(torch.cat([z_vals, new_z_samples], dim=-1), dim=-1)
  # [N_rays, N_samples + n_samples, 3]
  pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]  
  return pts, z_vals_combined, new_z_samples

def sample_pdf(bins: torch.Tensor,
               weights: torch.Tensor,
               n_samples: int,
               perturb: bool = False) -> torch.Tensor:
  # Apply inverse transform sampling to a weighted set of points.
  # Normalize weights to get PDF.
  # [n_rays, weights.shape[-1]]
  pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdims=True) 

  # Convert PDF to CDF.
  cdf = torch.cumsum(pdf, dim=-1) # [n_rays, weights.shape[-1]]
  # [n_rays, weights.shape[-1] + 1]
  cdf = torch.concat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1) 

  # Take sample positions to grab from CDF. Linear when perturb == 0.
  if not perturb:
    u = torch.linspace(0., 1., n_samples, device=cdf.device)
    u = u.expand(list(cdf.shape[:-1]) + [n_samples]) # [n_rays, n_samples]
  else:
    # [n_rays, n_samples]
    u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device) 

  # Find indices along CDF where values in u would be placed.
  u = u.contiguous() # Returns contiguous tensor with same values.
  inds = torch.searchsorted(cdf, u, right=True) # [n_rays, n_samples]

  # Clamp indices that are out of bounds.
  below = torch.clamp(inds - 1, min=0)
  above = torch.clamp(inds, max=cdf.shape[-1] - 1)
  inds_g = torch.stack([below, above], dim=-1) # [n_rays, n_samples, 2]

  # Sample from cdf and the corresponding bin centers.
  matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
  cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1,index=inds_g)
  bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1, index=inds_g)

  # Convert samples to ray length.
  denom = (cdf_g[..., 1] - cdf_g[..., 0])
  denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
  t = (u - cdf_g[..., 0]) / denom
  samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

  return samples # [n_rays, n_samples]

4 Training

论文中训练NeRF推荐的每网络8层、每层256维的架构在训练过程中会消耗大量内存。缓解这种情况的方法是将前传(forward pass)分成更小的部分,然后在这些部分上积累梯度。注意与minibatching的区别:梯度是在采样光线的单个小批次上累积的,这些光线可能已经被收集成块。如果没有论文中使用的NVIDIA V100类似性能的GPU,可能必须相应地调整块大小以避免OOM错误。Colab笔记本采用了更小的架构和更适中的分块尺寸。

我个人发现,由于局部极小值,即使选择了许多默认值,NeRF的训练也有些棘手。一些有帮助的技术包括早期训练迭代和早期重新启动期间的中心裁剪(center cropping)。随意尝试不同的超参数和技术,以进一步提高训练收敛性。

初始化

def init_models():
  # Initialize models, encoders, and optimizer for NeRF training.
  encoder = PositionalEncoder(d_input, n_freqs, log_space=log_space)
  encode = lambda x: encoder(x)
  # View direction encoders
  if use_viewdirs:
    encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views,log_space=log_space)
    encode_viewdirs  = lambda x: encoder_viewdirs(x)
    d_viewdirs       = encoder_viewdirs.d_output
  else:
    encode_viewdirs = None
    d_viewdirs = None

  model = NeRF(encoder.d_output, 
               n_layers=n_layers, 
               d_filter=d_filter, skip=skip,d_viewdirs=d_viewdirs)
  model.to(device)
  model_params = list(model.parameters())
  if use_fine_model:
    fine_model = NeRF(encoder.d_output, 
                      n_layers=n_layers, 
                      d_filter=d_filter, skip=skip,d_viewdirs=d_viewdirs)
    fine_model.to(device)
    model_params = model_params + list(fine_model.parameters())
  else:
    fine_model = None

  optimizer      = torch.optim.Adam(model_params, lr=lr)
  warmup_stopper = EarlyStopping(patience=50)
  return model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper

训练

def train():
  # Launch training session for NeRF.
  # Shuffle rays across all images.
  if not one_image_per_step:
    height, width = images.shape[1:3]
    all_rays = torch.stack([torch.stack(get_rays(height, width, focal, p), 0)
                           for p in poses[:n_training]], 0)
    rays_rgb = torch.cat([all_rays, images[:, None]], 1)
    rays_rgb = torch.permute(rays_rgb, [0, 2, 3, 1, 4])
    rays_rgb = rays_rgb.reshape([-1, 3, 3])
    rays_rgb = rays_rgb.type(torch.float32)
    rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
    i_batch = 0

  train_psnrs = []
  val_psnrs = []
  iternums = []
  for i in trange(n_iters):
    model.train()
    if one_image_per_step:
      # Randomly pick an image as the target.
      target_img_idx = np.random.randint(images.shape[0])
      target_img     = images[target_img_idx].to(device)
      if center_crop and i < center_crop_iters:
        target_img = crop_center(target_img)
      height, width = target_img.shape[:2]
      target_pose = poses[target_img_idx].to(device)
      rays_o, rays_d = get_rays(height, width, focal, target_pose)
      rays_o = rays_o.reshape([-1, 3])
      rays_d = rays_d.reshape([-1, 3])
    else:
      # Random over all images.
      batch = rays_rgb[i_batch:i_batch + batch_size]
      batch = torch.transpose(batch, 0, 1)
      rays_o, rays_d, target_img = batch
      height, width = target_img.shape[:2]
      i_batch += batch_size
      # Shuffle after one epoch
      if i_batch >= rays_rgb.shape[0]:
          rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
          i_batch = 0
    target_img = target_img.reshape([-1, 3])

    # Run one iteration of TinyNeRF and get the rendered RGB image.
    outputs = nerf_forward(rays_o, rays_d,
                           near, far, encode, model,
                           kwargs_sample_stratified=kwargs_sample_stratified,
                           n_samples_hierarchical=n_samples_hierarchical,
                           kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                           fine_model=fine_model,
                           viewdirs_encoding_fn=encode_viewdirs,
                           chunksize=chunksize)
    # Backprop!
    rgb_predicted = outputs['rgb_map']
    loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    psnr = -10. * torch.log10(loss)
    train_psnrs.append(psnr.item())

    # Evaluate testimg at given display rate.
    if i % display_rate == 0:
      model.eval()
      height, width = testimg.shape[:2]
      rays_o, rays_d = get_rays(height, width, focal, testpose)
      rays_o = rays_o.reshape([-1, 3])
      rays_d = rays_d.reshape([-1, 3])
      outputs = nerf_forward(rays_o, rays_d,
                             near, far, encode, model,
                             kwargs_sample_stratified=kwargs_sample_stratified,
                             n_samples_hierarchical=n_samples_hierarchical,
                             kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                             fine_model=fine_model,
                             viewdirs_encoding_fn=encode_viewdirs,
                             chunksize=chunksize)
      rgb_predicted = outputs['rgb_map']
      loss = torch.nn.functional.mse_loss(rgb_predicted, testimg.reshape(-1, 3))
      val_psnr = -10. * torch.log10(loss)
      val_psnrs.append(val_psnr.item())
      iternums.append(i)

    # Check PSNR for issues and stop if any are found.
    if i == warmup_iters - 1:
      if val_psnr < warmup_min_fitness:
        return False, train_psnrs, val_psnrs
    elif i < warmup_iters:
      if warmup_stopper is not None and warmup_stopper(i, psnr):
        return False, train_psnrs, val_psnrs

  return True, train_psnrs, val_psnrs

训练

# Run training session(s)
for _ in range(n_restarts):
  model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper = init_models()
  success, train_psnrs, val_psnrs = train()
  if success and val_psnrs[-1] >= warmup_min_fitness:
    print('Training successful!')
    break
print(f'Done!')

5 Conclusion

辐射场标志着处理3D数据的方式发生了巨大变化。NeRF模型和更广泛的可微分渲染正在迅速弥合图像创建和体积场景创建之间的差距。虽然我们的组件可能看起来非常复杂,但受vanilla NeRF启发的无数其他方法证明,基本概念(连续函数逼近器+可微分渲染器)是构建各种解决方案的坚实基础,这些解决方案可用于几乎无限的情况。

原文:NeRF From Nothing: A Tutorial with PyTorch | Towards Data Science

原文链接:https://mp.weixin.qq.com/s/zxJAIpAmLgsIuTsPqQqOVg

责任编辑:张燕妮 来源: 自动驾驶之心
相关推荐

2021-10-09 15:36:31

技术研发三维

2024-02-20 09:46:00

模型技术

2023-12-13 10:14:00

机器视觉技术

2023-08-05 13:53:34

2023-10-27 14:54:33

智能驾驶云计算

2024-06-19 11:30:36

2022-12-09 10:00:23

2022-12-22 10:15:05

神经网络AI

2021-03-16 09:53:35

人工智能机器学习技术

2023-06-02 14:10:05

三维重建

2021-04-14 15:03:16

数据性能存储

2024-02-29 09:38:13

神经网络模型

2022-09-26 15:18:15

3D智能

2022-02-25 23:46:16

神经网络机器学习AI

2022-08-10 10:00:00

人工智能三维模型编程技术

2023-01-31 12:30:26

模型代码

2021-04-07 10:13:51

人工智能深度学习

2024-02-06 09:55:33

框架代码

2023-04-03 11:52:51

6D英伟达

2024-04-30 09:54:59

模型训练
点赞
收藏

51CTO技术栈公众号