Source code for jet20.backend.core.primal_dual_interior_point_with_le

import torch
import numpy as np
from jet20.backend.constraints import Constraints
from jet20.backend.core.linear_solver import LUSolver, CholeskySolver, CGSolver
from jet20.backend.core.status import *
from jet20.backend.const import LINEAR,QUADRATIC
from jet20.backend.core.utils import NotImproving

import logging
logger = logging.getLogger(__name__)


def solve_kkt_fast(h2, d_f, lambda_, f_x, r_dual, r_cent):
    # Replace exact zeros in f(x) so the divisions below are well defined.
    if isinstance(f_x, torch.DoubleTensor):
        f_x[f_x == 0] = 1e-16
    else:
        f_x[f_x == 0] = 1e-8
    _r_cent = (f_x ** -1) * r_cent
    _d_f = -(lambda_ / f_x).unsqueeze(-1) * d_f
    # Eliminate the dual step and solve the reduced system for the
    # primal step via a Cholesky factorization.
    h_pd = h2 + d_f.T @ _d_f
    g = r_dual + d_f.T @ _r_cent
    solver = CholeskySolver()
    _dir_x = -solver(h_pd, g)
    # Recover the dual step from the eliminated block.
    _dir_lambda = _d_f @ _dir_x + _r_cent
    return _dir_x, _dir_lambda
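
For reference, solve_kkt_fast performs a block elimination on the full Newton
system that solve_kkt (below) assembles explicitly. Writing H for h2, Df for
d_f, f for f_x, and lambda for lambda_, the algebra is a sketch of the standard
reduction:

\begin{aligned}
\begin{bmatrix} H & Df^{\top} \\ -\operatorname{diag}(\lambda)\,Df & -\operatorname{diag}(f) \end{bmatrix}
\begin{bmatrix} \Delta x \\ \Delta\lambda \end{bmatrix}
&= -\begin{bmatrix} r_{\mathrm{dual}} \\ r_{\mathrm{cent}} \end{bmatrix}, \\
\Delta\lambda &= \operatorname{diag}(f)^{-1}\bigl(r_{\mathrm{cent}} - \operatorname{diag}(\lambda)\,Df\,\Delta x\bigr), \\
\bigl(H - Df^{\top}\operatorname{diag}(\lambda/f)\,Df\bigr)\,\Delta x
&= -\bigl(r_{\mathrm{dual}} + Df^{\top}\operatorname{diag}(f)^{-1}\,r_{\mathrm{cent}}\bigr).
\end{aligned}

Since f < 0 and lambda > 0 at strictly feasible iterates, -lambda/f > 0, so the
reduced matrix h_pd is symmetric positive definite whenever H is positive
semidefinite (given a full-column-rank Jacobian in the pure LP case), which is
what makes the CholeskySolver applicable.
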
def solve_kkt(h2, d_f, lambda_, f_x, r_dual, r_cent, n, m):
    # Assemble the full (n + m) x (n + m) primal-dual KKT matrix.
    p = n + m
    KKT = d_f.new_zeros((p, p))
    if not (isinstance(h2, float) and h2 == 0.0):  # h2 is the float 0.0 for pure LPs
        KKT[:n, :n] = h2
    KKT[:n, n:n+m] = d_f.T
    KKT[n:n+m, :n] = -lambda_.unsqueeze(-1) * d_f
    KKT[n:n+m, n:n+m] = -torch.diag(f_x)
    solver = LUSolver()
    _dir = -solver(KKT, torch.cat([r_dual, r_cent]))
    _dir_x, _dir_lambda = _dir[:n], _dir[n:n+m]
    return _dir_x, _dir_lambda
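
As a quick self-contained sanity check of the equivalence between the two
paths, both can be reproduced with plain torch (torch.linalg.solve stands in
for jet20's LUSolver and CholeskySolver here; the random problem data is
illustrative only):

import torch

torch.manual_seed(0)
n, m = 5, 3
H = torch.randn(n, n, dtype=torch.double)
H = H @ H.T + n * torch.eye(n, dtype=torch.double)  # SPD Hessian
Df = torch.randn(m, n, dtype=torch.double)          # constraint Jacobian
f = -torch.rand(m, dtype=torch.double) - 0.1        # f(x) < 0: strictly feasible
lam = torch.rand(m, dtype=torch.double) + 0.1       # lambda > 0
r_dual = torch.randn(n, dtype=torch.double)
r_cent = torch.randn(m, dtype=torch.double)

# Full KKT system, as assembled in solve_kkt.
KKT = torch.zeros(n + m, n + m, dtype=torch.double)
KKT[:n, :n] = H
KKT[:n, n:] = Df.T
KKT[n:, :n] = -lam.unsqueeze(-1) * Df
KKT[n:, n:] = -torch.diag(f)
full = -torch.linalg.solve(KKT, torch.cat([r_dual, r_cent]))

# Reduced system, as in solve_kkt_fast.
_r_cent = r_cent / f
_d_f = -(lam / f).unsqueeze(-1) * Df
h_pd = H + Df.T @ _d_f
g = r_dual + Df.T @ _r_cent
dx = -torch.linalg.solve(h_pd, g)
dlam = _d_f @ dx + _r_cent

assert torch.allclose(full[:n], dx) and torch.allclose(full[n:], dlam)
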
def primal_dual_interior_point_with_le(x, obj, le_cons=None, should_stop=None,
                                       u=10.0, tolerance=1e-3, constraint_tolerance=1e-3,
                                       alpha=0.1, beta=0.5, fast=False, verbose=False,
                                       duals=None):
    from torch.autograd.functional import jacobian, hessian

    m = le_cons.size()
    n = x.size(0)
    if duals is None:
        lambda_ = x.new_ones(m)
    else:
        lambda_ = duals

    def l(x, lambda_):
        # Lagrangian of the inequality-constrained problem.
        return obj(x) + le_cons(x) @ lambda_

    def residual(x, lambda_, t):
        # Dual and centering residuals of the perturbed KKT conditions.
        f_x = le_cons(x)
        r_dual = jacobian(lambda x: l(x, lambda_), x)
        r_cent = -lambda_ * f_x - 1 / t
        return r_dual, r_cent

    def r_norm(x, lambda_, t):
        r_dual, r_cent = residual(x, lambda_, t)
        return torch.cat([r_dual, r_cent]).norm(2)

    def jacobian_(f, x):
        # Linear constraints carry their Jacobian explicitly.
        if f.type() == LINEAR:
            return f.A
        else:
            return jacobian(f, x)

    def hessian_(x, lambda_):
        # The Lagrangian of a purely linear problem has a zero Hessian.
        if le_cons.type() == LINEAR and obj.type() == LINEAR:
            return 0.0
        else:
            return hessian(lambda x: l(x, lambda_), x)

    should_stop = should_stop or []
    not_improving = NotImproving()

    while True:
        f_x = le_cons(x)
        dual_gap = -f_x @ lambda_
        t = u * m / dual_gap
        r_dual, r_cent = residual(x, lambda_, t)
        obj_value = obj(x)
        norm = torch.cat([r_dual, r_cent]).norm(2)

        if verbose:
            logger.info("obj:%s,r_dual:%s,r_cent:%s,norm:%s",
                        obj_value, r_dual.norm(2), r_cent.norm(2), norm)

        if r_dual.norm(2) <= constraint_tolerance and dual_gap <= tolerance:
            return x, obj_value, OPTIMAL, lambda_
        if not_improving(norm):
            return x, obj_value, SUB_OPTIMAL, lambda_
        if torch.isnan(obj_value):
            return x, obj_value, FAIELD, lambda_

        h2 = hessian_(x, lambda_)
        d_f = jacobian_(le_cons, x)

        if fast:
            _dir_x, _dir_lambda = solve_kkt_fast(h2, d_f, lambda_, f_x, r_dual, r_cent)
        else:
            _dir_x, _dir_lambda = solve_kkt(h2, d_f, lambda_, f_x, r_dual, r_cent, n, m)

        step = line_search(r_norm, le_cons, (x, lambda_), (_dir_x, _dir_lambda),
                           t, norm, alpha=alpha, beta=beta)
        x = x + step * _dir_x
        lambda_ = lambda_ + step * _dir_lambda

        for ss in should_stop:
            if ss(x, obj_value, dual_gap):
                return x, obj_value, USER_STOPPED, lambda_
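
The line_search helper called above is defined elsewhere in the module and not
shown on this page. As a rough guide, here is a minimal sketch of such a
helper, assuming the standard backtracking scheme for primal-dual methods
(Boyd and Vandenberghe, §11.7.3), with the signature inferred from the call
site; this is illustrative only, not jet20's actual implementation:

def line_search(r_norm, le_cons, point, direction, t, norm, alpha=0.1, beta=0.5):
    # Illustrative sketch only; the real helper lives elsewhere in the module.
    x, lambda_ = point
    dir_x, dir_lambda = direction
    # Largest step keeping the duals nonnegative: lambda_ + s * dir_lambda >= 0.
    s = 1.0
    neg = dir_lambda < 0
    if neg.any():
        s = min(1.0, (-0.99 * lambda_[neg] / dir_lambda[neg]).min().item())
    # Backtrack until the primal iterate stays strictly feasible: f(x) < 0.
    while (le_cons(x + s * dir_x) >= 0).any():
        s *= beta
    # Backtrack until the residual norm shows sufficient decrease.
    while r_norm(x + s * dir_x, lambda_ + s * dir_lambda, t) > (1 - alpha * s) * norm:
        s *= beta
    return s

The three stages cap the dual step at nonnegativity, restore strict primal
feasibility, and then enforce sufficient decrease of the residual norm,
mirroring the alpha and beta parameters threaded through from the solver.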