using Flux
using Random, LinearAlgebra
# Set seed for reproducibility
Random.seed!(42)
# Dimensions
m, n, r = 10, 8, 8 # A is m x n, X is m x r, Y is r x n
# Generate a random matrix A
A = randn(m, n)
# Initialize learnable factors
X = randn(m, r)
Y = randn(r, n)
function custom_loss(A, X, Y)
    return norm(A - X*Y)^2
    #return norm(A-X*Y)^2 + norm(X)^2 + norm(Y)^2  # regularized variant, used below
end
model_loss = (X,Y) -> custom_loss(A,X,Y)
# Training loop
niterations = 5000
for epoch in 1:niterations
    gX, gY = gradient(model_loss, X, Y)
    # Plain gradient-descent step with fixed step size 0.01 (in-place via @.)
    @. X = X - 0.01*gX
    @. Y = Y - 0.01*gY
    # Print loss every 50 epochs
    if epoch % 50 == 0
        println("Epoch $epoch: Loss = ", model_loss(X, Y))
    end
end
opterr = model_loss(X,Y)
# Compare the optimization error with the best rank-r approximation from the SVD
Usvd,Ssvd,Vsvd = svd(A)
svderr = norm(A-Usvd[:,1:r]*Diagonal(Ssvd[1:r])*Vsvd[:,1:r]')^2
@show opterr
@show svderr
Epoch 50: Loss = 2.091547823016724
Epoch 100: Loss = 0.5661487081197589
Epoch 150: Loss = 0.37969498421413045
⋮
Epoch 4900: Loss = 8.188783876490361e-5
Epoch 4950: Loss = 7.481724195830831e-5
Epoch 5000: Loss = 6.835910783088071e-5
opterr = 6.835910783088071e-5
svderr = 1.2555705456793983e-28
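# Hedged follow-up (not part of the original run): since r equals min(m, n) = 8, an
# exact factorization of A exists, so the singular values of the learned product X*Y
# should be close to those of A.
@show norm(svdvals(X*Y) - Ssvd)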
# This will evaluate the gradient of norm(A - X*Y)^2 at a point X, Y
# Let's check against my "analytical" gradient...
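# For f(X, Y) = norm(A - X*Y)^2 the analytical gradients are
#   ∂f/∂X = -2*(A - X*Y)*Y'   and   ∂f/∂Y = -2*X'*(A - X*Y),
# so gX + 2*(A - X*Y)*Y' (and likewise gY + 2*X'*(A - X*Y)) should be numerically zero.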
X = ones(m,r)
Y = ones(r,n)
gX, gY = gradient(model_loss, X, Y)
gX
gX + (2*(A - X*Y)*Y')
10×8 Matrix{Float64}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
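# The analogous check for the Y factor (not run in the original) should also give an
# all-zero matrix:
gY + (2*X'*(A - X*Y))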
norm(X), norm(Y)
(3.8030981067374814, 3.803098106737481)
norm(A-X*Y)
2.6586124701944938
# Dimensions
m, n, r = 10, 8, 8 # A is m x n, X is m x r, Y is r x n
# Generate a random matrix A
A = randn(m, n)
# Initialize learnable factors
X = randn(m, r)
Y = randn(r, n)
function full_function(A, X, Y)
    #return norm(A-X*Y)^2
    # Squared Frobenius error plus Frobenius-norm penalties on both factors
    return norm(A - X*Y)^2 + norm(X)^2 + norm(Y)^2
end
optimization_function = (X,Y) -> full_function(A,X,Y)
using Optimisers  # load Optimisers.jl explicitly for setup/update below
opt = Optimisers.Adam(0.01)
state = Optimisers.setup(opt, (X,Y))
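# Optimisers.setup builds a state tree that mirrors the structure of (X, Y);
# Optimisers.update (used below) returns an updated state tree together with the
# updated parameters, so both must be rebound on every iteration.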
# Training loop (Flux calls these epochs, but I don't like confusing them...)
niteration = 5000
for iteration in 1:niteration
    global state, X, Y   # these globals are rebound; required if this runs as a script
    #@show X, Y
    grads = gradient(optimization_function, X, Y)
    state, (X, Y) = Optimisers.update(state, (X, Y), grads)
    # Print loss every 50 iterations
    if iteration % 50 == 0
        println("Iteration $iteration: Loss = ", optimization_function(X, Y))
    end
end
@show opterr = optimization_function(X,Y)
@show norm(A-X*Y)^2
Iteration 50: Loss = 179.60227608291302
Iteration 100: Loss = 95.79193142197144
Iteration 150: Loss = 71.33033035163707
⋮
Iteration 4900: Loss = 30.275810816729283
Iteration 4950: Loss = 30.275810816729283
Iteration 5000: Loss = 30.275810816729283
opterr = optimization_function(X, Y) = 30.275810816729283
norm(A - X * Y) ^ 2 = 7.141659826689626
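# A cross-check not in the original run (added as a hedged aside): over all
# factorizations of a fixed product B = X*Y (with inner dimension at least rank(B)),
# the minimum of norm(X)^2 + norm(Y)^2 equals twice the nuclear norm of B, so the
# regularized problem has a closed form via soft-thresholding the singular values of A
# at 1. If Adam reached the global minimum, these should match the values just shown.
σ = svdvals(A)
@show sum(s >= 1 ? 2s - 1 : s^2 for s in σ)   # expected optimal regularized loss
@show sum(min(s, 1)^2 for s in σ)             # expected norm(A - X*Y)^2 at the optimum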
using Optim
WARNING: using Optim.Adam in module Main conflicts with an existing identifier.
m = 10
n = 8
A = randn(m,n)
#A = Matrix(1.0I,m,n)
r = 2
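# matrix_approx_function and matrix_approx_gradient! are not defined in this section.
# The sketch below is an assumed reconstruction, not the original definitions: x packs
# vec(U) followed by vec(V), and the objective is half the squared Frobenius error of
# A ≈ U*V' (consistent with objval = 2*myf(x) further down).
function matrix_approx_function(x, A, r)
    m, n = size(A)
    U = reshape(x[1:m*r], m, r)
    V = reshape(x[(m*r+1):end], n, r)
    return 0.5*norm(A - U*V')^2
end
function matrix_approx_gradient!(x, storage, A, r)
    m, n = size(A)
    U = reshape(x[1:m*r], m, r)
    V = reshape(x[(m*r+1):end], n, r)
    R = U*V' - A                       # residual
    storage[1:m*r] = vec(R*V)          # gradient with respect to U
    storage[(m*r+1):end] = vec(R'*U)   # gradient with respect to V
    return storage
end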
myf = x -> matrix_approx_function(x, A, r)
# Note: current Optim.jl calls the in-place gradient as g!(storage, x) (storage first);
# if the installed Optim expects that order, swap the two lambda arguments below.
myg! = (x, storage) -> matrix_approx_gradient!(x, storage, A, r)
soln = optimize(myf, myg!, ones(m*r+n*r), BFGS(), Optim.Options(f_tol = 1e-8))
#soln = optimize(myf, myg!, randn(m*r+n*r), BFGS(), Optim.Options(f_tol = 1e-8))
x = Optim.minimizer(soln)
@show soln
Uopt = reshape(x[(1:m*r)],m,r)
Vopt = reshape(x[(m*r+1):end],n,r)
objval = 2*myf(x)  # factor of 2 assumes myf returns half the squared Frobenius error
opterr = norm(A-Uopt*Vopt')^2
Usvd,Ssvd,Vsvd = svd(A)
svderr = norm(A-Usvd[:,1:r]*Diagonal(Ssvd[1:r])*Vsvd[:,1:r]')^2
@show objval
@show opterr
@show svderr
; # hide final output in JuliaBox