pytorch 基础操作知识(二)

前言

接上一篇文章,pytorch 基础操作知识(一),继续记录整理一些pytorch的基本使用与实验代码

Broadcast机制

Ps:broadcast并不是函数,而是在不同sizetensor之间进行加减操作会自动进行的一种机制

broadcast的两个特点:

  1. 能够进行维度的扩展,相当于expand,但是是自动扩展
  2. 扩展的时候不需要拷贝数据,能够节约内存(直接计算的时候加到最终结果,中间转化是不需要存储的)

直接看下面的代码例子:

# broadcast 机制   注意:不是函数

import torch

a = torch.rand(3,3)
print(a)
b = torch.rand(1,3)
print(b)

c = a+b
print(c)

在这里插入图片描述

c = torch.rand(3,2,2)

# d = c.repeat(1,2,2)
d = c.expand(3,4,4)

print(d.shape)

在这里插入图片描述
注意,上面的例子就是想说明,expand对应非1的不能拓展,所以基于expandbroadcast机制也无法进行

合并操作

tensor的合并 catstack函数,介绍直接看下面的代码实例

# tensor的合并 cat stack   大前提,两个tensor的维度要相同

# example 1
a = torch.rand(2,3,4)
b = torch.rand(2,3,4)

print(torch.cat((a,b), dim=0).shape)
print(torch.stack((a,b), dim=0).shape)  # 第一个区别,cat在指定维度进行合并操作,然而stack是在指定维度新增一个维度进行合并操作

# example2
a = torch.rand(2,1,4)
b = torch.rand(2,3,4)

print(torch.cat((a,b), dim=0).shape)
print(torch.stack((a,b), dim=0).shape) # 指定维度之外的维度不一样的情况下,两个函数都会报错

# example3
a = torch.rand(1,3,4)
b = torch.rand(2,3,4)

print(torch.cat((a,b), dim=0).shape)
print(torch.stack((a,b), dim=0).shape) # 指定维度不同,cat可以进行合并操作,stack就不可以合并

在这里插入图片描述

分割操作

# tensor的拆分 split与chunk

# example 1

a = torch.rand(3,3,3)

aa, ab, ac = a.split(1, dim=0) # split里面是数字的话,表示在dim对应维度,均分成指定数值的tensor
print(aa.shape)
print(ab.shape)
print(ac.shape)

print("====================")
a = torch.rand(4,3,3)

aa, ab = a.split(2, dim=0)
print(aa.shape)
print(ab.shape)

print("====================")
# example 2
a = torch.rand(3,3,3)

aa, ab = a.split([1,2], dim=0)  # split里面是列表的话,代表着dim对应维度,拆分成指定的形式
print(aa.shape)
print(ab.shape)

# example 3

print("====================")
a = torch.rand(8,3,3)

aa, ab = a.chunk(2, dim=0)    # 与split的数字情况区分下,这个是直接按数字分成几份
print(aa.shape)
print(ab.shape)

在这里插入图片描述

数学运算

基本的加减乘除

直接看下面的代码吧,注意区分点对点运算,而不是矩阵运算

# 加减乘除 建议直接使用运算符

import torch

a = torch.rand(3,4)
b = torch.rand(4)

# 先输出a tensor与b tensor
print(a)
print(b)

# 加法
c = a+b
print(c)

# 减法
d = a-b
print(d)

# 乘法 注意这个是值的点乘,矩阵乘法不是这个符号
e = a*b
print(e)

# 除法 注意这个是值的点对点除法,矩阵运算注意区分
f = a/b
print(f)

在这里插入图片描述

矩阵的乘法

矩阵乘法有三种形式,其中mm只适用与二维矩阵,不推荐使用,一般推荐使用matmul

# 矩阵乘法,推荐写法matmul
# 注意mm只适用于二维的矩阵乘法

a = torch.rand(2,2)
b = torch.rand(2,2)

print(torch.mm(a,b))

print(torch.matmul(a,b))

print(a@b)

在这里插入图片描述
多维tensor的乘法,其实还是最后两个维度进行乘法,前两个维度保持不变即可

a = torch.rand(3,3,2,2)
b = torch.rand(3,3,2,3)

c = torch.matmul(a,b)
print(c.shape)

在这里插入图片描述

N次方/N次方根/倒数

一般N次方与N次方跟通用pow即可

# N次方
a = torch.full([2,2],4, dtype=torch.float)
print(a)

print(a.pow(2)) # 2指的就是n次方的n

print(a**2)

# N次方根

print(a**(0.5))

print(a.pow(0.5))

# 倒数

print(a.rsqrt())

在这里插入图片描述

自然指数与对数

这里其实就是对矩阵中的每个元素做处理

# e^x
a = torch.full([2,2],1, dtype=torch.float)
print(a)

b = torch.exp(a)
print(b)


# ln x
print(torch.log(b))

在这里插入图片描述

值近似处理的几个函数

# floor 向下取值
a = torch.rand([2,2])
print(a)

print(a.floor())

# ceil 向上取值
print(a.ceil())

# 取整数部分
print(a.trunc())

# 取小数部分
print(a.frac())


# 四舍五入
print(a.round())

在这里插入图片描述

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 代码科技 设计师:Amelia_0503 返回首页