{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Rectified Linear Unit\n", "\n", "The Rectified Linear Unit (ReLU) function is widely used in machine learning and is defined as follows:\n", "\n", "$$\n", "\\mathrm{ReLU}(x) = \\left\\{\n", " \\begin{array}{ll}\n", " 0 & \\text{if } x \\leq 0 \\\\\n", " x & \\text{if } x > 0\n", " \\end{array}\n", " \\right.\n", "$$\n", " \n", "In this example, we apply the ReLU function to several values chosen to illustrate its behavior: negative inputs and zero map to zero, while positive inputs pass through unchanged. Player 0 supplies the input values, which are secret shared among the three players; ReLU is then computed on the shares and the result is revealed to every player. As with other Cicada functions, ReLU operates element-wise on arrays of any shape." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Player 0 values: [-5 -1 0 1 5]\n", "INFO:root:Player 1 values: None\n", "INFO:root:Player 2 values: None\n", "INFO:root:Player 0 relu: [0. 0. 0. 1. 5.]\n", "INFO:root:Player 1 relu: [0. 0. 0. 1. 5.]\n", "INFO:root:Player 2 relu: [0. 0. 0. 1. 5.]\n" ] } ], "source": [ "import logging\n", "\n", "import numpy\n", "\n", "from cicada.additive import AdditiveProtocolSuite\n", "from cicada.communicator import SocketCommunicator\n", "from cicada.logger import Logger\n", "\n", "logging.basicConfig(level=logging.INFO)\n", "\n", "def main(communicator):\n", " log = Logger(logging.getLogger(), communicator)\n", " protocol = AdditiveProtocolSuite(communicator)\n", "\n", " values = numpy.array([-5, -1, 0, 1, 5]) if communicator.rank == 0 else None\n", " log.info(f\"Player {communicator.rank} values: {values}\")\n", "\n", " values_share = protocol.share(src=0, secret=values, shape=(5,))\n", " relu_share = protocol.relu(values_share)\n", " relu = protocol.reveal(relu_share)\n", "\n", " log.info(f\"Player {communicator.rank} relu: {relu}\")\n", "\n", "SocketCommunicator.run(world_size=3, fn=main);" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython",
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 4 }