22.02.24 [kfood프로젝트] 음식데이터 YOLOv3 training을 위한 train, test, validation 8:1:1 split하기
2022. 2. 24. 18:23ㆍ프로젝트/KFood
train test validation split하기 참고 사이트
https://lynnshin.tistory.com/46
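참고 코드 (주의: validation/test는 첫 split에서 나온 `test_names`를 `test_size=0.5`로 한 번 더 나눠야 train과 겹치지 않고 8:1:1이 된다)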
import os
from glob import glob
import shutil
from sklearn.model_selection import train_test_split
# Collect the image paths and strip the '.jpg' extension so each entry is the
# stem shared by an image and its label file.
# glob returns files in filesystem order, which differs between machines;
# sorting makes random_state=42 produce the same split everywhere.
image_files = sorted(glob('데이터경로/*.jpg'))
images = [os.path.splitext(name)[0] for name in image_files]

# Split the dataset train:test:val = 8:1:1.
# First hold out 20% of the data, then halve that hold-out into val and test.
# Splitting `images` again (instead of the hold-out) would make val/test
# overlap with train.
train_names, test_names = train_test_split(images, test_size=0.2, random_state=42, shuffle=True)
val_names, test_names = train_test_split(test_names, test_size=0.5, random_state=42, shuffle=True)
def batch_move_files(file_list, source_path, destination_path):
    """Copy each image and its JSON label file into ``destination_path``.

    Args:
        file_list: image paths with the '.jpg' extension removed.
        source_path: folder holding both the '.jpg' and the '.json' files.
        destination_path: folder to copy into; created if it does not exist.
    """
    # If the destination folder is missing, shutil.copy treats it as a target
    # *file name* and every copy overwrites that single file.
    os.makedirs(destination_path, exist_ok=True)
    for file in file_list:
        # Take only the file name from the path and add the extensions.
        # basename also copes with the '\\' separator glob returns on Windows.
        stem = os.path.basename(file)
        shutil.copy(os.path.join(source_path, stem + '.jpg'), destination_path)
        shutil.copy(os.path.join(source_path, stem + '.json'), destination_path)
# Folder holding the original images and label files
source_dir = '데이터경로'

# Output folder for each split
train_dir = '/output경로/train'
test_dir = '/output경로/test'
val_dir = '/output경로/val'

# Copy every split into its own folder, in train -> test -> val order.
for names, target_dir in ((train_names, train_dir), (test_names, test_dir), (val_names, val_dir)):
    batch_move_files(names, source_dir, target_dir)
우리 파일구성
우리 코드
# Entries of ./labels, sorted so the class order is stable between runs.
label_list = sorted(os.listdir('./labels'))
print(label_list)
print(len(label_list))
label_list  # displayed as the notebook cell output
split 코드
import os
from glob import glob
import shutil
import sklearn
from sklearn.model_selection import train_test_split
# Per-class split sizes, kept so the 8:1:1 ratio can be checked afterwards.
train_nums = []
test_nums = []
val_nums = []

# Final output layout: images/{train,test,val} and labels/{train,test,val}.
img_destination_path = './nochilsu/images/'
txt_destination_path = './nochilsu/labels/'


def batch_move_files(file_list, img_source_path, txt_source_path, img_destination_path, txt_destination_path):
    """Copy every image in ``file_list`` and its matching '.txt' label file.

    Args:
        file_list: image paths with the '.jpg' extension removed,
            e.g. './images/가지구이/image/B080302XX_10619'.
        img_source_path: folder containing the '.jpg' images.
        txt_source_path: folder containing the '.txt' label files.
        img_destination_path: folder the images are copied into (created if missing).
        txt_destination_path: folder the label files are copied into (created if missing).
    """
    # shutil.copy treats a non-existent destination as a target *file name*,
    # so without this every image would overwrite one file called
    # 'train' / 'test' / 'val'.
    os.makedirs(img_destination_path, exist_ok=True)
    os.makedirs(txt_destination_path, exist_ok=True)
    for file in file_list:
        # Keep only the last path component (e.g. B080302XX_10619) and add the
        # extensions; basename also handles the '\\' glob returns on Windows.
        stem = os.path.basename(file)
        shutil.copy(os.path.join(img_source_path, stem + '.jpg'), img_destination_path)
        shutil.copy(os.path.join(txt_source_path, stem + '.txt'), txt_destination_path)


for i in label_list:
    # Original data paths for this class
    img_source_path = f'./images/{i}/image'
    txt_source_path = f'./labels/{i}'

    # glob order depends on the filesystem; sort so that random_state=42
    # gives the same split on every machine.
    image_files = sorted(glob(f'./images/{i}/image/*.jpg'))
    # e.g. './images/가지구이/image/B080302XX_10619' -- splitext strips only the
    # final extension, unlike str.replace which removes every '.jpg' occurrence.
    images = [os.path.splitext(name)[0] for name in image_files]

    # 8:1:1 split: hold out 20% first, then halve that hold-out into val/test.
    # A fixed random_state keeps the split identical between runs, so
    # hyper-parameter tuning is not affected by a reshuffled dataset.
    train_names, test_names = train_test_split(images, test_size=0.2, random_state=42, shuffle=True)
    val_names, test_names = train_test_split(test_names, test_size=0.5, random_state=42, shuffle=True)

    print('train_names: ', train_names[:5], len(train_names), '개\n')
    train_nums.append(len(train_names))
    print('test_names: ', test_names[:5], len(test_names), '개\n')
    test_nums.append(len(test_names))
    print('val_names: ', val_names[:5], len(val_names), '개\n')
    val_nums.append(len(val_names))

    for split_name, names in (('train', train_names), ('test', test_names), ('val', val_names)):
        batch_move_files(names, img_source_path, txt_source_path,
                         img_destination_path + split_name, txt_destination_path + split_name)
8:1:1로 잘 split됐는지 확인
import pandas as pd
# One row per class label, one column per split, holding the number of
# images that landed in each split.
split_counts = {
    'train 갯수': train_nums,
    'test 갯수': test_nums,
    'val 갯수': val_nums,
}
data_split_df = pd.DataFrame(split_counts, index=label_list)
data_split_df  # displayed as the notebook cell output
# Count the files that ended up in every output folder (images first, then labels).
folder_checks = (
    ('train 이미지 갯수', './nochilsu/images/train'),
    ('test 이미지 갯수', './nochilsu/images/test'),
    ('validation 이미지 갯수', './nochilsu/images/val'),
    ('train 라벨 갯수', './nochilsu/labels/train'),
    ('test 라벨 갯수', './nochilsu/labels/test'),
    ('validation 라벨 갯수', './nochilsu/labels/val'),
)
for message, folder in folder_checks:
    print(message, len(os.listdir(folder)))
8:1:1 정도로 잘 split된 모습이다.