Source code for jet20.backend.core.linear_solver
import weakref

import torch
def robust_cholesky(A, beta=1e-3):
    """Cholesky-factorize ``A``, shifting its diagonal until that succeeds.

    The factorization is attempted on ``A + t * I``. ``t`` starts at ``0``
    when every diagonal entry of ``A`` is positive and at
    ``-min(diag(A)) + beta`` otherwise; after each failed attempt it is
    replaced by ``max(2 * t, beta)``.

    Args:
        A: Square 2-D floating-point tensor with finite entries.
        beta: Positive number; the smallest non-zero shift that is tried.

    Returns:
        The lower-triangular factor ``L`` of the first ``A + t * I`` whose
        factorization succeeds.

    Raises:
        ValueError: If ``A`` is not a square 2-D floating-point tensor,
            contains NaN/inf, or ``beta`` is not positive. These inputs can
            never be factorized, so retrying would loop forever.
        RuntimeError: If the shift grows past the largest finite value of
            ``A``'s dtype without the factorization succeeding.
    """
    if A.dim() != 2 or A.size(0) != A.size(1):
        raise ValueError(
            "robust_cholesky expects a square 2-D matrix, got shape {}".format(
                tuple(A.shape)
            )
        )
    if not A.is_floating_point():
        raise ValueError(
            "robust_cholesky expects a floating-point matrix, got {}".format(A.dtype)
        )
    beta = float(beta)
    if not beta > 0:
        # With beta <= 0 a zero shift would stay zero forever.
        raise ValueError("beta must be positive, got {!r}".format(beta))
    if A.size(0) == 0:
        # The factor of an empty matrix is the empty matrix.
        return A.clone()
    if not bool(torch.isfinite(A).all()):
        # No finite shift can repair NaN/inf entries.
        raise ValueError("robust_cholesky requires a matrix with only finite entries")

    min_aii = torch.diagonal(A).min()
    shift = 0.0 if min_aii > 0 else float(-min_aii) + beta

    largest = torch.finfo(A.dtype).max
    identity = None  # built lazily: not needed when A factorizes as-is
    while True:
        if shift == 0.0:
            candidate = A
        else:
            if identity is None:
                identity = torch.eye(A.size(0), dtype=A.dtype, device=A.device)
            candidate = A + shift * identity
        try:
            return _cholesky_factor(candidate)
        except RuntimeError as err:
            # PyTorch reports a failed factorization as a RuntimeError
            # (subclass); anything else is a genuine bug and propagates.
            failure = err
        shift = max(2.0 * shift, beta)
        if shift > largest:
            raise RuntimeError(
                "robust_cholesky could not factorize the matrix with any "
                "representable diagonal shift"
            ) from failure


def _cholesky_factor(matrix):
    """Return the lower Cholesky factor of ``matrix``.

    Uses ``torch.linalg.cholesky`` when the installed PyTorch provides it and
    falls back to the older (deprecated) ``Tensor.cholesky`` method otherwise.
    """
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "cholesky"):
        return linalg.cholesky(matrix)
    return matrix.cholesky()
def conjugate_gradients(A, b, nsteps, M_inv = None,residual_tol=1e-3):
    """Approximately solve ``A x = b`` with (preconditioned) conjugate gradients.

    The iteration starts from ``x = 0`` and runs for at most ``nsteps`` steps.

    Args:
        A: System matrix; only used through ``A @ p`` and ``A.new_zeros``.
        b: Right-hand side. ``@`` is used as an inner product on it, so it is
            presumably a 1-D tensor -- confirm against callers.
        nsteps: Maximum number of iterations.
        M_inv: Optional preconditioner, applied as ``M_inv @ r``.
        residual_tol: The loop stops once ``r @ r`` (or ``r @ (M_inv @ r)``
            when preconditioned) drops below this value. Note that this is a
            squared quantity, not the residual norm itself.

    Returns:
        The current iterate ``x``. If the curvature ``p @ A @ p`` is
        non-positive, the iterate reached so far is returned instead, except
        on the very first step, where the initial search direction
        (``b`` or ``M_inv @ b``) is returned.
    """
    use_precond = M_inv is not None

    solution = A.new_zeros(b.size())
    residual = -b.clone()

    # z is the preconditioned residual (or the residual itself without M_inv).
    z = M_inv @ residual if use_precond else residual
    direction = -z.clone()
    rho = residual @ z

    for step in range(nsteps):
        A_direction = A @ direction
        curvature = direction @ A_direction

        # Non-positive curvature: A is not positive definite along `direction`.
        if curvature <= 0:
            return direction if step == 0 else solution

        step_size = rho / curvature
        solution += step_size * direction
        residual += step_size * A_direction

        z = M_inv @ residual if use_precond else residual
        rho_next = residual @ z

        direction = -z + (rho_next / rho) * direction
        rho = rho_next

        if rho < residual_tol:
            break

    return solution
class LinearSolver(object):
    """Base class for callables that solve the linear system ``A x = b``.

    Subclasses override ``__call__``; the base implementation always raises.
    """

    def __call__(self, A, b):
        """Solve ``A x = b``; must be provided by a subclass.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class LUSolver(LinearSolver):
    """Solve ``A x = b`` through an LU factorization that is cached per matrix.

    The factorization of a matrix object is computed on first use and reused
    on later calls with the same, unmodified object.

    ``id(A)`` alone is not a safe cache key: ids are reused once an object is
    freed, so each entry is guarded by a weak reference to the matrix it was
    computed from and by the tensor's version counter (which changes when the
    tensor is modified in place). Entries are dropped when their matrix is
    garbage-collected, so the cache does not grow without bound.

    NOTE: ``Tensor.lu`` / ``torch.lu_solve`` are deprecated in recent PyTorch
    releases in favour of ``torch.linalg.lu_factor`` / ``torch.linalg.lu_solve``;
    they are kept here for compatibility with older installations.
    """

    def __init__(self):
        super(LUSolver, self).__init__()
        # id(A) -> value returned by ``A.lu()``.
        self.cache = {}
        # id(A) -> (weak reference to A, A's version counter when factorized).
        self._owners = {}

    def _factorization(self, A):
        """Return the LU factorization of ``A``, reusing a valid cache entry."""
        key = id(A)
        version = getattr(A, "_version", None)
        guard = self._owners.get(key)
        cached = self.cache.get(key)
        if (
            cached is not None
            and guard is not None
            and guard[0]() is A
            and guard[1] == version
        ):
            return cached

        lup = A.lu()
        self.cache[key] = lup
        self._owners[key] = (weakref.ref(A, self._make_evictor(key)), version)
        return lup

    def _make_evictor(self, key):
        """Build the weakref callback that drops ``key`` once its matrix dies."""
        def evict(dead_ref):
            guard = self._owners.get(key)
            # Only evict if the entry still belongs to the object that died.
            if guard is not None and guard[0] is dead_ref:
                del self._owners[key]
                self.cache.pop(key, None)
        return evict

    def __call__(self, A, b):
        """Return the solution of ``A x = b``.

        A 1-D ``b`` is treated as a single column and the result is returned
        1-D again; any other ``b`` is passed to ``torch.lu_solve`` unchanged.
        """
        lup = self._factorization(A)
        if b.ndim == 1:
            return torch.lu_solve(b.unsqueeze(-1), *lup).squeeze(-1)
        return torch.lu_solve(b, *lup)
class CholeskySolver(LinearSolver):
    """Solve ``A x = b`` through a Cholesky factor that is cached per matrix.

    The factor is obtained from ``robust_cholesky(A)`` on first use and reused
    on later calls with the same, unmodified matrix object.

    ``id(A)`` alone is not a safe cache key: ids are reused once an object is
    freed, so each entry is guarded by a weak reference to the matrix it was
    computed from and by the tensor's version counter (which changes when the
    tensor is modified in place). Entries are dropped when their matrix is
    garbage-collected, so the cache does not grow without bound.
    """

    def __init__(self):
        super(CholeskySolver, self).__init__()
        # id(A) -> factor returned by ``robust_cholesky(A)``.
        self.cache = {}
        # id(A) -> (weak reference to A, A's version counter when factorized).
        self._owners = {}

    def _factorization(self, A):
        """Return the Cholesky factor for ``A``, reusing a valid cache entry."""
        key = id(A)
        version = getattr(A, "_version", None)
        guard = self._owners.get(key)
        cached = self.cache.get(key)
        if (
            cached is not None
            and guard is not None
            and guard[0]() is A
            and guard[1] == version
        ):
            return cached

        factor = robust_cholesky(A)
        self.cache[key] = factor
        self._owners[key] = (weakref.ref(A, self._make_evictor(key)), version)
        return factor

    def _make_evictor(self, key):
        """Build the weakref callback that drops ``key`` once its matrix dies."""
        def evict(dead_ref):
            guard = self._owners.get(key)
            # Only evict if the entry still belongs to the object that died.
            if guard is not None and guard[0] is dead_ref:
                del self._owners[key]
                self.cache.pop(key, None)
        return evict

    def __call__(self, A, b):
        """Return the solution of the factorized system for right-hand side ``b``.

        A 1-D ``b`` is treated as a single column and the result is returned
        1-D again; any other ``b`` is passed to ``torch.cholesky_solve``
        unchanged.
        """
        factor = self._factorization(A)
        if b.ndim == 1:
            return torch.cholesky_solve(b.unsqueeze(-1), factor).squeeze(-1)
        return torch.cholesky_solve(b, factor)
class CGSolver(LinearSolver):
    """Linear solver that delegates to ``conjugate_gradients``.

    The iteration settings are fixed at construction time and applied to
    every call.
    """

    def __init__(self,nsteps, M_inv = None,residual_tol=1e-3):
        """Store the settings forwarded to ``conjugate_gradients``.

        Args:
            nsteps: Forwarded as the ``nsteps`` argument.
            M_inv: Forwarded as the ``M_inv`` argument (``None`` by default).
            residual_tol: Forwarded as the ``residual_tol`` argument.
        """
        super(CGSolver, self).__init__()
        self.nsteps = nsteps
        self.M_inv = M_inv
        self.residual_tol = residual_tol

    def __call__(self, A, b):
        """Return ``conjugate_gradients(A, b, ...)`` using the stored settings."""
        return conjugate_gradients(
            A,
            b,
            self.nsteps,
            M_inv=self.M_inv,
            residual_tol=self.residual_tol,
        )