pytorch函数学习之squeeze函数 cat函数的运用
发布时间:2022-01-11 22:09:09 所属栏目:语言 来源:互联网
导读:这篇文章主要给大家分享pytorch函数的内容,本文给大家介绍两个函数,分别是squeeze函数、cat函数。那么这两个函数有什么用呢?用法是什么?下面我们一起来学习一下。 1 squeeze(): 去除size为1的维度,包括行和列。 至于维度大于等于2时,squeeze()不起作
这篇文章主要给大家分享pytorch函数的内容,本文给大家介绍两个函数,分别是squeeze函数、cat函数。那么这两个函数有什么用呢?用法是什么?下面我们一起来学习一下。 1 squeeze(): 去除size为1的维度,包括行和列。 至于维度大于等于2时,squeeze()不起作用。 行、例: >>> torch.rand(4, 1, 3) (0 ,.,.) = 0.5391 0.8523 0.9260 (1 ,.,.) = 0.2507 0.9512 0.6578 (2 ,.,.) = 0.7302 0.3531 0.9442 (3 ,.,.) = 0.2689 0.4367 0.6610 [torch.FloatTensor of size 4x1x3] >>> torch.rand(4, 1, 3).squeeze() 0.0801 0.4600 0.1799 0.0236 0.7137 0.6128 0.0242 0.3847 0.4546 0.9004 0.5018 0.4021 [torch.FloatTensor of size 4x3] 列、例: >>> torch.rand(4, 3, 1) (0 ,.,.) = 0.7013 0.9818 0.9723 (1 ,.,.) = 0.9902 0.8354 0.3864 (2 ,.,.) = 0.4620 0.0844 0.5707 (3 ,.,.) = 0.5722 0.2494 0.5815 [torch.FloatTensor of size 4x3x1] >>> torch.rand(4, 3, 1).squeeze() 0.8784 0.6203 0.8213 0.7238 0.5447 0.8253 0.1719 0.7830 0.1046 0.0233 0.9771 0.2278 [torch.FloatTensor of size 4x3] 不变、例: >>> torch.rand(4, 3, 2) (0 ,.,.) = 0.6618 0.1678 0.3476 0.0329 0.1865 0.4349 (1 ,.,.) = 0.7588 0.8972 0.3339 0.8376 0.6289 0.9456 (2 ,.,.) = 0.1392 0.0320 0.0033 0.0187 0.8229 0.0005 (3 ,.,.) = 0.2327 0.6264 0.4810 0.6642 0.8625 0.6334 [torch.FloatTensor of size 4x3x2] >>> torch.rand(4, 3, 2).squeeze() (0 ,.,.) = 0.0593 0.8910 0.9779 0.1530 0.9210 0.2248 (1 ,.,.) = 0.7938 0.9362 0.1064 0.6630 0.9321 0.0453 (2 ,.,.) = 0.0189 0.9187 0.4458 0.9925 0.9928 0.7895 (3 ,.,.) = 0.5116 0.7253 0.0132 0.6673 0.9410 0.8159 [torch.FloatTensor of size 4x3x2] 2 cat函数 >>> t1=torch.FloatTensor(torch.randn(2,3)) >>> t1 -1.9405 1.2009 0.0018 0.9463 0.4409 -1.9017 [torch.FloatTensor of size 2x3] >>> t2=torch.FloatTensor(torch.randn(2,2)) >>> t2 0.0942 0.1581 1.1621 1.2617 [torch.FloatTensor of size 2x2] >>> torch.cat((t1, t2), 1) -1.9405 1.2009 0.0018 0.0942 0.1581 0.9463 0.4409 -1.9017 1.1621 1.2617 [torch.FloatTensor of size 2x5] 补充:pytorch中 max()、view()、 squeeze()、 unsqueeze() 查了好多博客都似懂非懂,后来写了几个小例子,瞬间一目了然。 一、torch.max() import torch a=torch.randn(3) print("a:n",a) print('max(a):',torch.max(a)) (编辑:温州站长网) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |