
"""Some useful util functions"""

import math
import os
import random
from pathlib import Path
from typing import List, Literal, Optional, Union
from typing import List, Optional, Tuple

import cv2
import numpy as np
import open3d as o3d
import torch
from natsort import natsorted
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import resize
from tqdm import tqdm

from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.models.base_model import Model
from nerfstudio.process_data.process_data_utils import (
    convert_video_to_images,
    get_num_frames_in_video,
)
from nerfstudio.utils import colormaps
from nerfstudio.utils.rich_utils import CONSOLE

# Depth scale factor: raw stored depth values (millimetres, per the original
# note) are multiplied by this on load to obtain metres; save_depth divides by it.
SCALE_FACTOR = 1e-3

# Homogeneous 4x4 transform that flips the y and z axes, converting a camera
# pose between the OpenGL and OpenCV camera conventions.
OPENGL_TO_OPENCV = np.diag([1, -1, -1, 1])


# ndc space is x to the right y up. uv space is x to the right, y down.
def pix2ndc_x(x, W):
    x = x.float()
    return (2 * x) / W - 1


def pix2ndc_y(y, H):
    """Map pixel row coordinates ``y`` in [0, H] (y down) to NDC y in [1, -1] (y up)."""
    return 1 - 2 * y.float() / H


# NDC: y points up, x points right. Pixel (uv) space: y points down, x points right.
def ndc2pix_x(x, W):
    """Map NDC x in [-1, 1] back to pixel column coordinates in [0, W]."""
    return W * ((x + 1) / 2)


def ndc2pix_y(y, H):
    """Map NDC y in [1, -1] (y up) back to pixel row coordinates in [0, H] (y down)."""
    return H * ((1 - y) / 2)


def euclidean_to_z_depth(
    depths: Tensor,
    fx: float,
    fy: float,
    cx: int,
    cy: int,
    img_size: tuple,
    device: torch.device,
) -> Tensor:
    """Turn per-pixel euclidean (ray-length) depths into z-depths.

    For every pixel the ray ((u - cx) / fx, (v - cy) / fy, 1) is normalised to unit
    length, scaled by that pixel's euclidean depth, and only its z component is kept.

    Args:
        depths: one euclidean depth per pixel, e.g. (H, W, 1), (H, W) or (H*W, 1).
        fx, fy: focal lengths.
        cx, cy: principal point.
        img_size: (W, H), as consumed by ``get_camera_coords``.
        device: device on which pixel coordinates and the result are created.

    Returns:
        z-depth tensor of shape (H, W, 1).
    """
    # Bring depths to a (num_pixels, 1) column.
    flatten = depths.dim() == 3
    if not flatten and depths.shape[-1] != 1:
        depths = depths.unsqueeze(-1).contiguous()
        flatten = True
    if flatten:
        depths = depths.view(-1, 1)
    if depths.dtype != torch.float:
        depths = depths.float()

    pixels = get_camera_coords(img_size=img_size).to(device)
    width, height = img_size[0], img_size[1]

    rays = torch.empty(
        size=(width, height, 3), dtype=torch.float32, device=device
    ).view(-1, 3)
    rays[:, 0] = (pixels[:, 0] - cx) / fx
    rays[:, 1] = (pixels[:, 1] - cy) / fy
    rays[:, 2] = 1

    unit_rays = rays / torch.norm(rays, dim=-1, keepdim=True)
    z_only = (unit_rays * depths)[:, 2]
    return z_only.unsqueeze(-1).view(height, width, 1)


def get_camera_coords(img_size: tuple, pixel_offset: float = 0.5) -> Tensor:
    """Build the (x, y) coordinates of every pixel of a W x H image.

    Args:
        img_size: (W, H).
        pixel_offset: added to the integer pixel indices (0.5 -> pixel centres).

    Returns:
        float32 tensor of shape [H*W, 2] in row-major pixel order, where
        column 0 is the x (W) coordinate and column 1 the y (H) coordinate.
    """
    width, height = img_size[0], img_size[1]
    grid_x, grid_y = torch.meshgrid(
        torch.arange(width), torch.arange(height), indexing="xy"
    )  # each grid is (H, W)
    coords = torch.stack((grid_x, grid_y), dim=-1) + pixel_offset
    return coords.view(-1, 2).float()


def get_means3d_backproj(
    depths: Tensor,
    fx: float,
    fy: float,
    cx: int,
    cy: int,
    img_size: tuple,
    c2w: Tensor,
    device: torch.device,
    mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Backproject a depth map to world-space points with a pinhole camera model.

    image_coords (u, v) + depth -> camera-space (X, Y, Z) -> world-space points.

    Args:
        depths: one z-depth per pixel in the row-major pixel order produced by
            ``get_camera_coords`` (e.g. (H, W, 1), (H, W) or (H*W, 1)); it is
            flattened, so any shape holding W*H values is accepted.
        fx, fy: focal lengths in pixels.
        cx, cy: principal point in pixels.
        img_size: (W, H).
        c2w: camera-to-world pose; ``c2w[..., :3, :3]`` is the rotation and
            ``c2w[..., :3, 3]`` the translation. A leading batch dimension such
            as (1, 4, 4) is broadcast into the output. ``None`` means identity.
        device: device on which the computation runs.
        mask: optional boolean or index mask over the flattened pixels.

    Returns:
        Tuple of (means3d, image_coords): world points of shape (N, 3) (with
        c2w's leading batch dims prepended, if any) and the matching (x, y)
        pixel coordinates of shape (N, 2).
    """
    # reshape (not view) also accepts non-contiguous inputs
    depths = depths.reshape(-1, 1).to(device)
    if depths.dtype != torch.float:
        depths = depths.float()

    # means3d below is float32; cast the pose independently of depths so the
    # final matmul never mixes dtypes (e.g. a float64 pose with float32 depths).
    if c2w is None:
        c2w = torch.eye(4, device=device)
    c2w = c2w.to(device=device, dtype=torch.float32)

    image_coords = get_camera_coords(img_size).to(device)  # (W*H, 2) as (x, y)

    # TODO: account for skew / radial distortion
    means3d = torch.empty(
        size=(img_size[0] * img_size[1], 3), dtype=torch.float32, device=device
    )
    means3d[:, 0] = (image_coords[:, 0] - cx) * depths[:, 0] / fx  # x
    means3d[:, 1] = (image_coords[:, 1] - cy) * depths[:, 0] / fy  # y
    means3d[:, 2] = depths[:, 0]  # z

    if mask is not None:
        if not torch.is_tensor(mask):
            mask = torch.tensor(mask, device=device)
        means3d = means3d[mask]
        image_coords = image_coords[mask]

    # camera -> world: X_w = R @ X_c + t, i.e. X_w^T = X_c^T @ R^T + t^T for row
    # vectors. R^T equals inv(R) for a pure rotation, and stays correct (and
    # cheaper) when R is not orthonormal.
    rotation = c2w[..., :3, :3]
    translation = c2w[..., :3, 3]
    means3d = means3d @ rotation.transpose(-1, -2) + translation
    return means3d, image_coords


def project_pix(
    p: Tensor,
    fx: float,
    fy: float,
    cx: int,
    cy: int,
    c2w: Tensor,
    device: torch.device,
    return_z_depths: bool = False,
) -> Tensor:
    """Project world-space 3D points to pixel (u, v) coordinates.

    Args:
        p: world points, shape (N, 3).
        fx, fy: focal lengths in pixels.
        cx, cy: principal point in pixels.
        c2w: camera-to-world pose; ``c2w[..., :3, :3]`` is the rotation and
            ``c2w[..., :3, 3]`` the translation. ``None`` is treated as the
            identity pose.
        device: device on which the projection is computed.
        return_z_depths: also return the camera-space z of every point.

    Returns:
        (N, 2) tensor of (u, v), or (N, 3) of (u, v, z) when ``return_z_depths``.
    """
    if c2w is None:
        # identity pose; match p's dtype when it is floating point
        c2w = torch.eye(
            4, device=device, dtype=p.dtype if p.is_floating_point() else None
        )
    if c2w.device != device:
        c2w = c2w.to(device)

    # world -> camera: R^T (p - t); written as (p - t) @ R for row vectors, which
    # assumes R is orthonormal (a pure rotation).
    points_cam = (p.to(device) - c2w[..., :3, 3]) @ c2w[..., :3, :3]
    u = points_cam[:, 0] * fx / points_cam[:, 2] + cx  # x
    v = points_cam[:, 1] * fy / points_cam[:, 2] + cy  # y
    if return_z_depths:
        return torch.stack([u, v, points_cam[:, 2]], dim=-1)
    return torch.stack([u, v], dim=-1)


def get_colored_points_from_depth(
    depths: Tensor,
    rgbs: Tensor,
    c2w: Tensor,
    fx: float,
    fy: float,
    cx: int,
    cy: int,
    img_size: tuple,
    mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Build a colored point cloud from a depth map, an rgb frame and a c2w pose.

    Args:
        depths: per-pixel depths, forwarded to ``get_means3d_backproj``.
        rgbs: rgb frame; viewed as (num_pixels, 3).
        c2w: camera-to-world pose.
        fx, fy, cx, cy: pinhole intrinsics.
        img_size: (W, H).
        mask: optional mask over the flattened pixels; a non-tensor mask is
            converted with ``torch.tensor`` on ``depths.device``.

    Returns:
        Tuple of (points, colors).
    """
    world_points, _ = get_means3d_backproj(
        depths=depths.float(),
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        img_size=img_size,
        c2w=c2w.float(),
        device=depths.device,
    )
    world_points = world_points.squeeze(0)

    if mask is None:
        return (world_points, rgbs.view(-1, 3))

    if not torch.is_tensor(mask):
        mask = torch.tensor(mask, device=depths.device)
    point_colors = rgbs.view(-1, 3)[mask]
    return (world_points[mask], point_colors)


def get_rays_x_y_1(H, W, focal, c2w):
    """Get ray origins and directions in world coordinates.

    Convention here is (x,y,-1) such that depth*rays_d give real z depth values in world coordinates.
    Pixel positions are mapped to NDC with ``pix2ndc_x`` / ``pix2ndc_y`` and then
    divided by ``focal`` (NOTE(review): this presumes ``focal`` is expressed in NDC
    units rather than pixels — confirm against callers).

    Args:
        H, W: image height and width in pixels.
        focal: focal length used to scale the NDC coordinates.
        c2w: (3, 4) camera-to-world matrix [R | t].

    Returns:
        (rays_o, rays_d), each of shape (W*H, 3). Rays are ordered with the pixel
        column index as the slow axis (``indexing="ij"`` over (W, H)).
    """
    assert c2w.shape == torch.Size([3, 4])
    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32),
        torch.arange(H, dtype=torch.float32),
        indexing="ij",
    )
    dirs = torch.stack(
        [pix2ndc_x(i, W) / focal, pix2ndc_y(j, H) / focal, -torch.ones_like(i)],
        dim=-1,
    ).view(-1, 3)
    # camera -> world rotation: d_world = R @ d_cam, i.e. d_cam^T @ R^T for row
    # vectors (multiplying by R itself would apply the world -> camera rotation).
    rays_d = dirs @ c2w[:3, :3].T
    rays_o = c2w[:3, -1].expand_as(rays_d)

    # return world coordinate rays_o and rays_d
    return rays_o, rays_d


def get_projection_matrix(znear=0.001, zfar=1000, fovx=None, fovy=None, **kwargs):
    """Build a 4x4 perspective projection matrix from near/far planes and fields of view.

    Args:
        znear: near plane distance.
        zfar: far plane distance.
        fovx: horizontal field of view in radians (required).
        fovy: vertical field of view in radians (required).
        **kwargs: forwarded to ``torch.tensor`` (e.g. ``dtype``, ``device``).

    Raises:
        ValueError: if ``fovx`` or ``fovy`` is not given.

    Returns:
        projmat: Tensor of shape (4, 4)
    """
    if fovx is None or fovy is None:
        raise ValueError("get_projection_matrix requires both fovx and fovy (radians)")

    top = znear * math.tan(0.5 * fovy)
    bottom = -top
    right = znear * math.tan(0.5 * fovx)
    left = -right
    n = znear
    f = zfar
    return torch.tensor(
        [
            [2 * n / (right - left), 0.0, (right + left) / (right - left), 0.0],
            [0.0, 2 * n / (top - bottom), (top + bottom) / (top - bottom), 0.0],
            [0.0, 0.0, (f + n) / (f - n), -1.0 * f * n / (f - n)],
            [0.0, 0.0, 1.0, 0.0],
        ],
        **kwargs,
    )




def video_to_frames(
    video_path: Path, image_dir: Path = Path("./data/frames"), force: bool = False
) -> None:
    """Extract every frame of a video into ``image_dir`` (requires nerfstudio install).

    Extraction is skipped when ``image_dir`` already exists and is non-empty,
    unless ``force`` is True. Frames are written with the ``frame_`` prefix.

    Args:
        video_path: path to the input video.
        image_dir: output directory (str or Path); defaults to ./data/frames.
        force: extract even when ``image_dir`` already has content.

    Raises:
        AssertionError: if fewer/more frames were extracted than the video reports.
    """
    image_dir = Path(image_dir)
    is_empty = not image_dir.exists() or not any(image_dir.iterdir())
    if not (is_empty or force):
        return

    num_frames_target = get_num_frames_in_video(video=video_path)
    _, num_extracted_frames = convert_video_to_images(
        video_path,
        image_dir=image_dir,
        num_frames_target=num_frames_target,
        num_downscales=0,
        verbose=True,
        image_prefix="frame_",
        keep_image_dir=False,
    )
    assert num_extracted_frames == num_frames_target, (
        f"extracted {num_extracted_frames} frames, expected {num_frames_target}"
    )


def get_filename_list(
    image_dir: Path, ends_with: Optional[Union[str, Tuple[str, ...]]] = None
) -> List:
    """List the entries of ``image_dir`` as full paths in natural sort order.

    Args:
        image_dir: directory to list (str or Path).
        ends_with: optional suffix, or tuple of suffixes, to keep; matched
            case-insensitively against each entry name.

    Returns:
        image_filenames: naturally sorted list of ``image_dir / name`` paths.
    """
    image_dir = Path(image_dir)
    names = os.listdir(image_dir)
    if ends_with is not None:
        suffixes = (ends_with,) if isinstance(ends_with, str) else tuple(ends_with)
        # lowercase both sides so ".PNG" and ".png" behave the same
        suffixes = tuple(suffix.lower() for suffix in suffixes)
        names = [name for name in names if name.lower().endswith(suffixes)]
    return natsorted([image_dir / name for name in names])


def image_path_to_tensor(
    image_path: Path, size: Optional[tuple] = None, black_and_white=False
) -> Tensor:
    """Load an image file as an (H, W, C) tensor, keeping at most 3 channels.

    Args:
        image_path: path of the image to load.
        size: optional target size passed to torchvision ``resize``.
        black_and_white: convert the image to PIL mode "1" before loading.

    Returns:
        image: Tensor in (H, W, C) layout, as produced by ``ToTensor`` then permuted.
    """
    pil_image = Image.open(image_path)
    if black_and_white:
        pil_image = pil_image.convert("1")
    hwc = transforms.ToTensor()(pil_image).permute(1, 2, 0)[..., :3]
    if not size:
        return hwc
    # resize expects channels first
    chw = resize(hwc.permute(2, 0, 1), size=size, antialias=None)
    return chw.permute(1, 2, 0)


def depth_path_to_tensor(
    depth_path: Path, scale_factor: float = SCALE_FACTOR, return_color=False
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Load depth image in either .npy or .png format and return tensor

    Args:
        depth_path: path to a .png or .npy depth file (suffix matched case-insensitively).
        scale_factor: multiplier applied to the raw depth values.
        return_color: also return a colormapped depth image.

    Raises:
        FileNotFoundError: if a .png depth image cannot be read.
        ValueError: if the file suffix is neither .png nor .npy.

    Returns:
        float32 depth tensor of shape (H, W, 1), and optionally the colored depth tensor
    """
    depth_path = Path(depth_path)
    suffix = depth_path.suffix.lower()
    if suffix == ".png":
        depth = cv2.imread(str(depth_path.absolute()), cv2.IMREAD_ANYDEPTH)
        if depth is None:  # cv2 signals unreadable/missing files with None
            raise FileNotFoundError(f"Could not read depth image {depth_path}")
    elif suffix == ".npy":
        # NOTE(review): allow_pickle=True can execute code from untrusted .npy
        # files; only load depth files from trusted sources.
        depth = np.load(depth_path, allow_pickle=True)
        if depth.ndim == 3:
            depth = depth[..., 0]
    else:
        raise ValueError(f"Format is not supported {depth_path.suffix}")

    depth = torch.from_numpy((depth * scale_factor).astype(np.float32)).unsqueeze(-1)
    if not return_color:
        return depth
    depth_color = colormaps.apply_depth_colormap(depth)
    return depth, depth_color


def save_img(image, image_path, verbose=True) -> None:
    """Save an image (numpy array or Tensor) to ``image_path`` with PIL.

    Tensors are treated as float images: a trailing single channel is repeated
    to 3 channels, values are scaled by 255, clipped to [0, 255] and cast to
    uint8. Numpy arrays are written as-is, except that an (H, W, 1) array is
    repeated to 3 channels so PIL can encode it. Missing parent directories
    are created.

    Args:
        image: image to save (numpy, Tensor)
        image_path: path to save (str or Path)
        verbose: whether to print the absolute save path

    Returns:
        None
    """
    if torch.is_tensor(image):
        if image.shape[-1] == 1:
            image = image.repeat(1, 1, 3)
        # clip before the uint8 cast so out-of-range values saturate instead of wrapping
        image = np.clip(image.detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
    elif image.ndim == 3 and image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)

    Path(image_path).parent.mkdir(parents=True, exist_ok=True)
    im = Image.fromarray(image)
    if verbose:
        # abspath leaves absolute paths untouched instead of prefixing cwd twice
        print("saving to: ", os.path.abspath(image_path))
    im.save(image_path)


def save_depth(depth, depth_path, verbose=True, scale_factor=SCALE_FACTOR) -> None:
    """Write a metric depth map to ``depth_path`` with ``np.save``.

    The depth is divided by ``scale_factor`` before saving; tensors are cast to
    float, detached and moved to CPU first.

    Args:
        depth: depth map to save (numpy, Tensor)
        depth_path: destination path; a missing parent directory is created
        verbose: whether to print the save path
        scale_factor: depth metric scaling factor

    Returns:
        None
    """
    if not torch.is_tensor(depth):
        raw_depth = depth / scale_factor
    else:
        raw_depth = (depth.float() / scale_factor).detach().cpu().numpy()

    out_dir = Path(os.path.dirname(depth_path))
    if not out_dir.exists():
        out_dir.mkdir(parents=True)
    if verbose:
        print("saving to: ", depth_path)
    np.save(depth_path, raw_depth)


def save_normal(
    normal: Union[np.ndarray, Tensor],
    normal_path: Path,
    verbose: bool = True,
    format: Literal["png", "npy"] = "png",
) -> None:
    """Save a normal map to disk as a PNG image or a raw .npy array.

    Args:
        normal: normal map to save (numpy, Tensor); tensors are cast to float,
            detached and moved to CPU first.
        normal_path: destination path; missing parent directories are created.
        verbose: whether to print save path.
        format: "png" scales values by 255, casts to uint8 and writes with PIL;
            "npy" writes the array with ``np.save``.

    Raises:
        ValueError: if ``format`` is neither "png" nor "npy" (previously such a
            call silently wrote nothing).

    Returns:
        None
    """
    if format not in ("png", "npy"):
        raise ValueError(f"Unsupported normal format: {format!r}")
    if torch.is_tensor(normal):
        normal = normal.float().detach().cpu().numpy()

    Path(os.path.dirname(normal_path)).mkdir(parents=True, exist_ok=True)
    if verbose:
        print("saving to: ", normal_path)

    if format == "npy":
        np.save(normal_path, normal)
    else:
        # NOTE(review): values outside [0, 1] wrap around in the uint8 cast; this
        # presumes normals were already mapped to [0, 1] — confirm with callers.
        nm = Image.fromarray((normal * 255).astype(np.uint8))
        nm.save(normal_path)


def gs_get_point_clouds(
    eval_data: Optional[InputDataset],
    train_data: Optional[InputDataset],
    model: Model,
    render_output_path: Path,
    num_points: int = 1_000_000,
) -> Tuple[np.ndarray, np.ndarray]:
    """Saves pointcloud rendered from a model using predicted eval/train depths

    Every train frame, then every eval frame, is rendered; a random subset of its
    pixels is backprojected with the predicted depth and written, together with
    the rendered colors, to ``{cwd}/{render_output_path}/pointcloud.ply``.

    Args:
        eval_data: eval input dataset (``None`` or empty is skipped)
        train_data: train input dataset (``None`` or empty is skipped)
        model: model object
        render_output_path: path to render results to (relative to cwd)
        num_points: approximate total number of points to extract

    Raises:
        ValueError: if neither dataset contains any frame.

    Returns:
        Tuple of (points, colors) numpy arrays, each of shape (num_sampled, 3).
    """
    CONSOLE.print("[bold green] Generating pointcloud ...")
    datasets = [d for d in (train_data, eval_data) if d is not None and len(d) > 0]
    num_frames = sum(len(d) for d in datasets)
    if num_frames == 0:
        raise ValueError("No train or eval frames available to build a pointcloud")

    # split the point budget evenly over frames (same rounding as before)
    samples_per_frame = (num_points + num_frames) // num_frames
    opengl_to_opencv = torch.from_numpy(OPENGL_TO_OPENCV).float().to(model.device)

    points: List[np.ndarray] = []
    colors: List[np.ndarray] = []
    for dataset in datasets:
        dataset_points, dataset_colors = _gs_sample_dataset_points(
            dataset, model, samples_per_frame, opengl_to_opencv
        )
        points.extend(dataset_points)
        colors.extend(dataset_colors)

    all_points = np.vstack(points)
    all_colors = np.vstack(colors)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(all_points)
    pcd.colors = o3d.utility.Vector3dVector(all_colors)
    # same "/"-joined layout as save_outputs_helper
    output_file = os.getcwd() + f"/{render_output_path}/pointcloud.ply"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    o3d.io.write_point_cloud(output_file, pcd)
    CONSOLE.print(f"[bold yellow]Saved pointcloud to {output_file}")
    return (all_points, all_colors)


def _gs_sample_dataset_points(
    dataset: InputDataset,
    model: Model,
    samples_per_frame: int,
    opengl_to_opencv: Tensor,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Render each camera of ``dataset`` and backproject a random subset of its pixels."""
    points: List[np.ndarray] = []
    colors: List[np.ndarray] = []
    for image_idx in tqdm(range(len(dataset)), leave=False):
        camera = dataset.cameras[image_idx : image_idx + 1].to(model.device)
        height = int(camera.height.item())
        width = int(camera.width.item())
        outputs = model.get_outputs(camera)
        rgb_out, depth_out = outputs["rgb"], outputs["depth"]

        # append the homogeneous [0, 0, 0, 1] row to the (1, 3, 4) pose
        c2w = torch.concatenate(
            [
                camera.camera_to_worlds,
                torch.tensor([[[0, 0, 0, 1]]]).to(model.device),
            ],
            dim=1,
        )
        # convert from opengl to opencv
        c2w = torch.matmul(c2w, opengl_to_opencv)
        # backproject
        point, _ = get_means3d_backproj(
            depths=depth_out.float(),
            fx=camera.fx,
            fy=camera.fy,
            cx=camera.cx,
            cy=camera.cy,
            img_size=(width, height),
            c2w=c2w.float(),
            device=model.device,
        )
        point = point.squeeze(0)

        # sample pixels for this frame; random.sample fails if asked for more
        # items than the frame has, so cap at the pixel count
        pixels_per_frame = width * height
        indices = random.sample(
            range(pixels_per_frame), min(samples_per_frame, pixels_per_frame)
        )
        mask = torch.tensor(indices, device=model.device)

        colors.append(rgb_out.view(-1, 3)[mask].detach().cpu().numpy())
        points.append(point[mask].detach().cpu().numpy())
    return points, colors


def gs_render_dataset_images(
    train_cache: List,
    eval_cache: List,
    train_dataset,
    eval_dataset,
    model,
    render_output_path: Path,
    mushroom=False,
    save_train_images=False,
) -> None:
    """Render and save all train/eval images of gs model to directory

    Args:
        train_cache: list of cached train frames (dicts with "image" and optional
            "sensor_depth" / "normal")
        eval_cache: list of cached eval frames, same layout as train_cache
        train_dataset: train input dataset (provides cameras and image filenames)
        eval_dataset: eval input dataset (provides cameras and image filenames)
        model: model object
        render_output_path: path to render results to
        mushroom: if dataset is Mushroom dataset or not
        save_train_images: whether to save train images or not

    Returns:
        None
    """
    CONSOLE.print(f"[bold yellow]Saving results to {render_output_path}")
    _gs_save_rendered_split(eval_cache, eval_dataset, model, render_output_path, mushroom)
    if save_train_images:
        _gs_save_rendered_split(
            train_cache, train_dataset, model, render_output_path, mushroom
        )


def _gs_save_rendered_split(
    cache: List, dataset, model, render_output_path: Path, mushroom: bool
) -> None:
    """Render every cached frame of one split and save predictions next to ground truth."""
    for image_idx in tqdm(range(len(cache)), leave=False):
        data = cache[image_idx]
        # ground truth data
        gt_img = data["image"]
        if "sensor_depth" in data:
            depth_gt = data["sensor_depth"]
            depth_gt_color = colormaps.apply_depth_colormap(depth_gt)
        else:
            depth_gt = None
            depth_gt_color = None
        normal_gt = data["normal"] if "normal" in data else None
        camera = dataset.cameras[image_idx : image_idx + 1].to(model.device)

        # save the image with its original name for easy comparison
        source_path = Path(dataset.image_filenames[image_idx])
        if mushroom:
            image_name = f"{source_path.parts[-3]}_{source_path.stem}"
        else:
            image_name = source_path.stem

        outputs = model.get_outputs(camera)
        depth_out = outputs["depth"]
        # optional, matching ns_render_dataset_images
        surface_normal = (
            outputs["surface_normal"] if "surface_normal" in outputs else None
        )
        save_outputs_helper(
            outputs["rgb"],
            gt_img,
            colormaps.apply_depth_colormap(depth_out),
            depth_gt_color,
            depth_gt,
            depth_out.detach().cpu().numpy(),
            normal_gt,
            surface_normal,
            render_output_path,
            image_name,
        )


def ns_render_dataset_images(
    train_dataloader: DataLoader,
    eval_dataloader: DataLoader,
    train_dataset: InputDataset,
    eval_dataset: InputDataset,
    model: Model,
    render_output_path: Path,
    mushroom=False,
    save_train_images=False,
) -> None:
    """render and save all train/eval images of nerfstudio model to directory

    Args:
        train_dataloader: train dataloader yielding (camera, batch) pairs
        eval_dataloader: eval dataloader yielding (camera, batch) pairs
        train_dataset: train input dataset (provides image filenames)
        eval_dataset: eval input dataset (provides image filenames)
        model: model object
        render_output_path: path to render results to
        mushroom: whether the dataset is Mushroom dataset or not
        save_train_images:  whether to save train images or not

    Returns:
        None
    """
    CONSOLE.print(f"[bold yellow]Saving results to {render_output_path}")
    if len(eval_dataloader) > 0:
        _ns_save_rendered_split(
            eval_dataloader, eval_dataset, model, render_output_path, mushroom
        )
    if save_train_images and len(train_dataloader) > 0:
        # train frames now save "surface_normal" like the eval split (they
        # previously saved outputs["normal"] instead)
        _ns_save_rendered_split(
            train_dataloader, train_dataset, model, render_output_path, mushroom
        )


def _ns_save_rendered_split(
    dataloader: DataLoader,
    dataset: InputDataset,
    model: Model,
    render_output_path: Path,
    mushroom: bool,
) -> None:
    """Render each (camera, batch) of one split and save predictions next to ground truth."""
    for image_idx, (camera, batch) in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            outputs = model.get_outputs_for_camera(camera)
        # ground truth data
        data = batch.copy()
        gt_img = data["image"]
        if "sensor_depth" in data:
            depth_gt = data["sensor_depth"]
            depth_gt_color = colormaps.apply_depth_colormap(depth_gt)
        else:
            depth_gt = None
            depth_gt_color = None
        normal_gt = data["normal"] if "normal" in data else None

        # save the image with its original name for easy comparison
        source_path = Path(dataset.image_filenames[image_idx])
        if mushroom:
            image_name = f"{source_path.parts[-3]}_{source_path.stem}"
        else:
            image_name = source_path.stem

        depth_out = outputs["depth"]
        surface_normal = (
            outputs["surface_normal"] if "surface_normal" in outputs else None
        )
        save_outputs_helper(
            outputs["rgb"],
            gt_img,
            colormaps.apply_depth_colormap(depth_out),
            depth_gt_color,
            depth_gt,
            depth_out.detach().cpu().numpy(),
            normal_gt,
            surface_normal,
            render_output_path,
            image_name,
        )


def save_outputs_helper(
    rgb_out: Optional[Tensor] = None,
    gt_img: Optional[Tensor] = None,
    depth_color: Optional[Tensor] = None,
    depth_gt_color: Optional[Tensor] = None,
    depth_gt: Optional[Tensor] = None,
    depth: Optional[Tensor] = None,
    normal_gt: Optional[Tensor] = None,
    normal: Optional[Tensor] = None,
    render_output_path: Optional[Path] = None,
    image_name: Optional[str] = None,
) -> None:
    """Helper to save model rgb/depth/normal outputs and ground truth to disk

    Files go under ``os.getcwd() + "/{render_output_path}"``, named
    ``{image_name}``. Any output left as ``None`` is skipped, and each one is
    saved independently of the others.

    Args:
        rgb_out: rgb image -> pred/rgb/*.png
        gt_img: gt rgb image -> gt/rgb/*.png
        depth_color: colored depth image -> pred/depth/colorised/*.png
        depth_gt_color: gt colored depth image -> gt/depth/colorised/*.png
        depth_gt: gt depth map -> gt/depth/raw/*.npy
        depth: depth map -> pred/depth/raw/*.npy
        normal_gt: gt normal map -> gt/normal/*.png
        normal: predicted normal map -> pred/normal/*.png
        render_output_path: save directory path
        image_name: stem of save name ("" when None)

    Returns:
        None
    """
    if image_name is None:
        image_name = ""
    base = os.getcwd() + f"/{render_output_path}"

    # predicted rgb is no longer skipped when no ground-truth image is given
    if rgb_out is not None:
        save_img(rgb_out, f"{base}/pred/rgb/{image_name}.png", False)
    if gt_img is not None:
        save_img(gt_img, f"{base}/gt/rgb/{image_name}.png", False)
    if depth_color is not None:
        save_img(depth_color, f"{base}/pred/depth/colorised/{image_name}.png", False)
    if depth_gt_color is not None:
        save_img(depth_gt_color, f"{base}/gt/depth/colorised/{image_name}.png", False)
    if depth_gt is not None:
        # save metric depths
        save_depth(depth_gt, f"{base}/gt/depth/raw/{image_name}.npy", False)
    if depth is not None:
        save_depth(depth, f"{base}/pred/depth/raw/{image_name}.npy", False)

    if normal is not None:
        save_normal(normal, f"{base}/pred/normal/{image_name}.png", verbose=False)

    if normal_gt is not None:
        save_normal(normal_gt, f"{base}/gt/normal/{image_name}.png", verbose=False)


def pcd_to_normal(xyz: Tensor) -> Tensor:
    """Estimate per-pixel unit normals from an organised (H, W, 3) point map.

    Forward differences along x (columns) and y (rows) are taken; the last
    column/row reuses the previous difference. The normal is
    ``normalize(cross(dy, dx))``.

    Args:
        xyz: (H, W, 3) points laid out on the image grid, with H >= 2 and W >= 2.

    Raises:
        ValueError: if ``xyz`` is not (H, W, 3) or H/W is smaller than 2.

    Returns:
        (H, W, 3) unit normals.
    """
    if xyz.dim() != 3 or xyz.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) point map, got {tuple(xyz.shape)}")
    if xyz.shape[0] < 2 or xyz.shape[1] < 2:
        raise ValueError("point map needs at least 2 rows and 2 columns")

    dx = torch.zeros_like(xyz)
    dy = torch.zeros_like(xyz)
    dx[:, :-1] = xyz[:, 1:] - xyz[:, :-1]
    dx[:, -1] = dx[:, -2]  # replicate last column difference
    dy[:-1, :] = xyz[1:, :] - xyz[:-1, :]
    dy[-1, :] = dy[-2, :]  # replicate last row difference

    # dim=-1 is explicit: without it torch picks the FIRST axis of size 3, which
    # is wrong whenever H or W equals 3.
    xyz_normal = torch.cross(dy, dx, dim=-1)
    return torch.nn.functional.normalize(xyz_normal, p=2, dim=-1)


def normal_from_depth_image(
    depths: Tensor,
    fx: float,
    fy: float,
    cx: float,
    cy: float,
    img_size: tuple,
    c2w: Tensor,
    device: torch.device,
    smooth: bool = False,
) -> Tensor:
    """Estimate normals from a depth map.

    The depth map is backprojected with ``get_means3d_backproj``, reshaped to an
    (H, W, 3) point map and passed to ``pcd_to_normal``.

    Args:
        depths: per-pixel depths.
        fx, fy, cx, cy: pinhole intrinsics.
        img_size: (W, H).
        c2w: camera-to-world pose.
        device: device for the computation.
        smooth: apply a 9x9 Gaussian blur to the depths first. Blurring is
            skipped when the map contains zero-valued pixels, presumably invalid
            depth that would otherwise be smeared into valid neighbours.

    Returns:
        (H, W, 3) normal map.
    """
    if smooth:
        # previously inverted: the blur was skipped whenever ANY depth was nonzero
        if torch.count_nonzero(depths) < depths.numel():
            print("Input depth map contains 0 elements, skipping smoothing filter")
        else:
            kernel_size = (9, 9)
            depths = torch.from_numpy(
                cv2.GaussianBlur(depths.detach().cpu().numpy(), kernel_size, 0)
            ).to(device)
    means3d, _ = get_means3d_backproj(depths, fx, fy, cx, cy, img_size, c2w, device)
    means3d = means3d.view(img_size[1], img_size[0], 3)
    normals = pcd_to_normal(means3d)
    return normals
