diff --git a/.readthedocs.yml b/.readthedocs.yml index 2e57d27a8..6599afb97 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,7 +1,7 @@ version: 2 python: - version: 3.7 + version: 3.8 system_packages: true install: - requirements: docs/readthedocs_requirements.txt diff --git a/docs/kaolin_ext.py b/docs/kaolin_ext.py index c501c4418..4ced3ac3a 100644 --- a/docs/kaolin_ext.py +++ b/docs/kaolin_ext.py @@ -29,6 +29,7 @@ def run_apidoc(_): "setup.py", "**.so", "kaolin/version.py", + "kaolin/version.txt", "kaolin/ops/conversions/pointcloud.py", "kaolin/ops/conversions/sdf.py", "kaolin/ops/conversions/trianglemesh.py", @@ -51,6 +52,7 @@ def run_apidoc(_): "kaolin/render/spc/raytrace.py", "kaolin/rep/spc.py", "kaolin/visualize/timelapse.py", + "kaolin/visualize/ipython.py", "kaolin/framework/*", "kaolin/render/camera/camera.py", "kaolin/render/camera/coordinates.py", @@ -70,9 +72,8 @@ def run_apidoc(_): "-d", "2", "--templatedir", DOCS_MODULE_PATH, - "-o", - DOCS_MODULE_PATH, - KAOLIN_ROOT, + "-o", DOCS_MODULE_PATH, + os.path.join(KAOLIN_ROOT, "kaolin"), *EXCLUDE_PATHS ] apidoc.main(argv) diff --git a/docs/readthedocs_requirements.txt b/docs/readthedocs_requirements.txt index d00ca8bf9..05d92c9d8 100644 --- a/docs/readthedocs_requirements.txt +++ b/docs/readthedocs_requirements.txt @@ -1,2 +1,3 @@ +numpy<1.27.0,>=1.19.5 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html torch==1.8.2+cpu diff --git a/examples/tutorial/interactive_visualizer.ipynb b/examples/tutorial/interactive_visualizer.ipynb new file mode 100644 index 000000000..37f148409 --- /dev/null +++ b/examples/tutorial/interactive_visualizer.ipynb @@ -0,0 +1,577 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "43840b4a-f9dc-4076-b1bf-5276da05f4ea", + "metadata": { + "tags": [] + }, + "source": [ + "# Interactive visualizer\n", + "Using [Interactive visualizers](https://kaolin.readthedocs.io/en/latest/modules/kaolin.visualize.html), you can bring your own renderer and connect it to the visualizer.\n", + "\n", + "The main condition is that the renderer has to take a [Camera](https://kaolin.readthedocs.io/en/latest/modules/kaolin.render.camera.camera.html#kaolin-render-camera-camera) as input.\n",
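+ "\n", + "For example, a minimal renderer can be sketched as below (`my_rasterize` and `scene` are hypothetical stand-ins for your own rendering pipeline):\n", + "```python\n", + "def render(camera):\n", + " # consume a kaolin Camera...\n", + " image = my_rasterize(scene, camera)\n", + " # ...and return a (height, width, 3) uint8 tensor to display\n", + " return image\n", + "```"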
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "69114969", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "import glob\n", + "import math\n", + "import copy\n", + "\n", + "import torch\n", + "import numpy as np\n", + "\n", + "import kaolin as kal\n", + "\n", + "import nvdiffrast\n", + "glctx = nvdiffrast.torch.RasterizeGLContext(False, device='cuda')" + ] + }, + { + "cell_type": "markdown", + "id": "dd8f9d9e", + "metadata": {}, + "source": [ + "## Load Mesh information" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4963d2ce", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Here replace \"root\" with the path where you are storing ShapeNet\n", + "ds = kal.io.shapenet.ShapeNetV2(root='/data/ShapeNetCore.v2/',\n", + " categories=['car'],\n", + " train=True, split=1.,\n", + " with_materials=True,\n", + " output_dict=True)\n", + "mesh = ds[0]['mesh']\n", + "\n", + "### Uncomment to load a specific obj\n", + "# mesh = kal.io.obj.import_mesh('path/to/obj', with_materials=True)\n", + "\n", + "# Normalize the data between [-0.5, 0.5]\n", + "vertices = mesh.vertices.unsqueeze(0).cuda()\n", + "vertices_min = vertices.min(dim=1, keepdims=True)[0]\n", + "vertices_max = vertices.max(dim=1, keepdims=True)[0]\n", + "vertices -= (vertices_max + vertices_min) / 2.\n", + "vertices /= (vertices_max - vertices_min).max()\n", + "faces = mesh.faces.cuda()\n", + "\n", + "# Here we are preprocessing the materials, assigning faces to materials and\n", + "# using a single diffuse color as backup when the map doesn't exist (and face_uvs_idx == -1)\n", + "uvs = torch.nn.functional.pad(mesh.uvs.unsqueeze(0).cuda(), (0, 0, 0, 1)) % 1.\n", + "face_uvs_idx = mesh.face_uvs_idx.cuda()\n", + "materials_order = mesh.materials_order\n", + "diffuse_maps = [m['map_Kd'].permute(2, 0, 1).unsqueeze(0).cuda().float() / 255. if 'map_Kd' in m else\n", + " m['Kd'].reshape(1, 3, 1, 1).cuda()\n", + " for m in mesh.materials]\n", + "specular_maps = [m['map_Ks'].permute(2, 0, 1).unsqueeze(0).cuda().float() / 255. if 'map_Ks' in m else\n", + " m['Ks'].reshape(1, 3, 1, 1).cuda()\n", + " for m in mesh.materials]\n", + "\n", + "nb_faces = faces.shape[0]\n", + "\n", + "num_consecutive_materials = \\\n", + " torch.cat([\n", + " materials_order[1:, 1],\n", + " torch.LongTensor([nb_faces])\n", + " ], dim=0) - materials_order[:, 1]\n", + "\n", + "face_material_idx = kal.ops.batch.tile_to_packed(\n", + " materials_order[:, 0],\n", + " num_consecutive_materials\n", + ").cuda().squeeze(-1)\n", + "mask = face_uvs_idx == -1\n", + "face_uvs_idx[mask] = uvs.shape[1] - 1\n", + "face_vertices = kal.ops.mesh.index_vertices_by_faces(vertices, faces)\n", + "face_world_normals = kal.ops.mesh.face_normals(face_vertices, unit=True)" + ] + }, + { + "cell_type": "markdown", + "id": "03e898f4", + "metadata": {}, + "source": [ + "## Instantiate a camera\n", + "\n", + "With the general constructor `Camera.from_args()`, the underlying constructors are `CameraExtrinsics.from_lookat()` and `PinholeIntrinsics.from_fov()`. We will use this camera as a starting point for the visualizers.\n",
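+ "\n", + "A roughly equivalent two-step construction (a sketch; the exact keyword names of `from_fov` are an assumption here):\n", + "```python\n", + "extrinsics = kal.render.camera.CameraExtrinsics.from_lookat(\n", + " eye=torch.tensor([2., 1., 1.]), at=torch.zeros(3), up=torch.tensor([1., 1., 1.]), device='cuda')\n", + "intrinsics = kal.render.camera.PinholeIntrinsics.from_fov(\n", + " fov=math.pi * 45 / 180, width=512, height=512, device='cuda')\n", + "```"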
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c6eee7ab", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "camera = kal.render.camera.Camera.from_args(eye=torch.tensor([2., 1., 1.], device='cuda'),\n", + " at=torch.tensor([0., 0., 0.]),\n", + " up=torch.tensor([1., 1., 1.]),\n", + " fov=math.pi * 45 / 180,\n", + " width=512, height=512, device='cuda')" + ] + }, + { + "cell_type": "markdown", + "id": "4fff8eb1", + "metadata": {}, + "source": [ + "## Rendering a mesh\n", + "\n", + "Here we are rendering the loaded mesh with [nvdiffrast](https://github.com/NVlabs/nvdiffrast), using the camera object created above and using both diffuse and specular reflectance for lighting.\n", + "\n", + "For more information on lighting in Kaolin, see the [diffuse](./diffuse_lighting.ipynb) and [specular](./sg_specular_lighting.ipynb) tutorials and the [documentation](https://kaolin.readthedocs.io/en/latest/modules/kaolin.render.lighting.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5e4b8a49", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# These are the parameters used to define the spherical Gaussian\n", + "azimuth = torch.zeros((1,), device='cuda')\n", + "elevation = torch.full((1,), math.pi / 3., device='cuda')\n", + "amplitude = torch.full((1, 3), 3., device='cuda')\n", + "sharpness = torch.full((1,), 5., device='cuda')\n", + "# We will use this variable to enable / disable specular reflectance\n", + "global apply_specular\n", + "apply_specular = True\n", + "\n", + "def generate_pinhole_rays_dir(camera, height, width, device='cuda'):\n", + " \"\"\"Generate centered grid.\n", + " \n", + " This is a utility function for specular reflectance with spherical gaussian.\n", + " \"\"\"\n", + " pixel_y, pixel_x = torch.meshgrid(\n", + " torch.arange(height, device=device),\n", + " torch.arange(width, device=device),\n", + " indexing='ij'\n", + " )\n", + " pixel_x = pixel_x + 0.5 # scale and add bias to pixel center\n", + " pixel_y = pixel_y + 0.5 # scale and add bias to pixel center\n", + "\n", + " # Account for principal point (offsets from the center)\n", + " pixel_x = pixel_x - camera.x0\n", + " pixel_y = pixel_y + camera.y0\n", + "\n", + " # Convert to NDC: pixel values will now be in range [-1, 1], both tensors of shape res_y x res_x\n", + " pixel_x = 2 * (pixel_x / width) - 1.0\n", + " pixel_y = 2 * (pixel_y / height) - 1.0\n", + "\n", + " ray_dir = torch.stack((pixel_x * camera.tan_half_fov(kal.render.camera.intrinsics.CameraFOV.HORIZONTAL),\n", + " -pixel_y * camera.tan_half_fov(kal.render.camera.intrinsics.CameraFOV.VERTICAL),\n", + " -torch.ones_like(pixel_x)), dim=-1)\n", + "\n", + " ray_dir = ray_dir.reshape(-1, 3) # Flatten grid rays to 1D array\n", + " ray_orig = torch.zeros_like(ray_dir)\n", + "\n", + " # Transform from camera to world coordinates\n", + " ray_orig, ray_dir = camera.extrinsics.inv_transform_rays(ray_orig, ray_dir)\n", + " ray_dir /= torch.linalg.norm(ray_dir, dim=-1, keepdim=True)\n", + "\n", + " return ray_dir[0].reshape(1, height, width, 3)\n", + "\n", + "\n", + "def base_render(camera, height, width):\n", + " \"\"\"Base function for rendering using separate height and width\"\"\"\n", + " transformed_vertices = camera.transform(vertices)\n", + " face_vertices_camera = kal.ops.mesh.index_vertices_by_faces(\n", + " transformed_vertices, faces)\n", + " face_normals_z = kal.ops.mesh.face_normals(\n", + " face_vertices_camera,\n", + " unit=True\n", + " )[..., -1:].contiguous()\n",
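+ " # face_normals_z keeps only the camera-space z component of each face normal;\n", + " # its sign tells whether a face points towards the camera, and is used below\n", + " # to re-orient the world-space normals before shading.\n",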
+ " # Create a fake W (See nvdiffrast documentation)\n", + " pos = torch.nn.functional.pad(\n", + " transformed_vertices, (0, 1), mode='constant', value=1.\n", + " ).contiguous()\n", + " rast = nvdiffrast.torch.rasterize(\n", + " glctx, pos, faces.int(), (height, width), grad_db=False)\n", + " hard_mask = rast[0][:, :, :, -1:] != 0\n", + " face_idx = (rast[0][..., -1].long() - 1).contiguous()\n", + "\n", + " uv_map = nvdiffrast.torch.interpolate(\n", + " uvs, rast[0], face_uvs_idx.int())[0]\n", + "\n", + " im_world_normals = face_world_normals.reshape(-1, 3)[face_idx]\n", + " im_cam_normals = face_normals_z.reshape(-1, 1)[face_idx]\n", + " im_world_normals = im_world_normals * torch.sign(im_cam_normals)\n", + " albedo = torch.zeros(\n", + " (1, height, width, 3),\n", + " dtype=torch.float, device='cuda'\n", + " )\n", + " spec_albedo = torch.zeros(\n", + " (1, height, width, 3),\n", + " dtype=torch.float, device='cuda'\n", + " )\n", + " # Obj meshes can be composed of multiple materials,\n", + " # so at render time we need to interpolate from the corresponding materials\n", + " im_material_idx = face_material_idx[face_idx]\n", + " im_material_idx[face_idx == -1] = -1\n", + "\n", + " for i, material in enumerate(diffuse_maps):\n", + " mask = im_material_idx == i\n", + " mask_idx = torch.nonzero(mask, as_tuple=False)\n", + " _texcoords = uv_map[mask] * 2. - 1.\n", + " _texcoords[:, 1] = -_texcoords[:, 1]\n", + " pixel_val = torch.nn.functional.grid_sample(\n", + " diffuse_maps[i], _texcoords.reshape(1, 1, -1, 2),\n", + " mode='bilinear', align_corners=False,\n", + " padding_mode='border')\n", + " albedo[mask] = pixel_val[0, :, 0].permute(1, 0)\n", + " pixel_val = torch.nn.functional.grid_sample(\n", + " specular_maps[i], _texcoords.reshape(1, 1, -1, 2),\n", + " mode='bilinear', align_corners=False,\n", + " padding_mode='border')\n", + " spec_albedo[mask] = pixel_val[0, :, 0].permute(1, 0)\n", + " img = torch.zeros((1, height, width, 3),\n", + " dtype=torch.float, device='cuda')\n", + " sg_x, sg_y, sg_z = kal.ops.coords.spherical2cartesian(azimuth, elevation)\n", + " directions = torch.stack(\n", + " [sg_x, sg_z, sg_y],\n", + " dim=-1\n", + " )\n", + " im_world_normals = im_world_normals[hard_mask.squeeze(-1)]\n", + " diffuse_effect = kal.render.lighting.sg_diffuse_inner_product(\n", + " amplitude, directions, sharpness,\n", + " im_world_normals,\n", + " albedo[hard_mask.squeeze(-1)]\n", + " )\n", + " img[hard_mask.squeeze(-1)] = diffuse_effect\n", + " global apply_specular\n", + " if apply_specular:\n", + " rays_d = generate_pinhole_rays_dir(camera, height, width)\n", + " specular_effect = kal.render.lighting.sg_warp_specular_term(\n", + " amplitude, directions, sharpness,\n", + " im_world_normals,\n", + " torch.full((im_world_normals.shape[0],), 0.5, device='cuda'),\n", + " -rays_d[hard_mask.squeeze(-1)],\n", + " spec_albedo[hard_mask.squeeze(-1)]\n", + " )\n", + " img[hard_mask.squeeze(-1)] += specular_effect\n", + "\n", + " # Need to flip the image because of OpenGL conventions\n", + " return (torch.flip(torch.clamp(\n", + " img * hard_mask, 0., 1.\n", + " )[0], dims=(0,)) * 255.).to(torch.uint8)\n", + "\n", + "def render(camera):\n", + " \"\"\"Render using camera dimension.\n", + " \n", + " This is the main function provided to the interactive visualizer\n", + " \"\"\"\n", + " return base_render(camera, camera.height, camera.width)\n", + "\n",
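+ "# Rendering at a quarter of the resolution shades roughly 16x fewer pixels,\n", + "# which keeps the interaction smooth while the camera is being dragged.\n",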
+ "def lowres_render(camera):\n", + " \"\"\"Render with lower dimension.\n", + " \n", + " This function will be used as a \"fast\" rendering when the mouse is moving, to avoid slowdowns.\n", + " \"\"\"\n", + " return base_render(camera, int(camera.height / 4), int(camera.width / 4))" + ] + }, + { + "cell_type": "markdown", + "id": "383f1ffa-936c-416f-8779-d20fe28e7230", + "metadata": {}, + "source": [ + "## Turntable visualizer\n", + "This is a simple visualizer, useful for inspecting a small object.\n", + "\n", + "You can move around with the mouse (left button) and zoom with the mouse wheel.\n", + "See the [documentation](https://kaolin.readthedocs.io/en/latest/modules/kaolin.visualize.html#kaolin.visualize.IpyTurntableVisualizer) to customize the sensitivity." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f10f6ab7-662c-4cf6-8af4-5f28ae79d795", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ac64aa63260a48ff9c4f892fe66b2931", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Canvas(height=512, width=512)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b16ce5f96a554ecfa44b90ca96b6a16f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "visualizer = kal.visualize.IpyTurntableVisualizer(\n", + " 512, 512, copy.deepcopy(camera), render,\n", + " fast_render=lowres_render, max_fps=24, world_up_axis=1)\n", + "visualizer.show()" + ] + }, + { + "cell_type": "markdown", + "id": "d8ff3945-66a4-47af-a3c0-4428c2adf6cf", + "metadata": {}, + "source": [ + "## First person visualizer\n", + "This is a visualizer useful for inspecting details of an object, or a big scene.\n", + "\n", + "You can move the orientation of the camera with the mouse left button, move the camera around with the mouse right button or\n", + "the keys 'i' (up), 'k' (down), 'j' (left), 'l' (right), 'o' (forward), 'u' (backward).\n", + "\n", + "See the [documentation](https://kaolin.readthedocs.io/en/latest/modules/kaolin.visualize.html#kaolin.visualize.IpyFirstPersonVisualizer) to customize the sensitivity and keys.\n", + "\n", + "--------------------\n", + "*Note: cameras are mutable in the visualizer.
If you want to keep track of the camera position, you can remove the `copy.deepcopy` on the camera argument, or you can check `visualizer.camera`.*" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "27a16dea", + "metadata": { + "scrolled": false, + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3b8c57ff3e434165bb8c0f6a1c000212", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Canvas(height=512, width=512)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bb6cac48b60e4073a6a7c07482474aae", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "visualizer = kal.visualize.IpyFirstPersonVisualizer(\n", + " 512, 512, copy.deepcopy(camera), render, fast_render=lowres_render,\n", + " max_fps=24, world_up=torch.tensor([0., 1., 0.], device='cuda'))\n", + "visualizer.show()" + ] + }, + { + "cell_type": "markdown", + "id": "0c6a9142", + "metadata": {}, + "source": [ + "# Adding events and other widgets\n", + "\n", + "The visualizer is modular.\n", + "Here we will add:\n", + "* Sliders to control the spherical Gaussian parameters (see the [ipywidgets tutorial](https://ipywidgets.readthedocs.io/en/stable/examples/Using%20Interact.html) for more info).\n", + "* A key event on 'space' to enable / disable specular reflectance (see the [ipyevents documentation](https://github.com/mwcraig/ipyevents/blob/main/docs/events.ipynb) for all the events that can be caught).\n", + "\n", + "In general, if you want to modify the rendering function, you can use global variables or make a class (with the rendering function being a method).\n", + "\n", + "-------------\n", + "More info on spherical Gaussian parameters in the [tutorial](./sg_specular_lighting.ipynb)\n", + "and the [documentation](https://kaolin.readthedocs.io/en/latest/modules/kaolin.render.lighting.html)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0a9fd84e", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4b177f1809084b9e8a1ff0ac907656e6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(Canvas(height=512, width=512), interactive(children=(FloatSlider(value=0.0, description='Elevat…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6dc3a41289c14251bc3e7337d1b0eb03", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from ipywidgets import interactive, HBox, FloatSlider\n", + "\n", + "def additional_event_handler(visualizer, event):\n", + " \"\"\"Event handler to be provided to Kaolin's visualizer\"\"\"\n", + " with visualizer.out: # This is for catching print and errors\n", + " if event['type'] == 'keydown' and event['key'] == ' ':\n", + " global apply_specular\n", + " apply_specular = not apply_specular\n", + " visualizer.render_update()\n", + " return False\n", + " return True\n", + "\n", + "visualizer = kal.visualize.IpyTurntableVisualizer(\n", + " 512, 512, copy.deepcopy(camera), render,\n", + " fast_render=lowres_render, max_fps=24,\n", + " additional_event_handler=additional_event_handler,\n", + " additional_watched_events=['keydown'] # We now need to watch for key press events\n",
+ ")\n", + "# we don't call visualizer.show() here\n", + "\n", + "def sliders_callback(new_elevation, new_azimuth, new_amplitude, new_sharpness):\n", + " \"\"\"ipywidgets sliders callback\"\"\"\n", + " with visualizer.out: # This is for catching prints and errors\n", + " elevation[:] = new_elevation\n", + " azimuth[:] = new_azimuth\n", + " amplitude[:] = new_amplitude\n", + " sharpness[:] = new_sharpness\n", + " # this is how we request a new update\n", + " visualizer.render_update()\n", + " \n", + "elevation_slider = FloatSlider(\n", + " value=0.,\n", + " min=-math.pi / 2.,\n", + " max=math.pi / 2.,\n", + " step=0.1,\n", + " description='Elevation:',\n", + " continuous_update=True,\n", + " readout=True,\n", + " readout_format='.1f',\n", + ")\n", + "\n", + "azimuth_slider = FloatSlider(\n", + " value=0.,\n", + " min=-math.pi,\n", + " max=math.pi,\n", + " step=0.1,\n", + " description='Azimuth:',\n", + " continuous_update=True,\n", + " readout=True,\n", + " readout_format='.1f',\n", + ")\n", + "\n", + "amplitude_slider = FloatSlider(\n", + " value=5.,\n", + " min=0.1,\n", + " max=20.,\n", + " step=0.1,\n", + " description='Amplitude:\n',\n", + " continuous_update=True,\n", + " readout=True,\n", + " readout_format='.1f',\n", + ")\n", + "\n", + "sharpness_slider = FloatSlider(\n", + " value=5.,\n", + " min=0.1,\n", + " max=20.,\n", + " step=0.1,\n", + " description='Sharpness:\n',\n", + " continuous_update=True,\n", + " readout=True,\n", + " readout_format='.1f',\n", + ")\n", + "\n", + "interactive_slider = interactive(\n", + " sliders_callback,\n", + " new_elevation=elevation_slider,\n", + " new_azimuth=azimuth_slider,\n", + " new_amplitude=amplitude_slider,\n", + " new_sharpness=sharpness_slider\n", + ")\n", + "\n", + "# We combine all the widgets and the visualizer canvas and output in a single display\n", + "full_output = HBox([visualizer.canvas, interactive_slider])\n", + "display(full_output, visualizer.out)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/kaolin/render/camera/intrinsics_pinhole.py b/kaolin/render/camera/intrinsics_pinhole.py index fdc0893ab..ddf8e814e 100644 --- a/kaolin/render/camera/intrinsics_pinhole.py +++ b/kaolin/render/camera/intrinsics_pinhole.py @@ -671,4 +671,3 @@ def zoom(self, amount): fov_ratio = self.fov_x / self.fov_y self.fov_y -= amount self.fov_x = self.fov_y * fov_ratio # Make sure the view is not distorted - diff --git a/kaolin/visualize/__init__.py b/kaolin/visualize/__init__.py index 8f9040bcd..8158aebf1 100644 --- a/kaolin/visualize/__init__.py +++ b/kaolin/visualize/__init__.py @@ -1,3 +1,4 @@ from .timelapse import * +from .ipython import * __all__ = [k for k in locals().keys() if not k.startswith('__')] diff --git a/kaolin/visualize/ipython.py b/kaolin/visualize/ipython.py new file mode 100644 index 000000000..4758e258f --- /dev/null +++ b/kaolin/visualize/ipython.py @@ -0,0 +1,710 @@ +# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +import math + +import torch + +from ipyevents import Event +from ipywidgets import Output +from ipywidgets import Image as ImageWidget +from ipycanvas import Canvas + +from io import BytesIO +from PIL import Image as PILImage + +from ..render.camera import CameraExtrinsics + +__all__ = [ + 'update_canvas', + 'BaseIpyVisualizer', + 'IpyTurntableVisualizer', + 'IpyFirstPersonVisualizer', +] + +def update_canvas(canvas, image): + assert isinstance(image, torch.Tensor) and image.dtype == torch.uint8, \ + "image must be a torch.Tensor of uint8 " + assert isinstance(canvas, Canvas) + f = BytesIO() + PILImage.fromarray(image.cpu().numpy()).save( + f, "PNG", quality=100) + image = ImageWidget(value=f.getvalue()) + canvas.draw_image(image, 0, 0, canvas.width, canvas.height) + +class BaseIpyVisualizer(object): + r"""Base class for ipython visualizers. + + To create a visualizer one must define the class attribute _WATCHED_EVENTS and + the method :func:`_handle_event`. + + The method :func:`_handle_event` must use the methods + :func:`self.render()` or :func:`self.fast_render()` to update the canvas. + + You can overload the constructor + (make sure to reuse the base class one so that ipycanvas and ipyevents are properly used). + + Args: + height (int): height of the canvas. + width (int): width of the canvas. + camera (kal.render.camera.Camera): Camera used for the visualization. + render (Callable): + render function that takes a :class:`kal.render.camera.Camera` as input. + Must return a torch.uint8 tensor as output, + of shape :math:`(\text{output_height}, \text{output_width}, 3)`; + height and width don't have to match the canvas dimensions. + fast_render (optional, Callable): + A faster rendering function that may be used when doing high frequency manipulation + such as moving the camera with a mouse. Default: same as ``render``. + watched_events (list of str): + Events to be watched by the visualizer + (see `ipyevents main documentation`_). + max_fps (float): + maximum framerate for handling consecutive events; + this is useful to avoid freezes when the rendering is slow. + Typically 24 fps is great when working on a local machine + with :func:`render` close to real-time, + and lower, around 10 fps, with slower rendering or network latency. + + .. _ipyevents main documentation: https://github.com/mwcraig/ipyevents/blob/main/docs/events.ipynb + """ + + def __init__(self, height, width, camera, render, fast_render=None, + watched_events=None, max_fps=None): + self.canvas = Canvas(height=height, width=width) + self.out = Output() + assert len(camera) == 1, "only single camera supported for visualizer" + self.camera = camera + self.render = render + if fast_render is None: + self.fast_render = render + else: + self.fast_render = fast_render + + self._max_fps = max_fps + + wait = 0 if max_fps is None else int(1000. / max_fps)
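+ # e.g. max_fps=24 gives wait=41: ipyevents then rate-limits delivery to roughly one event every 41ms, so slow renders don't pile up.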
+ + self.event = Event( + source=self.canvas, + watched_events=watched_events, + prevent_default_action=True, + wait=wait, + ) + self.event.on_dom_event(self._handle_event) + + + def render_update(self): + """Update the Canvas with :func:`render`""" + with torch.no_grad(): + update_canvas(self.canvas, self.render(self.camera)) + + def fast_render_update(self): + """Update the Canvas with :func:`fast_render`""" + with torch.no_grad(): + update_canvas(self.canvas, self.fast_render(self.camera)) + + def show(self): + """display the Canvas with interactive features""" + self.render_update() + display(self.canvas, self.out) + + @abstractmethod + def _handle_event(self, event): + pass + + @property + def max_fps(self): + """maximum fps for handling consecutive events""" + return self._max_fps + + @max_fps.setter + def max_fps(self, new_val): + self._max_fps = new_val + if new_val is None: + self.event.wait = 0 + else: + self.event.wait = int(1000. / new_val) + +@torch.jit.script +def make_quaternion_rotation(angle: float, vec: torch.Tensor): + r"""Represent a rotation around an axis as a quaternion. + + Args: + angle (float): angle of rotation. + vec (torch.Tensor): + axis around which the rotation is done, + of shape :math:`(\text{batch_size}, 3)` + + Returns: + (torch.Tensor): A quaternion of shape :math:`(\text{batch_size}, 4)` + """ + half_angle = angle / 2 + sin_half_angle = math.sin(half_angle) + cos_half_angle = math.cos(half_angle) + return torch.stack([ + vec[:, 0] * sin_half_angle, + vec[:, 1] * sin_half_angle, + vec[:, 2] * sin_half_angle, + torch.full((vec.shape[0],), cos_half_angle, dtype=vec.dtype, device=vec.device) + ], dim=-1) + +@torch.jit.script +def conjugate(quat: torch.Tensor): + r"""Return the conjugate of a quaternion. + + Args: + quat (torch.Tensor): The quaternion, of shape :math:`(\text{batch_size}, 4)`. + + Returns: + (torch.Tensor): the conjugate, of shape :math:`(\text{batch_size}, 4)`. + """ + return torch.stack([-quat[:, 0], -quat[:, 1], -quat[:, 2], quat[:, 3]], dim=-1) + +@torch.jit.script +def mulqv(q: torch.Tensor, v: torch.Tensor): + r"""Return the product of a quaternion with a 3D vector. + + Supports broadcasting. + + Args: + q (torch.Tensor): The quaternion, of shape :math:`(\text{batch_size}, 4)`. + v (torch.Tensor): The vector, of shape :math:`(\text{batch_size}, 3)`. + + Returns: + (torch.Tensor): A quaternion, of shape :math:`(\text{batch_size}, 4)`. + """ + output = torch.stack([ + q[:, 3] * v[:, 0] + q[:, 1] * v[:, 2] - q[:, 2] * v[:, 1], + q[:, 3] * v[:, 1] + q[:, 2] * v[:, 0] - q[:, 0] * v[:, 2], + q[:, 3] * v[:, 2] + q[:, 0] * v[:, 1] - q[:, 1] * v[:, 0], + - q[:, 0] * v[:, 0] - q[:, 1] * v[:, 1] - q[:, 2] * v[:, 2], + ], dim=-1) + return output +
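+# Note: quaternions here are laid out as (x, y, z, w), scalar part last, matching make_quaternion_rotation above.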
+ """ + output = torch.stack([ + l[:, 0] * r[:, 3] + l[:, 3] * r[:, 0] + l[:, 1] * r[:, 2] - l[:, 2] * r[:, 1], + l[:, 1] * r[:, 3] + l[:, 3] * r[:, 1] + l[:, 2] * r[:, 0] - l[:, 0] * r[:, 2], + l[:, 2] * r[:, 3] + l[:, 3] * r[:, 2] + l[:, 0] * r[:, 1] - l[:, 1] * r[:, 0], + l[:, 3] * r[:, 3] - l[:, 0] * r[:, 0] - l[:, 1] * r[:, 1] - l[:, 2] * l[:, 2], + ], dim=-1) + return output + +@torch.jit.script +def rotate_around_axis(point: torch.Tensor, angle: float, axis: torch.Tensor): + r"""Compute the rotation of a point around an axis. + + Args: + point (torch.Tensor): The point to be rotated, of shape :math:`(\text{batch_size}, 3)`. + angle (float): The angle of rotation + axis (torch.Tensor): The axis around which the point is revolving, + of shape :math:`(\text{batch_size}, 3)`. + + Returns: + (torch.Tensor): The rotated point, of shape :math:`(\text{batch_size}, 3)`. + """ + rot_q = make_quaternion_rotation(angle, axis) + conj_q = conjugate(rot_q) + w = mulqq(mulqv(rot_q, point), conj_q) + return w[:, :-1] + +class IpyTurntableVisualizer(BaseIpyVisualizer): + r"""An interactive turntable visualizer that can display on jupyter notebook. + + You can move around with the mouse (using the left button), zoom with the wheel and + get closer to the center with the wheel + control key. + + Args: + height (int): height of the canvas. + width (int): width of the canvas. + camera (kal.render.camera.Camera): + Camera used for the visualization. + Note: The camera will be reoriented to look at ``focus_at`` + and with respect to ``world_up``. + render (Callable): + render function that take a :class:`kal.render.camera.Camera` as input. + Must return a torch.uint8 tensor as output + of shape :math:`(\text{output_height}, \text{output_width})`, + height and width don't have to match canvas dimension. + fast_render (optional, Callable): + A faster rendering function that may be used when doing high frequency manipulation + such as moving the camera with a mouse. Default: same than ``render``. + focus_at (optional, torch.Tensor): + The center of the turntable on which the camera is focusing on. + Default: (0, 0, 0). + world_up_axis (optional, int): + The up axis of the world, in the coordinate system. Default: 1. + zoom_sensitivity (float): + Sensitivity of the wheel on zoom. Default: 1e-3. + forward_sensitivity (float): + Sensitivity of the wheel on forward. Default: 1e-3. + mouse_sensitivity (float): + Sensitivity of the mouse on movements. Default: 1.5. + max_fps (optional, float): + maximum framerate for handling consecutive events, + this is useful when the rendering is slow to avoid freezes. + Typically 24 fps is great when working on a local machine + with :func:`render` close to real-time. + And you lower to 10 fps with slower rendering or network latency. + Default: 24 fps. + update_only_on_release (bool): + If true, the canvas won't be updated while the mouse button is pressed + and only when it's released. To avoid freezes with very slow rendering functions. + Default: False. + additional_watched_events (optional, list of str): + Additional events to be watched by the visualizer + (see `ipyevents main documentation`_). + To be used for customed events such as enabling / disabling a feature on a key press. + ['wheel', 'mousedown', 'mouseup', 'mousemove', 'mouseleave'] are already watched. + Default: None. + additional_event_handler (optional, Callable): + Additional event handler to be used for customed events such as + enabling / disabling a feature on a key press. 
+class IpyTurntableVisualizer(BaseIpyVisualizer): + r"""An interactive turntable visualizer that can display in a jupyter notebook. + + You can move around with the mouse (using the left button), zoom with the wheel and + get closer to the center with the wheel + control key. + + Args: + height (int): height of the canvas. + width (int): width of the canvas. + camera (kal.render.camera.Camera): + Camera used for the visualization. + Note: The camera will be reoriented to look at ``focus_at`` + and with respect to ``world_up_axis``. + render (Callable): + render function that takes a :class:`kal.render.camera.Camera` as input. + Must return a torch.uint8 tensor as output + of shape :math:`(\text{output_height}, \text{output_width}, 3)`; + height and width don't have to match the canvas dimensions. + fast_render (optional, Callable): + A faster rendering function that may be used when doing high frequency manipulation + such as moving the camera with a mouse. Default: same as ``render``. + focus_at (optional, torch.Tensor): + The center of the turntable, which the camera focuses on. + Default: (0, 0, 0). + world_up_axis (optional, int): + The up axis of the world, in the coordinate system. Default: 1. + zoom_sensitivity (float): + Sensitivity of the wheel on zoom. Default: 1e-3. + forward_sensitivity (float): + Sensitivity of the wheel on forward. Default: 1e-3. + mouse_sensitivity (float): + Sensitivity of the mouse on movements. Default: 1.5. + max_fps (optional, float): + maximum framerate for handling consecutive events; + this is useful to avoid freezes when the rendering is slow. + Typically 24 fps is great when working on a local machine + with :func:`render` close to real-time, + and lower, around 10 fps, with slower rendering or network latency. + Default: 24 fps. + update_only_on_release (bool): + If true, the canvas is only updated when the mouse button is released, + not while it is pressed. This avoids freezes with very slow rendering functions. + Default: False. + additional_watched_events (optional, list of str): + Additional events to be watched by the visualizer + (see `ipyevents main documentation`_). + To be used for custom events such as enabling / disabling a feature on a key press. + ['wheel', 'mousedown', 'mouseup', 'mousemove', 'mouseleave'] are already watched. + Default: None. + additional_event_handler (optional, Callable): + Additional event handler to be used for custom events such as + enabling / disabling a feature on a key press. + The Callable must take as inputs this visualizer object and the event. + (see `ipyevents main documentation`_). + + Attributes: + camera (kal.render.camera.Camera): The camera used for rendering. + canvas (ipycanvas.Canvas): The canvas the rendering is copied to. + out (ipywidgets.Output): An output where errors and prints are displayed. + render (Callable): The rendering function. + fast_render (Callable) + focus_at (torch.Tensor) + world_up_axis (int) + zoom_sensitivity (float) + forward_sensitivity (float) + mouse_sensitivity (float) + max_fps (int) + update_only_on_release (bool) + + Methods: + show: display the Canvas with interactive features. + render_update: Update the Canvas with :func:`render()`. + fast_render_update: Update the Canvas with :func:`fast_render()`. + + .. _ipyevents main documentation: https://github.com/mwcraig/ipyevents/blob/main/docs/events.ipynb + """ + def __init__(self, + height, + width, + camera, + render, + fast_render=None, + focus_at=None, + world_up_axis=1, + zoom_sensitivity=0.001, + forward_sensitivity=0.001, + mouse_sensitivity=1.5, + max_fps=24., + update_only_on_release=False, + additional_watched_events=None, + additional_event_handler=None): + + with torch.no_grad(): + if focus_at is None: + self.focus_at = torch.zeros((3,), device=camera.device) + else: + self.focus_at = focus_at + + up = torch.zeros((3,), device=camera.device) + up[world_up_axis] = float(camera.cam_up().squeeze()[world_up_axis] >= 0) * 2. - 1. + camera.extrinsics = CameraExtrinsics.from_lookat( + eye=camera.cam_pos().squeeze(), + at=self.focus_at, + up=up, + dtype=camera.dtype, + device=camera.device, + ) + + self.position = None + + self.world_up_axis = world_up_axis + self.zoom_sensitivity = zoom_sensitivity + self.forward_sensitivity = forward_sensitivity + self.mouse_scale = mouse_sensitivity * math.pi + self.update_only_on_release = update_only_on_release + + watched_events = ['wheel', 'mousedown', 'mouseup', 'mousemove', 'mouseleave', 'mouseenter'] + if additional_watched_events is not None: + watched_events += additional_watched_events + self.additional_event_handler = additional_event_handler + + super().__init__(height, width, camera, render, fast_render, + watched_events, max_fps) + + def _move_turntable(self, amount_elevation, amount_azimuth): + """Move the camera around a focus point as a turntable + + Args: + amount_elevation (float): + Amount of elevation rotation, measured in radians. + amount_azimuth (float): + Amount of azimuth rotation, measured in radians.
+ """ + radius = (self.camera.cam_pos().squeeze() - self.focus_at).norm(dim=-1, keepdim=True) + # Rotates camera in normal direction to plane (world space) + in_plane_amount = -amount_azimuth + # Rotates camera in up-forward direction (camera space) + pitch = -amount_elevation + self.camera.t = torch.zeros(3, device=self.camera.device, dtype=self.camera.dtype) + self.camera.rotate(pitch=pitch) + translate = torch.eye(4, device=self.camera.device, dtype=self.camera.dtype) + translate[:3, 3] = -self.focus_at + rot_mat = torch.eye(4, device=self.camera.device, dtype=self.camera.dtype) + if self.world_up_axis == 1: + rot_mat[0, 0] = math.cos(in_plane_amount) + rot_mat[0, 2] = -math.sin(in_plane_amount) + rot_mat[2, 0] = math.sin(in_plane_amount) + rot_mat[2, 2] = math.cos(in_plane_amount) + elif self.world_up_axis == 2: + rot_mat[0, 0] = math.cos(in_plane_amount) + rot_mat[0, 1] = math.sin(in_plane_amount) + rot_mat[1, 0] = -math.sin(in_plane_amount) + rot_mat[1, 1] = math.cos(in_plane_amount) + elif self.world_up_axis == 0: + rot_mat[1, 1] = math.cos(in_plane_amount) + rot_mat[1, 2] = math.sin(in_plane_amount) + rot_mat[2, 1] = -math.sin(in_plane_amount) + rot_mat[2, 2] = math.cos(in_plane_amount) + view_matrix = self.camera.view_matrix().squeeze(0) @ rot_mat @ translate + self.camera._backend.update(view_matrix) + cam_forward = self.camera.cam_forward().squeeze() + backward_dir = cam_forward / cam_forward.norm(dim=-1, keepdim=True) + self.camera.translate(radius * backward_dir) + + def _safe_zoom(self, amount): + r"""Applies a zoom on the camera by adjusting the lens. + + This function is different from :func:`kal.render.camera.CameraExtrinsics.zoom` + in which the FOV is constrained by a sigmoid. + + Args: + amount (float): + Amount of adjustment. + Mind the conventions - + To zoom in, give a positive amount (decrease fov by amount -> increase focal length) + To zoom out, give a negative amount (increase fov by amount -> decrease focal length) + """ + fov_ratio = self.camera.fov_x / self.camera.fov_y + fov_y_coeff = self.camera.fov_y / 180. + inv_fov_y = torch.log(fov_y_coeff / (1 - fov_y_coeff)) + self.camera.fov_y = torch.sigmoid(inv_fov_y + amount) * 180. 
+ self.camera.fov_x = self.camera.fov_y * fov_ratio # Make sure the view is not distorted + + def _safe_forward(self, amount): + r"""Move the camera forward (or backward if negative) + + This function is different from :func:`kal.render.camera.CameraExtrinsics.move_forward` + in that the radius is restricted by :math:`new\_radius = exp(log(old\_radius) + amount)` + + Args: + amount (float): Amount of adjustment (a positive amount moves the camera away from ``focus_at``) + """ + radius = (self.camera.cam_pos().squeeze() - self.focus_at).norm(dim=-1, keepdim=True) + log_radius = torch.log(radius) + new_radius = torch.exp(log_radius + amount) + self.camera.move_forward(new_radius - radius) + + def _handle_event(self, event): + with torch.no_grad(): + with self.out: + process_event = True + if self.additional_event_handler is not None: + process_event = self.additional_event_handler(self, event) + if process_event: + if event['type'] == 'wheel': + if event['ctrlKey']: + self._safe_forward(event['deltaY'] * self.forward_sensitivity) + else: + self._safe_zoom(event['deltaY'] * self.zoom_sensitivity) + self.render_update() + elif event['type'] == 'mousedown': + self.position = (event['relativeX'], event['relativeY']) + # If the camera is upside down w.r.t. the world we need to invert the azimuth movement + self.sign = torch.sign(self.camera.cam_up()[0, self.world_up_axis, 0]) + elif event['type'] in ['mouseup', 'mouseleave', 'mouseenter']: + self.render_update() + elif event['type'] == 'mousemove' and event['buttons'] == 1: + dx = (self.mouse_scale * + (event['relativeX'] - self.position[0]) / self.canvas.width) + dy = (self.mouse_scale * + (event['relativeY'] - self.position[1]) / self.canvas.height) + self._move_turntable(dy, self.sign * dx) + self.position = (event['relativeX'], event['relativeY']) + if not self.update_only_on_release: + self.fast_render_update() +
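+# Typical usage (sketch; `render` takes a Camera and returns a (height, width, 3) uint8 tensor): +# visualizer = IpyTurntableVisualizer(512, 512, camera, render, max_fps=24) +# visualizer.show()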
+class IpyFirstPersonVisualizer(BaseIpyVisualizer): + r"""An interactive first person visualizer that can display in a jupyter notebook. + + You can move the orientation with the left button of the mouse, + move the position of the camera with the right button of the mouse or the associated keys, + and zoom with the wheel. + + Args: + height (int): height of the canvas. + width (int): width of the canvas. + camera (kal.render.camera.Camera): + Camera used for the visualization. + render (Callable): + render function that takes a :class:`kal.render.camera.Camera` as input. + Must return a torch.uint8 tensor as output, + of shape :math:`(\text{output_height}, \text{output_width}, 3)`; + height and width don't have to match the canvas dimensions. + fast_render (optional, Callable): + A faster rendering function that may be used when doing high frequency manipulation + such as moving the camera with a mouse. Default: same as ``render``. + world_up (optional, torch.Tensor): + World up axis, of shape :math:`(3,)`. If provided, the camera will be reoriented to avoid roll. + Default: ``camera.cam_up()``. + zoom_sensitivity (float): + Sensitivity of the wheel on zoom. Default: 1e-3. + rotation_sensitivity (float): + Sensitivity of the mouse on rotations. Default: 0.4. + translation_sensitivity (float): + Sensitivity of the mouse on camera translation. Default: 1. + key_move_sensitivity (float): + Amount of camera movement on key press. Default 0.05. + max_fps (optional, float): + maximum framerate for handling consecutive events; + this is useful to avoid freezes when the rendering is slow. + Typically 24 fps is great when working on a local machine + with :func:`render` close to real-time, + and lower, around 10 fps, with slower rendering or network latency. + Default: 24 fps. + up_key (str): key associated to moving up. Default 'i'. + down_key (str): key associated to moving down. Default 'k'. + left_key (str): key associated to moving left. Default 'j'. + right_key (str): key associated to moving right. Default 'l'. + forward_key (str): key associated to moving forward. Default 'o'. + backward_key (str): key associated to moving backward. Default 'u'. + update_only_on_release (bool): + If true, the canvas is only updated when the mouse button is released, + not while it is pressed. This avoids freezes with very slow rendering functions. + Default: False. + additional_watched_events (optional, list of str): + Additional events to be watched by the visualizer + (see `ipyevents main documentation`_). + To be used for custom events such as enabling / disabling a feature on a key press. + ['wheel', 'mousedown', 'mouseup', 'mousemove', 'mouseleave'] are already watched. + Default: None. + additional_event_handler (optional, Callable): + Additional event handler to be used for custom events such as + enabling / disabling a feature on a key press. + The Callable must take as inputs this visualizer object and the event. + (see `ipyevents main documentation`_). + + Attributes: + camera (kal.render.camera.Camera): The camera used for rendering. + canvas (ipycanvas.Canvas): The canvas the rendering is copied to. + out (ipywidgets.Output): An output where errors and prints are displayed. + render (Callable): The rendering function. + fast_render (Callable) + world_up (torch.Tensor) + zoom_sensitivity (float) + rotation_sensitivity (float) + translation_sensitivity (float) + key_move_sensitivity (float) + max_fps (int) + update_only_on_release (bool) + + Methods: + show: display the Canvas with interactive features. + render_update: Update the Canvas with :func:`render()`. + fast_render_update: Update the Canvas with :func:`fast_render()`. + + .. _ipyevents main documentation: https://github.com/mwcraig/ipyevents/blob/main/docs/events.ipynb
+ """ + + def __init__(self, + height, + width, + camera, + render, + fast_render=None, + world_up=None, + zoom_sensitivity=0.001, + rotation_sensitivity=0.4, + translation_sensitivity=1., + key_move_sensitivity=0.05, + max_fps=24., + up_key='i', + down_key='k', + left_key='j', + right_key='l', + forward_key='o', + backward_key='u', + update_only_on_release=False, + additional_watched_events=None, + additional_event_handler=None): + + self.position = None + + with torch.no_grad(): + if world_up is None: + self.world_up = torch.nn.functional.normalize( + camera.cam_up().clone().detach().squeeze(-1), dim=-1) + self.world_right = torch.nn.functional.normalize( + camera.cam_right().clone().detach().squeeze(-1), dim=-1) + self.elevation = torch.zeros((1,), device=camera.device, dtype=camera.dtype) + else: + self.world_up = torch.nn.functional.normalize(world_up, dim=-1) + camera.extrinsics = CameraExtrinsics.from_lookat( + eye=camera.cam_pos().squeeze(), + at=(camera.cam_pos() - camera.cam_forward()).squeeze(), + up=self.world_up, + device=camera.device, + dtype=camera.dtype + ) + if self.world_up.ndim == 1: + self.world_up = self.world_up.unsqueeze(0) + + self.world_right = camera.cam_right().squeeze(-1) + self.elevation = torch.acos(torch.dot( + self.world_up.squeeze(), camera.cam_up().squeeze() + )).reshape(1) + if torch.dot(self.world_up.squeeze(), camera.cam_forward().squeeze()) >= 0: + self.elevation = -self.elevation + self.azimuth = torch.zeros((1,), device=camera.device, dtype=camera.dtype) + + self.zoom_sensitivity = zoom_sensitivity + self.rotation_scale = rotation_sensitivity * math.pi + self.translation_sensitivity = translation_sensitivity + self.key_move_sensitivity = key_move_sensitivity + + self.up_key = up_key + self.down_key = down_key + self.left_key = left_key + self.right_key = right_key + self.forward_key = forward_key + self.backward_key = backward_key + + self.update_only_on_release = update_only_on_release + + watched_events = ['wheel', 'mousedown', 'mouseup', 'mousemove', + 'mouseleave', 'mouseenter', 'contextmenu', 'keydown', 'keyup'] + if additional_watched_events is not None: + watched_events += additional_watched_events + self.additional_event_handler = additional_event_handler + + super().__init__(height, width, camera, render, fast_render, + watched_events, max_fps) + + def _safe_zoom(self, amount): + r"""Applies a zoom on the camera by adjusting the lens. + + This function is different from :func:`kal.render.camera.PinholeIntrinsics.zoom` + in that the FOV is constrained by a sigmoid. + + Args: + amount (float): + Amount of adjustment. + Mind the conventions - + a positive amount increases the FOV (zoom out -> decrease focal length), + a negative amount decreases the FOV (zoom in -> increase focal length). + """ + fov_ratio = self.camera.fov_x / self.camera.fov_y + fov_y_coeff = self.camera.fov_y / 180. + inv_fov_y = torch.log(fov_y_coeff / (1 - fov_y_coeff)) + self.camera.fov_y = torch.sigmoid(inv_fov_y + amount) * 180. + self.camera.fov_x = self.camera.fov_y * fov_ratio # Make sure the view is not distorted + + def _first_person_rotate(self, move_azimuth, move_elevation): + """Do a combination of rotations around the camera-right axis and the world up axis""" + self.azimuth[:] = (self.azimuth + move_azimuth) % (2 * math.pi) + self.elevation[:] = torch.clamp(self.elevation + move_elevation, + -math.pi / 2., math.pi / 2.)
+ cam_right = rotate_around_axis(self.world_right, self.azimuth, self.world_up) + cam_up = rotate_around_axis(self.world_up, self.elevation, cam_right) + cam_forward = torch.cross(cam_right, cam_up) + world_rotation = torch.stack((cam_right, cam_up, cam_forward), dim=1) + world_translation = -world_rotation @ self.camera.cam_pos() + mat = self.camera.view_matrix() + mat[:, :3, :3] = world_rotation + mat[:, :3, 3] = world_translation.squeeze(-1) + self.camera._backend.update(mat) + + def _handle_event(self, event): + with torch.no_grad(): + with self.out: + process_event = True + if self.additional_event_handler is not None: + process_event = self.additional_event_handler(self, event) + if process_event: + if event['type'] == 'wheel': + self._safe_zoom(event['deltaY'] * self.zoom_sensitivity) + self.render_update() + elif event['type'] == 'mousedown': + self.position = (event['relativeX'], event['relativeY']) + elif event['type'] in ['mouseup', 'mouseleave', 'mouseenter']: + self.render_update() + elif event['type'] == 'mousemove': + if event['buttons'] == 1: + dx = (self.rotation_scale * + (event['relativeX'] - self.position[0]) / self.canvas.width) + dy = (self.rotation_scale * + (event['relativeY'] - self.position[1]) / self.canvas.height) + self._first_person_rotate(dx, dy) + self.position = (event['relativeX'], event['relativeY']) + if not self.update_only_on_release: + self.fast_render_update() + elif event['buttons'] == 2: + dx = (-self.translation_sensitivity * + (event['relativeX'] - self.position[0]) / self.canvas.width) + dy = (self.translation_sensitivity * + (event['relativeY'] - self.position[1]) / self.canvas.height) + self.camera.move_up(dy) + self.camera.move_right(dx) + self.position = (event['relativeX'], event['relativeY']) + if not self.update_only_on_release: + self.fast_render_update() + elif event['type'] == 'keydown': + if event['key'] == self.forward_key: + # Camera notion of forward is backward (OpenGL convention) + self.camera.move_forward(-self.key_move_sensitivity) + self.fast_render_update() + elif event['key'] == self.backward_key: + self.camera.move_forward(self.key_move_sensitivity) + self.fast_render_update() + elif event['key'] == self.up_key: + self.camera.move_up(self.key_move_sensitivity) + self.fast_render_update() + elif event['key'] == self.down_key: + self.camera.move_up(-self.key_move_sensitivity) + self.fast_render_update() + elif event['key'] == self.left_key: + self.camera.move_right(-self.key_move_sensitivity) + self.fast_render_update() + elif event['key'] == self.right_key: + self.camera.move_right(self.key_move_sensitivity) + self.fast_render_update() + elif event['type'] == 'keyup': + if event['key'] in [self.forward_key, self.backward_key, self.up_key, + self.down_key, self.right_key, self.left_key]: + self.render_update() diff --git a/setup.py b/setup.py index ad56aacce..9a702b2cf 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ import warnings TORCH_MIN_VER = '1.6.0' -TORCH_MAX_VER = '1.13.0' +TORCH_MAX_VER = '1.13.1' CYTHON_MIN_VER = '0.29.20' INCLUDE_EXPERIMENTAL = os.getenv('KAOLIN_INSTALL_EXPERIMENTAL') is not None IGNORE_TORCH_VER = os.getenv('IGNORE_TORCH_VER') is not None @@ -190,8 +190,11 @@ def get_requirements(): "to use USD related features") requirements.append('usd-core<22.8; python_version < "3.10"') if INCLUDE_EXPERIMENTAL: - requirements.append('tornado==6.1') + # requirements.append('tornado==6.1') requirements.append('flask==2.0.3') + with open(os.path.join(cwd, 'tools', 'viz_requirements.txt'), 'r') as f: + for 
line in f.readlines(): + requirements.append(line.strip()) return requirements diff --git a/tests/python/kaolin/render/camera/test_extrinsics.py b/tests/python/kaolin/render/camera/test_extrinsics.py index 44215f265..660c001e6 100644 --- a/tests/python/kaolin/render/camera/test_extrinsics.py +++ b/tests/python/kaolin/render/camera/test_extrinsics.py @@ -906,8 +906,11 @@ def test_min_max_angles_batched_tensor(self, device, dtype, backend): class TestCameraExtrinsicsMoveCam: def test_move_right(self, device, dtype, backend, cam_pos_data): - cam_pos, cam_dir, view_matrix = cam_pos_data['cam_pos'], cam_pos_data['cam_dir'], cam_pos_data['view_matrix'] - extrinsics = CameraExtrinsics.from_camera_pose(cam_pos, cam_dir, dtype=dtype, device=device, backend=backend) + cam_pos = cam_pos_data['cam_pos'] + cam_dir = cam_pos_data['cam_dir'] + view_matrix = cam_pos_data['view_matrix'] + extrinsics = CameraExtrinsics.from_camera_pose( + cam_pos, cam_dir, dtype=dtype, device=device, backend=backend) amount = 1.0 extrinsics.move_right(amount) @@ -917,8 +920,11 @@ def test_move_right(self, device, dtype, backend, cam_pos_data): assert_view_matrix(extrinsics.view_matrix(), expected_mat) def test_move_up(self, device, dtype, backend, cam_pos_data): - cam_pos, cam_dir, view_matrix = cam_pos_data['cam_pos'], cam_pos_data['cam_dir'], cam_pos_data['view_matrix'] - extrinsics = CameraExtrinsics.from_camera_pose(cam_pos, cam_dir, dtype=dtype, device=device, backend=backend) + cam_pos = cam_pos_data['cam_pos'] + cam_dir = cam_pos_data['cam_dir'] + view_matrix = cam_pos_data['view_matrix'] + extrinsics = CameraExtrinsics.from_camera_pose( + cam_pos, cam_dir, dtype=dtype, device=device, backend=backend) amount = 1.0 extrinsics.move_up(amount) @@ -928,8 +934,11 @@ def test_move_up(self, device, dtype, backend, cam_pos_data): assert_view_matrix(extrinsics.view_matrix(), expected_mat) def test_move_forward(self, device, dtype, backend, cam_pos_data): - cam_pos, cam_dir, view_matrix = cam_pos_data['cam_pos'], cam_pos_data['cam_dir'], cam_pos_data['view_matrix'] - extrinsics = CameraExtrinsics.from_camera_pose(cam_pos, cam_dir, dtype=dtype, device=device, backend=backend) + cam_pos = cam_pos_data['cam_pos'] + cam_dir = cam_pos_data['cam_dir'] + view_matrix = cam_pos_data['view_matrix'] + extrinsics = CameraExtrinsics.from_camera_pose( + cam_pos, cam_dir, dtype=dtype, device=device, backend=backend) amount = 1.0 extrinsics.move_forward(amount) @@ -938,7 +947,6 @@ def test_move_forward(self, device, dtype, backend, cam_pos_data): expected_mat[..., axis_idx, 3] -= amount assert_view_matrix(extrinsics.view_matrix(), expected_mat) - @pytest.mark.parametrize('backend', CameraExtrinsics.available_backends()) @pytest.mark.parametrize('requires_grad', (True, False)) class TestCameraExtrinsicsBackendProperties: @@ -997,8 +1005,11 @@ def test_all_close_backend(self, device, dtype, requires_grad, cam_pos_data): class TestCameraExtrinsicsCamPosDir: def test_cam_pos(self, device, dtype, backend, cam_pos_data): - cam_pos, cam_dir, view_matrix = cam_pos_data['cam_pos'], cam_pos_data['cam_dir'], cam_pos_data['view_matrix'] - extrinsics = CameraExtrinsics.from_camera_pose(cam_pos, cam_dir, device=device, dtype=dtype, backend=backend) + cam_pos = cam_pos_data['cam_pos'] + cam_dir = cam_pos_data['cam_dir'] + view_matrix = cam_pos_data['view_matrix'] + extrinsics = CameraExtrinsics.from_camera_pose( + cam_pos, cam_dir, device=device, dtype=dtype, backend=backend) num_cams = view_matrix.shape[0] expected = torch.tensor(cam_pos, device=device, 
dtype=dtype) expected = expected.reshape(num_cams, 3, 1) @@ -1007,8 +1018,11 @@ def test_cam_pos(self, device, dtype, backend, cam_pos_data): assert extrinsics_result.shape == (num_cams, 3, 1) def test_cam_right(self, device, dtype, backend, cam_pos_data): - cam_pos, cam_dir, view_matrix = cam_pos_data['cam_pos'], cam_pos_data['cam_dir'], cam_pos_data['view_matrix'] - extrinsics = CameraExtrinsics.from_camera_pose(cam_pos, cam_dir, device=device, dtype=dtype, backend=backend) + cam_pos = cam_pos_data['cam_pos'] + cam_dir = cam_pos_data['cam_dir'] + view_matrix = cam_pos_data['view_matrix'] + extrinsics = CameraExtrinsics.from_camera_pose( + cam_pos, cam_dir, device=device, dtype=dtype, backend=backend) num_cams = view_matrix.shape[0] cam_dir = torch.tensor(cam_dir, device=device, dtype=dtype) if cam_dir.ndim == 2: @@ -1021,8 +1035,11 @@ def test_cam_right(self, device, dtype, backend, cam_pos_data): assert extrinsics_result.shape == (num_cams, 3, 1) def test_cam_up(self, device, dtype, backend, cam_pos_data): - cam_pos, cam_dir, view_matrix = cam_pos_data['cam_pos'], cam_pos_data['cam_dir'], cam_pos_data['view_matrix'] - extrinsics = CameraExtrinsics.from_camera_pose(cam_pos, cam_dir, device=device, dtype=dtype, backend=backend) + cam_pos = cam_pos_data['cam_pos'] + cam_dir = cam_pos_data['cam_dir'] + view_matrix = cam_pos_data['view_matrix'] + extrinsics = CameraExtrinsics.from_camera_pose( + cam_pos, cam_dir, device=device, dtype=dtype, backend=backend) num_cams = view_matrix.shape[0] cam_dir = torch.tensor(cam_dir, device=device, dtype=dtype) if cam_dir.ndim == 2: @@ -1034,8 +1051,11 @@ def test_cam_up(self, device, dtype, backend, cam_pos_data): assert extrinsics_result.shape == (num_cams, 3, 1) def test_cam_forward(self, device, dtype, backend, cam_pos_data): - cam_pos, cam_dir, view_matrix = cam_pos_data['cam_pos'], cam_pos_data['cam_dir'], cam_pos_data['view_matrix'] - extrinsics = CameraExtrinsics.from_camera_pose(cam_pos, cam_dir, device=device, dtype=dtype, backend=backend) + cam_pos = cam_pos_data['cam_pos'] + cam_dir = cam_pos_data['cam_dir'] + view_matrix = cam_pos_data['view_matrix'] + extrinsics = CameraExtrinsics.from_camera_pose( + cam_pos, cam_dir, device=device, dtype=dtype, backend=backend) num_cams = view_matrix.shape[0] cam_dir = torch.tensor(cam_dir, device=device, dtype=dtype) if cam_dir.ndim == 2: @@ -1045,3 +1065,4 @@ def test_cam_forward(self, device, dtype, backend, cam_pos_data): extrinsics_result = extrinsics.cam_forward() assert torch.allclose(extrinsics_result, expected, rtol=1e-3, atol=1e-3) assert extrinsics_result.shape == (num_cams, 3, 1) + diff --git a/tests/python/kaolin/visualize/test_ipython.py b/tests/python/kaolin/visualize/test_ipython.py new file mode 100644 index 000000000..d27ffe073 --- /dev/null +++ b/tests/python/kaolin/visualize/test_ipython.py @@ -0,0 +1,625 @@ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import math +import random + +import torch +import numpy as np +import pytest + +import kaolin + +class DummyRenderer(): + def __init__(self, height, width, value): + self.height = height + self.width = width + self.value = value + self.render_count = 0 + self.event_count = 0 + + def __call__(self, camera): + self.render_count += 1 + return torch.full((self.height, self.width, 3), self.value, + device=camera.device, dtype=torch.uint8) + +@pytest.mark.parametrize('height,width', [(64, 64), (32, 32)]) +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +class TestVisualizers: + + @pytest.fixture(autouse=True) + def camera(self, height, width, device): + return kaolin.render.camera.Camera.from_args( + eye=(torch.rand((3,)) - 0.5) * 10, + at=(torch.rand((3,)) - 0.5) * 10, + up=(torch.rand((3,)) - 0.5) * 10, + fov=random.uniform(0.1, math.pi - 0.1), + height=height, + width=width, + dtype=torch.float, + device=device + ) + + @pytest.fixture(autouse=True) + def renderer(self, height, width): + return DummyRenderer( + height, width, 0 + ) + + @pytest.fixture(autouse=True) + def fast_renderer(self, height, width): + return DummyRenderer( + int(height / 4), int(width / 4), 255 + ) + + #TODO(cfujitsang): can't find a way to test max_fps + @pytest.mark.parametrize('with_fast_renderer', [True, False]) + @pytest.mark.parametrize('world_up_axis', [0, 1, 2]) + @pytest.mark.parametrize('with_focus_at', [True, False]) + @pytest.mark.parametrize('with_sensitivity', [True, False]) + @pytest.mark.parametrize('with_additional_event', [True, False]) + @pytest.mark.parametrize('update_only_on_release', [True, False]) + def test_turntable_visualizer( + self, height, width, device, camera, renderer, fast_renderer, world_up_axis, + with_focus_at, with_sensitivity, with_additional_event, + update_only_on_release, with_fast_renderer): + kwargs = {} + + if with_focus_at: + focus_at = (torch.rand((3,), device=camera.device, dtype=camera.dtype) - 0.5) * 10 + kwargs['focus_at'] = focus_at + else: + focus_at = torch.zeros((3,), device=camera.device, dtype=camera.dtype) + + if with_sensitivity: + zoom_sensitivity = 0.01 + forward_sensitivity = 0.01 + mouse_sensitivity = 2.
+            kwargs['zoom_sensitivity'] = zoom_sensitivity
+            kwargs['forward_sensitivity'] = forward_sensitivity
+            kwargs['mouse_sensitivity'] = mouse_sensitivity
+        else:
+            zoom_sensitivity = 0.001
+            forward_sensitivity = 0.001
+            mouse_sensitivity = 1.5
+
+        global event_count
+        event_count = 0
+        if with_additional_event:
+            def additional_event_handler(visualizer, event):
+                with visualizer.out:
+                    if event['type'] == 'mousedown' and event['buttons'] == 3:
+                        global event_count
+                        event_count += 1
+                        return False
+                    return True
+            kwargs['additional_event_handler'] = additional_event_handler
+            kwargs['additional_watched_events'] = []
+
+        if with_fast_renderer:
+            kwargs['fast_render'] = fast_renderer
+
+        viz = kaolin.visualize.IpyTurntableVisualizer(
+            height,
+            width,
+            copy.deepcopy(camera),
+            renderer,
+            world_up_axis=world_up_axis,
+            update_only_on_release=update_only_on_release,
+            **kwargs
+        )
+        expected_render_count = 0
+        expected_fast_render_count = 0
+        def check_counts():
+            if with_fast_renderer:
+                assert renderer.render_count == expected_render_count
+                assert fast_renderer.render_count == expected_fast_render_count
+            else:
+                assert renderer.render_count == expected_render_count + expected_fast_render_count
+        assert torch.allclose(viz.focus_at, focus_at)
+        check_counts()
+        assert viz.canvas.height == height
+        assert viz.canvas.width == width
+
+        # Test reorientation at ctor
+        assert torch.allclose(viz.camera.cam_pos(), camera.cam_pos(), atol=1e-5, rtol=1e-5), \
+            "After ctor: camera moved"
+        signed_world_up = torch.zeros((3,), device=camera.device)
+        signed_world_distance = float(camera.cam_up().squeeze()[world_up_axis] >= 0) * 2. - 1.
+        signed_world_up[world_up_axis] = signed_world_distance
+        assert torch.dot(signed_world_up, viz.camera.cam_up().squeeze()) >= 0, \
+            "After ctor: camera up points in the wrong direction"
+        assert torch.dot(signed_world_up, viz.camera.cam_right().squeeze()) == 0, \
+            "After ctor: camera right is not perpendicular to the world up"
+
+        expected_cam_forward = torch.nn.functional.normalize(viz.focus_at - camera.cam_pos().squeeze(), dim=-1)
+        assert torch.allclose(
+            torch.dot(-viz.camera.cam_forward().squeeze(), expected_cam_forward),
+            torch.ones((1,), device=camera.device)
+        ), "After ctor: camera is not looking at focus_at"
+
+        ctor_camera = copy.deepcopy(viz.camera)
+        ref_radius = torch.linalg.norm(
+            viz.focus_at - ctor_camera.cam_pos().squeeze(),
+            dim=-1
+        )
+        signed_world_right = torch.zeros((3,), device=camera.device)
+        signed_world_right[world_up_axis - 1] = signed_world_distance
+        signed_world_forward = torch.zeros((3,), device=camera.device)
+        signed_world_forward[world_up_axis - 2] = signed_world_distance
+        ctor_cam_2d_pos = torch.stack([
+            viz.camera.cam_pos().squeeze()[world_up_axis - 1],
+            viz.camera.cam_pos().squeeze()[world_up_axis - 2],
+        ], dim=0)
+
+        try:
+            viz.show()
+        except NameError:  # show() uses display(), which is a builtin only inside IPython
+            pass
+
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(ctor_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After .show(): camera has moved"
+        assert torch.equal(ctor_camera.params, viz.camera.params), \
+            "After .show(): camera intrinsics have changed"
+
+        from_x = random.randint(0, width)
+        from_y = random.randint(0, height)
+        viz._handle_event({'type': 'mousedown', 'relativeX': from_x, 'relativeY': from_y, 'buttons': 1})
+        check_counts()
+        assert torch.equal(ctor_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After mousedown: camera has moved"
+        assert torch.equal(ctor_camera.params, viz.camera.params), \
+            "After mousedown: camera intrinsics have changed"
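+
+        # The synthetic dicts passed to viz._handle_event throughout this test
+        # mimic the event payloads delivered by ipyevents (keys such as 'type',
+        # 'buttons', 'relativeX'/'relativeY'), so user interaction can be
+        # simulated without a live Jupyter front-end.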
+        to_x = random.randint(0, width)
+        while to_x == from_x:
+            to_x = random.randint(0, width)
+
+        to_y = random.randint(0, height)
+        while to_y == from_y:
+            to_y = random.randint(0, height)
+
+        viz._handle_event({'type': 'mousemove', 'relativeX': to_x, 'relativeY': to_y, 'buttons': 1})
+        if not update_only_on_release:
+            expected_fast_render_count += 1
+        check_counts()
+        cur_radius = torch.linalg.norm(
+            viz.focus_at - viz.camera.cam_pos().squeeze(),
+            dim=-1
+        )
+        assert torch.allclose(cur_radius, ref_radius)
+        cur_focus_at = (
+            viz.camera.cam_pos() - viz.camera.cam_forward() * cur_radius
+        ).squeeze()
+        assert torch.allclose(viz.focus_at, cur_focus_at, atol=1e-5, rtol=1e-5)
+
+        azimuth_diff = mouse_sensitivity * (to_x - from_x) * math.pi / viz.canvas.width
+        elevation_diff = mouse_sensitivity * (to_y - from_y) * math.pi / viz.canvas.height
+
+        cur_cam_pos = kaolin.visualize.ipython.rotate_around_axis(
+            ctor_camera.cam_pos().squeeze(-1) - focus_at.unsqueeze(0),
+            -azimuth_diff,
+            signed_world_up.unsqueeze(0)
+        )
+        cur_cam_pos = kaolin.visualize.ipython.rotate_around_axis(
+            cur_cam_pos,
+            -elevation_diff,
+            viz.camera.cam_right().squeeze(-1),
+        ) + focus_at.unsqueeze(0)
+        assert torch.allclose(cur_cam_pos, viz.camera.cam_pos().squeeze(-1),
+                              atol=1e-4, rtol=1e-4)
+        cur_camera = copy.deepcopy(viz.camera)
+        viz._handle_event({'type': 'mouseup', 'buttons': 1})
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(cur_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After mouseup: camera has moved"
+        assert torch.equal(cur_camera.params, viz.camera.params), \
+            "After mouseup: camera intrinsics have changed"
+        wheel_amount = 120 * random.randint(1, 10)
+        viz._handle_event({'type': 'wheel', 'deltaY': wheel_amount, 'ctrlKey': False})
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(cur_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After unzoom: camera has moved"
+        assert viz.camera.fov_x > cur_camera.fov_x, \
+            "After unzoom: didn't unzoom"
+        assert viz.camera.fov_x < 180.
+        cur_camera = copy.deepcopy(viz.camera)
+        viz._handle_event({'type': 'wheel', 'deltaY': -2. * wheel_amount, 'ctrlKey': False})
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(cur_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After zoom: camera has moved"
+        assert viz.camera.fov_x < cur_camera.fov_x, \
+            "After zoom: didn't zoom"
+        assert viz.camera.fov_x > 0.
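+
+        # kaolin.visualize.ipython.rotate_around_axis (used in the mousemove
+        # check above) is expected to apply a standard axis-angle (Rodrigues)
+        # rotation: the turntable orbit decomposes into an azimuth rotation
+        # around the world up vector followed by an elevation rotation around
+        # the camera right vector.
+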
+        cur_camera = copy.deepcopy(viz.camera)
+        viz._handle_event({'type': 'wheel', 'deltaY': -wheel_amount, 'ctrlKey': True})
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(cur_camera.params, viz.camera.params), \
+            "After move forward: camera intrinsics have changed"
+        normalized_distance = torch.nn.functional.normalize(
+            cur_camera.cam_pos().squeeze() - viz.camera.cam_pos().squeeze(),
+            dim=-1
+        )
+        assert torch.allclose(cur_camera.cam_forward(), viz.camera.cam_forward()), \
+            "After move forward: camera has changed cam_forward()"
+        assert torch.allclose(cur_camera.cam_up(), viz.camera.cam_up()), \
+            "After move forward: camera has changed cam_up()"
+        assert torch.allclose(normalized_distance, cur_camera.cam_forward().squeeze(),
+                              atol=1e-5, rtol=1e-5), \
+            "After move forward: camera hasn't moved forward"
+        assert torch.all(torch.sign(focus_at - cur_camera.cam_pos().squeeze()) *
+                         torch.sign(focus_at - viz.camera.cam_pos().squeeze()) >= 0.), \
+            "After move forward: camera has crossed the focusing point"
+
+        assert event_count == 0
+        viz._handle_event({'type': 'mousedown', 'buttons': 3, 'relativeX': 0, 'relativeY': 0})
+        check_counts()
+        if with_additional_event:
+            assert event_count == 1
+        else:
+            assert event_count == 0
+
+    @pytest.mark.parametrize('with_fast_renderer', [True, False])
+    @pytest.mark.parametrize('with_world_up', [True, False])
+    @pytest.mark.parametrize('with_sensitivity', [True, False])
+    @pytest.mark.parametrize('with_additional_event', [True, False])
+    @pytest.mark.parametrize('update_only_on_release', [True, False])
+    def test_first_person_visualizer(
+            self, height, width, device, camera, renderer, fast_renderer,
+            with_fast_renderer, with_world_up, with_sensitivity,
+            with_additional_event, update_only_on_release):
+        kwargs = {}
+        if with_fast_renderer:
+            kwargs['fast_render'] = fast_renderer
+        if with_world_up:
+            world_up = torch.nn.functional.normalize(
+                torch.rand((3,), device=camera.device, dtype=camera.dtype),
+                dim=-1
+            )
+            kwargs['world_up'] = world_up
+        else:
+            world_up = camera.cam_up().squeeze()
+
+        if with_sensitivity:
+            rotation_sensitivity = 0.1
+            translation_sensitivity = 0.1
+            key_move_sensitivity = 0.1
+            zoom_sensitivity = 0.01
+            kwargs['rotation_sensitivity'] = rotation_sensitivity
+            kwargs['translation_sensitivity'] = translation_sensitivity
+            kwargs['key_move_sensitivity'] = key_move_sensitivity
+            kwargs['zoom_sensitivity'] = zoom_sensitivity
+
+            up_key = 'w'
+            down_key = 's'
+            left_key = 'a'
+            right_key = 'd'
+            forward_key = 'e'
+            backward_key = 'q'
+            kwargs['up_key'] = up_key
+            kwargs['down_key'] = down_key
+            kwargs['left_key'] = left_key
+            kwargs['right_key'] = right_key
+            kwargs['forward_key'] = forward_key
+            kwargs['backward_key'] = backward_key
+        else:
+            rotation_sensitivity = 0.4
+            translation_sensitivity = 1.
+            key_move_sensitivity = 0.05
+            zoom_sensitivity = 0.001
+            up_key = 'i'
+            down_key = 'k'
+            left_key = 'j'
+            right_key = 'l'
+            forward_key = 'o'
+            backward_key = 'u'
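+
+        # The fallback bindings ('i'/'k'/'j'/'l'/'o'/'u') are assumed to match
+        # IpyFirstPersonVisualizer's built-in defaults, so the keydown/keyup
+        # events fired later in this test work whether or not custom keys were
+        # passed through kwargs.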
+        global event_count
+        event_count = 0
+        if with_additional_event:
+            def additional_event_handler(visualizer, event):
+                with visualizer.out:
+                    if event['type'] == 'mousedown' and event['buttons'] == 3:
+                        global event_count
+                        event_count += 1
+                        return False
+                    return True
+            kwargs['additional_event_handler'] = additional_event_handler
+            kwargs['additional_watched_events'] = ['mouseenter']
+
+        viz = kaolin.visualize.IpyFirstPersonVisualizer(
+            height,
+            width,
+            copy.deepcopy(camera),
+            renderer,
+            update_only_on_release=update_only_on_release,
+            **kwargs
+        )
+        expected_render_count = 0
+        expected_fast_render_count = 0
+        def check_counts():
+            if with_fast_renderer:
+                assert renderer.render_count == expected_render_count
+                assert fast_renderer.render_count == expected_fast_render_count
+            else:
+                assert renderer.render_count == expected_render_count + expected_fast_render_count
+        check_counts()
+        assert viz.canvas.height == height
+        assert viz.canvas.width == width
+
+        # Test reorientation at ctor
+        expected_extrinsics = kaolin.render.camera.CameraExtrinsics.from_lookat(
+            eye=camera.cam_pos().squeeze(),
+            at=(camera.cam_pos().squeeze() - camera.cam_forward().squeeze()),
+            up=world_up,
+            device=camera.device,
+            dtype=camera.dtype
+        )
+        assert torch.allclose(expected_extrinsics.view_matrix(), viz.camera.view_matrix(),
+                              atol=1e-5, rtol=1e-5)
+        ctor_camera = copy.deepcopy(viz.camera)
+
+        try:
+            viz.show()
+        except NameError:  # show() uses display(), which is a builtin only inside IPython
+            pass
+
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(ctor_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After .show(): camera has moved"
+        assert torch.equal(ctor_camera.params, viz.camera.params), \
+            "After .show(): camera intrinsics have changed"
+
+        from_x = random.randint(0, width)
+        from_y = random.randint(0, height)
+        viz._handle_event({'type': 'mousedown', 'relativeX': from_x, 'relativeY': from_y, 'buttons': 1})
+        check_counts()
+        assert torch.equal(ctor_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After mousedown: camera has moved"
+        assert torch.equal(ctor_camera.params, viz.camera.params), \
+            "After mousedown: camera intrinsics have changed"
+
+        to_x = random.randint(0, width)
+        while to_x == from_x:
+            to_x = random.randint(0, width)
+
+        to_y = random.randint(0, height)
+        while to_y == from_y:
+            to_y = random.randint(0, height)
+
+        ctor_elevation = viz.elevation
+
+        viz._handle_event({'type': 'mousemove', 'relativeX': to_x, 'relativeY': to_y, 'buttons': 1})
+        if not update_only_on_release:
+            expected_fast_render_count += 1
+        check_counts()
+
+        azimuth_diff = rotation_sensitivity * (to_x - from_x) * math.pi / viz.canvas.width
+        elevation_diff = rotation_sensitivity * (to_y - from_y) * math.pi / viz.canvas.height
+        _elevation = ctor_elevation + elevation_diff
+        if _elevation > math.pi / 2.:
+            elevation_diff = math.pi / 2. - ctor_elevation
+        if _elevation < -math.pi / 2.:
+            elevation_diff = -math.pi / 2. - ctor_elevation
+        assert viz.elevation == ctor_elevation + elevation_diff
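+
+        # This clamp mirrors the expected visualizer behavior: elevation is
+        # limited to [-pi/2, pi/2] so a drag cannot flip the camera past the
+        # poles and invert the world-up direction.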
+        cur_cam_forward = kaolin.visualize.ipython.rotate_around_axis(
+            ctor_camera.cam_forward().squeeze(-1),
+            -azimuth_diff,
+            world_up.unsqueeze(0)
+        )
+        cur_cam_right = kaolin.visualize.ipython.rotate_around_axis(
+            ctor_camera.cam_right().squeeze(-1),
+            -azimuth_diff,
+            world_up.unsqueeze(0)
+        )
+        cur_cam_up = kaolin.visualize.ipython.rotate_around_axis(
+            ctor_camera.cam_up().squeeze(-1),
+            -azimuth_diff,
+            world_up.unsqueeze(0)
+        )
+
+        cur_cam_forward = kaolin.visualize.ipython.rotate_around_axis(
+            cur_cam_forward,
+            -elevation_diff,
+            cur_cam_right,
+        )
+        cur_cam_up = kaolin.visualize.ipython.rotate_around_axis(
+            cur_cam_up,
+            -elevation_diff,
+            cur_cam_right,
+        )
+
+        assert torch.allclose(ctor_camera.cam_pos().squeeze(-1), viz.camera.cam_pos().squeeze(-1),
+                              atol=1e-4, rtol=1e-4)
+        assert torch.allclose(cur_cam_right, viz.camera.cam_right().squeeze(-1),
+                              atol=1e-4, rtol=1e-4)
+        assert torch.allclose(cur_cam_forward, viz.camera.cam_forward().squeeze(-1),
+                              atol=1e-4, rtol=1e-4)
+        assert torch.allclose(cur_cam_up, viz.camera.cam_up().squeeze(-1),
+                              atol=1e-4, rtol=1e-4)
+        cur_camera = copy.deepcopy(viz.camera)
+
+        viz._handle_event({'type': 'mouseup'})
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(cur_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After mouseup: camera has moved"
+        assert torch.equal(cur_camera.params, viz.camera.params), \
+            "After mouseup: camera intrinsics have changed"
+
+        from_x = random.randint(0, width)
+        from_y = random.randint(0, height)
+
+        viz._handle_event({
+            'type': 'mousedown', 'relativeX': from_x, 'relativeY': from_y, 'buttons': 2
+        })
+        check_counts()
+        assert torch.equal(cur_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After mousedown: camera has moved"
+        assert torch.equal(cur_camera.params, viz.camera.params), \
+            "After mousedown: camera intrinsics have changed"
+
+        to_x = random.randint(0, width)
+        while to_x == from_x:
+            to_x = random.randint(0, width)
+
+        to_y = random.randint(0, height)
+        while to_y == from_y:
+            to_y = random.randint(0, height)
+
+        viz._handle_event({
+            'type': 'mousemove', 'relativeX': to_x, 'relativeY': to_y, 'buttons': 2
+        })
+        if not update_only_on_release:
+            expected_fast_render_count += 1
+        check_counts()
+
+        cur_camera.move_up(translation_sensitivity * (to_y - from_y) / height)
+        cur_camera.move_right(-translation_sensitivity * (to_x - from_x) / width)
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'mouseup'})
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(cur_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After mouseup: camera has moved"
+        assert torch.equal(cur_camera.params, viz.camera.params), \
+            "After mouseup: camera intrinsics have changed"
+
+        viz._handle_event({'type': 'keydown', 'key': up_key})
+        expected_fast_render_count += 1
+        check_counts()
+        cur_camera.move_up(key_move_sensitivity)
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keyup', 'key': up_key})
+        expected_render_count += 1
+        check_counts()
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
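+
+        # The remaining keys follow one pattern: keydown triggers a fast
+        # (preview) render and moves the camera by key_move_sensitivity, while
+        # the matching keyup triggers a single full render without any
+        # further movement.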
+        viz._handle_event({'type': 'keydown', 'key': down_key})
+        expected_fast_render_count += 1
+        check_counts()
+        cur_camera.move_up(-key_move_sensitivity)
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keyup', 'key': down_key})
+        expected_render_count += 1
+        check_counts()
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keydown', 'key': left_key})
+        expected_fast_render_count += 1
+        check_counts()
+        cur_camera.move_right(-key_move_sensitivity)
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keyup', 'key': left_key})
+        expected_render_count += 1
+        check_counts()
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keydown', 'key': right_key})
+        expected_fast_render_count += 1
+        check_counts()
+        cur_camera.move_right(key_move_sensitivity)
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keyup', 'key': right_key})
+        expected_render_count += 1
+        check_counts()
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keydown', 'key': forward_key})
+        expected_fast_render_count += 1
+        check_counts()
+        cur_camera.move_forward(-key_move_sensitivity)
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keyup', 'key': forward_key})
+        expected_render_count += 1
+        check_counts()
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keydown', 'key': backward_key})
+        expected_fast_render_count += 1
+        check_counts()
+        cur_camera.move_forward(key_move_sensitivity)
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keyup', 'key': backward_key})
+        expected_render_count += 1
+        check_counts()
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        # 'x' is not bound to any action: neither the render counts nor the
+        # camera should change.
+        viz._handle_event({'type': 'keydown', 'key': 'x'})
+        check_counts()
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        viz._handle_event({'type': 'keyup', 'key': 'x'})
+        check_counts()
+        assert torch.allclose(cur_camera.view_matrix(), viz.camera.view_matrix())
+        assert torch.allclose(cur_camera.params, viz.camera.params)
+
+        wheel_amount = 120 * random.randint(1, 10)
+        viz._handle_event({'type': 'wheel', 'deltaY': wheel_amount})
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(cur_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After unzoom: camera has moved"
+        assert viz.camera.fov_x > cur_camera.fov_x, \
+            "After unzoom: didn't unzoom"
+        assert viz.camera.fov_x < 180.
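+        # deltaY follows the browser wheel convention (multiples of 120 per
+        # notch): positive widens fov_x (zoom out), negative narrows it
+        # (zoom in).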
+        cur_camera = copy.deepcopy(viz.camera)
+        viz._handle_event({'type': 'wheel', 'deltaY': -2. * wheel_amount})
+        expected_render_count += 1
+        check_counts()
+        assert torch.equal(cur_camera.view_matrix(), viz.camera.view_matrix()), \
+            "After zoom: camera has moved"
+        assert viz.camera.fov_x < cur_camera.fov_x, \
+            "After zoom: didn't zoom"
+        assert viz.camera.fov_x > 0.
+
+        assert event_count == 0
+        viz._handle_event({'type': 'mousedown', 'buttons': 3, 'relativeX': 0, 'relativeY': 0})
+        check_counts()
+        if with_additional_event:
+            assert event_count == 1
+        else:
+            assert event_count == 0
+
diff --git a/tools/doc_requirements.txt b/tools/doc_requirements.txt
index 2c9a8bae4..fd26b35f8 100644
--- a/tools/doc_requirements.txt
+++ b/tools/doc_requirements.txt
@@ -1,3 +1,3 @@
-setuptools>=50.3
+setuptools==58.0.0
 sphinx>=3.5.4,<5.1.0
 sphinx_rtd_theme==1.0.0
diff --git a/tools/linux/Dockerfile.install b/tools/linux/Dockerfile.install
index 375e0d010..ffbe279b4 100644
--- a/tools/linux/Dockerfile.install
+++ b/tools/linux/Dockerfile.install
@@ -34,6 +34,7 @@ RUN echo "Acquire { https::Verify-Peer false }" > /etc/apt/apt.conf.d/99verify-p
         libglvnd-dev \
         curl \
         cmake \
+        xvfb \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
@@ -60,7 +61,8 @@ RUN npm install
 RUN pip install --upgrade pip && \
     pip install --no-cache-dir setuptools==58.0.0 ninja cython==0.29.20 \
-        imageio imageio-ffmpeg
+        imageio imageio-ffmpeg && \
+    pip install --no-cache-dir -r tools/viz_requirements.txt
 
 RUN cd /tmp && \
     git clone https://github.com/NVlabs/nvdiffrast && \
diff --git a/tools/viz_requirements.txt b/tools/viz_requirements.txt
new file mode 100644
index 000000000..6c4b6ccad
--- /dev/null
+++ b/tools/viz_requirements.txt
@@ -0,0 +1,4 @@
+ipycanvas
+ipyevents
+jupyter_client<8
+pyzmq<25
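
Usage note on the new visualizers exercised above: any callable that maps a kaolin Camera to an (H, W, 3) uint8 image tensor can drive them, exactly as DummyRenderer does in the tests. A minimal sketch, assuming only APIs used in this diff (Camera.from_args, IpyTurntableVisualizer, show()); constant_color is a hypothetical stand-in for a real renderer such as the nvdiffrast one in the tutorial notebook:

    import math
    import torch
    import kaolin

    height, width = 256, 256

    def constant_color(camera):
        # A real renderer would rasterize the scene from `camera`;
        # here we return a solid red frame of the expected shape and dtype.
        img = torch.zeros((height, width, 3), device=camera.device, dtype=torch.uint8)
        img[..., 0] = 255
        return img

    camera = kaolin.render.camera.Camera.from_args(
        eye=torch.tensor([2., 1., 1.]),
        at=torch.zeros((3,)),
        up=torch.tensor([0., 1., 0.]),
        fov=math.pi / 4.,
        height=height, width=width,
        dtype=torch.float
    )
    viz = kaolin.visualize.IpyTurntableVisualizer(height, width, camera, constant_color)
    viz.show()  # renders to an interactive ipycanvas when run inside Jupyter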