深度学习能不能学会乘法

import torch
import numpy as np
import torch.nn as nn

# Experiment: can a plain fully connected network learn 8-bit x 8-bit integer
# multiplication when operands and product are presented as bit vectors?
#
# NOTE(review): the original snippet was flattened onto a single line and lost
# its bracketed expressions (list literals, indexing, array shapes) during
# extraction. Those parts are reconstructed from the surrounding code (e.g. the
# nn.Linear(16, 32) input layer and the "/16." in the metric) and should be
# confirmed against the author's original post.

OPERAND_BITS = 8                # width of each factor
INPUT_BITS = 2 * OPERAND_BITS   # both factors concatenated -> 16 input features
PRODUCT_BITS = 2 * OPERAND_BITS  # an 8x8-bit product always fits in 16 bits


def dec2bin(num, bin):
    """Return ``num`` as a list of 0/1 ints, most significant bit first.

    Args:
        num: Non-negative integer to encode.
        bin: Number of bits in the result; the value is left-padded with
            zeros. (The name shadows the builtin but is kept so existing
            keyword callers keep working.)

    Returns:
        A list of exactly ``bin`` integers, each 0 or 1.

    Raises:
        ValueError: If ``num`` is negative or needs more than ``bin`` bits.
    """
    num = int(num)
    if num < 0:
        # The original divmod loop never terminated for negative input.
        raise ValueError("dec2bin expects a non-negative integer")
    # Zero encodes as no digits at all, so it becomes pure padding below.
    digits = [int(ch) for ch in format(num, "b")] if num else []
    if len(digits) > bin:
        raise ValueError(f"{num} does not fit in {bin} bits")
    return [0] * (bin - len(digits)) + digits


def gen_xy(size):
    """Generate one random batch of multiplication examples.

    Args:
        size: Number of examples in the batch.

    Returns:
        A tuple ``(x, y)`` of float32 arrays. ``x`` has shape
        ``(size, 16)``: the 8 bits of the first factor followed by the 8 bits
        of the second. ``y`` has shape ``(size, 16)``: the bits of the product.
    """
    x = np.zeros([size, INPUT_BITS], np.float32)
    y = np.zeros([size, PRODUCT_BITS], np.float32)
    for idx in range(size):
        # randint's upper bound is exclusive; 1 << 8 covers the full 0..255.
        x1 = np.random.randint(0, 1 << OPERAND_BITS)
        x2 = np.random.randint(0, 1 << OPERAND_BITS)
        # The second factor was originally encoded with 16 bits while only the
        # first 16 entries of the concatenation fit the input (as
        # reconstructed), so the network saw x1 plus the always-zero upper
        # byte of x2. Encoding x2 with 8 bits feeds both factors to the model.
        x[idx] = dec2bin(x1, OPERAND_BITS) + dec2bin(x2, OPERAND_BITS)
        y[idx] = dec2bin(x1 * x2, PRODUCT_BITS)
    return x, y


class acc_func(nn.Module):
    """Bit-level accuracy: fraction of outputs on the correct side of 0.5."""

    def forward(self, x, y):
        """Return the fraction of elements where ``|y - x| < 0.5``.

        The comparison is strict: with ``<=`` a network that outputs a
        constant 0.5 would be scored as 100% accurate.
        """
        correct = torch.sum(torch.abs(y - x) < 0.5)
        # Same value as the original "/16./batch" for (batch, 16) inputs.
        return correct / float(y.numel())


class multi_model(nn.Module):
    """MLP: Linear(16, 32) stem, 20 x (Linear(32, 32) + ReLU), sigmoid head."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Linear(INPUT_BITS, 32)
        hidden = []
        for _ in range(20):
            hidden.append(nn.Linear(32, 32))
            hidden.append(nn.ReLU())
        self.ly = nn.Sequential(*hidden)
        # Sigmoid keeps every output in (0, 1) so it can be read as a bit.
        self.out = nn.Sequential(nn.Linear(32, PRODUCT_BITS), nn.Sigmoid())

    def forward(self, x):
        """Map a ``(batch, 16)`` input to ``(batch, 16)`` bit estimates."""
        return self.out(self.ly(self.stem(x)))


def main():
    """Train on freshly generated batches, printing loss and bit accuracy."""
    model = multi_model()
    loss_func = nn.MSELoss()
    # NOTE(review): 0.1 is the author's value; it is far above Adam's default
    # step size of 1e-3 and may itself prevent a 20-layer stack from training.
    optm = torch.optim.Adam(model.parameters(), 0.1)
    acc = acc_func()
    for step in range(1000000):
        x, y = gen_xy(256)
        x = torch.from_numpy(x) - 0.5  # centre the {0, 1} inputs around zero
        y = torch.from_numpy(y)
        pred = model(x)
        loss = loss_func(pred, y)
        print(loss, acc(pred, y))
        optm.zero_grad()
        loss.backward()
        optm.step()


if __name__ == "__main__":
    main()

# Author's remark on the original run (translated from the article):
# "Tested it -- it really did not learn multiplication..."
即使用单层很宽的网络也没学出来。靠Dense层应该是没啥希望。
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Experiment: give the network an explicit multiplication unit next to a
# linear layer and let it learn only how to weight the two branches.
#
# NOTE(review): the original snippet was flattened onto a single line and lost
# its bracketed expressions (array shapes, column slices, weight indices)
# during extraction. Those parts are reconstructed from the surrounding code
# (nn.Linear(2, 1), a 2-element mixing weight) and should be confirmed against
# the author's original post.


def gen_xy(size):
    """Generate a random batch of factor pairs and their products.

    Args:
        size: Number of examples in the batch.

    Returns:
        A tuple ``(x, y)`` of float32 arrays: ``x`` has shape ``(size, 2)``
        with integer-valued factors drawn from ``randint(0, 255)``, and ``y``
        has shape ``(size, 1)`` holding the product of the two columns.
    """
    x = np.random.randint(0, 255, [size, 2])
    # Slices (not plain indices) keep y two-dimensional, matching the model's
    # (batch, 1) output.
    y = x[:, 0:1] * x[:, 1:2]
    return x.astype(np.float32), y.astype(np.float32)


class mul_cell(nn.Module):
    """Parameter-free unit that multiplies the two input columns."""

    def forward(self, x):
        """Return ``x[:, 0] * x[:, 1]`` with shape ``(batch, 1)``."""
        return x[:, 0:1] * x[:, 1:2]


class multi_model(nn.Module):
    """Softmax-weighted blend of an explicit product and a linear layer."""

    def __init__(self):
        super().__init__()
        self.mul = mul_cell()
        self.dense = nn.Linear(2, 1)
        # Two mixing logits; the values are set in reset_param().
        self.weight = nn.Parameter(torch.empty(2))
        self.reset_param()

    def reset_param(self):
        """Set both mixing logits to 0.01, i.e. an even 50/50 blend."""
        torch.nn.init.constant_(self.weight, 0.01)

    def forward(self, x):
        """Return ``w[0] * mul(x) + w[1] * dense(x)``, ``w = softmax(weight)``."""
        weight = torch.softmax(self.weight, 0)
        return weight[0] * self.mul(x) + weight[1] * self.dense(x)


def main():
    """Train the blend with SGD, printing the loss and the mixing logits."""
    # Fall back to the CPU instead of crashing when no CUDA device is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = multi_model().to(device)
    loss_func = nn.MSELoss()
    optm = torch.optim.SGD(model.parameters(), .01)
    for step in range(1000000):
        x, y = gen_xy(64)
        x = torch.from_numpy(x).to(device)
        y = torch.from_numpy(y).to(device)
        pred = model(x)
        loss = loss_func(pred, y)
        print(loss, model.weight)
        optm.zero_grad()
        loss.backward()
        optm.step()


if __name__ == "__main__":
    main()

# Author's remark (translated from the article): "Of course, if the
# multiplication is written out explicitly and the network only has to learn
# the weights, there is no problem at all."

■网友
可以参考 NIPS’17 的一篇挺有意思的 paper。
Deep Sets,https://papers.nips.cc/paper/6931-deep-sets.pdf
文中尝试学习 permutation invariant 的 set representation,并应用到了加法操作上,效果很不错。我认为推广到乘法也值得一试。

■网友
这个问题其实已经算是基本解决了。


推荐阅读