Source code for jet20.backend.constraints
import torch
from jet20.backend.const import LINEAR,QUADRATIC
class LeConstraitConflict(Exception):
    """Exception type for a conflict among ``<=`` ("Le") constraints.

    Nothing in this module raises it; presumably a solver raises it when
    inequality constraints cannot be satisfied together — confirm at the
    raising site. The misspelling ("Constrait") is part of the public name
    and is kept for compatibility.
    """
class EqConstraitConflict(Exception):
    """Exception type for a conflict among equality ("Eq") constraints.

    Nothing in this module raises it; presumably a solver raises it when
    equality constraints are mutually inconsistent — confirm at the raising
    site. The misspelling ("Constrait") is part of the public name and is
    kept for compatibility.
    """
class Constraints(object):
    """Abstract interface shared by the constraint families in this module.

    Every method here raises ``NotImplementedError``; concrete subclasses
    (such as ``LinearConstraints``) override the operations they support.
    """

    def __call__(self, x):
        """Evaluate the constraint function at ``x``."""
        raise NotImplementedError("")

    def validate(self, x, *args, **kwargs):
        """Report whether ``x`` satisfies the constraints.

        Extra positional/keyword arguments (e.g. a tolerance) are
        subclass-specific.
        """
        raise NotImplementedError("")

    def type(self):
        """Return the tag identifying the constraint kind."""
        raise NotImplementedError("")

    def size(self):
        """Return the number of individual constraints."""
        raise NotImplementedError("")

    def float(self):
        """Return a copy with data converted to single precision."""
        raise NotImplementedError("")

    def double(self):
        """Return a copy with data converted to double precision."""
        raise NotImplementedError("")

    def to(self, device):
        """Return a copy with data moved to ``device``."""
        raise NotImplementedError("")
class LinearConstraints(Constraints):
    """Constraints defined by a matrix ``A`` and a vector ``b``.

    Calling the object returns the residual ``A @ x - b``. How that residual
    is judged (equality or inequality) is left to subclasses via ``validate``.
    NOTE(review): presumably ``A`` is ``(m, n)`` and ``b`` is ``(m,)`` — confirm.
    """

    def __init__(self, A, b):
        super(LinearConstraints, self).__init__()
        self.A = A
        self.b = b

    def __call__(self, x):
        """Return the residual ``A @ x - b``."""
        product = self.A @ x
        return product - self.b

    def validate(self, x, *args, **kwargs):
        """Not implemented here; subclasses define the acceptance rule."""
        raise NotImplementedError("")

    def type(self):
        """Return the ``LINEAR`` tag."""
        return LINEAR

    def size(self):
        """Return the number of rows of ``A``, i.e. the number of constraints."""
        return self.A.size(0)

    def _rebuild(self, convert):
        # Apply ``convert`` to A, then to b, and wrap the results in the same
        # concrete class so subclasses keep their type across conversions.
        new_A = convert(self.A)
        new_b = convert(self.b)
        return self.__class__(new_A, new_b)

    def float(self):
        """Return a new instance whose ``A`` and ``b`` went through ``.float()``."""
        return self._rebuild(lambda tensor: tensor.float())

    def double(self):
        """Return a new instance whose ``A`` and ``b`` went through ``.double()``."""
        return self._rebuild(lambda tensor: tensor.double())

    def to(self, device):
        """Return a new instance whose ``A`` and ``b`` went through ``.to(device)``."""
        return self._rebuild(lambda tensor: tensor.to(device))
class LinearEqConstraints(LinearConstraints):
    """Linear equality constraints ``A @ x == b``."""

    def __init__(self, A, b):
        super(LinearEqConstraints, self).__init__(A, b)

    def validate(self, x, tolerance=1e-8):
        """Check whether ``x`` satisfies every equality within ``tolerance``.

        Args:
            x: candidate point; evaluated as ``A @ x - b``.
            tolerance: largest accepted absolute residual ``|A @ x - b|``.

        Returns:
            A 0-dim boolean tensor that is True when every residual is within
            ``tolerance``. NaN residuals count as violations.
        """
        residual = torch.abs(self(x))
        # Test the *satisfied* condition (<=) instead of counting violations
        # (>): every comparison with NaN is False, so the old form let NaN
        # residuals slip through and reported NaN solutions as feasible.
        return torch.all(residual <= tolerance)
class LinearLeConstraints(LinearConstraints):
    """Linear inequality constraints ``A @ x <= b``."""

    def __init__(self, A, b):
        super(LinearLeConstraints, self).__init__(A, b)

    def validate(self, x, tolerance=0.0):
        """Check whether ``x`` satisfies every inequality.

        Args:
            x: candidate point; evaluated as ``A @ x - b``.
            tolerance: largest accepted residual. The default ``0.0`` keeps the
                original strict rule; the parameter makes the signature match
                ``LinearEqConstraints.validate(x, tolerance=...)``.

        Returns:
            A 0-dim boolean tensor that is True when every residual is
            ``<= tolerance``. NaN residuals count as violations.
        """
        residual = self(x)
        # Test the *satisfied* condition (<=) rather than counting ``> 0``
        # violations: NaN compares False, so NaN would otherwise pass.
        return torch.all(residual <= tolerance)