diff --git a/learn_the_basics.rst b/learn_the_basics.rst index 040ac35d..c6a2edba 100644 --- a/learn_the_basics.rst +++ b/learn_the_basics.rst @@ -4,6 +4,11 @@ Learn the basics .. grid:: 1 1 3 3 :gutter: 4 + .. grid-item-card:: Transpiling Functions from PyTorch to TensorFlow + :link: learn_the_basics/torch_to_tf_functions.ipynb + + Transpiling Kornia functions to TensorFlow. + .. grid-item-card:: Trace Code :link: learn_the_basics/03_trace_code.ipynb @@ -23,6 +28,7 @@ Learn the basics :hidden: :maxdepth: -1 + learn_the_basics/torch_to_tf_functions.ipynb learn_the_basics/03_trace_code.ipynb learn_the_basics/05_lazy_vs_eager.ipynb learn_the_basics/06_how_to_use_decorators.ipynb diff --git a/learn_the_basics/torch_to_tf_functions.ipynb b/learn_the_basics/torch_to_tf_functions.ipynb new file mode 100644 index 00000000..2c4403ff --- /dev/null +++ b/learn_the_basics/torch_to_tf_functions.ipynb @@ -0,0 +1,326 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Transpiling Functions from PyTorch to TensorFlow" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can install the dependencies required for this notebook by running the cell below ⬇️, or check out the [Get Started](https://ivy.dev/docs/overview/get_started.html) section of the docs to find out more about installing ivy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install ivy\n", + "!pip install torch\n", + "!pip install tensorflow\n", + "!pip install kornia" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we'll go through an example of how any function using torch can be converted, and used in, tensorflow via `ivy.transpile`. We'll use kornia as our example, which is a state-of-the-art computer vision library built on top of torch." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, some boiler plate imports:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import ivy\n", + "import kornia\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, lets transpile a kornia function that we want to use in tensorflow. The `ivy.transpile` call returns a new tensorflow function which is mathematically equivalent to the torch function we passed. This can take up to a minute to run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tf_rgb_to_grayscale = ivy.transpile(kornia.color.rgb_to_grayscale, source=\"torch\", target=\"tensorflow\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now use this function exactly as the original kornia function would be, just passing tensorflow tensors rather than torch tensors:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# using the original torch function\n", + "kornia.color.rgb_to_grayscale(torch.rand(1, 3, 28, 28))\n", + "\n", + "# using the transpiled tensorflow function\n", + "tf_rgb_to_grayscale(tf.random.uniform((1, 3, 28, 28)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, lets check that the outputs of both the original torch function and transpiled tensorflow are identical when given the same inputs." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_x = torch.rand(1, 3, 28, 28)\n", + "tf_x = tf.convert_to_tensor(torch_x.numpy())\n", + "\n", + "torch_out = kornia.color.rgb_to_grayscale(torch_x)\n", + "tf_out = tf_rgb_to_grayscale(tf_x)\n", + "\n", + "np.allclose(torch_out.numpy(), tf_out.numpy(), atol=1e-6)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}