{ "cells": [ { "cell_type": "markdown", "id": "41f0fa12", "metadata": {}, "source": [ "## Example: Sequence Data\n", "\n", "Next, let's consider the AAV dataset (designed vs. mutant split)\n", "from the FLIP benchmark suite. For this dataset, we train on about 200,000\n", "amino acid sequences of up to 57 residues and try to predict the fitness\n", "of the sequences in a pre-specified test set. Dallago et al. report that a standard\n", "1d-CNN trained on this split achieves a Spearman's r of 0.75, while\n", "a 750-million parameter pretrained model that took 50 GPU-days\n", "to train achieves a Spearman's r of 0.79.\n", "\n", "We'll evaluate a convolution kernel and show that, with one-hot\n", "encoded input, we can easily match the 1d-CNN. (If we use embeddings\n", "as input, we can easily match the pretrained model, although we\n", "won't do that comparison here.)\n", "\n", "This was run using xGPR v0.4.8." ] }, { "cell_type": "code", "execution_count": 1, "id": "b9a773dd", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/ssd1/Documents/gp_proteins/venv_testing/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "import shutil\n", "import subprocess\n", "import math\n", "import time\n", "import zipfile\n", "\n", "import pandas as pd\n", "import numpy as np\n", "\n", "from xGPR import xGPRegression as xGPReg\n", "from xGPR import build_regression_dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "a42875b3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Cloning into 'FLIP'...\n" ] } ], "source": [ "# This may take a minute, because we have to download the FLIP repository first.\n", "subprocess.run([\"git\", \"clone\", \"https://github.com/J-SNACKKB/FLIP\"])\n", "\n", "# Move the zipped AAV data out of the repository, extract it, then clean up.\n", "shutil.move(os.path.join(\"FLIP\", \"splits\", \"aav\", \"full_data.csv.zip\"), \"full_data.csv.zip\")\n", "fname = \"full_data.csv.zip\"\n", "\n", "with zipfile.ZipFile(fname, \"r\") as zip_ref:\n", "    zip_ref.extractall()\n", "\n", "os.remove(\"full_data.csv.zip\")\n", "shutil.rmtree(\"FLIP\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "4e5dc401", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_194426/4195622896.py:1: DtypeWarning: Columns (12) have mixed types. Specify dtype option on import or set low_memory=False.\n", " raw_data = pd.read_csv(\"full_data.csv\")\n" ] } ], "source": [ "raw_data = pd.read_csv(\"full_data.csv\")\n", "os.remove(\"full_data.csv\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "80ee02b7", "metadata": {}, "outputs": [], "source": [ "# Upper-case each mutated region and drop any '*' characters.\n", "raw_data[\"input_seq\"] = [f.upper().replace(\"*\", \"\") for f in raw_data[\"mutated_region\"].tolist()]" ] }, { "cell_type": "markdown", "id": "1fef6ad7", "metadata": {}, "source": [ "We'll use simple one-hot encoding for the sequences. 
This may take a minute to set up. 
Notice that when\n", "encoding the sequences we record the length of each sequence so that the zero-padding we've added\n", "to the end of each sequence can be masked out when fitting the model. If you\n", "want the zero-padding included for some reason, you can set every sequence length to the padded length (57).\n", "\n", "We're saving all of the encoded sequences to disk, which isn't STRICTLY necessary in this case\n", "but may be if there are a lot of sequences. Alternatively, we could create a custom Dataset object\n", "(see the Advanced tutorials) that encodes the sequences as they are loaded from a FASTA file." ] }, { "cell_type": "code", "execution_count": 8, "id": "98745fa1", "metadata": {}, "outputs": [], "source": [ "def one_hot_encode(input_seq_list, y_values, chunk_size, ftype = \"train\"):\n", "    # One-hot encode each sequence and save x, y and sequence lengths to disk in chunks.\n", "    aas = [\"A\", \"C\", \"D\", \"E\", \"F\", \"G\", \"H\", \"I\",\n", "           \"K\", \"L\", \"M\", \"N\", \"P\", \"Q\", \"R\", \"S\", \"T\",\n", "           \"V\", \"W\", \"Y\", \"-\"]\n", "    output_x, output_y, output_seqlen = [], [], []\n", "    xfiles, yfiles, seqlen_files = [], [], []\n", "    fcounter = 0\n", "    n_seqs = len(input_seq_list)\n", "\n", "    for n, (seq, y_value) in enumerate(zip(input_seq_list, y_values)):\n", "        # Zero-padded (1, 57, 21) array; positions past len(seq) stay all-zero.\n", "        encoded_x = np.zeros((1,57,21), dtype = np.uint8)\n", "        for i, letter in enumerate(seq):\n", "            encoded_x[0, i, aas.index(letter)] = 1\n", "\n", "        output_x.append(encoded_x)\n", "        output_y.append(y_value)\n", "        output_seqlen.append(len(seq))\n", "\n", "        # Save each full chunk, plus the leftover partial chunk after the last\n", "        # sequence, so that no sequences are silently dropped.\n", "        if len(output_x) >= chunk_size or n == n_seqs - 1:\n", "            xfiles.append(f\"{fcounter}_{ftype}_xblock.npy\")\n", "            yfiles.append(f\"{fcounter}_{ftype}_yblock.npy\")\n", "            seqlen_files.append(f\"{fcounter}_{ftype}_seqlen.npy\")\n", "            np.save(xfiles[-1], np.vstack(output_x))\n", "            np.save(yfiles[-1], np.asarray(output_y))\n", "            np.save(seqlen_files[-1], np.array(output_seqlen).astype(np.int32))\n", "            fcounter += 1\n", "            output_x, output_y, output_seqlen = [], [], []\n", "            print(f\"Encoded file {fcounter}\")\n", "    return xfiles, yfiles, seqlen_files" ] }, { "cell_type": "code", "execution_count": 9, 
"id": "18d97977", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Encoded file 1\n", "Encoded file 2\n", "Encoded file 3\n", "Encoded file 4\n", "Encoded file 5\n", "Encoded file 6\n", "Encoded file 7\n", "Encoded file 8\n", "Encoded file 9\n", "Encoded file 10\n", "Encoded file 11\n", "Encoded file 12\n", "Encoded file 13\n", "Encoded file 14\n", "Encoded file 15\n", "Encoded file 16\n", "Encoded file 17\n", "Encoded file 18\n", "Encoded file 19\n", "Encoded file 20\n", "Encoded file 21\n", "Encoded file 22\n", "Encoded file 23\n", "Encoded file 24\n", "Encoded file 25\n", "Encoded file 26\n", "Encoded file 27\n", "Encoded file 28\n", "Encoded file 29\n", "Encoded file 30\n", "Encoded file 31\n", "Encoded file 32\n", "Encoded file 33\n", "Encoded file 34\n", "Encoded file 35\n", "Encoded file 36\n", "Encoded file 37\n", "Encoded file 38\n", "Encoded file 39\n", "Encoded file 40\n", "Encoded file 41\n", "Encoded file 42\n", "Encoded file 43\n", "Encoded file 44\n", "Encoded file 45\n", "Encoded file 46\n", "Encoded file 47\n", "Encoded file 48\n", "Encoded file 49\n", "Encoded file 50\n", "Encoded file 51\n", "Encoded file 52\n", "Encoded file 53\n", "Encoded file 54\n", "Encoded file 55\n", "Encoded file 56\n", "Encoded file 57\n", "Encoded file 58\n", "Encoded file 59\n", "Encoded file 60\n", "Encoded file 61\n", "Encoded file 62\n", "Encoded file 63\n", "Encoded file 64\n", "Encoded file 65\n", "Encoded file 66\n", "Encoded file 67\n", "Encoded file 68\n", "Encoded file 69\n", "Encoded file 70\n", "Encoded file 71\n", "Encoded file 72\n", "Encoded file 73\n", "Encoded file 74\n", "Encoded file 75\n", "Encoded file 76\n", "Encoded file 77\n", "Encoded file 78\n", "Encoded file 79\n", "Encoded file 80\n", "Encoded file 81\n", "Encoded file 82\n", "Encoded file 83\n", "Encoded file 84\n", "Encoded file 85\n", "Encoded file 86\n", "Encoded file 87\n", "Encoded file 88\n", "Encoded file 89\n", "Encoded file 90\n", "Encoded file 
91\n", "Encoded file 92\n", "Encoded file 93\n", "Encoded file 94\n", "Encoded file 95\n", "Encoded file 96\n", "Encoded file 97\n", "Encoded file 98\n", "Encoded file 99\n", "Encoded file 100\n", "Encoded file 1\n", "Encoded file 2\n", "Encoded file 3\n", "Encoded file 4\n", "Encoded file 5\n", "Encoded file 6\n", "Encoded file 7\n", "Encoded file 8\n", "Encoded file 9\n", "Encoded file 10\n", "Encoded file 11\n", "Encoded file 12\n", "Encoded file 13\n", "Encoded file 14\n", "Encoded file 15\n", "Encoded file 16\n", "Encoded file 17\n", "Encoded file 18\n", "Encoded file 19\n", "Encoded file 20\n", "Encoded file 21\n", "Encoded file 22\n", "Encoded file 23\n", "Encoded file 24\n", "Encoded file 25\n", "Encoded file 26\n", "Encoded file 27\n", "Encoded file 28\n", "Encoded file 29\n", "Encoded file 30\n", "Encoded file 31\n", "Encoded file 32\n", "Encoded file 33\n", "Encoded file 34\n", "Encoded file 35\n", "Encoded file 36\n", "Encoded file 37\n", "Encoded file 38\n", "Encoded file 39\n", "Encoded file 40\n", "Encoded file 41\n" ] } ], "source": [ "train_data = raw_data[raw_data[\"des_mut_split\"]==\"train\"]\n", "test_data = raw_data[raw_data[\"des_mut_split\"]==\"test\"]\n", "\n", "\n", "train_x_files, train_y_files, train_seqlen_files = one_hot_encode(train_data[\"input_seq\"].tolist(),\n", " train_data[\"score\"].tolist(), 2000, \"train\")\n", "test_x_files, test_y_files, test_seqlen_files = one_hot_encode(test_data[\"input_seq\"].tolist(),\n", " test_data[\"score\"].tolist(), 2000, \"test\")" ] }, { "cell_type": "markdown", "id": "07b24d58-245d-454e-973b-b3ef81f4e740", "metadata": {}, "source": [ "Notice that we pass the list of seqlen_files into the dataset builder. This is required if working with\n", "3d arrays / convolution kernels. If you are working with 2d arrays / fixed-length vector kernels,\n", "the default for the third argument (```None```) is appropriate." 
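, "\n", "\n", "To make the role of the sequence lengths concrete, here is a minimal NumPy sketch of the idea. It is an illustration only, not xGPR's internal code, and ``encode_padded`` and ``padding_mask`` are hypothetical helper names: each sequence is one-hot encoded into a zero-padded array, and its recorded length defines a boolean mask that separates real residues from padding.\n", "\n", "```python\n", "import numpy as np\n", "\n", "# Same 21-symbol alphabet (20 amino acids plus '-') as one_hot_encode above.\n", "AAS = 'ACDEFGHIKLMNPQRSTVWY-'\n", "\n", "def encode_padded(seqs, max_len):\n", "    # Hypothetical helper: one-hot encode into a zero-padded (N, max_len, 21)\n", "    # array and record each sequence's true length.\n", "    x = np.zeros((len(seqs), max_len, len(AAS)), dtype=np.uint8)\n", "    for n, seq in enumerate(seqs):\n", "        for i, letter in enumerate(seq):\n", "            x[n, i, AAS.index(letter)] = 1\n", "    seqlen = np.array([len(s) for s in seqs], dtype=np.int32)\n", "    return x, seqlen\n", "\n", "def padding_mask(seqlen, max_len):\n", "    # Hypothetical helper: True where a position holds a real residue,\n", "    # False where it is zero-padding.\n", "    return np.arange(max_len)[None, :] < seqlen[:, None]\n", "\n", "x, seqlen = encode_padded(['ACD', 'WY'], max_len=5)\n", "mask = padding_mask(seqlen, max_len=5)\n", "print(mask.astype(int))\n", "# [[1 1 1 0 0]\n", "#  [1 1 0 0 0]]\n", "```\n", "\n", "Setting every entry of ``seqlen`` to ``max_len`` makes the mask all True, so the padding would no longer be excluded."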
] }, { "cell_type": "code", "execution_count": 10, "id": "99584178", "metadata": {}, "outputs": [], "source": [ "training_dset = build_regression_dataset(train_x_files, train_y_files, train_seqlen_files, chunk_size = 2000)" ] }, { "cell_type": "markdown", "id": "d4b89e96", "metadata": {}, "source": [ "Here we'll use the Conv1dRBF kernel, a convolution kernel for sequences. Convolution kernels are usually slower than RBF / Matern, especially for long sequences. We'll run a quick and dirty tuning experiment using 1024 random features, then fine-tune the result using a larger number of random features, just as we did for the tabular dataset.\n", "\n", "Many kernels in xGPR have kernel-specific settings. For Conv1dRBF, we can set two key options: sequence averaging (one of 'none', 'sqrt' or 'full') and the width of the convolution. Just as with a convolutional network, the width of the convolution filters can affect performance. One way to choose a good width: run hyperparameter tuning (e.g. with ``tune_hyperparams_crude``, as below) using a small number of RFFs (e.g. 1024 to 2048) for several different settings of \"conv_width\", and compare the resulting scores. Since the score is a negative marginal log likelihood, the setting that achieves the lowest score is likely the best choice for \"conv_width\"."
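, "\n", "\n", "The width comparison described above can be scripted. The sketch below is one way to do it, assuming the same xGPR calls used elsewhere in this notebook (``xGPRegression`` with ``kernel_settings``, then ``tune_hyperparams_crude``). Both ``crude_nmll_for_width`` and ``pick_conv_width`` are hypothetical helpers, and the candidate widths are arbitrary examples. Each candidate costs one full crude tuning run.\n", "\n", "```python\n", "def crude_nmll_for_width(width, dataset, num_rffs=1024):\n", "    # Hypothetical wrapper around the xGPR calls used in this notebook;\n", "    # returns the best estimated negative marginal log likelihood.\n", "    from xGPR import xGPRegression as xGPReg\n", "    model = xGPReg(num_rffs=num_rffs, variance_rffs=512,\n", "                   kernel_choice='Conv1dRBF',\n", "                   kernel_settings={'conv_width': width, 'averaging': 'none'},\n", "                   verbose=False, device='cuda')\n", "    _, _, best_score = model.tune_hyperparams_crude(dataset)\n", "    return best_score\n", "\n", "def pick_conv_width(widths, score_fn):\n", "    # Hypothetical helper: score_fn(width) returns a negative marginal log\n", "    # likelihood, so the lowest score wins.\n", "    scores = {w: score_fn(w) for w in widths}\n", "    best_width = min(scores, key=scores.get)\n", "    return best_width, scores\n", "\n", "# Example usage in this notebook (slow: one crude tuning run per width):\n", "# best_width, scores = pick_conv_width(\n", "#     [5, 7, 9, 11, 13], lambda w: crude_nmll_for_width(w, training_dset))\n", "```"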
] }, { "cell_type": "code", "execution_count": 11, "id": "1c451d27", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Grid point 0 acquired.\n", "Grid point 1 acquired.\n", "Grid point 2 acquired.\n", "Grid point 3 acquired.\n", "Grid point 4 acquired.\n", "Grid point 5 acquired.\n", "Grid point 6 acquired.\n", "Grid point 7 acquired.\n", "Grid point 8 acquired.\n", "Grid point 9 acquired.\n", "New hparams: [-2.4229936]\n", "Additional acquisition 10.\n", "New hparams: [-1.9242547]\n", "Additional acquisition 11.\n", "New hparams: [-1.8032397]\n", "Additional acquisition 12.\n", "New hparams: [-1.8347338]\n", "Best score achieved: 128773.492\n", "Best hyperparams: [-0.2832717 -1.9242547]\n", "Best estimated negative marginal log likelihood: 128773.492\n", "Wallclock: 20.574951887130737\n" ] } ], "source": [ "aav_model = xGPReg(num_rffs = 1024, variance_rffs = 512,\n", " kernel_choice = \"Conv1dRBF\",\n", " kernel_settings = {\"conv_width\":11, \"averaging\":'none'},\n", " verbose = True, device = \"cuda\")\n", "\n", "start_time = time.time()\n", "hparams, niter, best_score = aav_model.tune_hyperparams_crude(training_dset)\n", "end_time = time.time()\n", "\n", "print(f\"Best estimated negative marginal log likelihood: {best_score}\")\n", "print(f\"Wallclock: {end_time - start_time}\")" ] }, { "cell_type": "markdown", "id": "9974ca79-ec65-44ca-a897-6288d7bed2a7", "metadata": {}, "source": [ "We now have a rough estimate of our hyperparameters, acquired using a sketchy kernel approximation\n", "(num_rffs=1024) and a crude tuning procedure. Let's fine-tune this a little. We could use\n", "the built-in tuning routine in xGPR the way we did for the tabular data, use\n", "Optuna (or some other library), or do a simple grid search. For illustrative\n", "purposes, we'll use Optuna with num_rffs=4,096 (a somewhat better kernel\n", "approximation) and see what that looks like. 
We'll search the region around the\n", "hyperparameters obtained from ``tune_hyperparams_crude``. To run this\n", "next piece, you'll need to have Optuna installed." ] }, { "cell_type": "code", "execution_count": 12, "id": "8d6ec16a-6e18-4b46-923f-a357fb745d88", "metadata": {}, "outputs": [], "source": [ "import optuna\n", "from optuna.samplers import TPESampler\n", "\n", "def objective(trial):\n", " lambda_ = trial.suggest_float(\"lambda_\", -2., 0.)\n", " sigma = trial.suggest_float(\"sigma\", -3., -1.)\n", " hparams = np.array([lambda_, sigma])\n", " nmll = aav_model.exact_nmll(hparams, training_dset)\n", " return nmll" ] }, { "cell_type": "code", "execution_count": 13, "id": "aef1d565-62f2-438e-9e47-be6a396a9b39", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:04:42,791] A new study created in memory with name: no-name-a0d27c2b-5d22-41b1-8556-94839784d7cb\n", "[I 2025-04-16 10:04:56,000] Trial 0 finished with value: 113938.7126392912 and parameters: {'lambda_': -0.6070616288042767, 'sigma': -2.4277213300992413}. Best is trial 0 with value: 113938.7126392912.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:05:09,268] Trial 1 finished with value: 110497.60820715068 and parameters: {'lambda_': -1.5462970928715938, 'sigma': -1.8973704618342175}. Best is trial 1 with value: 110497.60820715068.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:05:22,538] Trial 2 finished with value: 110013.32776801154 and parameters: {'lambda_': -0.5610620604288739, 'sigma': -2.153787079751078}. 
Best is trial 2 with value: 110013.32776801154.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:05:35,831] Trial 3 finished with value: 108222.15752609915 and parameters: {'lambda_': -0.038471603230769036, 'sigma': -1.6303405228302734}. Best is trial 3 with value: 108222.15752609915.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:05:49,204] Trial 4 finished with value: 109165.39421682109 and parameters: {'lambda_': -1.0381361970312781, 'sigma': -2.215764963611699}. Best is trial 3 with value: 108222.15752609915.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:06:02,584] Trial 5 finished with value: 112470.01840575809 and parameters: {'lambda_': -1.3136439676982612, 'sigma': -1.5419005852319168}. Best is trial 3 with value: 108222.15752609915.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:06:15,959] Trial 6 finished with value: 117925.01634157504 and parameters: {'lambda_': -1.1228555106407512, 'sigma': -2.8806442067808633}. Best is trial 3 with value: 108222.15752609915.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:06:29,331] Trial 7 finished with value: 112197.12596152317 and parameters: {'lambda_': -1.2039114893391372, 'sigma': -1.5240091885359286}. 
Best is trial 3 with value: 108222.15752609915.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:06:42,706] Trial 8 finished with value: 111654.5965248778 and parameters: {'lambda_': -1.635016539093, 'sigma': -2.649096487705015}. Best is trial 3 with value: 108222.15752609915.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:06:56,075] Trial 9 finished with value: 108705.16831389123 and parameters: {'lambda_': -0.9368972523163233, 'sigma': -1.9363448258062679}. Best is trial 3 with value: 108222.15752609915.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:07:09,452] Trial 10 finished with value: 112370.9705344652 and parameters: {'lambda_': -0.03129046371014721, 'sigma': -1.0097302786311597}. Best is trial 3 with value: 108222.15752609915.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:07:22,820] Trial 11 finished with value: 108204.02327921556 and parameters: {'lambda_': -0.01866037953665966, 'sigma': -1.635646720641679}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:07:36,180] Trial 12 finished with value: 108794.3001803944 and parameters: {'lambda_': -0.023783021369743196, 'sigma': -1.453582345648574}. 
Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:07:49,543] Trial 13 finished with value: 112054.93092277527 and parameters: {'lambda_': -0.3826565674804564, 'sigma': -1.1788298267417594}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:08:02,904] Trial 14 finished with value: 108288.82550766022 and parameters: {'lambda_': -0.2593656053541154, 'sigma': -1.7832602361921872}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:08:16,264] Trial 15 finished with value: 109545.03620512375 and parameters: {'lambda_': -0.7745612805651452, 'sigma': -1.674253799818601}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:08:29,624] Trial 16 finished with value: 110288.55511532603 and parameters: {'lambda_': -0.2603175001600282, 'sigma': -1.333904798028223}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:08:42,986] Trial 17 finished with value: 110162.80893674257 and parameters: {'lambda_': -1.9409547062276986, 'sigma': -2.1327955818958895}. 
Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:08:56,346] Trial 18 finished with value: 111491.38554989494 and parameters: {'lambda_': -0.3835066195659215, 'sigma': -1.2419508741475265}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:09:09,706] Trial 19 finished with value: 108306.86184706952 and parameters: {'lambda_': -0.043835835276730635, 'sigma': -1.7445010980321713}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:09:23,067] Trial 20 finished with value: 112432.8051483832 and parameters: {'lambda_': -0.7304436166034872, 'sigma': -2.3852934884287125}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:09:36,433] Trial 21 finished with value: 108276.14911072241 and parameters: {'lambda_': -0.23150335811101075, 'sigma': -1.769944025853498}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:09:49,793] Trial 22 finished with value: 108456.88610376373 and parameters: {'lambda_': -0.1582976030195825, 'sigma': -1.5940025908052449}. 
Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:10:03,154] Trial 23 finished with value: 108352.78678593124 and parameters: {'lambda_': -0.44394970021613395, 'sigma': -1.8396633950108474}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:10:16,512] Trial 24 finished with value: 109960.12475624992 and parameters: {'lambda_': -0.2567864803764853, 'sigma': -1.3756449133769941}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:10:29,872] Trial 25 finished with value: 109707.05886217361 and parameters: {'lambda_': -0.17503409708419343, 'sigma': -1.9891378872913226}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:10:43,233] Trial 26 finished with value: 108924.54346795652 and parameters: {'lambda_': -0.51559181835663, 'sigma': -1.6593062029334495}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:10:56,590] Trial 27 finished with value: 111844.7980757734 and parameters: {'lambda_': -0.16068988881110718, 'sigma': -1.1118194185816517}. 
Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:11:09,953] Trial 28 finished with value: 110171.82096851458 and parameters: {'lambda_': -0.3405223941277511, 'sigma': -2.088141363129792}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:11:23,315] Trial 29 finished with value: 111724.92956268287 and parameters: {'lambda_': -0.6843658044124352, 'sigma': -2.323089362687985}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:11:36,674] Trial 30 finished with value: 112351.14569938189 and parameters: {'lambda_': -0.8877938379147875, 'sigma': -1.363034462928606}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:11:50,039] Trial 31 finished with value: 108305.69229256368 and parameters: {'lambda_': -0.15545204796200673, 'sigma': -1.774862465080108}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:12:03,411] Trial 32 finished with value: 108446.19393160455 and parameters: {'lambda_': -0.29254801396567537, 'sigma': -1.8600837006081945}. 
Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:12:16,775] Trial 33 finished with value: 108626.4612872399 and parameters: {'lambda_': -0.5591398529547196, 'sigma': -1.750835718271973}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-04-16 10:12:30,132] Trial 34 finished with value: 110345.63855749354 and parameters: {'lambda_': -0.03233495711217844, 'sigma': -1.9864859129386128}. Best is trial 11 with value: 108204.02327921556.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated NMLL.\n" ] } ], "source": [ "aav_model.num_rffs = 4096\n", "\n", "sampler = TPESampler(seed=123)\n", "study = optuna.create_study(sampler=sampler)\n", "study.optimize(objective, n_trials=35)" ] }, { "cell_type": "markdown", "id": "47f479d2-eea9-4e0d-bde6-9a01b9d499b9", "metadata": {}, "source": [ "Here are the best hyperparameters found by Optuna. Note that the values passed to ``set_hyperparams`` in the next cell (``[-1.9997, -1.005464]``) are not these best parameters; the fitting and test set results below were obtained with the values in that cell. To use the Optuna result instead, pass ``np.array([study.best_params[\"lambda_\"], study.best_params[\"sigma\"]])``." ] }, { "cell_type": "code", "execution_count": 14, "id": "ec8b4763", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'lambda_': -0.01866037953665966, 'sigma': -1.635646720641679}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "study.best_params" ] }, { "cell_type": "code", "execution_count": 15, "id": "5f55ef42-abec-408a-8750-30e4a1ef40b8", "metadata": {}, "outputs": [], "source": [ "aav_model.set_hyperparams(np.array([-1.9997, -1.005464]), training_dset)" ] }, { 
"cell_type": "markdown", "id": "d47e4f4e", "metadata": {}, "source": [ "Now we'll fit the model using 8,192 RFFs. We like to use a more accurate kernel approximation when fitting than when tuning, for two reasons. 
First, tuning is more expensive, because the model has to be fit multiple times when tuning hyperparameters. Second, model performance usually\n", "increases faster when we increase the number of RFFs used for fitting than when we increase the number used for tuning. (Using 16,384 or 32,768 RFFs for fitting here further\n", "increases test set performance, as you'd expect, but with diminishing returns. The resulting small performance gain may be worthwhile for some applications and not so much for others.)\n", "\n", "On GPU, for fitting, ``mode=exact`` works well up to 8,192 RFFs or so, while ``mode=cg``, although\n", "slower for small numbers of RFFs, is more scalable. On this dataset, using 8,192 RFFs, \"exact\" takes about 70 seconds on our GPU.\n", "We'll use cg here just for illustrative purposes.\n", "\n", "``tol`` determines how tight the fit is; 1e-6 (the default) is usually fine. Decreasing it can improve performance, but\n", "with rapidly diminishing returns, and it makes fitting take longer. For noise-free data, or to get a small additional boost in\n", "performance, use 1e-7. 1e-8 is (nearly always) overkill."
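, "\n", "\n", "To build some intuition for what ``tol`` controls, here is a minimal conjugate gradient (CG) solver in NumPy applied to a small ridge-regression system. This is a generic, unpreconditioned textbook CG, not xGPR's implementation, and the random matrix ``Z`` and penalty ``lam`` are stand-ins chosen only for illustration. The solver stops once the relative residual falls below ``tol``, so on the same system a tighter tolerance never needs fewer iterations.\n", "\n", "```python\n", "import numpy as np\n", "\n", "def conjugate_gradient(A, b, tol=1e-6, max_iter=1000):\n", "    # Solve A x = b for symmetric positive definite A, stopping once\n", "    # ||b - A x|| / ||b|| < tol. Returns the solution and the iteration count.\n", "    x = np.zeros_like(b)\n", "    r = b - A @ x\n", "    p = r.copy()\n", "    rs_old = r @ r\n", "    b_norm = np.linalg.norm(b)\n", "    for k in range(max_iter):\n", "        if np.sqrt(rs_old) / b_norm < tol:\n", "            return x, k\n", "        Ap = A @ p\n", "        alpha = rs_old / (p @ Ap)\n", "        x += alpha * p\n", "        r -= alpha * Ap\n", "        rs_new = r @ r\n", "        p = r + (rs_new / rs_old) * p\n", "        rs_old = rs_new\n", "    return x, max_iter\n", "\n", "# Stand-in problem: random features Z and a ridge penalty lam (illustrative only).\n", "rng = np.random.default_rng(0)\n", "Z = rng.standard_normal((2000, 256))\n", "y = rng.standard_normal(2000)\n", "lam = 1.0\n", "A = Z.T @ Z + lam * np.eye(256)\n", "b = Z.T @ y\n", "\n", "iters_by_tol = {}\n", "for tol in (1e-6, 1e-7, 1e-8):\n", "    w, iters = conjugate_gradient(A, b, tol=tol)\n", "    iters_by_tol[tol] = iters\n", "    print(f'tol={tol:.0e}: {iters} CG iterations')\n", "```"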
] }, { "cell_type": "code", "execution_count": 16, "id": "fa173a47", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "starting fitting\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", "Chunk 20 complete.\n", "Chunk 30 complete.\n", "Chunk 40 complete.\n", "Chunk 50 complete.\n", "Chunk 60 complete.\n", "Chunk 70 complete.\n", "Chunk 80 complete.\n", "Chunk 90 complete.\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", "Chunk 20 complete.\n", "Chunk 30 complete.\n", "Chunk 40 complete.\n", "Chunk 50 complete.\n", "Chunk 60 complete.\n", "Chunk 70 complete.\n", "Chunk 80 complete.\n", "Chunk 90 complete.\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", "Chunk 20 complete.\n", "Chunk 30 complete.\n", "Chunk 40 complete.\n", "Chunk 50 complete.\n", "Chunk 60 complete.\n", "Chunk 70 complete.\n", "Chunk 80 complete.\n", "Chunk 90 complete.\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", "Chunk 20 complete.\n", "Chunk 30 complete.\n", "Chunk 40 complete.\n", "Chunk 50 complete.\n", "Chunk 60 complete.\n", "Chunk 70 complete.\n", "Chunk 80 complete.\n", "Chunk 90 complete.\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", "Chunk 20 complete.\n", "Chunk 30 complete.\n", "Chunk 40 complete.\n", "Chunk 50 complete.\n", "Chunk 60 complete.\n", "Chunk 70 complete.\n", "Chunk 80 complete.\n", "Chunk 90 complete.\n", "Using rank: 3000\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", "Chunk 20 complete.\n", "Chunk 30 complete.\n", "Chunk 40 complete.\n", "Chunk 50 complete.\n", "Chunk 60 complete.\n", "Chunk 70 complete.\n", "Chunk 80 complete.\n", "Chunk 90 complete.\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", "Chunk 20 complete.\n", "Chunk 30 complete.\n", "Chunk 40 complete.\n", "Chunk 50 complete.\n", "Chunk 60 complete.\n", "Chunk 70 complete.\n", "Chunk 80 complete.\n", "Chunk 90 complete.\n", "0 iterations complete.\n", "5 iterations complete.\n", "10 iterations complete.\n", "CG iterations: 13\n", "Now performing 
variance calculations...\n", "Fitting complete.\n", "Wallclock: 112.52763795852661\n" ] } ], "source": [ "aav_model.num_rffs = 8192\n", "start_time = time.time()\n", "aav_model.fit(training_dset, mode = 'cg', tol = 1e-6)\n", "end_time = time.time()\n", "print(f\"Wallclock: {end_time - start_time}\")" ] }, { "cell_type": "code", "execution_count": 17, "id": "af17b1dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Wallclock: 0.5303468704223633\n" ] } ], "source": [ "start_time = time.time()\n", "all_preds, ground_truth = [], []\n", "for xfile, yfile, sfile in zip(test_x_files, test_y_files, test_seqlen_files):\n", "    x, y, s = np.load(xfile), np.load(yfile), np.load(sfile)\n", "    ground_truth.append(y)\n", "    # get_var = False: return predictions only, without variance.\n", "    preds = aav_model.predict(x, s, get_var = False)\n", "    all_preds.append(preds)\n", "\n", "all_preds, ground_truth = np.concatenate(all_preds), np.concatenate(ground_truth)\n", "end_time = time.time()\n", "print(f\"Wallclock: {end_time - start_time}\")" ] }, { "cell_type": "code", "execution_count": 18, "id": "ca63c6ce", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SignificanceResult(statistic=np.float64(0.7659180535079166), pvalue=np.float64(0.0))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from scipy.stats import spearmanr\n", "\n", "spearmanr(all_preds, ground_truth)" ] }, { "cell_type": "markdown", "id": "a6e3d321", "metadata": {}, "source": [ "A Spearman's r of about 0.77 is slightly better than the 1d-CNN result reported by Dallago et al.\n", "for this dataset (0.75) and comes close to the performance of a fine-tuned LLM (Spearman's r of 0.79).\n", "As discussed above, we can get further slight improvements in performance\n", "just by tweaking this model (e.g. by using more RFFs for fitting). We can do even better by using a more informative\n", "representation of the protein sequences. 
In our original paper, we achieved a Spearman's r\n", "of about 0.8 on this dataset, outperforming fine-tuned LLMs at a significantly lower training cost.\n", "Whether the small gains from further \"tweaking\" or from more informative representations are worthwhile\n", "depends on your application." ] }, { "cell_type": "code", "execution_count": 19, "id": "02dd194d", "metadata": {}, "outputs": [], "source": [ "# Clean up the encoded test files written to disk.\n", "for testx, testy, tests in zip(test_x_files, test_y_files, test_seqlen_files):\n", "    os.remove(testx)\n", "    os.remove(testy)\n", "    os.remove(tests)" ] }, { "cell_type": "code", "execution_count": 20, "id": "b857b33e-a21e-45c3-9394-fdb50806ea98", "metadata": {}, "outputs": [], "source": [ "# Clean up the encoded training files written to disk.\n", "for xfile, yfile, sfile in zip(train_x_files, train_y_files, train_seqlen_files):\n", "    os.remove(xfile)\n", "    os.remove(yfile)\n", "    os.remove(sfile)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }