运行出现报错。修改数据格式
输出sample_ids的值,可以看到数据类型是 torch.int32
解决
需要将sample_ids类型转为long,修改方式:
idx= idx.type(torch.long)
或
idx= self.tensor(idx, dtype=torch.long)
参考:
IndexError: tensors used as indices must be long, byte or bool tensors
知乎:https://zhuanlan.zhihu.com/p/565931659