{ "cells": [ { "cell_type": "markdown", "id": "b1382a54", "metadata": {}, "source": [ "## Example: Using a custom Dataset object\n", "\n", "You can use the `build_regression_dataset` and `build_classification_dataset` functions in xGPR to build a Dataset object that wraps your training data, then pass that object to the fitting and tuning routines. Both functions work with numpy arrays\n", "either in memory or saved on disk as `.npy` files. Sometimes, however, your data is not a numpy array or a list of\n", "`.npy` files; it might be stored in an HDF5 file or a SQLite database, for example. You could convert it and save\n", "it to disk as `.npy` files, but it is often more convenient to keep the data in its original form and avoid making\n", "an unnecessary copy, especially if the dataset is large. In these situations, you can create a custom Dataset by\n", "subclassing xGPR's `DatasetBaseclass` (a little like writing a custom `Dataset` for a PyTorch `DataLoader`).\n", "\n", "In this example, we'll illustrate how to build a custom Dataset that we can pass to all of the training\n", "and tuning functions. The same approach is useful if you want to apply some special preprocessing to each\n", "datapoint before it's passed to xGPR." ] }, { "cell_type": "code", "execution_count": 1, "id": "c1d0a6db", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/ssd1/Documents/gp_proteins/venv_testing/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "import math\n", "import time\n", "\n", "import wget\n", "import pandas as pd\n", "import numpy as np\n", "import sklearn\n", "from sklearn.model_selection import train_test_split\n", "\n", "from xGPR import xGPRegression as xGPReg\n", "from xGPR import DatasetBaseclass" ] }, { "cell_type": "code", "execution_count": 2, "id": "39d7217e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-1 / unknown" ] } ], "source": [ "fname = wget.download(\"https://archive.ics.uci.edu/ml/machine-learning-databases/00265/CASP.csv\")\n", "raw_data = pd.read_csv(fname)\n", "os.remove(fname)" ] }, { "cell_type": "code", "execution_count": 3, "id": "ec1e70f2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | RMSD | F1 | F2 | F3 | F4 | F5 | F6 | F7 | F8 | F9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17.284 | 13558.30 | 4305.35 | 0.31754 | 162.1730 | 1.872791e+06 | 215.3590 | 4287.87 | 102 | 27.0302 |
| 1 | 6.021 | 6191.96 | 1623.16 | 0.26213 | 53.3894 | 8.034467e+05 | 87.2024 | 3328.91 | 39 | 38.5468 |
| 2 | 9.275 | 7725.98 | 1726.28 | 0.22343 | 67.2887 | 1.075648e+06 | 81.7913 | 2981.04 | 29 | 38.8119 |
| 3 | 15.851 | 8424.58 | 2368.25 | 0.28111 | 67.8325 | 1.210472e+06 | 109.4390 | 3248.22 | 70 | 39.0651 |
| 4 | 7.962 | 7460.84 | 1736.94 | 0.23280 | 52.4123 | 1.021020e+06 | 94.5234 | 2814.42 | 41 | 39.9147 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 45725 | 3.762 | 8037.12 | 2777.68 | 0.34560 | 64.3390 | 1.105797e+06 | 112.7460 | 3384.21 | 84 | 36.8036 |
| 45726 | 6.521 | 7978.76 | 2508.57 | 0.31440 | 75.8654 | 1.116725e+06 | 102.2770 | 3974.52 | 54 | 36.0470 |
| 45727 | 10.356 | 7726.65 | 2489.58 | 0.32220 | 70.9903 | 1.076560e+06 | 103.6780 | 3290.46 | 46 | 37.4718 |
| 45728 | 9.791 | 8878.93 | 3055.78 | 0.34416 | 94.0314 | 1.242266e+06 | 115.1950 | 3421.79 | 41 | 35.6045 |
| 45729 | 18.827 | 12732.40 | 4444.36 | 0.34905 | 157.6300 | 1.788897e+06 | 229.4590 | 4626.85 | 141 | 29.8118 |
45730 rows × 10 columns
\n", "