前言
接上一篇文章,pytorch 基础操作知识(一),继续记录整理一些pytorch
的基本使用与实验代码
Broadcast机制
Ps:broadcast
并不是函数,而是在不同size
的tensor
之间进行加减操作会自动进行的一种机制
broadcast
的两个特点:
- 能够进行维度的扩展,相当于
expand
,但是是自动扩展 - 扩展的时候不需要拷贝数据,能够节约内存(直接计算的时候加到最终结果,中间转化是不需要存储的)
直接看下面的代码例子:
# broadcast 机制 注意:不是函数
import torch
a = torch.rand(3,3)
print(a)
b = torch.rand(1,3)
print(b)
c = a+b
print(c)
c = torch.rand(3,2,2)
# d = c.repeat(1,2,2)
d = c.expand(3,4,4)
print(d.shape)
注意,上面的例子就是想说明,expand
对应非1的不能拓展,所以基于expand
的broadcast
机制也无法进行
合并操作
tensor
的合并 cat
和stack
函数,介绍直接看下面的代码实例
# tensor的合并 cat stack 大前提,两个tensor的维度要相同
# example 1
a = torch.rand(2,3,4)
b = torch.rand(2,3,4)
print(torch.cat((a,b), dim=0).shape)
print(torch.stack((a,b), dim=0).shape) # 第一个区别,cat在指定维度进行合并操作,然而stack是在指定维度新增一个维度进行合并操作
# example2
a = torch.rand(2,1,4)
b = torch.rand(2,3,4)
print(torch.cat((a,b), dim=0).shape)
print(torch.stack((a,b), dim=0).shape) # 指定维度之外的维度不一样的情况下,两个函数都会报错
# example3
a = torch.rand(1,3,4)
b = torch.rand(2,3,4)
print(torch.cat((a,b), dim=0).shape)
print(torch.stack((a,b), dim=0).shape) # 指定维度不同,cat可以进行合并操作,stack就不可以合并
分割操作
# tensor的拆分 split与chunk
# example 1
a = torch.rand(3,3,3)
aa, ab, ac = a.split(1, dim=0) # split里面是数字的话,表示在dim对应维度,均分成指定数值的tensor
print(aa.shape)
print(ab.shape)
print(ac.shape)
print("====================")
a = torch.rand(4,3,3)
aa, ab = a.split(2, dim=0)
print(aa.shape)
print(ab.shape)
print("====================")
# example 2
a = torch.rand(3,3,3)
aa, ab = a.split([1,2], dim=0) # split里面是列表的话,代表着dim对应维度,拆分成指定的形式
print(aa.shape)
print(ab.shape)
# example 3
print("====================")
a = torch.rand(8,3,3)
aa, ab = a.chunk(2, dim=0) # 与split的数字情况区分下,这个是直接按数字分成几份
print(aa.shape)
print(ab.shape)
数学运算
基本的加减乘除
直接看下面的代码吧,注意区分点对点运算,而不是矩阵运算
# 加减乘除 建议直接使用运算符
import torch
a = torch.rand(3,4)
b = torch.rand(4)
# 先输出a tensor与b tensor
print(a)
print(b)
# 加法
c = a+b
print(c)
# 减法
d = a-b
print(d)
# 乘法 注意这个是值的点乘,矩阵乘法不是这个符号
e = a*b
print(e)
# 除法 注意这个是值的点对点除法,矩阵运算注意区分
f = a/b
print(f)
矩阵的乘法
矩阵乘法有三种形式,其中mm只适用与二维矩阵,不推荐使用,一般推荐使用matmul
# 矩阵乘法,推荐写法matmul
# 注意mm只适用于二维的矩阵乘法
a = torch.rand(2,2)
b = torch.rand(2,2)
print(torch.mm(a,b))
print(torch.matmul(a,b))
print(a@b)
多维tensor
的乘法,其实还是最后两个维度进行乘法,前两个维度保持不变即可
a = torch.rand(3,3,2,2)
b = torch.rand(3,3,2,3)
c = torch.matmul(a,b)
print(c.shape)
N次方/N次方根/倒数
一般N次方与N次方跟通用pow
即可
# N次方
a = torch.full([2,2],4, dtype=torch.float)
print(a)
print(a.pow(2)) # 2指的就是n次方的n
print(a**2)
# N次方根
print(a**(0.5))
print(a.pow(0.5))
# 倒数
print(a.rsqrt())
自然指数与对数
这里其实就是对矩阵中的每个元素做处理
# e^x
a = torch.full([2,2],1, dtype=torch.float)
print(a)
b = torch.exp(a)
print(b)
# ln x
print(torch.log(b))
值近似处理的几个函数
# floor 向下取值
a = torch.rand([2,2])
print(a)
print(a.floor())
# ceil 向上取值
print(a.ceil())
# 取整数部分
print(a.trunc())
# 取小数部分
print(a.frac())
# 四舍五入
print(a.round())