我有一个熊猫数据框架,我想把它分为3个单独的集。我知道使用sklearn中的train_test_split。交叉验证,可以将数据分为两组(训练和测试)。然而,我无法找到将数据分成三组的任何解决方案。最好是有原始数据的下标。
我知道一个解决办法是使用train_test_split两次,并以某种方式调整索引。但是是否有一种更标准/内置的方法将数据分成3组而不是2组?
我有一个熊猫数据框架,我想把它分为3个单独的集。我知道使用sklearn中的train_test_split。交叉验证,可以将数据分为两组(训练和测试)。然而,我无法找到将数据分成三组的任何解决方案。最好是有原始数据的下标。
我知道一个解决办法是使用train_test_split两次,并以某种方式调整索引。但是是否有一种更标准/内置的方法将数据分成3组而不是2组?
当前回答
def train_val_test_split(X, y, train_size, val_size, test_size):
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size = test_size)
relative_train_size = train_size / (val_size + train_size)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val,
train_size = relative_train_size, test_size = 1-relative_train_size)
return X_train, X_val, X_test, y_train, y_val, y_test
在这里,我们使用sklearn的train_test_split将数据分割2次
其他回答
在监督学习的情况下,你可能想拆分X和y(其中X是你的输入,y是基本真理输出)。 你只需要注意在分割之前以同样的方式洗牌X和y。
在这里,X和y在同一个数据帧中,所以我们对它们进行洗牌,将它们分开,并对每个数据帧应用拆分(就像在选择的答案中一样),或者X和y在两个不同的数据帧中,所以我们洗牌X,将y按洗牌X的方式重新排序,并对每个数据帧应用拆分。
# 1st case: df contains X and y (where y is the "target" column of df)
df_shuffled = df.sample(frac=1)
X_shuffled = df_shuffled.drop("target", axis = 1)
y_shuffled = df_shuffled["target"]
# 2nd case: X and y are two separated dataframes
X_shuffled = X.sample(frac=1)
y_shuffled = y[X_shuffled.index]
# We do the split as in the chosen answer
X_train, X_validation, X_test = np.split(X_shuffled, [int(0.6*len(X)),int(0.8*len(X))])
y_train, y_validation, y_test = np.split(y_shuffled, [int(0.6*len(X)),int(0.8*len(X))])
def train_val_test_split(X, y, train_size, val_size, test_size):
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size = test_size)
relative_train_size = train_size / (val_size + train_size)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val,
train_size = relative_train_size, test_size = 1-relative_train_size)
return X_train, X_val, X_test, y_train, y_val, y_test
在这里,我们使用sklearn的train_test_split将数据分割2次
将数据集分割为训练集和测试集,如在其他答案中一样,使用
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
然后,如果您适合您的模型,您可以添加validation_split作为参数。这样就不需要提前创建验证集。例如:
from tensorflow.keras import Model
model = Model(input_layer, out)
[...]
history = model.fit(x=X_train, y=y_train, [...], validation_split = 0.3)
验证集旨在作为训练集训练期间的代表运行测试集,完全来自训练集,无论是通过k-fold交叉验证(推荐)还是通过validation_split;然后,您不需要单独创建一个验证集,仍然可以将数据集分为您所要求的三个集。
注意:
函数被编写来处理随机集创建的播种。你不应该依赖集分割,它不会随机化集合。
import numpy as np
import pandas as pd
def train_validate_test_split(df, train_percent=.6, validate_percent=.2, seed=None):
np.random.seed(seed)
perm = np.random.permutation(df.index)
m = len(df.index)
train_end = int(train_percent * m)
validate_end = int(validate_percent * m) + train_end
train = df.iloc[perm[:train_end]]
validate = df.iloc[perm[train_end:validate_end]]
test = df.iloc[perm[validate_end:]]
return train, validate, test
示范
np.random.seed([3,1415])
df = pd.DataFrame(np.random.rand(10, 5), columns=list('ABCDE'))
df
train, validate, test = train_validate_test_split(df)
train
validate
test
回答任意数量的子集:
def _separate_dataset(patches, label_patches, percentage, shuffle: bool = True):
"""
:param patches: data patches
:param label_patches: label patches
:param percentage: list of percentages for each value, example [0.9, 0.02, 0.08] to get 90% train, 2% val and 8% test.
:param shuffle: Shuffle dataset before split.
:return: tuple of two lists of size = len(percentage), one with data x and other with labels y.
"""
x_test = patches
y_test = label_patches
percentage = list(percentage) # need it to be mutable
assert sum(percentage) == 1., f"percentage must add to 1, but it adds to sum{percentage} = {sum(percentage)}"
x = []
y = []
for i, per in enumerate(percentage[:-1]):
x_train, x_test, y_train, y_test = train_test_split(x_test, y_test, test_size=1-per, shuffle=shuffle)
percentage[i+1:] = [value / (1-percentage[i]) for value in percentage[i+1:]]
x.append(x_train)
y.append(y_train)
x.append(x_test)
y.append(y_test)
return x, y
这适用于任何比例。在本例中,您应该执行percentage = [train_percentage, val_percentage, test_percentage]。