跳转至

ConcatDataset 和 StackDataset

在 PyTorch 中,ConcatDatasetStackDataset 是两种不同的数据集组合方式。本文介绍了它们的作用及其适用场景。

Python
# 使用 ConcatDataset 连接数据集
concat_dataset = ConcatDataset([dataset1, dataset2])
# 遍历 ConcatDataset
for sample in concat_dataset:
    print(sample)
Text Only
1
2
3
4
5
6
Python
# 使用 StackDataset 组合数据集
stack_dataset = StackDataset(dataset1, dataset2)

# 遍历 StackDataset
for sample in stack_dataset:
    print(sample)
Text Only
(1, 4)
(2, 5)
(3, 6)

作用

ConcatDataset

  • 将多个数据集按顺序拼接,形成一个更大的数据集。
  • 遍历时顺序访问每个子数据集的所有样本。
  • 例如,ConcatDataset([dataset1, dataset2]) 会先遍历 dataset1,再遍历 dataset2
  • 适用于需要合并多个数据集为单一数据集的情况。

StackDataset

  • 将多个数据集的样本按索引一一对应组合。
  • 遍历时同时从每个子数据集中取出相同索引的样本,并组合成元组或列表。
  • 例如,StackDataset(dataset1, dataset2) 会返回 (sample1, sample2),其中 sample1 来自 dataset1sample2 来自 dataset2
  • 适用于需要同时处理多个数据集样本的情况,如混合频率的数据集。

代码示例

Python
import torch
from torch.utils.data import ConcatDataset, Dataset, StackDataset


# 创建两个简单的数据集
class SimpleDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# 数据集 1
dataset1 = SimpleDataset([1, 2, 3])
# 数据集 2
dataset2 = SimpleDataset([4, 5, 6])
Python
# 使用 ConcatDataset 连接数据集
concat_dataset = ConcatDataset([dataset1, dataset2])
# 遍历 ConcatDataset
for sample in concat_dataset:
    print(sample)
Text Only
1
2
3
4
5
6
Python
# 使用 StackDataset 组合数据集
stack_dataset = StackDataset(dataset1, dataset2)

# 遍历 StackDataset
for sample in stack_dataset:
    print(sample)
Text Only
(1, 4)
(2, 5)
(3, 6)

评论