PyTorch indexing x[indices] runs out of memory and crashes, and Python reports a process-lock error
Error:
Process Process-167:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 199, in _finalize_join
    thread.join()
  File "/usr/lib/python3.10/threading.py", line 1096, in join
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
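Seeing a process-lock error left me completely stumped. After several days of searching, I went through the code line by line and found the culprit: X[mask.expand_as(X)]
When X is very large, this line runs out of memory and crashes. Here X holds 2957312 × 1024 float32 values, which is about 12 GB. Boolean-mask indexing also needs extra working memory and a copy of the selected values on top of X. The fix is to apply the mask chunk by chunk.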
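A quick size estimate shows why "very large" matters here. The sketch below uses a hypothetical helper, `tensor_gib`, which is not part of PyTorch. It is built on the real `Tensor.element_size()` and `Tensor.numel()` methods. It runs the helper on a small tensor so it works on any machine, and it computes the full-size figure with plain arithmetic.

```python
import torch

def tensor_gib(t: torch.Tensor) -> float:
    """Hypothetical helper: memory used by a tensor's elements, in GiB."""
    return t.element_size() * t.numel() / 1024**3

# X in this post: 2957312 rows * 1024 columns * 4 bytes per float32 value.
full_bytes = 2957312 * 1024 * 4
print(full_bytes, full_bytes / 1024**3)  # bytes, then GiB for X alone

small = torch.rand([1000, 1024])  # torch.rand returns float32 by default
print(tensor_gib(small))
```

X alone takes about 11.3 GiB before any indexing happens. A 10000-row chunk of the same width is only about 41 MB (10000 × 1024 × 4 bytes).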
This is the code that crashes. It fails on the last line:
import torch
X=torch.rand([2957312, 1024])
mask=torch.randint(2,[2957312, 1],dtype=bool)
X=X[mask.expand_as(X)]
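Each row of `mask` holds a single value, and `expand_as` repeats that value across all columns. Because of this, the same elements can, as far as I can tell, also be selected by indexing whole rows with a 1-D mask and then flattening. This is an alternative technique that the original post does not use. The sketch below checks the equivalence on a small random tensor with arbitrary sizes.

```python
import torch

torch.manual_seed(0)
X = torch.rand([1000, 64])
mask = torch.randint(2, [1000, 1], dtype=torch.bool)

# Approach from the post: broadcast the mask to X's shape; result is 1-D.
flat = X[mask.expand_as(X)]

# Alternative: select whole rows with a 1-D row mask; result keeps 2-D shape.
rows = X[mask.squeeze(1)]

# Boolean indexing visits elements in row-major order, so flattening the
# selected rows gives the same values in the same order.
print(torch.equal(flat, rows.reshape(-1)))
print(rows.shape)
```

The row mask has one entry per row instead of one per element, so the work spent on locating selected positions scales with the number of rows. The selected rows themselves still need memory. Whether this alone avoids the crash at full size has not been tested here.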
Here is the fixed version, which processes the data in chunks:
import torch
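def select_elements(X, mask, chunk_size=10000):
    # Apply the boolean mask to X one block of rows at a time to limit peak memory.
    selected_chunks = []  # holds the selected values (not indices) of each chunk
    for i in range(0, len(X), chunk_size):
        chunk = X[i:i + chunk_size]                  # slicing returns a view, no copy
        chunk_mask = mask[i:i + chunk_size]          # shape [rows, 1]
        expanded_mask = chunk_mask.expand_as(chunk)  # broadcast to [rows, 1024] as a view
        selected_chunks.append(chunk[expanded_mask]) # 1-D tensor of the kept values
    return torch.cat(selected_chunks)                # same 1-D result as X[mask.expand_as(X)]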
X=torch.rand([2957312, 1024])
mask=torch.randint(2,[2957312, 1],dtype=bool)
X=select_elements(X, mask)
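Chunking changes how the result is built, so it is worth checking on small inputs that it returns exactly what the one-shot version returns. This is a minimal sketch with arbitrary small sizes. It uses a `chunk_size` that does not divide the row count evenly, so the last, partial chunk is also covered.

```python
import torch

def select_elements(X, mask, chunk_size=10000):
    # Same chunked selection as above: apply the mask block by block of rows.
    parts = []
    for i in range(0, len(X), chunk_size):
        chunk = X[i:i + chunk_size]
        chunk_mask = mask[i:i + chunk_size]
        parts.append(chunk[chunk_mask.expand_as(chunk)])
    return torch.cat(parts)

torch.manual_seed(0)
X = torch.rand([1003, 32])  # 1003 rows: not a multiple of chunk_size below
mask = torch.randint(2, [1003, 1], dtype=torch.bool)

direct = X[mask.expand_as(X)]
chunked = select_elements(X, mask, chunk_size=100)
print(torch.equal(direct, chunked))
```

Keep in mind that the list of chunk results and the concatenated output briefly coexist during `torch.cat`. Peak memory is therefore still roughly X plus twice the selected data, but the large per-element temporaries are bounded by one chunk at a time.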
With this change, the code runs successfully.