Did you know that you can navigate the posts by swiping left and right?

PyTorch 常用易混函数

10 May 2022 . category: tech .

@TOC

1.张量复制

detach()

返回一个与当前图形分离的新张量。

返回的张量与原始张量共享相同的存储空间。对其中任何一个进行修改,可能会在正确性检查中触发错误。重要提示:之前,对返回的张量进行更改,也会更新原始张量。现在,这些更改将不再更新原始张量,而是会触发一个错误。

detach_()

将一个张量从创建它的图中分离,并把它设置成叶结点。

clone()

Tensor.clone(*, memory_format=torch.preserve_format)  Tensor

返回与原张量有相同大小和数据类型的张量。

总结

Operation | New/Shared memory | Still in computation graph | |–|–|–| tensor.clone() | New | Yes | tensor.detach() | Shared | No | tensor.detach().clone() | New | No |

2.模型参数

parameters()

返回一个模型参数的迭代器。

# example
for name, param in model.named_parameters():
   if name in ['bias']:
       print(param.size())

源码:

def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
	for name, param in self.named_parameters(recurse=recurse):
		yield param

named_parameters()

返回一个模型参数的迭代器,包括参数名称和参数本身。

# example
for param in model.parameters():
    print(type(param), param.size())

源码:

def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
	gen = self._named_members(lambda module: module._parameters.items(), prefix=prefix, recurse=recurse)
	for elem in gen:
		yield elem

state_dict()

返回包含模型所有状态的字典,包括模型参数和持久缓冲区(persistent buffers)。字典的键是相应的参数和缓冲区名称。

3.张量初始化

rand()

torch.rand(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)  Tensor

返回一个由区间[0,1)上均匀分布的随机数填充的张量。

# example
>>> torch.rand(4)
tensor([ 0.5204,  0.2503,  0.3525,  0.5673])
>>> torch.rand(2, 3)
tensor([[ 0.8237,  0.5781,  0.6879],
        [ 0.3816,  0.7249,  0.0998]])

randn()

torch.randn(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)  Tensor

返回一个由标准正态分布中的随机数填充的张量。

# example
>>> torch.randn(4)
tensor([-2.1436,  0.9966,  2.3426, -0.6366])
>>> torch.randn(2, 3)
tensor([[ 1.5954,  2.8929, -1.0923],
        [ 1.1719, -0.4709, -0.1996]])

小技巧

如需生成大量的随机张量,且张量维度高,则非常耗费时间,但设定device为cuda后可以提速数倍。 代码示例如下:

torch.randn(torch.size([5000,500]),device=torch.device("cuda:0"))