{ "cells": [ { "cell_type": "markdown", "id": "b1382a54", "metadata": {}, "source": [ "## Example: Fitting tabular data\n", "\n", "This straightforward example makes use of a small,\n", "arbitrarily chosen UCI repository dataset with about 45,000 datapoints. We'll\n", "download this data, do some light preprocessing, and fit a regression model with an RBF kernel.\n", "\n", "These experiments used xGPR v0.4.8. Note that if the device is set to cuda,\n", "xGPR always uses the currently active cuda device. To control which\n", "device this is, you can set the environment variable \"CUDA_VISIBLE_DEVICES\"." ] }, { "cell_type": "code", "execution_count": 1, "id": "c1d0a6db", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/ssd1/Documents/gp_proteins/venv_testing/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "import math\n", "import time\n", "\n", "import wget\n", "import pandas as pd\n", "import numpy as np\n", "import sklearn\n", "from sklearn.model_selection import train_test_split\n", "\n", "from xGPR import xGPRegression as xGPReg\n", "from xGPR import build_regression_dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "39d7217e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-1 / unknown" ] } ], "source": [ "fname = wget.download(\"https://archive.ics.uci.edu/ml/machine-learning-databases/00265/CASP.csv\")\n", "raw_data = pd.read_csv(fname)\n", "os.remove(fname)" ] }, { "cell_type": "code", "execution_count": 3, "id": "ec1e70f2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | RMSD | \n", "F1 | \n", "F2 | \n", "F3 | \n", "F4 | \n", "F5 | \n", "F6 | \n", "F7 | \n", "F8 | \n", "F9 | \n", "
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "17.284 | \n", "13558.30 | \n", "4305.35 | \n", "0.31754 | \n", "162.1730 | \n", "1.872791e+06 | \n", "215.3590 | \n", "4287.87 | \n", "102 | \n", "27.0302 | \n", "
| 1 | \n", "6.021 | \n", "6191.96 | \n", "1623.16 | \n", "0.26213 | \n", "53.3894 | \n", "8.034467e+05 | \n", "87.2024 | \n", "3328.91 | \n", "39 | \n", "38.5468 | \n", "
| 2 | \n", "9.275 | \n", "7725.98 | \n", "1726.28 | \n", "0.22343 | \n", "67.2887 | \n", "1.075648e+06 | \n", "81.7913 | \n", "2981.04 | \n", "29 | \n", "38.8119 | \n", "
| 3 | \n", "15.851 | \n", "8424.58 | \n", "2368.25 | \n", "0.28111 | \n", "67.8325 | \n", "1.210472e+06 | \n", "109.4390 | \n", "3248.22 | \n", "70 | \n", "39.0651 | \n", "
| 4 | \n", "7.962 | \n", "7460.84 | \n", "1736.94 | \n", "0.23280 | \n", "52.4123 | \n", "1.021020e+06 | \n", "94.5234 | \n", "2814.42 | \n", "41 | \n", "39.9147 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 45725 | \n", "3.762 | \n", "8037.12 | \n", "2777.68 | \n", "0.34560 | \n", "64.3390 | \n", "1.105797e+06 | \n", "112.7460 | \n", "3384.21 | \n", "84 | \n", "36.8036 | \n", "
| 45726 | \n", "6.521 | \n", "7978.76 | \n", "2508.57 | \n", "0.31440 | \n", "75.8654 | \n", "1.116725e+06 | \n", "102.2770 | \n", "3974.52 | \n", "54 | \n", "36.0470 | \n", "
| 45727 | \n", "10.356 | \n", "7726.65 | \n", "2489.58 | \n", "0.32220 | \n", "70.9903 | \n", "1.076560e+06 | \n", "103.6780 | \n", "3290.46 | \n", "46 | \n", "37.4718 | \n", "
| 45728 | \n", "9.791 | \n", "8878.93 | \n", "3055.78 | \n", "0.34416 | \n", "94.0314 | \n", "1.242266e+06 | \n", "115.1950 | \n", "3421.79 | \n", "41 | \n", "35.6045 | \n", "
| 45729 | \n", "18.827 | \n", "12732.40 | \n", "4444.36 | \n", "0.34905 | \n", "157.6300 | \n", "1.788897e+06 | \n", "229.4590 | \n", "4626.85 | \n", "141 | \n", "29.8118 | \n", "
45730 rows × 10 columns
\n", "