Differentiable Sphere Tracing

Oct 9, 2020 03:00 · 3832 words · 8 minute read

CVPR'20の論文を眺めていたら,最近ハマっているGLSLと関連する面白そうな論文を見つけた.

  1. Differentiable Volumetric Rendering: Learning Implicit 3D Representations without 3D Supervision [Niemeyer et al., CVPR'20]
  2. DIST: Rendering Deep Implicit Signed Distance Function with Differentiable Sphere Tracing [Liu et al., CVPR'20]
  3. SDFDiff: Differentiable Rendering of Signed Distance Fields for 3D Shape Optimization [Jiang et al., CVPR'20]

ここ最近のCVPRなどのコンピュータビジョンの学会では3Dの研究が爆発的に増えていて,ビジョンとグラフィクスを行ったり来たりしている研究も多い気がする.特に3D形状をどう表現するかという問題に焦点が当たっており,ボクセルやポイントクラウド,メッシュなどの離散表現ではなく,DNNでSigned Distance Function (SDF)をモデル化 [Park et al., CVPR'19] したり,Occupancy Functionという物体の内部か外部かを判別する2値分類器などをモデル化 [Mescheder et al., CVPR'19] したりして,3D形状を陰関数表現しているのをよく見かける.


Differentiable Rendering: A Survey [Kato et al., 2020]


またこれとは別にDifferentiable Rendering [Kato et al., CVPR'18] という手法がある.これはレンダリングプロセスを微分可能にし,計算グラフに組み込むことで,3D教師データを必要とせず,2D教師データのみで3D形状の学習を行うことができるでものである.言い換えれば,どのような3D形状を学習すれば,それによってレンダリングされた結果が教師である2Dデータと一致するか,ということを学習する.例えばラスタライズベースのレンダリングプロセスの大部分は幾何計算であり微分可能であるが,ラスタライズだけは微分できない.そこで様々な近似勾配が提案されている.


Differentiable Rendering: A Survey [Kato et al., 2020]


今回の3つの論文は全て,DNNで陰関数表現された3Dオブジェクトのレンダリングを微分可能な形でどう計算グラフに組み込むか,という問題を扱っている.陰関数表現された3Dオブジェクトのレンダリングはまさしくレイキャスティングである.レイキャスティングを微分可能な形で計算グラフに組み込めれば,レンダリングを通してDNNを最適化できる.

今回の目標は,まずこれらの論文を理解し,実装し,学習させる.次に学習されたネットワークをSigned Distance FunctionとしてGLSLにぶちこみそのままスフィアトレーシングでレンダリングしてやる.これで学習された3D形状をそのままレンダリングできるはずである.

まず今回の問題設定を定式化する.DVR [1] の手法が最も洗練されている気がしたこれをベースに理解を進めることにした. まず3Dオブジェクトの形状は$$f_{\theta}: \mathbb{R}^{3} \times \mathcal{Z} \rightarrow \mathbb{R}$$で陰関数表現する.これはSigned Distance FunctionでもOccupancy Functionでも良い. 3Dオブジェクトのテクスチャは$$t_{\theta}: \mathbb{R}^{3} \times \mathcal{Z} \rightarrow \mathbb{R}^{3}$$で表現する.共に$z \in \mathcal{Z}$は3Dオブジェクトの形状,テクスチャを表す埋め込み表現であり,2D表現からDNNで獲得する.

よってレンダリングされた2D表現を$\hat{I}$とすると,以下のような最適化問題を解きたいわけである. $$\theta^{*}=\text{argmin}_{\theta}\mathcal{L}(\hat{I}, I)$$

勾配法で最適化するとして $$\cfrac{\partial{\mathcal{L}}}{\partial{\theta}}=\sum_{u}\cfrac{\partial{\mathcal{L}}}{\partial{\hat{I}_{u}}}\cfrac{\partial{\hat{I}_{u}}}{\partial{\theta}}$$


DVR [Niemeyer et al., CVPR'20]


ここで$f_{\theta}$を用いてレイキャスティングした結果の交点を$\hat{p}$とすると,$\hat{I}_{u}=t_{\theta}(\hat{p})$であるから, $$\cfrac{\partial{\hat{I}_{u}}}{\partial{\theta}}=\cfrac{\partial{t_{\theta}(\hat{p})}}{\partial{\theta}}+\cfrac{\partial{t_{\theta}(\hat{p})}}{\partial{\hat{p}}}\cdot\cfrac{\partial{\hat{p}}}{\partial{\theta}}$$

ここで$\cfrac{\partial{\hat{p}}}{\partial{\theta}}$は陽には計算できないが,$f_{\theta}(\hat{p})=0$の陰関数微分により, $$\cfrac{\partial{f_{\theta}(\hat{p})}}{\partial{\theta}}+\cfrac{\partial{f_{\theta}(\hat{p})}}{\partial{\hat{p}}}\cdot\cfrac{\partial{\hat{p}}}{\partial{\theta}}=0$$

ここでレイを$r(d)=r_{0}+dw$と表すと,$\hat{p}=r(\hat{d})$と表せ,

$$\cfrac{\partial{\hat{p}}}{\partial{\theta}}=\cfrac{\partial{\hat{d}}}{\partial{\theta}}w$$

よって, $$\cfrac{\partial{f_{\theta}(\hat{p})}}{\partial{\theta}}+\cfrac{\partial{f_{\theta}(\hat{p})}}{\partial{\hat{p}}}\cdot\cfrac{\partial{\hat{d}}}{\partial{\theta}}w=0$$

$$\cfrac{\partial{\hat{d}}}{\partial{\theta}}=-(\cfrac{\partial{f_{\theta}(\hat{p})}}{\partial{\hat{p}}} \cdot w)^{-1}\cfrac{\partial{f_{\theta}(\hat{p})}}{\partial{\theta}}$$

すなわち, $$\cfrac{\partial{\hat{p}}}{\partial{\theta}}=-(\cfrac{\partial{f_{\theta}(\hat{p})}}{\partial{\hat{p}}} \cdot w)^{-1}\cfrac{\partial{f_{\theta}(\hat{p})}}{\partial{\theta}}w$$

これにより,レイキャスティングによる交点$\hat{p}$を微分可能な形で求める必要はない. 例えばDVR [1] ではレイ上の点をサンプリングして,Occupancy Networkの出力が初めて0.5を跨いだ点(Occupancy Networkは与えられた点が3Dオブジェクト内部に存在する確率を出力するので,出力が0.5となる点はオブジェクト表面に存在すると推定されたことになる)を交点$\hat{p}$としている.

他にもDIST [2], SDFDiff [3] ではスフィアトレーシングを用いてレンダリングしており,微分可能な形で交点$\hat{p}$を求める工夫をしている.


DIST [Liu et al., CVPR'20]


DIST [2] では以下のような戦略を用いてスフィアトレーシングを加速させている. イテレーション毎に計算グラフを構築しているので計算量が多く,そのため近似勾配を用いて対処している. イテレーション毎に素直にDNNを通して距離を計算していては,計算グラフが肥大化し,やがてVRAMを食い潰すだろう.


DIST [Liu et al., CVPR'20]


SDFDiff [3] は,スフィアトレーシングにおける最後のイテレーションのみ計算グラフを構築することで,微分可能な形で交点$\hat{p}$を求めている.コアのアイデアはこんな感じである.

# --- ray marching --- #

with torch.no_grad():
    while not converged:
        p += SDF(p) * v

# make only the last step differentiable
with torch.enable_grad():
    p += SDF(p) * v

この論文のSigned Distance Functionはボクセルベースなので,任意の点における距離は近傍ボクセルの線形補間で求めている.

コアのアイデアは多分理解できたので,実装に移る. 今回は DVR [1] をベースにOccupancy FunctionではなくSigned Distance FunctionをDNNでモデル化した. 学習の流れとしては,以下のように行った.論文の再現は目的ではないので,実験設定は異なっている. 自分なりに理解しやすく,まずはなるべくシンプルになるように書いたので,実装は元論文とはだいぶ違うかもしれない. タスクとしては"Single-View Reconstruction with Multi-View Supervision"を扱う.

  1. 3Dオブジェクトをランダムにサンプリングしたカメラ,ライト,マテリアルを用いてレンダリングし,これをGTとする.
  2. Signed Distance Function $f_{\theta}$ を用いて,スフィアトレーシングにより3Dオブジェクトとの交点を推定する.法線は陰関数の勾配で与えられるので,素直に微分するか,有限差分で近似するかして求める.
  3. GTのカメラ,ライト,マテリアルを用いてphongライティングにより推定された3Dオブジェクトをレンダリングする.またこの時使用されるテクスチャは $t_{\theta}$ により推定する.
  4. Backpropにより,各パラメータにおける勾配を求める.スフィアトレーシングによる交点の微分は陰関数微分を用いて求める.

今回はGTのレンダリングや,反射モデル,幾何変換のために PyTorch3D を用いた.データセットとしては ShapeNetCore を用いた. このShapeNetCore,3Dモデルが汚いためか,普通にphongシェーディングすると,めちゃくちゃなシェーディングになった. なので,GTのレンダリングにはとりあえずflatシェーディングを用いている.


ShapeNetCore


まずは普通のスフィアトレーシングをPyTorchで実装してみる. カメラやライトはPyTorch3Dのクラスを用いている. PyTorch3Dはかなり新しいライブラリでまだ洗練されていない部分も多く,結構苦労した. 特にテンソルの形状のミスマッチに起因するエラーが多く,自分で書き直していかないといけなかった. とはいえ,PyTorchにレンダリングのパイプラインを自然に組み込めるのはとてもありがたいことである.

def sphere_tracing(
    signed_distance_function, 
    positions, 
    directions, 
    foreground_masks, 
    num_iterations, 
    convergence_threshold,
):
    for i in range(num_iterations):
        signed_distances = signed_distance_function(positions)
        if i:
            positions = torch.where(converged, positions, positions + directions * signed_distances)
        else:
            positions = positions + directions * signed_distances
        converged = torch.abs(signed_distances) < convergence_threshold
        if torch.all(converged[foreground_masks] if foreground_masks else converged):
            break

    return positions, converged


def compute_normal(signed_distance_function, positions, finite_difference_epsilon):

    if finite_difference_epsilon:
        finite_difference_epsilon = positions.new_tensor(finite_difference_epsilon)
        finite_difference_epsilon = finite_difference_epsilon.reshape(1, 1, 1)
        finite_difference_epsilon_x = nn.functional.pad(finite_difference_epsilon, (0, 2))
        finite_difference_epsilon_y = nn.functional.pad(finite_difference_epsilon, (1, 1))
        finite_difference_epsilon_z = nn.functional.pad(finite_difference_epsilon, (2, 0))
        normals_x = signed_distance_function(positions + finite_difference_epsilon_x) - signed_distance_function(positions - finite_difference_epsilon_x)
        normals_y = signed_distance_function(positions + finite_difference_epsilon_y) - signed_distance_function(positions - finite_difference_epsilon_y)
        normals_z = signed_distance_function(positions + finite_difference_epsilon_z) - signed_distance_function(positions - finite_difference_epsilon_z)
        normals = torch.cat((normals_x, normals_y, normals_z), dim=-1)

    else:
        create_graph = positions.requires_grad
        positions.requires_grad_(True)
        with torch.enable_grad():
            signed_distances = signed_distance_function(positions)
            normals, = autograd.grad(
                outputs=signed_distances, 
                inputs=positions, 
                grad_outputs=torch.ones_like(signed_distances),
                create_graph=create_graph,
            )
            
    return normals


def phong_shading(positions, normals, textures, cameras, lights, materials):
    light_diffuse_color = lights.diffuse(
        normals=normals, 
        points=positions,
    )
    light_specular_color = lights.specular(
        normals=normals,
        points=positions,
        camera_position=cameras.get_camera_center(),
        shininess=materials.shininess,
    )
    ambient_colors = materials.ambient_color * lights.ambient_color
    diffuse_colors = materials.diffuse_color * light_diffuse_color
    specular_colors = materials.specular_color * light_specular_color
    # NOTE: pytorch3d.renderer.phong_shading should be fixed as well
    assert diffuse_colors.shape == specular_colors.shape
    ambient_colors = ambient_colors.reshape(-1, *[1] * len(diffuse_colors.shape[1:-1]), 3)
    colors = (ambient_colors + diffuse_colors) * textures + specular_colors
    return colors

上記の関数を用いて,Constructive Solid Geometry (CSG) をレンダリングしてみた. 完全なコードはここに置いてある: [code]


Constructive Solid Geometry (CSG)


各プリミティブに対応するSigned Distance Functionは,まだその導出を理解していないものも多く,一度ちゃんと勉強したいと思う.

一応スフィアトレーシングは動作していそうなので,これを自動微分の枠組みに組み込む. 基本的には上記のsphere_tracing(...)torch.autograd.Function.forward(...)に移植し, torch.autograd.Function.backward(...)を陰関数微分にしたがって実装すれば良さそうである.

class SphereTracing(autograd.Function):

    @staticmethod
    def forward(
        ctx, 
        signed_distance_function, 
        positions, 
        directions, 
        foreground_masks, 
        num_iterations, 
        convergence_threshold,
        *parameters,
    ):
        # vanilla sphere tracing
        with torch.no_grad():
            positions, converged = sphere_tracing(
                signed_distance_function=signed_distance_function, 
                positions=positions, 
                directions=directions, 
                foreground_masks=foreground_masks,
                num_iterations=num_iterations, 
                convergence_threshold=convergence_threshold,
            )
            positions = torch.where(converged, positions, torch.zeros_like(positions))

        # save tensors for backward pass
        ctx.save_for_backward(positions, directions, foreground_masks, converged)
        ctx.signed_distance_function = signed_distance_function
        ctx.parameters = parameters

        return positions, converged

    @staticmethod
    def backward(ctx, grad_outputs, *_):
        
        # restore tensors from forward pass
        positions, directions, foreground_masks, converged = ctx.saved_tensors
        signed_distance_function = ctx.signed_distance_function
        parameters = ctx.parameters

        # compute gradients using implicit differentiation
        with torch.enable_grad():
            positions = positions.detach()
            positions.requires_grad_(True)
            signed_distances = signed_distance_function(positions)
            grad_positions, = autograd.grad(
                outputs=signed_distances, 
                inputs=positions, 
                grad_outputs=torch.ones_like(signed_distances), 
                retain_graph=True,
            )
            grad_outputs_dot_directions = torch.sum(grad_outputs * directions, dim=-1, keepdim=True)
            grad_positions_dot_directions = torch.sum(grad_positions * directions, dim=-1, keepdim=True)
            # NOTE: avoid division by zero
            grad_positions_dot_directions = torch.where(
                grad_positions_dot_directions > 0,
                torch.max(grad_positions_dot_directions, torch.full_like(grad_positions_dot_directions, +1e-6)),
                torch.min(grad_positions_dot_directions, torch.full_like(grad_positions_dot_directions, -1e-6)),
            )
            grad_outputs = -grad_outputs_dot_directions / grad_positions_dot_directions
            # NOTE: zero gradient for unconverged points 
            grad_outputs = torch.where(converged, grad_outputs, torch.zeros_like(grad_outputs))
            grad_parameters = autograd.grad(
                outputs=signed_distances, 
                inputs=parameters, 
                grad_outputs=grad_outputs, 
                retain_graph=True,
            )

        return (None, None, None, None, None, None, *grad_parameters)

これで一応学習させてみるが,どうせ最初からうまくはいかないので,適宜修正を加えていくことになるだろう. 損失関数は単純なL1 lossを用いてみたが,SDFDiff [3] で提案されているSigned Distance Functionの勾配やラプラシアンに関する正則化項を入れたり,他にも実装レベルの様々なトリックが必要になってくるかもしれない.

結果はうまくいき次第載せようと思う.