import torch
from jet20.backend.const import LINEAR,QUADRATIC
import logging
logger = logging.getLogger(__name__)
class Objective:
    """Base interface for objective functions.

    Every method here only raises ``NotImplementedError``; concrete
    subclasses are expected to override all of them.
    """

    def __call__(self, x):
        """Evaluate the objective at ``x``."""
        raise NotImplementedError("")

    def type(self):
        """Return a tag describing the kind of objective (e.g. linear/quadratic)."""
        raise NotImplementedError("")

    def float(self):
        """Return a single-precision (``.float()``) version of this objective."""
        raise NotImplementedError("")

    def double(self):
        """Return a double-precision (``.double()``) version of this objective."""
        raise NotImplementedError("")

    def to(self, device):
        """Return a version of this objective moved to ``device``."""
        raise NotImplementedError("")
class LambdaObjective(Objective):
    """Objective defined by an arbitrary callable, evaluated as ``f(x, *args)``.

    Args:
        _type: value returned verbatim by :meth:`type` (presumably one of the
            backend constants such as ``LINEAR``/``QUADRATIC`` -- confirm with
            callers).
        f: callable invoked as ``f(x, *args)``.
        *args: extra positional arguments forwarded to ``f``. Arguments that
            provide ``.float()`` / ``.double()`` / ``.to()`` (e.g. tensors) are
            converted by the corresponding methods; any other argument (Python
            scalars, strings, ...) is passed through unchanged.
    """

    def __init__(self, _type, f, *args):
        super(LambdaObjective, self).__init__()
        self._type = _type
        self.f = f
        self.args = args

    def __call__(self, x):
        return self.f(x, *self.args)

    def type(self):
        return self._type

    def _convert_args(self, method, *method_args):
        """Build a new instance with ``arg.<method>(*method_args)`` applied to each arg.

        Arguments lacking a callable ``method`` attribute are kept as-is. The
        original code called the method unconditionally, which raised
        ``AttributeError`` for plain Python values forwarded to ``f``.
        """
        converted = []
        for arg in self.args:
            convert = getattr(arg, method, None)
            converted.append(convert(*method_args) if callable(convert) else arg)
        return self.__class__(self._type, self.f, *converted)

    def float(self):
        """Return a copy whose convertible args have had ``.float()`` applied."""
        return self._convert_args("float")

    def double(self):
        """Return a copy whose convertible args have had ``.double()`` applied."""
        return self._convert_args("double")

    def to(self, device):
        """Return a copy whose convertible args have been moved with ``.to(device)``."""
        return self._convert_args("to", device)
class LinearObjective(Objective):
    """Affine objective ``f(x) = b @ x + c``.

    Args:
        b: coefficient tensor (presumably a 1-D vector matching ``x``; not
            checked here). Its dtype and device are used whenever ``c`` has to
            be materialized as a tensor.
        c: constant offset. ``None`` means ``0.0``; a ``torch.Tensor`` is
            stored as-is (no dtype/device cast); any other value is wrapped
            with ``torch.tensor`` using ``b``'s dtype and device.
    """

    def __init__(self, b, c=None):
        super().__init__()
        self.b = b
        if isinstance(c, torch.Tensor):
            self.c = c
        else:
            offset = 0.0 if c is None else c
            self.c = torch.tensor(offset, dtype=b.dtype, device=b.device)

    def __call__(self, x):
        linear_term = self.b @ x
        return linear_term + self.c

    def type(self):
        return LINEAR

    def _rebuilt(self, convert):
        # Apply ``convert`` to b and then c, and construct a fresh instance;
        # the attributes of ``self`` are not reassigned.
        return self.__class__(convert(self.b), convert(self.c))

    def float(self):
        return self._rebuilt(lambda tensor: tensor.float())

    def double(self):
        return self._rebuilt(lambda tensor: tensor.double())

    def to(self, device):
        return self._rebuilt(lambda tensor: tensor.to(device))
class QuadraticObjective(Objective):
    """Quadratic objective ``f(x) = x @ Q @ x + b @ x + c``.

    Note that the quadratic term carries no ``1/2`` factor.

    Args:
        Q: coefficient tensor (presumably a square ``(n, n)`` matrix; not
            checked here). Its dtype and device are used for the defaults.
        b: linear coefficients. ``None`` means a zero vector of length
            ``Q.size(0)`` created with ``Q.new_zeros``.
        c: constant offset. ``None`` means ``0.0``; a ``torch.Tensor`` is
            stored as-is (no dtype/device cast); any other value is wrapped
            with ``torch.tensor`` using ``Q``'s dtype and device.
    """

    def __init__(self, Q, b=None, c=None):
        super().__init__()
        self.Q = Q
        if b is None:
            b = Q.new_zeros(Q.size(0))
        self.b = b
        if isinstance(c, torch.Tensor):
            self.c = c
        else:
            offset = 0.0 if c is None else c
            self.c = torch.tensor(offset, dtype=Q.dtype, device=Q.device)

    def __call__(self, x):
        quadratic_term = x @ self.Q @ x
        linear_term = self.b @ x
        return quadratic_term + linear_term + self.c

    def type(self):
        return QUADRATIC

    def _rebuilt(self, convert):
        # Apply ``convert`` to Q, b and c (in that order) and construct a fresh
        # instance; the attributes of ``self`` are not reassigned.
        return self.__class__(convert(self.Q), convert(self.b), convert(self.c))

    def float(self):
        return self._rebuilt(lambda tensor: tensor.float())

    def double(self):
        return self._rebuilt(lambda tensor: tensor.double())

    def to(self, device):
        return self._rebuilt(lambda tensor: tensor.to(device))