In [1]:
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import os
import numpy as np
In [2]:
def equalizer(U):
    """Return an orthogonal matrix that equalizes the column norms of U.

    The columns are fixed one at a time. At step ``i`` the Gram matrix of the
    not-yet-fixed columns ``i..r-1`` is diagonalized, and those columns are
    rotated so that column ``i`` points along the normalized sum of the
    eigenvectors. Its squared norm then equals the mean eigenvalue, i.e. the
    mean squared norm of the remaining columns. By induction, every column of
    ``U @ sol`` ends up with squared norm ``trace(U.T @ U) / r``.

    Args:
        U: input matrix to be equalized (its r columns are equalized).
    Returns:
        sol: an r x r orthogonal matrix such that the columns of ``U @ sol``
            all have the same norm.
    """
    n_cols = np.size(U, 1)
    gram = U.T @ U
    sol = np.identity(n_cols)
    for step in range(n_cols):
        remaining = n_cols - step
        _, eigvecs = np.linalg.eigh(gram)
        # eigvecs is orthogonal, so eigvecs @ ones / sqrt(remaining) is a unit vector.
        direction = (np.mean(eigvecs, 1) * np.sqrt(remaining)).reshape(-1, 1)
        if remaining > 1:
            # The full SVD of a single column gives an orthogonal basis whose
            # first column is +/- direction.
            rotation = np.linalg.svd(direction)[0]
        else:
            rotation = np.array([1])
        block = np.identity(n_cols)
        block[step:, step:] = rotation
        sol = sol @ block
        gram = rotation.T @ gram @ rotation
        if remaining > 1:
            # Column `step` is now fixed; continue with the trailing sub-block.
            gram = gram[1:, 1:]
    return sol
In [3]:
def dropout_solution(M, r, lam):
    """ solves the matrix dropout problem in closed form.

    Finds factors U (d1 x r) and V (d2 x r) of M that minimize
        ||M - U @ V.T||_F^2 + lam * sum_i ||U[:, i]||^2 * ||V[:, i]||^2.
    The top singular values of M are soft-thresholded by a common cutoff.
    Both factors are then rotated by the same orthogonal matrix from
    `equalizer`, which leaves U @ V.T unchanged and makes the columns
    equal-norm.

    Args:
        M:   input matrix to be factorized (2-D; anything np.asarray accepts,
             e.g. a numpy array or a CPU torch tensor)
        r:   the factorization size
        lam: the regularization parameter. lam = p/(1-p), where p is the dropout rate

    Returns:
        sol:          sol = Ustar @ Vstar.T
        Ustar, Vstar: a pair of factors which is globally optimum.
                      If M == 0 or r == 0 this is the zero factorization
                      (previously the function returned None in that case).
    """
    M = np.asarray(M)
    U_svd, s_svd, Vt_svd = np.linalg.svd(M, full_matrices=False)

    # Compact SVD: keep only strictly positive singular values. They are
    # sorted in decreasing order, so this keeps a prefix. Then truncate to r.
    keep = s_svd > 0
    rho_lim = min(int(np.count_nonzero(keep)), r)
    s = s_svd[keep][:rho_lim]
    U = U_svd[:, keep][:, :rho_lim]
    V = Vt_svd[keep, :][:rho_lim, :].T

    # Find the largest rank rho whose rho-th singular value exceeds the
    # shrinkage cutoff computed from the mean of the top-rho singular values.
    for rho in range(rho_lim, 0, -1):
        kappa = np.mean(s[:rho])
        cutoff = (lam * rho * kappa) / (r + lam * rho)
        if s[rho - 1] > cutoff:
            break
    else:
        # No rank passed the test; in particular rho_lim == 0 (M == 0 or
        # r == 0). Return the zero factorization instead of falling off the
        # end of the function and returning None.
        Ustar = np.zeros((M.shape[0], r))
        Vstar = np.zeros((M.shape[1], r))
        return (Ustar @ Vstar.T, Ustar, Vstar)

    # Soft-threshold: values at or below the cutoff become exactly zero. Then
    # pad everything with zeros up to r factors.
    pad = r - rho_lim
    shrunk = np.maximum(s - cutoff, 0)
    shrunk = np.concatenate([shrunk, np.zeros(pad)])
    root_s = np.diag(np.sqrt(shrunk))
    U = np.hstack((U, np.zeros((U.shape[0], pad))))
    V = np.hstack((V, np.zeros((V.shape[0], pad))))

    # Rotate both factors by the same orthogonal matrix: the product is
    # unchanged while the columns become equal-norm.
    scaled_U = U @ root_s
    eq = equalizer(scaled_U)
    Ustar = scaled_U @ eq
    Vstar = V @ root_s @ eq
    sol = Ustar @ Vstar.T
    return (sol, Ustar, Vstar)
In [4]:
# N is sample size; D_in is input dimension;
# H is the number of hidden nodes; D_out is output dimension.
N, D_in, H, D_out = 1000, 20, 10, 15

# Seed BOTH generators. torch draws x, the initial weights and the dropout
# masks; numpy draws the ground-truth factors um, vm below. The numpy RNG was
# previously unseeded, so M (and every result) changed from run to run.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Create training points from standard Gaussian distribution
x = Variable(torch.randn(N, D_in))

# Create the transformation matrix M (D_in x D_out, rank at most H) that
# generates the labels.
um = np.random.randn(D_in, H)
vm = np.random.randn(H, D_out)
M = Variable(torch.Tensor(um @ vm), requires_grad=False)

# Regression targets: y = x @ M, one row per sample (N x D_out).
y = Variable(torch.mm(x, M).data, requires_grad=False)

# Retain probabilities theta. The dropout rate is p = 1 - theta and the
# corresponding regularizer is lam = (1 - theta) / theta.
thetas = np.array([1/2, 2/3, 10/11])
lt = len(thetas)

# initialize performance measures
num_itr = 20000
obj = torch.Tensor(num_itr, lt).zero_()       # dropout objective per iteration
loss_obj = torch.Tensor(num_itr, lt).zero_()  # NOTE(review): allocated but never filled
reg = torch.Tensor(num_itr, lt).zero_()       # std of normalized importance scores
opt = torch.Tensor(lt).zero_()                # closed-form optimum per theta

# Lets initialize the weights to be highly non-equalized
scaling = torch.from_numpy(np.diag(2**np.arange(H))).float()
W1 = torch.mm(scaling, torch.rand(H, D_in))
W1 = torch.norm(M.data)**.5 * W1 / torch.norm(W1)
W2 = torch.mm(torch.rand(D_out, H), scaling)
W2 = torch.norm(M.data)**.5 * W2 / torch.norm(W2)

for theta_idx in range(lt):
    theta = thetas[theta_idx]
    lam = (1 - theta) / theta

    # Reference value: the closed-form global optimum of the dropout objective.
    _, U_opt, V_opt = dropout_solution(M.data, H, lam)
    uu = np.diag(U_opt.T @ U_opt)
    vv = np.diag(V_opt.T @ V_opt)
    opt[theta_idx] = np.linalg.norm(M.data.numpy() - U_opt @ V_opt.T)**2 + lam * np.sum(uu * vv)

    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H, bias=False),
        torch.nn.Dropout(p=1-theta),
        torch.nn.Linear(H, D_out, bias=False),
    )

    # Every theta starts from the same non-equalized weights.
    model[0].weight.data = W1.clone()
    model[2].weight.data = W2.clone()

    # Summed squared error (replaces the deprecated size_average=False).
    loss_fn = torch.nn.MSELoss(reduction='sum')

    learning_rate = 1e-7
    for t in range(num_itr):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)

        a = model[0].weight.data  # H x D_in
        b = model[2].weight.data  # D_out x H
        aa = torch.mm(a, torch.t(a)).diag()  # squared row norms of a
        bb = torch.mm(torch.t(b), b).diag()  # squared column norms of b
        obj[t, theta_idx] = torch.norm(M.data - torch.mm(torch.t(a), torch.t(b)))**2 + lam * torch.sum(aa * bb)

        _, sm, _ = np.linalg.svd((b @ a).numpy(), full_matrices=False)
        nuc_norm = sm.sum()

        # Standard deviation (torch.std, not a variance) of the per-unit
        # importance scores ||a_i||^2 ||b_i||^2, normalized by nuc_norm^2.
        reg[t, theta_idx] = torch.std(aa * bb / (nuc_norm**2))

        if t % 2000 == 0:
            # loss.item() replaces loss.data[0], which fails on 0-dim tensors.
            print('itr: {}, truth: {}, loss: {}, obj: {}, std: {}'.format(
                t, float(opt[theta_idx]), loss.item() / N,
                float(obj[t, theta_idx]), float(reg[t, theta_idx])))

        # Plain SGD step.
        model.zero_grad()
        loss.backward()
        for param in model.parameters():
            param.data -= learning_rate * param.grad.data
itr: 0, truth: 1040.0146484375, loss: 7370.3995, obj: 6869.98046875, var: 0.1705906242132187
itr: 2000, truth: 1040.0146484375, loss: 1171.102875, obj: 1168.78173828125, var: 0.00798219908028841
itr: 4000, truth: 1040.0146484375, loss: 1070.9235, obj: 1052.3408203125, var: 0.0018507082713767886
itr: 6000, truth: 1040.0146484375, loss: 1104.560875, obj: 1043.1497802734375, var: 0.0007635527290403843
itr: 8000, truth: 1040.0146484375, loss: 1089.558625, obj: 1041.877685546875, var: 0.0004560464876703918
itr: 10000, truth: 1040.0146484375, loss: 1046.921375, obj: 1041.5294189453125, var: 0.00036811360041610897
itr: 12000, truth: 1040.0146484375, loss: 1037.9776875, obj: 1041.4373779296875, var: 0.00034691602922976017
itr: 14000, truth: 1040.0146484375, loss: 1055.9675, obj: 1041.358154296875, var: 0.00030572788091376424
itr: 16000, truth: 1040.0146484375, loss: 1066.925375, obj: 1041.3018798828125, var: 0.00027281345683149993
itr: 18000, truth: 1040.0146484375, loss: 1057.7115, obj: 1041.32470703125, var: 0.0002893458877224475
itr: 0, truth: 674.9812622070312, loss: 6231.501, obj: 6049.263671875, var: 0.1705906242132187
itr: 2000, truth: 674.9812622070312, loss: 890.2166875, obj: 879.8507690429688, var: 0.012582366354763508
itr: 4000, truth: 674.9812622070312, loss: 736.8441875, obj: 715.0987548828125, var: 0.004989123437553644
itr: 6000, truth: 674.9812622070312, loss: 675.10875, obj: 684.8736572265625, var: 0.002388444496318698
itr: 8000, truth: 674.9812622070312, loss: 655.512625, obj: 678.614013671875, var: 0.0012385585578158498
itr: 10000, truth: 674.9812622070312, loss: 684.5153125, obj: 677.0885620117188, var: 0.0007080661016516387
itr: 12000, truth: 674.9812622070312, loss: 682.4474375, obj: 676.6397094726562, var: 0.0004785564378835261
itr: 14000, truth: 674.9812622070312, loss: 695.9023125, obj: 676.4458618164062, var: 0.00037044999771751463
itr: 16000, truth: 674.9812622070312, loss: 704.6805625, obj: 676.3016357421875, var: 0.0003187335387337953
itr: 18000, truth: 674.9812622070312, loss: 671.4925625, obj: 676.2023315429688, var: 0.00030061014695093036
itr: 0, truth: 183.17147827148438, loss: 5633.919, obj: 5392.68994140625, var: 0.1705906242132187
itr: 2000, truth: 183.17147827148438, loss: 394.31315625, obj: 406.016357421875, var: 0.018426906317472458
itr: 4000, truth: 183.17147827148438, loss: 256.085578125, obj: 262.52239990234375, var: 0.012422796338796616
itr: 6000, truth: 183.17147827148438, loss: 266.35528125, obj: 242.3431396484375, var: 0.01100333034992218
itr: 8000, truth: 183.17147827148438, loss: 250.932140625, obj: 224.51071166992188, var: 0.009590476751327515
itr: 10000, truth: 183.17147827148438, loss: 210.691109375, obj: 217.0332489013672, var: 0.008518378250300884
itr: 12000, truth: 183.17147827148438, loss: 215.401953125, obj: 211.4070281982422, var: 0.007586944382637739
itr: 14000, truth: 183.17147827148438, loss: 189.439015625, obj: 205.0497283935547, var: 0.0065371873788535595
itr: 16000, truth: 183.17147827148438, loss: 194.15615625, obj: 198.04759216308594, var: 0.005397842265665531
itr: 18000, truth: 183.17147827148438, loss: 213.33175, obj: 192.0817413330078, var: 0.00437609339132905
In [5]:
path = "../plots"
try:
    # exist_ok=True replaces the racy os.path.exists() check. It still raises
    # (an OSError) if `path` exists but is not a directory.
    os.makedirs(path, exist_ok=True)
except OSError:
    print("Creation of the directory %s failed" % path)
In [6]:
plt.close()
fig, ax = plt.subplots()
for theta_idx in range(lt):
    theta = thetas[theta_idx]
    lam = (1 - theta) / theta
    ax.plot(range(num_itr), reg[:, theta_idx].numpy(), label=r"$\lambda$ = %.1f" % lam)

ax.legend()
ax.set_xlabel('Iterations')
# `reg` is filled with torch.std(...) in the training cell, so this is a
# standard deviation, not a variance.
ax.set_ylabel('Std of importance scores')
ax.set_title('Equalization')
ax.set_yscale('linear')
ax.set_xscale('log')
ax.grid()
fig.savefig("%s/nn_dout_r=%d.pdf" % (path, H))
plt.show()
In [7]:
plt.close()
for theta_idx in range(lt):
    theta = thetas[theta_idx]
    lam = (1 - theta) / theta
    # Use a fresh figure per theta. plt.figure(theta_idx) would return a
    # still-open figure with that number when the cell is re-run under a
    # persistent backend, and draw on top of the stale curves.
    fig, ax = plt.subplots()
    ax.plot(range(num_itr), obj[:, theta_idx].numpy(), label=r"Dropout, $\lambda$ = %.1f" % lam)
    # The closed-form optimum as a constant reference line.
    ax.plot(range(num_itr), np.full(num_itr, float(opt[theta_idx])), label='Truth')
    ax.legend()
    ax.set_xlabel('Iterations')
    ax.set_ylabel('Objective ')
    ax.set_title('Convergence in objective')
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.grid()
    fig.savefig('%s/nn_dout_r=%d_theta=%.1f.pdf' % (path, H, theta))
    plt.show()
In [ ]: