婷婷综合国产,91蜜桃婷婷狠狠久久综合9色 ,九九九九九精品,国产综合av

主頁 > 知識庫 > pytorch自定義不可導激活函數的操作

pytorch自定義不可導激活函數的操作

熱門標簽:商家地圖標注海報 打電話機器人營銷 騰訊地圖標注沒法顯示 南陽打電話機器人 孝感營銷電話機器人效果怎么樣 地圖標注自己和別人標注區(qū)別 海外網吧地圖標注注冊 聊城語音外呼系統(tǒng) ai電銷機器人的優(yōu)勢

pytorch自定義不可導激活函數

今天自定義不可導函數的時候遇到了一個大坑。

首先我需要自定義一個函數:sign_f

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs  0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_-1.] = 0
        return grad_output

然后我需要把它封裝為一個module 類型,就像 nn.Conv2d 模塊 封裝 f.conv2d 一樣,于是

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
	# 我需要的module
    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        
    def forward(self, inputs):
    	# 使用自定義函數
        outs = sign_f(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs  0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_-1.] = 0
        return grad_output

結果報錯

TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'

我試了半天,發(fā)現自定義函數后面要加 apply ,詳細見下面

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):

    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        self.r = sign_f.apply ### -----注意此處
        
    def forward(self, inputs):
        outs = self.r(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs  0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_-1.] = 0
        return grad_output

問題解決了!

PyTorch自定義帶學習參數的激活函數(如sigmoid)

有的時候我們需要給損失函數設一個超參數但是又不想設固定閾值想和網絡一起自動學習,例如給Sigmoid一個參數alpha進行調節(jié)

函數如下:

import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))

驗證和Sigmoid的一致性

class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
   
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
        [0.4379, 0.1828, 0.4629],
        [0.4302, 0.1358, 0.4180]])

print(Sigmoid(input))
print(LearnSigmoid(input))

輸出結果

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=MulBackward0>)

驗證權重是不是會更新

import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()       
        self.LSigmoid = LearnableSigmoid()
    def forward(self, x):                
        x = self.LSigmoid(x)
        return x

net = Net()  
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)

for i in range(2):
    optimizer.zero_grad()     
    output = net(input_data)   
    loss = criterion(output, target)
    loss.backward()             
    optimizer.step()           
    print(list(net.parameters()))

輸出結果

tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]

會更新~

以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • pytorch方法測試——激活函數(ReLU)詳解
  • PyTorch中常用的激活函數的方法示例
  • Pytorch 實現自定義參數層的例子

標簽:南寧 迪慶 牡丹江 撫州 聊城 楊凌 揚州 六盤水

巨人網絡通訊聲明:本文標題《pytorch自定義不可導激活函數的操作》,本文關鍵詞  pytorch,自定義,不,可導,激活,;如發(fā)現本文內容存在版權問題,煩請?zhí)峁┫嚓P信息告之我們,我們將及時溝通與處理。本站內容系統(tǒng)采集于網絡,涉及言論、版權與本站無關。
  • 相關文章
  • 下面列出與本文章《pytorch自定義不可導激活函數的操作》相關的同類信息!
  • 本頁收集關于pytorch自定義不可導激活函數的操作的相關信息資訊供網民參考!
  • 推薦文章
    主站蜘蛛池模板: 柳江县| 宝坻区| 浦县| 河东区| 资阳市| 调兵山市| 囊谦县| 富顺县| 商丘市| 镇宁| 平泉县| 铁力市| 黑龙江省| 潜江市| 通江县| 同德县| 岢岚县| 丰城市| 盐城市| 台北县| 河北省| 灌南县| 太康县| 红安县| 武功县| 尚志市| 北辰区| 岚皋县| 陆川县| 昌宁县| 阿拉善右旗| 新河县| 洞口县| 太谷县| 新民市| 镇原县| 银川市| 蒙城县| 金昌市| 高淳县| 扶沟县|