Example: Sequence Data

Next, let’s consider the AAV dataset (designed vs. mutant split) from the FLIP benchmark suite. For this dataset, we train on 200,000 amino acid sequences of length up to 57 and try to predict the fitness of a pre-specified test set. Dallago et al. report that a standard 1d-CNN trained on this data achieves a Spearman’s r of 0.75, while a 750-million parameter pretrained model that took 50 GPU-days to train achieves a Spearman’s r of 0.79.
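
Spearman’s r, the metric quoted above, is the Pearson correlation of the ranks, so it only rewards getting the ordering of the fitness values right. The following is a minimal sketch that assumes there are no tied values; the helper name `spearman_no_ties` is ours, for illustration only (the notebook itself uses `scipy.stats.spearmanr` further down).

```python
import numpy as np

def spearman_no_ties(a, b):
    """Spearman's r for 1-D arrays without ties: Pearson correlation of the ranks."""
    rank_a = np.argsort(np.argsort(a))
    rank_b = np.argsort(np.argsort(b))
    return float(np.corrcoef(rank_a, rank_b)[0, 1])

y_true = np.array([0.1, 0.4, 0.35, 0.8, 1.2])
y_pred = np.exp(3.0 * y_true)   # a monotone but nonlinear function of y_true

r_monotone = spearman_no_ties(y_true, y_pred)    # same ordering as y_true
r_reversed = spearman_no_ties(y_true, -y_true)   # exactly reversed ordering
print(r_monotone, r_reversed)
```

A prediction that is badly calibrated but ranks the sequences correctly still scores perfectly, which is why this metric suits fitness ranking tasks.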

We’ll evaluate a convolution kernel and show that, using one-hot encoded input, we can easily match the 1d-CNN. (If we use embeddings as input, we can easily match the pretrained model, although we won’t do that comparison here.)

This was run using xGPR v0.4.8.

[1]:
import os
import shutil
import subprocess
import math
import time
import zipfile

import pandas as pd
import numpy as np

from xGPR import xGPRegression as xGPReg
from xGPR import build_regression_dataset
[2]:
#This may take a minute, because we have to download the data here...
subprocess.run(["git", "clone", "https://github.com/J-SNACKKB/FLIP"])

shutil.move(os.path.join("FLIP", "splits", "aav", "full_data.csv.zip"), "full_data.csv.zip")
fname = "full_data.csv.zip"

with zipfile.ZipFile(fname, "r") as zip_ref:
    zip_ref.extractall()

os.remove("full_data.csv.zip")


shutil.rmtree("FLIP")
Cloning into 'FLIP'...
[3]:
raw_data = pd.read_csv("full_data.csv")
os.remove("full_data.csv")
/tmp/ipykernel_194426/4195622896.py:1: DtypeWarning: Columns (12) have mixed types. Specify dtype option on import or set low_memory=False.
  raw_data = pd.read_csv("full_data.csv")
[7]:
raw_data["input_seq"] = [f.upper().replace("*", "") for f in raw_data["mutated_region"].tolist()]

We’ll use simple one-hot encoding for the sequences. This may take a minute to set up. Notice that when encoding the sequences we record the length of each sequence, so that the zero-padding added to the end of the sequence can be masked out when fitting the model. If you want the zero-padding included for some reason, you can set all sequence lengths to the same value.
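
To make the padding and length bookkeeping concrete, here is a small sketch on a toy sequence. It uses the same 21-symbol alphabet as the notebook, but `MAX_LEN = 8` and the helper `encode_one` are illustrative choices of ours, not part of xGPR.

```python
import numpy as np

AAS = "ACDEFGHIKLMNPQRSTVWY-"   # same 21-symbol alphabet as the notebook
MAX_LEN = 8                      # toy value; the notebook pads to 57

def encode_one(seq):
    """One-hot encode one sequence into a zero-padded (MAX_LEN, 21) array."""
    x = np.zeros((MAX_LEN, len(AAS)), dtype=np.uint8)
    for i, letter in enumerate(seq):
        x[i, AAS.index(letter)] = 1
    return x, len(seq)

x, seqlen = encode_one("ACDW")

# Rows at or beyond seqlen are padding and contain only zeros.
is_real = np.arange(MAX_LEN) < seqlen
print(x.sum(axis=1))   # one 1 per real position, 0 for each padded position
print(is_real)
```

Storing `seqlen` alongside the array is what lets downstream code tell real positions from padding.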

We’re saving all of the encoded sequences to disk, which isn’t strictly necessary in this case but may be if there are a lot of sequences. Alternatively, we could create a custom Dataset object (see the Advanced tutorials) which encodes the sequences as they are loaded from a FASTA file.

[8]:
def one_hot_encode(input_seq_list, y_values, chunk_size, ftype = "train"):
    """One-hot encode sequences and save them to disk in chunks of chunk_size.

    Returns the lists of x, y and sequence-length file names."""
    aas = ["A", "C", "D", "E", "F", "G", "H", "I",
               "K", "L", "M", "N", "P", "Q", "R", "S", "T",
               "V", "W", "Y", "-"]
    output_x, output_y, output_seqlen = [], [], []
    xfiles, yfiles, seqlen_files = [], [], []
    fcounter = 0

    for seq, y_value in zip(input_seq_list, y_values):
        # Each sequence becomes a (1, 57, 21) array; positions past len(seq)
        # stay all-zero (padding).
        encoded_x = np.zeros((1,57,21), dtype = np.uint8)
        for i, letter in enumerate(seq):
            encoded_x[0, i, aas.index(letter)] = 1

        output_x.append(encoded_x)
        output_y.append(y_value)
        # Record the true length so the padding can be masked out later.
        output_seqlen.append(len(seq))

        # Once a full chunk has accumulated, write it to disk and start a new one.
        if len(output_x) >= chunk_size:
            xfiles.append(f"{fcounter}_{ftype}_xblock.npy")
            yfiles.append(f"{fcounter}_{ftype}_yblock.npy")
            seqlen_files.append(f"{fcounter}_{ftype}_seqlen.npy")
            np.save(xfiles[-1], np.vstack(output_x))
            np.save(yfiles[-1], np.asarray(output_y))
            np.save(seqlen_files[-1], np.array(output_seqlen).astype(np.int32))
            fcounter += 1
            output_x, output_y, output_seqlen = [], [], []
            print(f"Encoded file {fcounter}")
    # Note: sequences left over after the last full chunk are not written to disk.
    return xfiles, yfiles, seqlen_files
[9]:
train_data = raw_data[raw_data["des_mut_split"]=="train"]
test_data = raw_data[raw_data["des_mut_split"]=="test"]


train_x_files, train_y_files, train_seqlen_files = one_hot_encode(train_data["input_seq"].tolist(),
                                              train_data["score"].tolist(), 2000, "train")
test_x_files, test_y_files, test_seqlen_files = one_hot_encode(test_data["input_seq"].tolist(),
                                            test_data["score"].tolist(), 2000, "test")
Encoded file 1
Encoded file 2
Encoded file 3
Encoded file 4
Encoded file 5
Encoded file 6
Encoded file 7
Encoded file 8
Encoded file 9
Encoded file 10
Encoded file 11
Encoded file 12
Encoded file 13
Encoded file 14
Encoded file 15
Encoded file 16
Encoded file 17
Encoded file 18
Encoded file 19
Encoded file 20
Encoded file 21
Encoded file 22
Encoded file 23
Encoded file 24
Encoded file 25
Encoded file 26
Encoded file 27
Encoded file 28
Encoded file 29
Encoded file 30
Encoded file 31
Encoded file 32
Encoded file 33
Encoded file 34
Encoded file 35
Encoded file 36
Encoded file 37
Encoded file 38
Encoded file 39
Encoded file 40
Encoded file 41
Encoded file 42
Encoded file 43
Encoded file 44
Encoded file 45
Encoded file 46
Encoded file 47
Encoded file 48
Encoded file 49
Encoded file 50
Encoded file 51
Encoded file 52
Encoded file 53
Encoded file 54
Encoded file 55
Encoded file 56
Encoded file 57
Encoded file 58
Encoded file 59
Encoded file 60
Encoded file 61
Encoded file 62
Encoded file 63
Encoded file 64
Encoded file 65
Encoded file 66
Encoded file 67
Encoded file 68
Encoded file 69
Encoded file 70
Encoded file 71
Encoded file 72
Encoded file 73
Encoded file 74
Encoded file 75
Encoded file 76
Encoded file 77
Encoded file 78
Encoded file 79
Encoded file 80
Encoded file 81
Encoded file 82
Encoded file 83
Encoded file 84
Encoded file 85
Encoded file 86
Encoded file 87
Encoded file 88
Encoded file 89
Encoded file 90
Encoded file 91
Encoded file 92
Encoded file 93
Encoded file 94
Encoded file 95
Encoded file 96
Encoded file 97
Encoded file 98
Encoded file 99
Encoded file 100
Encoded file 1
Encoded file 2
Encoded file 3
Encoded file 4
Encoded file 5
Encoded file 6
Encoded file 7
Encoded file 8
Encoded file 9
Encoded file 10
Encoded file 11
Encoded file 12
Encoded file 13
Encoded file 14
Encoded file 15
Encoded file 16
Encoded file 17
Encoded file 18
Encoded file 19
Encoded file 20
Encoded file 21
Encoded file 22
Encoded file 23
Encoded file 24
Encoded file 25
Encoded file 26
Encoded file 27
Encoded file 28
Encoded file 29
Encoded file 30
Encoded file 31
Encoded file 32
Encoded file 33
Encoded file 34
Encoded file 35
Encoded file 36
Encoded file 37
Encoded file 38
Encoded file 39
Encoded file 40
Encoded file 41

Notice that we pass the list of seqlen_files into the dataset builder. This is required if working with 3d arrays / convolution kernels. If you are working with 2d arrays / fixed-length vector kernels, the default for the third argument (None) is appropriate.
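
Before building a dataset from saved chunks, it can be worth checking that the three arrays in each chunk agree with one another. The sketch below is a hypothetical helper of ours (`check_sequence_chunk`); the checks reflect how the arrays are built in this notebook and are not a statement of what xGPR validates internally.

```python
import numpy as np

def check_sequence_chunk(x, y, seqlen):
    """Raise ValueError if a chunk is not consistent with the 3d-array layout."""
    if x.ndim != 3:
        raise ValueError("x should be 3d: (n_sequences, max_length, n_features)")
    if not (len(y) == len(seqlen) == x.shape[0]):
        raise ValueError("x, y and seqlen must describe the same number of sequences")
    if seqlen.min() < 1 or seqlen.max() > x.shape[1]:
        raise ValueError("each sequence length must lie between 1 and max_length")
    return True

# Stand-in arrays with the same layout as one saved chunk.
x = np.zeros((3, 57, 21), dtype=np.uint8)
y = np.array([0.5, -1.2, 2.0])
seqlen = np.array([57, 50, 42], dtype=np.int32)

ok = check_sequence_chunk(x, y, seqlen)
print(ok)
```

With the files from this notebook, the same function could be called on `np.load(xfile), np.load(yfile), np.load(sfile)` for each chunk.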

[10]:
training_dset = build_regression_dataset(train_x_files, train_y_files, train_seqlen_files, chunk_size = 2000)

Here we’ll use the Conv1dRBF kernel, a kernel for sequences. Convolution kernels are usually slower than RBF / Matern, especially if the sequence is long. We’ll run a quick and dirty tuning experiment using 1024 random features, then fine-tune this using a larger number of random features just as we did for the tabular dataset.
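
One way to see why convolution kernels get slower as sequences get longer: a convolution of width w looks at every window of w consecutive positions, so the number of windows grows with the sequence length. The sketch below only illustrates window extraction with NumPy; it is not how xGPR implements Conv1dRBF.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

seq_len, n_symbols, conv_width = 57, 21, 11
x = np.zeros((seq_len, n_symbols), dtype=np.uint8)   # stand-in for one encoded sequence

# All windows of conv_width consecutive positions.
# The window axis is appended last: (n_windows, n_symbols, conv_width).
windows = sliding_window_view(x, window_shape=conv_width, axis=0)
n_windows = windows.shape[0]
print(windows.shape)
```

For a length-57 sequence and a width of 11 there are 57 - 11 + 1 = 47 windows, each of which has to be compared against the kernel.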

Many kernels in xGPR have kernel-specific settings. For Conv1dRBF, we can set two key options: sequence averaging, which is one of ‘none’, ‘sqrt’ or ‘full’, and the width of the convolution to use. Just as with a convolutional network, the width of the convolution filters can affect performance. One way to choose a good setting is to compare the negative marginal log likelihood score you get from hyperparameter tuning (e.g. with crude_bayes or crude_grid) using a small number of RFFs (e.g. 1024 - 2048) for several different settings of “conv_width”. The smallest score achieved likely corresponds to the best value for “conv_width”.
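
The selection procedure just described amounts to scoring each candidate width and keeping the one with the lowest score. Here is a minimal sketch; `pick_best_setting` and `toy_score` are hypothetical names, and the toy scorer returns made-up numbers purely so that the example runs on its own.

```python
def pick_best_setting(candidates, score_fn):
    """Return (best_candidate, scores); lower scores are treated as better."""
    scores = {c: score_fn(c) for c in candidates}
    best = min(scores, key=scores.get)
    return best, scores

# Stand-in scorer with made-up numbers, NOT a real marginal likelihood.
def toy_score(conv_width):
    return (conv_width - 9) ** 2 + 100.0

best_width, scores = pick_best_setting([5, 7, 9, 11, 13], toy_score)
print(best_width, scores)
```

In practice `score_fn` would build a model with the given “conv_width”, run a crude tuning pass with a small number of RFFs, and return the best score, as the next cell does for a single width of 11.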

[11]:
aav_model = xGPReg(num_rffs = 1024, variance_rffs = 512,
                   kernel_choice = "Conv1dRBF",
                   kernel_settings = {"conv_width":11, "averaging":'none'},
                   verbose = True, device = "cuda")

start_time = time.time()
hparams, niter, best_score = aav_model.tune_hyperparams_crude(training_dset)
end_time = time.time()

print(f"Best estimated negative marginal log likelihood: {best_score}")
print(f"Wallclock: {end_time - start_time}")
Grid point 0 acquired.
Grid point 1 acquired.
Grid point 2 acquired.
Grid point 3 acquired.
Grid point 4 acquired.
Grid point 5 acquired.
Grid point 6 acquired.
Grid point 7 acquired.
Grid point 8 acquired.
Grid point 9 acquired.
New hparams: [-2.4229936]
Additional acquisition 10.
New hparams: [-1.9242547]
Additional acquisition 11.
New hparams: [-1.8032397]
Additional acquisition 12.
New hparams: [-1.8347338]
Best score achieved: 128773.492
Best hyperparams: [-0.2832717 -1.9242547]
Best estimated negative marginal log likelihood: 128773.492
Wallclock: 20.574951887130737

We now have a rough estimate of our hyperparameters, acquired using a coarse kernel approximation (num_rffs=1024) and a crude tuning procedure. Let’s fine-tune this a little. We could use the built-in tuning routine in xGPR the way we did for the tabular data, or we could use Optuna (or some other library), or we could do a simple grid search. For illustrative purposes here, we’ll use Optuna with num_rffs=4096 (a somewhat better kernel approximation) and see what that looks like. We’ll search the region around the hyperparameters obtained from tune_hyperparams_crude. To run this next piece, you’ll need to have Optuna installed.
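
For comparison, the simple grid search mentioned above could look roughly like the sketch below. The helpers `grid_around` and `grid_search` are hypothetical, the half-width of 1.0 is an arbitrary choice, and `toy_nmll` is a stand-in quadratic rather than a real marginal likelihood.

```python
import itertools
import numpy as np

def grid_around(center, half_width=1.0, points_per_axis=5):
    """Build a regular grid of hyperparameter vectors centred on `center`."""
    axes = [np.linspace(c - half_width, c + half_width, points_per_axis)
            for c in center]
    return [np.array(p) for p in itertools.product(*axes)]

def grid_search(nmll_fn, grid):
    """Evaluate nmll_fn at every grid point; return the best point and its value."""
    scores = [nmll_fn(h) for h in grid]
    return grid[int(np.argmin(scores))], min(scores)

# Toy objective (a quadratic bowl), NOT a real marginal likelihood.
def toy_nmll(hparams):
    return float(np.sum((hparams - np.array([-0.5, -1.5])) ** 2))

# "Best hyperparams" reported by the crude tuning run above.
crude_estimate = [-0.2832717, -1.9242547]
grid = grid_around(crude_estimate)
best_hparams, best_value = grid_search(toy_nmll, grid)
print(best_hparams, best_value)
```

With the objects from this notebook, `nmll_fn` could be something like `lambda h: aav_model.exact_nmll(h, training_dset)`, the same call used in the Optuna objective below. A 5 x 5 grid costs 25 evaluations, so a sampler such as TPE is often the cheaper option.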

[12]:
import optuna
from optuna.samplers import TPESampler

def objective(trial):
    lambda_ = trial.suggest_float("lambda_", -2., 0.)
    sigma = trial.suggest_float("sigma", -3., -1.)
    hparams = np.array([lambda_, sigma])
    nmll = aav_model.exact_nmll(hparams, training_dset)
    return nmll
[13]:
aav_model.num_rffs = 4096

sampler = TPESampler(seed=123)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=35)
[I 2025-04-16 10:04:42,791] A new study created in memory with name: no-name-a0d27c2b-5d22-41b1-8556-94839784d7cb
[I 2025-04-16 10:04:56,000] Trial 0 finished with value: 113938.7126392912 and parameters: {'lambda_': -0.6070616288042767, 'sigma': -2.4277213300992413}. Best is trial 0 with value: 113938.7126392912.
Evaluated NMLL.
[I 2025-04-16 10:05:09,268] Trial 1 finished with value: 110497.60820715068 and parameters: {'lambda_': -1.5462970928715938, 'sigma': -1.8973704618342175}. Best is trial 1 with value: 110497.60820715068.
Evaluated NMLL.
[I 2025-04-16 10:05:22,538] Trial 2 finished with value: 110013.32776801154 and parameters: {'lambda_': -0.5610620604288739, 'sigma': -2.153787079751078}. Best is trial 2 with value: 110013.32776801154.
Evaluated NMLL.
[I 2025-04-16 10:05:35,831] Trial 3 finished with value: 108222.15752609915 and parameters: {'lambda_': -0.038471603230769036, 'sigma': -1.6303405228302734}. Best is trial 3 with value: 108222.15752609915.
Evaluated NMLL.
[I 2025-04-16 10:05:49,204] Trial 4 finished with value: 109165.39421682109 and parameters: {'lambda_': -1.0381361970312781, 'sigma': -2.215764963611699}. Best is trial 3 with value: 108222.15752609915.
Evaluated NMLL.
[I 2025-04-16 10:06:02,584] Trial 5 finished with value: 112470.01840575809 and parameters: {'lambda_': -1.3136439676982612, 'sigma': -1.5419005852319168}. Best is trial 3 with value: 108222.15752609915.
Evaluated NMLL.
[I 2025-04-16 10:06:15,959] Trial 6 finished with value: 117925.01634157504 and parameters: {'lambda_': -1.1228555106407512, 'sigma': -2.8806442067808633}. Best is trial 3 with value: 108222.15752609915.
Evaluated NMLL.
[I 2025-04-16 10:06:29,331] Trial 7 finished with value: 112197.12596152317 and parameters: {'lambda_': -1.2039114893391372, 'sigma': -1.5240091885359286}. Best is trial 3 with value: 108222.15752609915.
Evaluated NMLL.
[I 2025-04-16 10:06:42,706] Trial 8 finished with value: 111654.5965248778 and parameters: {'lambda_': -1.635016539093, 'sigma': -2.649096487705015}. Best is trial 3 with value: 108222.15752609915.
Evaluated NMLL.
[I 2025-04-16 10:06:56,075] Trial 9 finished with value: 108705.16831389123 and parameters: {'lambda_': -0.9368972523163233, 'sigma': -1.9363448258062679}. Best is trial 3 with value: 108222.15752609915.
Evaluated NMLL.
[I 2025-04-16 10:07:09,452] Trial 10 finished with value: 112370.9705344652 and parameters: {'lambda_': -0.03129046371014721, 'sigma': -1.0097302786311597}. Best is trial 3 with value: 108222.15752609915.
Evaluated NMLL.
[I 2025-04-16 10:07:22,820] Trial 11 finished with value: 108204.02327921556 and parameters: {'lambda_': -0.01866037953665966, 'sigma': -1.635646720641679}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:07:36,180] Trial 12 finished with value: 108794.3001803944 and parameters: {'lambda_': -0.023783021369743196, 'sigma': -1.453582345648574}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:07:49,543] Trial 13 finished with value: 112054.93092277527 and parameters: {'lambda_': -0.3826565674804564, 'sigma': -1.1788298267417594}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:08:02,904] Trial 14 finished with value: 108288.82550766022 and parameters: {'lambda_': -0.2593656053541154, 'sigma': -1.7832602361921872}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:08:16,264] Trial 15 finished with value: 109545.03620512375 and parameters: {'lambda_': -0.7745612805651452, 'sigma': -1.674253799818601}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:08:29,624] Trial 16 finished with value: 110288.55511532603 and parameters: {'lambda_': -0.2603175001600282, 'sigma': -1.333904798028223}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:08:42,986] Trial 17 finished with value: 110162.80893674257 and parameters: {'lambda_': -1.9409547062276986, 'sigma': -2.1327955818958895}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:08:56,346] Trial 18 finished with value: 111491.38554989494 and parameters: {'lambda_': -0.3835066195659215, 'sigma': -1.2419508741475265}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:09:09,706] Trial 19 finished with value: 108306.86184706952 and parameters: {'lambda_': -0.043835835276730635, 'sigma': -1.7445010980321713}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:09:23,067] Trial 20 finished with value: 112432.8051483832 and parameters: {'lambda_': -0.7304436166034872, 'sigma': -2.3852934884287125}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:09:36,433] Trial 21 finished with value: 108276.14911072241 and parameters: {'lambda_': -0.23150335811101075, 'sigma': -1.769944025853498}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:09:49,793] Trial 22 finished with value: 108456.88610376373 and parameters: {'lambda_': -0.1582976030195825, 'sigma': -1.5940025908052449}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:10:03,154] Trial 23 finished with value: 108352.78678593124 and parameters: {'lambda_': -0.44394970021613395, 'sigma': -1.8396633950108474}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:10:16,512] Trial 24 finished with value: 109960.12475624992 and parameters: {'lambda_': -0.2567864803764853, 'sigma': -1.3756449133769941}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:10:29,872] Trial 25 finished with value: 109707.05886217361 and parameters: {'lambda_': -0.17503409708419343, 'sigma': -1.9891378872913226}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:10:43,233] Trial 26 finished with value: 108924.54346795652 and parameters: {'lambda_': -0.51559181835663, 'sigma': -1.6593062029334495}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:10:56,590] Trial 27 finished with value: 111844.7980757734 and parameters: {'lambda_': -0.16068988881110718, 'sigma': -1.1118194185816517}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:11:09,953] Trial 28 finished with value: 110171.82096851458 and parameters: {'lambda_': -0.3405223941277511, 'sigma': -2.088141363129792}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:11:23,315] Trial 29 finished with value: 111724.92956268287 and parameters: {'lambda_': -0.6843658044124352, 'sigma': -2.323089362687985}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:11:36,674] Trial 30 finished with value: 112351.14569938189 and parameters: {'lambda_': -0.8877938379147875, 'sigma': -1.363034462928606}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:11:50,039] Trial 31 finished with value: 108305.69229256368 and parameters: {'lambda_': -0.15545204796200673, 'sigma': -1.774862465080108}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:12:03,411] Trial 32 finished with value: 108446.19393160455 and parameters: {'lambda_': -0.29254801396567537, 'sigma': -1.8600837006081945}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:12:16,775] Trial 33 finished with value: 108626.4612872399 and parameters: {'lambda_': -0.5591398529547196, 'sigma': -1.750835718271973}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.
[I 2025-04-16 10:12:30,132] Trial 34 finished with value: 110345.63855749354 and parameters: {'lambda_': -0.03233495711217844, 'sigma': -1.9864859129386128}. Best is trial 11 with value: 108204.02327921556.
Evaluated NMLL.

Set the model hyperparameters to the best ones found by Optuna.

[14]:
study.best_params
[14]:
{'lambda_': -0.01866037953665966, 'sigma': -1.635646720641679}
[15]:
# Same order as in the objective above: [lambda_, sigma]
best = study.best_params
aav_model.set_hyperparams(np.array([best["lambda_"], best["sigma"]]), training_dset)

Now we’ll fit the model using 8192 RFFs. We like to use a more accurate kernel approximation when fitting than when tuning, for two reasons. First, tuning is more expensive because the model has to be fit multiple times when tuning hyperparameters. Second, model performance usually increases faster by increasing the number of RFFs used for fitting than for tuning. (Using 16,384 or 32,768 RFFs here for fitting further increases test set performance as you’d expect, but with diminishing returns. The resulting small performance gain may be worthwhile for some applications and not so much for others.)
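
To illustrate why more RFFs give a more accurate kernel approximation, the sketch below applies the generic random Fourier features construction to a plain RBF kernel on random data. It is not xGPR’s implementation and does not use the Conv1dRBF kernel; the function name, data, and lengthscale are all illustrative assumptions.

```python
import numpy as np

def rff_features(x, num_rffs, lengthscale, rng):
    """Random Fourier features whose inner products approximate an RBF kernel."""
    w = rng.normal(scale=1.0 / lengthscale, size=(x.shape[1], num_rffs))
    b = rng.uniform(0.0, 2.0 * np.pi, size=num_rffs)
    return np.sqrt(2.0 / num_rffs) * np.cos(x @ w + b)

rng = np.random.default_rng(0)
x = rng.normal(size=(50, 5))
lengthscale = 2.0

# Exact RBF kernel matrix, for comparison.
sq_dists = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)
k_exact = np.exp(-sq_dists / (2.0 * lengthscale ** 2))

# Mean absolute error of the approximate kernel for increasing feature counts.
errors = {}
for num_rffs in (64, 1024, 8192):
    z = rff_features(x, num_rffs, lengthscale, rng)
    errors[num_rffs] = float(np.abs(z @ z.T - k_exact).mean())
print(errors)
```

The error shrinks roughly in proportion to one over the square root of the number of features, which is consistent with the diminishing returns noted above.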

On GPU, mode=exact works well for fitting up to 8,192 RFFs or so, while mode=cg, although slower for small numbers of RFFs, is more scalable. On this dataset, using 8,192 RFFs, “exact” takes about 70 seconds on our GPU. We’ll use cg here just for illustrative purposes.

tol determines how tight the fit is. 1e-6 (default) is usually fine. Decreasing tol will improve performance, but with rapidly diminishing returns, and will make fitting take longer. For noise-free data, or to get a small additional boost in performance, use 1e-7. 1e-8 is (nearly always) overkill.
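
The trade-off between tol and fitting time can be seen with a textbook conjugate gradient solver on a small ridge-regression system. This is a sketch under our own assumptions (no preconditioner, random stand-in data, a relative-residual stopping rule); it is not xGPR’s solver.

```python
import numpy as np

def conjugate_gradient(a, b, tol=1e-6, max_iter=1000):
    """Solve a @ w = b for symmetric positive definite `a`.

    Stops once ||residual|| / ||b|| <= tol; returns (solution, n_iterations).
    """
    w = np.zeros_like(b)
    r = b - a @ w
    p = r.copy()
    rs_old = r @ r
    b_norm = np.linalg.norm(b)
    for n_iter in range(1, max_iter + 1):
        ap = a @ p
        alpha = rs_old / (p @ ap)
        w += alpha * p
        r -= alpha * ap
        rs_new = r @ r
        if np.sqrt(rs_new) / b_norm <= tol:
            return w, n_iter
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return w, max_iter

rng = np.random.default_rng(0)
z = rng.normal(size=(500, 40))     # stand-in for a random-feature matrix
y = rng.normal(size=500)
a = z.T @ z + 1.0 * np.eye(40)     # ridge-regularised normal equations
b = z.T @ y

w_loose, iters_loose = conjugate_gradient(a, b, tol=1e-2)
w_tight, iters_tight = conjugate_gradient(a, b, tol=1e-8)
w_direct = np.linalg.solve(a, b)   # direct solve, for reference
print(iters_loose, iters_tight)
```

A tighter tolerance never needs fewer iterations, and past a certain point the extra iterations change the solution by less than the noise in the data.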

[16]:
aav_model.num_rffs = 8192
start_time = time.time()
aav_model.fit(training_dset, mode = 'cg', tol = 1e-6)
end_time = time.time()
print(f"Wallclock: {end_time - start_time}")
starting fitting
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Using rank: 3000
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
0 iterations complete.
5 iterations complete.
10 iterations complete.
CG iterations: 13
Now performing variance calculations...
Fitting complete.
Wallclock: 112.52763795852661
[17]:
start_time = time.time()
all_preds, ground_truth = [], []
for xfile, yfile, sfile in zip(test_x_files, test_y_files, test_seqlen_files):
    x, y, s = np.load(xfile), np.load(yfile), np.load(sfile)
    ground_truth.append(y)
    preds = aav_model.predict(x, s, get_var = False)
    all_preds.append(preds)

all_preds, ground_truth = np.concatenate(all_preds), np.concatenate(ground_truth)
end_time = time.time()
print(f"Wallclock: {end_time - start_time}")
Wallclock: 0.5303468704223633
[18]:
from scipy.stats import spearmanr

spearmanr(all_preds, ground_truth)
[18]:
SignificanceResult(statistic=np.float64(0.7659180535079166), pvalue=np.float64(0.0))

A Spearman’s r of about 0.77 is slightly better than the 1d-CNN reported by Dallago et al. for this dataset (0.75) and is similar to the performance of a fine-tuned LLM (Spearman’s r of 0.79). As discussed above, we can get further slight improvements in performance just by tweaking this model. We can do even better by using a more informative representation of the protein sequences. In our original paper we achieved a Spearman’s r of about 0.8 on this dataset, outperforming fine-tuned LLMs (and costing significantly less to train than a fine-tuned LLM). Whether small gains in performance from further “tweaking” or more informative representations are worthwhile obviously depends on your application.

[19]:
for testx, testy, tests in zip(test_x_files, test_y_files, test_seqlen_files):
    os.remove(testx)
    os.remove(testy)
    os.remove(tests)
[20]:
for xfile, yfile, sfile in zip(train_x_files, train_y_files, train_seqlen_files):
    os.remove(xfile)
    os.remove(yfile)
    os.remove(sfile)