diff --git a/DEMO_Estimate_the_Memory_Consumption_of_the_KV_Cache.ipynb b/DEMO_Estimate_the_Memory_Consumption_of_the_KV_Cache.ipynb new file mode 100644 index 0000000..7b5f801 --- /dev/null +++ b/DEMO_Estimate_the_Memory_Consumption_of_the_KV_Cache.ipynb @@ -0,0 +1,679 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "10aa92a552fe48dbbeea1737d40fe0b7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_882f53a067254928b788f4d183ef5c62", + "IPY_MODEL_b9661dd3895f4d88a0b0a4a6401f8ac4", + "IPY_MODEL_52060bf55f6c4af8a29534c0d49ad536" + ], + "layout": "IPY_MODEL_9974ca19bb08471584c6191d38f9ac26" + } + }, + "882f53a067254928b788f4d183ef5c62": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_60dd1cdf6c6b40dc8a73efd918ab80a1", + "placeholder": "​", + "style": "IPY_MODEL_1c017a88d17c4d1f8707abf3367181da", + "value": "config.json: 100%" + } + }, + "b9661dd3895f4d88a0b0a4a6401f8ac4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_50a045105a784fed8cd3bd25ec27ffda", + "max": 2182, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b0419ff90fcc45348aa1eab896f0b960", + "value": 2182 + } + }, + "52060bf55f6c4af8a29534c0d49ad536": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b99e9f7333a545e9a480e444d450807f", + "placeholder": "​", + "style": "IPY_MODEL_9f49f0cdc3cd4772af5f358d6188a6d3", + "value": " 2.18k/2.18k [00:00<00:00, 75.8kB/s]" + } + }, + "9974ca19bb08471584c6191d38f9ac26": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + 
"_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "60dd1cdf6c6b40dc8a73efd918ab80a1": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1c017a88d17c4d1f8707abf3367181da": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "50a045105a784fed8cd3bd25ec27ffda": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + 
"grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b0419ff90fcc45348aa1eab896f0b960": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b99e9f7333a545e9a480e444d450807f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9f49f0cdc3cd4772af5f358d6188a6d3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "This notebook estimates the memory consumption of the KV cache. It works for Transformer models with and without Grouped-Query Attention (GQA).\n", + "\n", + "To get the estimation, run all the cells.\n", + "\n", + "\n", + "*Important Note: The memory consumption computed by this notebook is an **estimate**. 
In practice, many factors can decrease or increase consumption, depending on the optimizations implemented by the inference framework or on architectural tricks used by the model.*\n", + "\n", + "\n", + "The formula is adapted from [NVIDIA](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization) and [this repo](https://github.com/Leon-Sander/KV-Cache-Calculator).\n" + ], + "metadata": { + "id": "Zny8982jpQ9O" + } + }, + { + "cell_type": "markdown", + "source": [ + "## KV Cache Memory Formula\n", + "\n", + "To estimate the memory usage of the key-value cache during inference:\n", + "\n", + "$$\n", + "M_{\\text{KVCache}} = 2 \\cdot L \\cdot s \\cdot b \\cdot \\frac{h}{a} \\cdot g \\cdot p\n", + "$$\n", + "\n", + "Where: \n", + "- \\( L \\): number of transformer layers \n", + "- \\( s \\): sequence length \n", + "- \\( b \\): batch size \n", + "- \\( h \\): hidden size (embedding dimension), so \\( h/a \\) is the head dimension \n", + "- \\( a \\): number of attention heads \n", + "- \\( g \\): number of key-value heads (without GQA, \\( g = a \\) and \\( \\frac{h}{a} \\cdot g \\) reduces to \\( h \\)) \n", + "- \\( p \\): bytes per element, i.e., the bit-width of the cache divided by 8 \n", + "- Factor \\( 2 \\): accounts for both the key and the value caches\n" + ], + "metadata": { + "id": "yqW8ZscGTliR" + } + },
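+ { + "cell_type": "markdown", + "source": [ + "As a quick sanity check, the next cell evaluates the formula on an illustrative set of values. These numbers are assumptions chosen for the example (roughly the shape of a mid-sized GQA model), not values read from any configuration; the rest of the notebook computes the actual estimate from a model's config.\n" + ], + "metadata": {} + }, + { + "cell_type": "code", + "source": [ + "# Illustrative sanity check of the formula above.\n", + "# All values are assumptions for the example, not read from a model.\n", + "L = 32    # transformer layers\n", + "s = 4096  # sequence length (tokens)\n", + "b = 1     # batch size\n", + "h = 4096  # hidden size, so the head dimension is h/a = 128\n", + "a = 32    # attention heads\n", + "g = 8     # key-value heads (GQA)\n", + "p = 2     # bytes per element (16-bit cache)\n", + "\n", + "kv_bytes = 2 * L * s * b * (h / a) * g * p\n", + "print(\"Estimated KV cache: \"+str(round(kv_bytes/(1000**3), 4))+\" GB\")" + ], + "metadata": {}, + "execution_count": null, + "outputs": [] + },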
+ { + "cell_type": "markdown", + "source": [ + "# Make Sure Your Transformers Is Up-To-Date" + ], + "metadata": { + "id": "TE_E_RGKTYZb" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade git+https://github.com/huggingface/transformers" + ], + "metadata": { + "id": "SdV3LUtH2YK1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!huggingface-cli login" + ], + "metadata": { + "id": "bMleAv7JQWdA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Get the Model's Configuration" + ], + "metadata": { + "id": "p-u_NDdNTpQ5" + } + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "In the following interactive cell, enter the name of the model. It can be the name of a repository on the Hugging Face Hub or a local path.\n", + "This cell retrieves the architecture of the model." + ], + "metadata": { + "id": "u3wJXIeXpKTV" + } + }, + { + "cell_type": "code", + "source": [ + "from transformers import AutoConfig\n", + "\n", + "model_name = \"meta-llama/Llama-4-Scout-17B-16E-Instruct\"\n", + "\n", + "model_config = AutoConfig.from_pretrained(model_name)" + ], + "metadata": { + "id": "PAaucRvmmf0F", + "outputId": "3c1b5915-755d-4f00-9c2a-0926bb5f59e3", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "10aa92a552fe48dbbeea1737d40fe0b7", + "882f53a067254928b788f4d183ef5c62", + "b9661dd3895f4d88a0b0a4a6401f8ac4", + "52060bf55f6c4af8a29534c0d49ad536", + "9974ca19bb08471584c6191d38f9ac26", + "60dd1cdf6c6b40dc8a73efd918ab80a1", + "1c017a88d17c4d1f8707abf3367181da", + "50a045105a784fed8cd3bd25ec27ffda", + "b0419ff90fcc45348aa1eab896f0b960", + "b99e9f7333a545e9a480e444d450807f", + "9f49f0cdc3cd4772af5f358d6188a6d3" + ] + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "config.json:   0%|          | 0.00/2.18k [00:00<?, ?B/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "10aa92a552fe48dbbeea1737d40fe0b7" + } + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "# For Llama 4, we are only interested in the text part of the configuration\n", + "config = getattr(model_config, \"text_config\", model_config)\n", + "\n", + "hidden_layers = config.num_hidden_layers\n", + "hidden_size = config.hidden_size\n", + "attention_heads = config.num_attention_heads\n", + "kv_heads = getattr(config, \"num_key_value_heads\", 0)\n", + "\n", + "print(\"Model: \"+model_name)\n", + "print(\"Hidden layers (L): \"+str(hidden_layers))\n", + "print(\"Hidden size (h): \"+str(hidden_size))\n", + "print(\"Attention heads (a): \"+str(attention_heads))\n", + "if kv_heads > 0:\n", + "    print(\"Key-value heads (g): \"+str(kv_heads))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SWeYIu79Up_Q", + "outputId": "3bb8f9ae-448d-4b56-9b37-d8752fa31f3b" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model: meta-llama/Llama-4-Scout-17B-16E-Instruct\n", + "Hidden layers (L): 48\n", + "Hidden size (h): 5120\n", + "Attention heads (a): 40\n", + "Key-value heads (g): 8\n" + ] + } + ] + },
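+ { + "cell_type": "markdown", + "source": [ + "Optionally, the next cell (an extra check added here, relying on the variables set just above) shows the head dimension and the factor \\( a/g \\) by which GQA shrinks the KV cache relative to full multi-head attention.\n" + ], + "metadata": {} + }, + { + "cell_type": "code", + "source": [ + "# Optional: quantify the KV cache savings from GQA.\n", + "# Uses hidden_size, attention_heads, and kv_heads from the previous cell.\n", + "head_dim = hidden_size // attention_heads\n", + "print(\"Head dimension (h/a): \"+str(head_dim))\n", + "if 0 < kv_heads < attention_heads:\n", + "    print(\"GQA shrinks the KV cache by a factor of \"+str(attention_heads/kv_heads)+\"x\")" + ], + "metadata": {}, + "execution_count": null, + "outputs": [] + },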
+ { + "cell_type": "markdown", + "source": [ + "# Set Inference Hyperparameters" + ], + "metadata": { + "id": "WgcseYKuTsBy" + } + }, + { + "cell_type": "markdown", + "source": [ + "In the following interactive cell, enter:\n", + "- bitwidth: The bit-width of the KV cache.\n", + "- seqlen: The maximum sequence length in your batches.\n", + "- batch_size: The number of instances in one batch." + ], + "metadata": { + "id": "rWkpPCD0GLuo" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mMAwju0nKh3r", + "outputId": "82b15eb6-246b-451a-9f8d-8924ed95bd59" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Bit-width of the KV cache: 16-bit\n", + "Sequence length (s): 5000\n", + "Batch size (b): 1\n" + ] + } + ], + "source": [ + "#The bit-width of the KV cache, e.g., 16-bit if not quantized\n", + "bitwidth = 16 # @param {type:\"number\"}\n", + "print(\"Bit-width of the KV cache: \"+str(bitwidth)+\"-bit\")\n", + "\n", + "#The maximum number of tokens in a sequence\n", + "seqlen = 5000 # @param {type:\"integer\"}\n", + "print(\"Sequence length (s): \"+str(seqlen))\n", + "\n", + "#The batch size\n", + "batch_size = 1 # @param {type:\"integer\"}\n", + "print(\"Batch size (b): \"+str(batch_size))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Estimate" + ], + "metadata": { + "id": "CjeTr7awTuvg" + } + }, + { + "cell_type": "markdown", + "source": [ + "Run the following cell to get the estimation given the information provided in the previous cells." + ], + "metadata": { + "id": "9MOm8LVQHm_i" + } + }, + { + "cell_type": "code", + "source": [ + "def kv_cache():\n", + "    # Without GQA (g = a), each layer caches 2*h elements per token\n", + "    return round(2*hidden_layers*seqlen*batch_size*hidden_size*(bitwidth/8)/(1000**3),4)\n", + "\n", + "def kv_cache_gqa():\n", + "    # With GQA, each layer caches 2*(h/a)*g elements per token\n", + "    return round(2*hidden_layers*seqlen*batch_size*(hidden_size/attention_heads)*kv_heads*(bitwidth/8)/(1000**3),4)\n", + "\n", + "if kv_heads > 0:\n", + "    kv_cache_cost = kv_cache_gqa()\n", + "    print(\"Memory consumption of the KV cache (with GQA): \"+str(kv_cache_cost)+\" GB \\n\")\n", + "else:\n", + "    kv_cache_cost = kv_cache()\n", + "    print(\"Memory consumption of the KV cache: \"+str(kv_cache_cost)+\" GB \\n\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tgSFQfZcOEKx", + "outputId": "afc9831a-d29c-4585-814b-e607b13b8740" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Memory consumption of the KV cache (with GQA): 0.983 GB \n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "m8RvNMdMVVZM" + }, + "execution_count": null, + "outputs": [] + } + ] +}
\ No newline at end of file