torch.where() 用于將兩個broadcastable的tensor組合成新的tensor,類似于c++中的三元操作符“?:”
區別于python numpy中的where()直接可以找到特定條件元素的index

想要實現numpy中where()的功能,可以借助nonzero()

對應numpy中的where()操作效果:

補充:Pytorch torch.Tensor.detach()方法的用法及修改指定模塊權重的方法
detach
detach的中文意思是分離,官方解釋是返回一個新的Tensor,從當前的計算圖中分離出來

需要注意的是,返回的Tensor和原Tensor共享相同的存儲空間,但是返回的 Tensor 永遠不會需要梯度

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
那么這個函數有什么作用?
–假如A網絡輸出了一個Tensor類型的變量a, a要作為輸入傳入到B網絡中,如果我想通過損失函數反向傳播修改B網絡的參數,但是不想修改A網絡的參數,這個時候就可以使用detcah()方法
a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()
來看一個實際的例子:
import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad #True
y = t.ones(1, requires_grad=True)
y.requires_grad #True
x = x.detach() #分離之后
x.requires_grad #False
y = x+y #tensor([2.])
y.requires_grad #我還是True
y.retain_grad() #y不是葉子張量,要加上這一行
z = t.pow(y, 2)
z.backward() #反向傳播
y.grad #tensor([4.])
x.grad #None
以上代碼就說明了反向傳播到y就結束了,沒有到達x,所以x的grad屬性為None
既然談到了修改模型的權重問題,那么還有一種情況是:
–假如A網絡輸出了一個Tensor類型的變量a, a要作為輸入傳入到B網絡中,如果我想通過損失函數反向傳播修改A網絡的參數,但是不想修改B網絡的參數,這個時候又應該怎么辦了?
這時可以使用Tensor.requires_grad屬性,只需要將requires_grad修改為False即可.
for param in B.parameters():
param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。
您可能感興趣的文章:- Python深度學習之使用Pytorch搭建ShuffleNetv2
- win10系統配置GPU版本Pytorch的詳細教程
- 淺談pytorch中的nn.Sequential(*net[3: 5])是啥意思
- pytorch visdom安裝開啟及使用方法
- PyTorch CUDA環境配置及安裝的步驟(圖文教程)
- pytorch中的nn.ZeroPad2d()零填充函數實例詳解
- 使用pytorch實現線性回歸
- pytorch實現線性回歸以及多元回歸
- pytorch顯存一直變大的解決方案
- 在Windows下安裝配置CPU版的PyTorch的方法
- PyTorch兩種安裝方法
- PyTorch的Debug指南