{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/eric/miniconda3/envs/dl3/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning.\n", "Loading checkpoint shards: 100%|██████████| 33/33 [00:08<00:00, 4.08it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 262410240 || all params: 6738415616 || trainable%: 3.894242429584207\n", "trainable params: 8388608 || all params: 6746804224 || trainable%: 0.12433454005023165\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import os\n", "\n", "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "import torch\n", "import torch.nn as nn\n", "import bitsandbytes as bnb\n", "from dataset import load_dataset\n", "import transformers\n", "from transformers import AutoTokenizer, AutoConfig, LLaMAForCausalLM, LLaMATokenizer\n", "from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model\n", "\n", "model = LLaMAForCausalLM.from_pretrained(\n", " \"./7B/llama-7b\",\n", " load_in_8bit=True,\n", " device_map=\"auto\",\n", ")\n", "\n", "tokenizer = LLaMATokenizer.from_pretrained(\"./7B/tokenizer\")\n", "\n", "\n", "def print_trainable_parameters(model):\n", " \"\"\"\n", " Prints the number of trainable parameters in the model.\n", " \"\"\"\n", " trainable_params = 0\n", " all_param = 0\n", " for _, param in model.named_parameters():\n", " all_param += param.numel()\n", " if param.requires_grad:\n", " trainable_params += param.numel()\n", " print(\n", " f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n", " )\n", "\n", "\n", "print_trainable_parameters(model)\n", "model = prepare_model_for_int8_training(model)\n", "\n", "config = LoraConfig(\n", " r=16,\n", " lora_alpha=32,\n", " target_modules=[\"q_proj\", \"v_proj\"],\n", " lora_dropout=0.05,\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", ")\n", "model = get_peft_model(model, config)\n", "\n", "print_trainable_parameters(model)\n", "\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.pad_token_id = tokenizer.eos_token_id" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", "\n", "\u001b[A\u001b[A" ] }, { "ename": "OutOfMemoryError", "evalue": "CUDA out of memory. Tried to allocate 22.00 MiB (GPU 0; 23.65 GiB total capacity; 19.92 GiB already allocated; 41.31 MiB free; 20.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. 
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 22.00 MiB (GPU 0; 23.65 GiB total capacity; 19.92 GiB already allocated; 41.31 MiB free; 20.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
     "output_type": "error",
     "traceback": [
      "---------------------------------------------------------------------------",
      "OutOfMemoryError                          Traceback (most recent call last)",
      "Cell In[6], line 26: trainer.train()",
      "  (intermediate frames trimmed: transformers/trainer.py -> peft/peft_model.py -> transformers/models/llama/modeling_llama.py -> bitsandbytes/nn/modules.py)",
      "File ~/miniconda3/envs/dl3/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:397, in MatMul8bitLt.forward(ctx, A, B, out, bias, state)",
      "--> 397     output += torch.matmul(subA, state.subB)",
      "OutOfMemoryError: CUDA out of memory. Tried to allocate 22.00 MiB (GPU 0; 23.65 GiB total capacity; 19.92 GiB already allocated; 41.31 MiB free; 20.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
     ]
    }
   ],
   "source": [
    "# Stream OIG and tokenize on the fly; a streaming dataset has no length, so\n",
    "# training is bounded by max_steps below.\n",
    "data = load_dataset(\"laion/OIG\", streaming=True)\n",
    "data = data.shuffle(seed=42).map(\n",
    "    lambda x: tokenizer(\n",
    "        x[\"text\"], truncation=True, max_length=256, padding=\"max_length\"\n",
    "    ),\n",
    "    batched=True,\n",
    ")\n",
    "\n",
    "trainer = transformers.Trainer(\n",
    "    model=model,\n",
    "    train_dataset=data[\"train\"],\n",
    "    args=transformers.TrainingArguments(\n",
    "        per_device_train_batch_size=4,\n",
    "        gradient_accumulation_steps=4,\n",
    "        warmup_steps=1000,\n",
    "        max_steps=6000,\n",
    "        learning_rate=2e-4,\n",
    "        fp16=True,\n",
    "        logging_steps=1,\n",
    "        output_dir=\"checkpoints\",\n",
    "        save_total_limit=3,\n",
    "    ),\n",
    "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    ")\n",
    "model.config.use_cache = False  # silence the use_cache warning; re-enable for inference!\n",
    "trainer.train()\n",
    "\n",
    "model.save_pretrained(\"outputs\")"
   ]
  },
\u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias\u001b[39m.\u001b[39mdtype \u001b[39m!=\u001b[39m x\u001b[39m.\u001b[39mdtype:\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias\u001b[39m.\u001b[39mdata \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias\u001b[39m.\u001b[39mdata\u001b[39m.\u001b[39mto(x\u001b[39m.\u001b[39mdtype)\n\u001b[0;32m--> 242\u001b[0m out \u001b[39m=\u001b[39m bnb\u001b[39m.\u001b[39;49mmatmul(x, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, bias\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, state\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstate)\n\u001b[1;32m 243\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mhas_fp16_weights:\n\u001b[1;32m 244\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mCB \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mCxB \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 245\u001b[0m \u001b[39m# we converted 8-bit row major to turing/ampere format in the first inference pass\u001b[39;00m\n\u001b[1;32m 246\u001b[0m \u001b[39m# we no longer need the row-major weight\u001b[39;00m\n", "File \u001b[0;32m~/miniconda3/envs/dl3/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:488\u001b[0m, in \u001b[0;36mmatmul\u001b[0;34m(A, B, out, state, threshold, bias)\u001b[0m\n\u001b[1;32m 486\u001b[0m \u001b[39mif\u001b[39;00m threshold \u001b[39m>\u001b[39m \u001b[39m0.0\u001b[39m:\n\u001b[1;32m 487\u001b[0m state\u001b[39m.\u001b[39mthreshold \u001b[39m=\u001b[39m threshold\n\u001b[0;32m--> 488\u001b[0m \u001b[39mreturn\u001b[39;00m MatMul8bitLt\u001b[39m.\u001b[39;49mapply(A, B, out, bias, state)\n", "File \u001b[0;32m~/miniconda3/envs/dl3/lib/python3.10/site-packages/torch/autograd/function.py:508\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 505\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m torch\u001b[39m.\u001b[39m_C\u001b[39m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m 506\u001b[0m \u001b[39m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m 507\u001b[0m args \u001b[39m=\u001b[39m _functorch\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 508\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49mapply(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 510\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39msetup_context \u001b[39m==\u001b[39m _SingleLevelFunction\u001b[39m.\u001b[39msetup_context:\n\u001b[1;32m 511\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m 512\u001b[0m \u001b[39m'\u001b[39m\u001b[39mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 513\u001b[0m \u001b[39m'\u001b[39m\u001b[39m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 514\u001b[0m \u001b[39m'\u001b[39m\u001b[39mstaticmethod. 
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": { "display_name": "dl3", "language": "python", "name": "python3" },
  "language_info": {
   "codemirror_mode": { "name": "ipython", "version": 3 },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": { "interpreter": { "hash": "90bfda469df5ac7fed8d7e225d563f60a7a7aa420ccfadb091c914debf775e49" } }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}