dataloader各项参数详解
后台-插件-广告管理-内容页头部广告(手机) |
-
pin_memory(bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.
当设置为True时,将会在返回**batch之前将batch**数据复制到固定的内存区域,这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。
通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用**pin_memory**可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。
需要注意的是,使用**pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory**并不会带来明显的加速效果。
-
num_workers (int, optional) – how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)
这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为**0**,即在主进程中进行数据加载,而不使用额外的子进程。
下面说一下个人的理解,在初始化 dataloader 对象时,会根据num_workers创建子线程用于加载数据(主线程数+子线程数=num_workers)。每个worker或者说线程都有自己负责的dataset范围(下面统称worker)
每当迭代 dataloader 对象时,工人们(workers)就开始干活了:将数据从数据源(如硬盘)加载到内存(数据加载),当一个worker读取(调用__getitem__)到足够的数据(看你在dataset中怎么定义一个item了)后,会将这些数据封装成一个 (即一帧),并将其放到该worker独有的内存队列中。
要注意的是,每次迭代时,worker会尽可能地读数据,直到自己的队列被填满。
当所有workers的队列都被填满时,一个名为sampler的线程将会被创建,它的作用就是收集各workers队列中队首的 ,把他们放到一个各线程共享内存的缓冲队列中,并调用 collate_fn 函数来将 batch_size 个 整合,最后返回给迭代的输出。
这时候大家肯定会有点疑惑,那当迭代到后期时,需要读取的样本都已经在队列中了,是不是意味着这时候工人们已经在休息了?根据chatgpt的回答:是的!下面以一张图来帮助大家理解 -
collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
整合多个样本到一个batch时需要调用的函数,当 __getitem__ 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足pytorch的输入要求
比如在openpcdet框架的poinpillar中, __getitem__ 返回的是一个包含标注信息、点云信息、图像信息等的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包
在poinpillar中该函数为:
1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,请转载时务必注明文章作者和来源,不尊重原创的行为我们将追究责任;3.作者投稿可能会经我们编辑修改或补充。
在线投稿:投稿 站长QQ:1888636
后台-插件-广告管理-内容页尾部广告(手机) |