• Author: 老汪软件
  • Published: 2023-12-29 23:00

I am training a multi-class classifier on a very large dataset (Ubuntu 22.04, i9, RTX 4090, 64 GB RAM). The first run used a training set of 7 million images and trained successfully. After collecting more data and applying augmentation, the set grew to 10 million images. On the second run, the process was killed by the system after about 4 hours of training; the reason was Out of Memory. After a long search I found that RAM usage grew linearly with the training step count, which suggested a memory leak, and I finally traced it to the DataLoader's num_workers parameter (with num_workers=0 the problem disappears).
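
For context, this is roughly how such growth can be observed. A minimal sketch, assuming psutil is installed; the helper name total_rss_mb and the loop skeleton are illustrative, not the original post's code:

import os
import psutil

def total_rss_mb():
    # RSS of the training process plus all DataLoader workers, in MB
    proc = psutil.Process(os.getpid())
    rss = proc.memory_info().rss + sum(
        child.memory_info().rss for child in proc.children(recursive=True)
    )
    return rss / 1024 ** 2

# Inside the training loop (illustrative):
# for step, (images, labels) in enumerate(train_loader):
#     ...
#     if step % 100 == 0:
#         print(f"step {step}: {total_rss_mb():.0f} MB")

If this number climbs linearly with the step count, as it did here, the data pipeline is the first place to look.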

When the Python lists held by the Dataset are consumed by a DataLoader with num_workers > 0, memory leaks during training; avoid Python lists and use np.array instead. Strictly speaking it is not a true leak: the workers are forked copies of the main process, and reading any element of a Python list bumps that object's reference count, which dirties the copy-on-write page holding it, so the "shared" lists are gradually duplicated into every worker.

The fix is to write a custom Dataset class and replace every Python list in it with an np.array. With that change, converting the stored data to tensors no longer makes memory grow.
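
A minimal, self-contained sketch of the effect (dataset names, sizes, and the toy __getitem__ are illustrative, not the original benchmark):

import numpy as np
from torch.utils.data import Dataset, DataLoader

class ListDataset(Dataset):
    # Stores millions of small Python objects in a list. Each access from
    # a forked worker bumps the object's refcount, dirtying the
    # copy-on-write page it lives on, so worker RSS climbs steadily.
    def __init__(self, n=10_000_000):
        self.items = [str(i) for i in range(n)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return len(self.items[idx])

class ArrayDataset(Dataset):
    # Same data in a single NumPy array: indexing reads one flat buffer,
    # no stored Python objects are touched, and the pages stay shared.
    def __init__(self, n=10_000_000):
        self.items = np.array([str(i) for i in range(n)])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return len(self.items[idx])

if __name__ == "__main__":
    # Swap ListDataset for ArrayDataset and watch worker RSS in `top`:
    # the list version grows with every pass, the array version stays flat.
    loader = DataLoader(ListDataset(), batch_size=256, num_workers=8)
    for _ in loader:
        pass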

1. Incorrect dataset-loading method #1

# Load the datasets with torchvision's built-in ImageFolder
# (transform, BATCH_SIZE and the *_DIR paths are defined elsewhere)
from torch.utils.data import DataLoader
from torchvision import datasets

train_data = datasets.ImageFolder(root=TRAIN_DIR_ARG, transform=transform)
valid_data = datasets.ImageFolder(root=VALIDATION_DIR, transform=transform)
test_data = datasets.ImageFolder(root=TEST_DIR, transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
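
There is no bug in this code itself: ImageFolder (via DatasetFolder) keeps its index of (path, class_index) pairs in a plain Python list, so with num_workers > 0 it is exposed to exactly the copy-on-write growth described above. You can confirm the internal storage directly:

print(type(train_data.samples))  # <class 'list'> of (path, class_index) tuples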

2. Incorrect dataset-loading method #2 (custom Dataset with the standard methods overridden)


import os

from PIL import Image
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        # Walk the data directory and collect image paths with their labels
        # (sorted so that class indices are deterministic across runs)
        classes = sorted(os.listdir(data_dir))
        for i, class_name in enumerate(classes):
            class_dir = os.path.join(data_dir, class_name)
            if os.path.isdir(class_dir):
                for image_name in os.listdir(class_dir):
                    image_path = os.path.join(class_dir, image_name)
                    self.image_paths.append(image_path)
                    self.labels.append(i)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        # Load the image lazily, only when this item is requested
        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        return image, label

train_data = CustomDataset(data_dir=TRAIN_DIR_ARG, transform=transform)
valid_data = CustomDataset(data_dir=VALIDATION_DIR, transform=transform)
test_data = CustomDataset(data_dir=TEST_DIR, transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
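
This version leaks for the same reason as method 1: self.image_paths and self.labels are Python lists holding millions of string and int objects, and every __getitem__ call in a worker process touches their reference counts, gradually copying the shared pages into each of the 8 workers.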

3. The correct method (custom Dataset with every list converted to np.array)

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []  # built as Python lists first...
        self.labels = []
        # Walk the data directory and collect image paths with their labels
        classes = sorted(os.listdir(data_dir))
        for i, class_name in enumerate(classes):
            class_dir = os.path.join(data_dir, class_name)
            if os.path.isdir(class_dir):
                for image_name in os.listdir(class_dir):
                    image_path = os.path.join(class_dir, image_name)
                    self.image_paths.append(image_path)
                    self.labels.append(i)
        # ...then converted to NumPy arrays. This is the key fix: a NumPy
        # array is a single flat buffer, so indexing it from a worker does
        # not touch per-item Python objects and leaves the copy-on-write
        # pages shared between processes.
        self.image_paths = np.array(self.image_paths)
        self.labels = np.array(self.labels)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        # Load the image lazily, only when this item is requested
        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        # Convert the image data to a NumPy array (note: if transform ends
        # with ToTensor(), this turns the tensor back into NumPy, and the
        # DataLoader's default collate re-tensorizes it)
        image = np.array(image)
        return image, label

train_data = CustomDataset(data_dir=TRAIN_DIR_ARG, transform=transform)
valid_data = CustomDataset(data_dir=VALIDATION_DIR, transform=transform)
test_data = CustomDataset(data_dir=TEST_DIR, transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
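
With this version, total RSS should plateau after a warm-up phase instead of growing linearly with the step count; the monitoring sketch above can be reused to verify. If memory still grows, check whether the Dataset holds any other large Python containers such as dicts or nested lists, since they trigger the same copy-on-write behavior in the workers.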