forked from zxjzxj9/PyTorchIntroduction
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathex_2_33.py
More file actions
23 lines (21 loc) · 947 Bytes
/
ex_2_33.py
File metadata and controls
23 lines (21 loc) · 947 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
""" 该代码仅为演示类的构造方法所用,并不能实际运行
"""
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that yields the integers in ``[start, end)``.

    When ``torch.utils.data.get_worker_info()`` reports a worker, only that
    worker's contiguous slice of the range is yielded. The slices of all
    workers are disjoint and together cover the whole range.

    Args:
        start: First integer of the range (inclusive).
        end: End of the range (exclusive); must be greater than ``start``.
    """

    def __init__(self, start, end):
        # Zero-argument super() really runs the parent initializer; the
        # one-argument form ``super(MyIterableDataset)`` creates an unbound
        # super object and silently skips it.
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    @staticmethod
    def _shard_bounds(start, end, worker_id, num_workers):
        """Return the half-open ``(lo, hi)`` slice of ``[start, end)`` for one worker.

        The range is cut into chunks of ``ceil((end - start) / num_workers)``
        items; trailing workers may receive a shorter or empty slice.
        """
        # Exact integer ceiling division: no float rounding on large ranges.
        per_worker = -(-(end - start) // num_workers)
        lo = min(start + worker_id * per_worker, end)
        hi = min(lo + per_worker, end)
        return lo, hi

    def __iter__(self):
        """Return an iterator over this process's share of ``[start, end)``."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process data loading: yield the full range.
            iter_start, iter_end = self.start, self.end
        else:
            # Multi-process data loading: each worker takes its own slice.
            iter_start, iter_end = self._shard_bounds(
                self.start, self.end, worker_info.id, worker_info.num_workers
            )
        return iter(range(iter_start, iter_end))