Source code for jet20.backend.plugins.feasible


import torch

from jet20.backend.plugins import Plugin
from jet20.backend.constraints import *
from jet20.backend.obj import *
from jet20.backend.core import solve
from jet20.backend.solver import Problem,USER_STOPPED,SUB_OPTIMAL,OPTIMAL

import logging
logger = logging.getLogger(__name__)


class LinearDependent(Exception):
    """Raised when the equality-constraint matrix has a near-zero singular value."""
class EnsureEqFeasible(Plugin):
    """Plugin that supplies a starting point satisfying the equality constraints.

    ``preprocess`` replaces a missing or infeasible ``x`` with a point
    computed from the SVD of the constraint matrix; ``postprocess`` rejects
    a result that violates the equality constraints.
    """

    def find_feasible(self, eq, config):
        """Return ``V diag(1/s) U^T b`` for the equality constraints.

        Args:
            eq: Equality constraints exposing the tensors ``A`` and ``b``.
            config: Solver configuration. Not used by this method; kept so
                the signature stays unchanged.

        Returns:
            The tensor obtained by applying the SVD-based pseudo-inverse of
            ``eq.A`` to ``eq.b``.

        Raises:
            LinearDependent: If any singular value of ``eq.A`` is below 1e-8.
        """
        # NOTE(review): Tensor.svd() is deprecated in recent PyTorch releases
        # in favour of torch.linalg.svd (which returns V^H rather than V).
        # Kept as-is to stay compatible with the torch version in use.
        u, singular_values, v = eq.A.svd()
        if (singular_values < 1e-8).any():
            raise LinearDependent("linear dependent in eq constraints..")
        # Dividing the columns of U by the singular values gives
        # (U diag(1/s)); its transpose is diag(1/s) U^T. This avoids
        # materialising the dense diagonal matrix and the explicit
        # pseudo-inverse that chained matrix products would build.
        return v @ ((u / singular_values).T @ eq.b)

    def preprocess(self, p, x, config):
        """Return ``(p, x)``, replacing ``x`` when it is missing or infeasible.

        Args:
            p: Problem object; only its ``eq`` attribute is read here.
            x: Candidate starting point, or ``None``.
            config: Solver configuration, forwarded to ``find_feasible``.

        Returns:
            A ``(p, x)`` tuple. ``x`` is returned unchanged when ``p.eq`` is
            falsy or when ``p.eq.validate(x)`` succeeds; otherwise it is the
            result of ``find_feasible``.
        """
        if not p.eq:
            return p, x
        if x is not None:
            # NOTE(review): validate() runs with its default tolerance here,
            # whereas postprocess passes config.opt_constraint_tolerance --
            # confirm the asymmetry is intended.
            if p.eq.validate(x):
                return p, x
            logger.warning("x is not a feasible solution, eq constraints not satisfied")
        return p, self.find_feasible(p.eq, config)

    def postprocess(self, p, x, config):
        """Check ``x`` against the equality constraints and return ``(p, x)``.

        Args:
            p: Problem object; only its ``eq`` attribute is read here.
            x: Point to check.
            config: Solver configuration; ``opt_constraint_tolerance`` is
                passed to ``validate``.

        Returns:
            The unchanged ``(p, x)`` tuple.

        Raises:
            EqConstraitConflict: If ``p.eq`` is truthy and ``x`` fails
                ``p.eq.validate`` at the configured tolerance.
        """
        if p.eq and not p.eq.validate(x, config.opt_constraint_tolerance):
            raise EqConstraitConflict("conflict in eq constraints")
        return p, x
class EnsureLeFeasible(Plugin):
    """Plugin that supplies a starting point satisfying the ``<=`` constraints.

    ``preprocess`` replaces a missing or infeasible ``x`` by solving an
    auxiliary problem that minimises a slack variable; ``postprocess``
    rejects a result that violates the inequality constraints.
    """

    def find_feasible(self, p, x, config):
        """Search for a point that satisfies ``p.le`` (and a linear ``p.eq``).

        The variable vector is extended with a leading slack entry: a column
        of ``-1`` is prepended to ``p.le.A`` and a column of zeros to
        ``p.eq.A``. The auxiliary objective is the slack entry itself, and
        each solve is told to stop once the objective value is ``<= 0``.

        The first solve runs on ``float()`` copies of the problem and the
        point. While the status is ``SUB_OPTIMAL`` with a positive objective,
        up to two more solves run on the un-cast problem with ``x`` and the
        duals converted to double, first with ``fast=True`` and then with
        ``fast=False``.

        Args:
            p: Problem object; ``eq``, ``le`` and ``n`` are read.
            x: Starting point, or ``None`` to start from a vector of ones of
                length ``p.n``.
            config: Solver configuration, forwarded to ``solve``.

        Returns:
            The final point without its leading slack entry.

        Raises:
            NotImplementedError: If ``p.eq`` is truthy and not ``LINEAR``.
            LeConstraitConflict: If the final status is not ``USER_STOPPED``
                and the final objective value is still positive.
        """
        if p.eq:
            if p.eq.type() != LINEAR:
                raise NotImplementedError("non linear constrait not supported")
            # The slack does not take part in the equality rows.
            zero_col = p.eq.A.new_zeros(p.eq.A.size(0)).unsqueeze(-1)
            aux_eq = LinearEqConstraints(torch.cat([zero_col, p.eq.A], dim=1), p.eq.b)
        else:
            aux_eq = None

        # Each inequality row becomes  A x - s <= b.
        slack_col = -1 * p.le.A.new_ones(p.le.A.size(0)).unsqueeze(-1)
        aux_le_A = torch.cat([slack_col, p.le.A], dim=1)
        aux_le = LinearLeConstraints(aux_le_A, p.le.b)

        if x is None:
            x = aux_le_A.new_ones(p.n)

        # Start the slack just above the largest violation so the initial
        # point strictly satisfies every auxiliary inequality.
        slack = (torch.mv(p.le.A, x) - p.le.b).max() + 1e-3
        x = torch.cat([slack.unsqueeze(0), x])

        def slack_objective(x):
            return x[0]

        def feasibility_reached(x, obj_value, dual_gap):
            return obj_value <= 0

        aux_objective = LambdaObjective(LINEAR, slack_objective)
        aux_problem = Problem([], aux_objective, aux_le, aux_eq)

        aux_problem_f32 = aux_problem.float()
        x, obj_value, status, duals = solve(
            aux_problem_f32,
            x.float(),
            config,
            fast=True,
            should_stops=[feasibility_reached],
        )

        x = x.double()
        if isinstance(duals, (tuple, list)):
            duals = [d.double() for d in duals]
        else:
            duals = duals.double()

        # Retry on the un-cast problem, first fast and then not, for as long
        # as the previous attempt ended sub-optimal with a positive slack.
        for fast in (True, False):
            if not (status == SUB_OPTIMAL and obj_value > 0):
                break
            x, obj_value, status, duals = solve(
                aux_problem,
                x,
                config,
                fast=fast,
                should_stops=[feasibility_reached],
                duals=duals,
            )

        if status == USER_STOPPED or obj_value <= 0:
            return x[1:]
        raise LeConstraitConflict("conflict in le constraints")

    def preprocess(self, p, x, config):
        """Return ``(p, x)``, replacing ``x`` when it is missing or infeasible.

        Args:
            p: Problem object; its ``le`` attribute decides whether anything
                is done.
            x: Candidate starting point, or ``None``.
            config: Solver configuration, forwarded to ``find_feasible``.

        Returns:
            A ``(p, x)`` tuple. ``x`` is returned unchanged when ``p.le`` is
            falsy or when ``p.le.validate(x)`` succeeds; otherwise it is the
            result of ``find_feasible``.
        """
        if not p.le:
            return p, x
        if x is not None:
            if p.le.validate(x):
                return p, x
            logger.warning("x is not a feasible solution, le constraints not satisfied")
        return p, self.find_feasible(p, x, config)

    def postprocess(self, p, x, config):
        """Check ``x`` against the inequality constraints and return ``(p, x)``.

        Args:
            p: Problem object; only its ``le`` attribute is read here.
            x: Point to check.
            config: Solver configuration; not used by this method.

        Returns:
            The unchanged ``(p, x)`` tuple.

        Raises:
            LeConstraitConflict: If ``p.le`` is truthy and ``x`` fails
                ``p.le.validate``.
        """
        if not p.le:
            return p, x
        if p.le.validate(x):
            return p, x
        raise LeConstraitConflict("conflict in le constraints")