Saya perlu membagi data saya menjadi satu set pelatihan (75%) dan set pengujian (25%). Saat ini saya melakukannya dengan kode di bawah ini:
X, Xt, userInfo, userInfo_train = sklearn.cross_validation.train_test_split(X, userInfo)
Namun, saya ingin membuat stratifikasi set data pelatihan saya. Bagaimana aku melakukan itu? Saya telah mempelajari StratifiedKFold
metode ini, tetapi tidak membiarkan saya menentukan pembagian 75% / 25% dan hanya menyusun set data pelatihan.
python
scikit-learn
pir
sumber
sumber
TL; DR: Gunakan StratifiedShuffleSplit dengan
test_size=0.25
Scikit-learn menyediakan dua modul untuk Stratified Splitting:
n_folds
pelatihan / pengujian sedemikian rupa sehingga kelas seimbang di keduanya.Berikut beberapa kode (langsung dari dokumentasi di atas)
>>> skf = cross_validation.StratifiedKFold(y, n_folds=2) #2-fold cross validation >>> len(skf) 2 >>> for train_index, test_index in skf: ... print("TRAIN:", train_index, "TEST:", test_index) ... X_train, X_test = X[train_index], X[test_index] ... y_train, y_test = y[train_index], y[test_index] ... #fit and predict with X_train/test. Use accuracy metrics to check validation performance
n_iter=1
. Anda dapat menyebutkan ukuran tes di sini sama seperti ditrain_test_split
Kode:
>>> sss = StratifiedShuffleSplit(y, n_iter=1, test_size=0.5, random_state=0) >>> len(sss) 1 >>> for train_index, test_index in sss: ... print("TRAIN:", train_index, "TEST:", test_index) ... X_train, X_test = X[train_index], X[test_index] ... y_train, y_test = y[train_index], y[test_index] >>> # fit and predict with your classifier using the above X/y train/test
sumber
0.18.x
,n_iter
seharusnyan_splits
untukStratifiedShuffleSplit
- dan ada API yang sedikit berbeda untuk itu: scikit-learn.org/stable/modules/generated/…y
adalah Seri Pandas, gunakany.iloc[train_index], y.iloc[test_index]
dataframe index: 2,3,5
the first split in sss:[(array([2, 1]), array([0]))]
:(X_train, X_test = X[train_index], X[test_index]
dipanggil itu menimpaX_train
danX_test
? Lalu mengapa tidak hanya satunext(sss)
?Anda cukup melakukannya dengan
train_test_split()
metode yang tersedia di Scikit learn:from sklearn.model_selection import train_test_split train, test = train_test_split(X, test_size=0.25, stratify=X['YOUR_COLUMN_LABEL'])
Saya juga telah menyiapkan GitHub Gist singkat yang menunjukkan cara
stratify
kerja opsi:https://gist.github.com/SHi-ON/63839f3a3647051a180cb03af0f7d0d9
sumber
Berikut adalah contoh untuk data berkelanjutan / regresi (hingga masalah ini di GitHub diselesaikan).
min = np.amin(y) max = np.amax(y) # 5 bins may be too few for larger datasets. bins = np.linspace(start=min, stop=max, num=5) y_binned = np.digitize(y, bins, right=True) X_train, X_test, y_train, y_test = train_test_split( X, y, stratify=y_binned )
start
min danstop
maksimal target kontinu Anda.right=True
maka itu akan lebih atau kurang membuat nilai maksimal Anda menjadi bin terpisah dan pemisahan Anda akan selalu gagal karena terlalu sedikit sampel akan berada di bin tambahan itu.sumber
Selain jawaban yang diterima oleh @Andreas Mueller, hanya ingin menambahkannya sebagai @tangy yang disebutkan di atas:
StratifiedShuffleSplit paling mirip dengan train_test_split ( stratify = y) dengan fitur tambahan:
sumber
#train_size is 1 - tst_size - vld_size tst_size=0.15 vld_size=0.15 X_train_test, X_valid, y_train_test, y_valid = train_test_split(df.drop(y, axis=1), df.y, test_size = vld_size, random_state=13903) X_train_test_V=pd.DataFrame(X_train_test) X_valid=pd.DataFrame(X_valid) X_train, X_test, y_train, y_test = train_test_split(X_train_test, y_train_test, test_size=tst_size, random_state=13903)
sumber
Memperbarui jawaban @tangy dari atas ke versi scikit-learn: 0.23.2 ( dokumentasi StratifiedShuffleSplit ).
from sklearn.model_selection import StratifiedShuffleSplit n_splits = 1 # We only want a single split in this case sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=0) for train_index, test_index in sss.split(X, y): X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index]
sumber