Source code for jet20.backend.constraints



import torch

from jet20.backend.const import LINEAR,QUADRATIC


class LeConstraitConflict(Exception):
    """Exception type for conflicting less-or-equal constraints.

    NOTE(review): nothing in this module raises it; the meaning is inferred
    from the name only (a conflict among ``<=`` constraints). Confirm against
    the callers that raise it. The spelling "Constrait" is part of the public
    name and is kept for compatibility.
    """
class EqConstraitConflict(Exception):
    """Exception type for conflicting equality constraints.

    NOTE(review): nothing in this module raises it; the meaning is inferred
    from the name only (a conflict among ``==`` constraints). Confirm against
    the callers that raise it. The spelling "Constrait" is part of the public
    name and is kept for compatibility.
    """
class Constraints(object):
    """Abstract interface shared by the constraint classes in this module.

    Every method here is a placeholder that raises ``NotImplementedError``
    with an empty message. Concrete subclasses (such as ``LinearConstraints``)
    provide the real behaviour.
    """

    def __call__(self, x):
        """Evaluate the constraint function at ``x`` (subclass responsibility)."""
        raise NotImplementedError("")

    def validate(self, x, *args, **kwargs):
        """Report whether ``x`` satisfies the constraints (subclass responsibility)."""
        raise NotImplementedError("")

    def type(self):
        """Return a tag identifying the constraint family (subclass responsibility)."""
        raise NotImplementedError("")

    def size(self):
        """Return how many constraints this object holds (subclass responsibility)."""
        raise NotImplementedError("")

    def float(self):
        """Return a single-precision counterpart (subclass responsibility)."""
        raise NotImplementedError("")

    def double(self):
        """Return a double-precision counterpart (subclass responsibility)."""
        raise NotImplementedError("")

    def to(self, device):
        """Return a counterpart placed on ``device`` (subclass responsibility)."""
        raise NotImplementedError("")
class LinearConstraints(Constraints):
    """Affine constraints described by a coefficient tensor ``A`` and offset ``b``.

    Calling the object on ``x`` returns the residual ``A @ x - b``. Subclasses
    decide how that residual is judged (equality vs. inequality) in
    ``validate``.
    """

    def __init__(self, A, b):
        super(LinearConstraints, self).__init__()
        self.A = A
        self.b = b

    def __call__(self, x):
        """Return the residual ``A @ x - b``."""
        return self.A @ x - self.b

    def validate(self, x, *args, **kwargs):
        """Not implemented at this level; concrete subclasses override it."""
        raise NotImplementedError("")

    def type(self):
        """Return the ``LINEAR`` tag imported from ``jet20.backend.const``."""
        return LINEAR

    def size(self):
        """Return ``A.size(0)``: the number of rows when ``A`` is a matrix."""
        return self.A.size(0)

    def _map_tensors(self, fn):
        """Build a new instance of the same class from ``fn(A)`` and ``fn(b)``."""
        # self.__class__ keeps the concrete subclass (Eq/Le) on conversion;
        # this relies on subclasses keeping the ``(A, b)`` constructor.
        A = fn(self.A)
        b = fn(self.b)
        return self.__class__(A, b)

    def float(self):
        """Return a copy whose ``A`` and ``b`` are converted with ``Tensor.float()``."""
        return self._map_tensors(lambda t: t.float())

    def double(self):
        """Return a copy whose ``A`` and ``b`` are converted with ``Tensor.double()``."""
        return self._map_tensors(lambda t: t.double())

    def to(self, device, *args, **kwargs):
        """Return a copy with ``A`` and ``b`` moved via ``Tensor.to``.

        Extra positional/keyword arguments (e.g. ``dtype=``,
        ``non_blocking=``) are forwarded unchanged to ``Tensor.to``.
        Calling ``to(device)`` behaves exactly as before.
        """
        return self._map_tensors(lambda t: t.to(device, *args, **kwargs))
class LinearEqConstraints(LinearConstraints):
    """Linear equality constraints ``A @ x == b``, checked up to a tolerance."""

    def __init__(self, A, b):
        super(LinearEqConstraints, self).__init__(A, b)

    def validate(self, x, tolerance=1e-8):
        """Return whether every residual satisfies ``|A @ x - b| <= tolerance``.

        Args:
            x: point to check; it is passed to ``self(x)``.
            tolerance: largest absolute residual still treated as satisfied.

        Returns:
            A 0-dim bool tensor, as before. Empty constraint sets validate
            True.
        """
        residual = torch.abs(self(x))
        # Require the *satisfied* condition everywhere. A NaN residual
        # compares False with anything, so the old "count residual >
        # tolerance" check silently reported NaN as feasible.
        return torch.all(residual <= tolerance)
class LinearLeConstraints(LinearConstraints):
    """Linear inequality constraints ``A @ x <= b``."""

    def __init__(self, A, b):
        super(LinearLeConstraints, self).__init__(A, b)

    def validate(self, x, tolerance=0.0):
        """Return whether every residual satisfies ``A @ x - b <= tolerance``.

        Args:
            x: point to check; it is passed to ``self(x)``.
            tolerance: optional slack. The default ``0.0`` reproduces the
                original strict check.

        Returns:
            A 0-dim bool tensor, as before. Empty constraint sets validate
            True.
        """
        # Test the satisfied condition rather than counting "> 0" violations.
        # NaN residuals fail "<=", so they are no longer accepted as feasible.
        return torch.all(self(x) <= tolerance)