# Source code for jet20.backend.obj


import torch
from jet20.backend.const import LINEAR,QUADRATIC

import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class Objective:
    """Interface shared by all objective functions in this module.

    Every method here is a placeholder that raises ``NotImplementedError``
    with an empty message; concrete subclasses are expected to override
    each one (evaluation, type tag, and the tensor-conversion helpers).
    """

    def __call__(self, x):
        """Evaluate the objective at ``x``."""
        raise NotImplementedError("")

    def type(self):
        """Return the type tag describing this objective."""
        raise NotImplementedError("")

    def float(self):
        """Return a copy of the objective with its tensors converted via ``.float()``."""
        raise NotImplementedError("")

    def double(self):
        """Return a copy of the objective with its tensors converted via ``.double()``."""
        raise NotImplementedError("")

    def to(self, device):
        """Return a copy of the objective with its tensors moved via ``.to(device)``."""
        raise NotImplementedError("")
class LambdaObjective(Objective):
    """Objective defined by an arbitrary callable plus extra bound arguments.

    Evaluating the objective at ``x`` calls ``f(x, *args)``. The type tag
    is supplied by the caller and returned verbatim by :meth:`type`.

    Args:
        _type: type tag returned by :meth:`type`.
        f: callable invoked as ``f(x, *args)``.
        *args: extra arguments forwarded to ``f`` on every call.
    """

    def __init__(self, _type, f, *args):
        super(LambdaObjective, self).__init__()
        self._type = _type
        self.f = f
        self.args = args

    def __call__(self, x):
        return self.f(x, *self.args)

    def type(self):
        return self._type

    @staticmethod
    def _convert(arg, method, *method_args):
        """Return ``arg.<method>(*method_args)``, or ``arg`` itself if it lacks that method."""
        # Bound args need not all be tensors (e.g. plain Python scalars);
        # previously those made float()/double()/to() raise AttributeError.
        converter = getattr(arg, method, None)
        if callable(converter):
            return converter(*method_args)
        return arg

    def _map_args(self, method, *method_args):
        """Build a new LambdaObjective with every bound arg passed through ``_convert``."""
        args = [self._convert(arg, method, *method_args) for arg in self.args]
        return self.__class__(self._type, self.f, *args)

    def float(self):
        """Return a copy whose convertible args are cast with ``.float()``."""
        return self._map_args("float")

    def double(self):
        """Return a copy whose convertible args are cast with ``.double()``."""
        return self._map_args("double")

    def to(self, device):
        """Return a copy whose convertible args are moved with ``.to(device)``."""
        return self._map_args("to", device)
class LinearObjective(Objective):
    """Affine objective ``b @ x + c``.

    Args:
        b: coefficient tensor applied to ``x`` with ``@``.
        c: constant offset. ``None`` means 0; a tensor is stored as-is; any
            other value is wrapped in a tensor with ``b``'s dtype and device.
    """

    def __init__(self, b, c=None):
        super(LinearObjective, self).__init__()
        self.b = b
        if isinstance(c, torch.Tensor):
            self.c = c
        else:
            offset = 0.0 if c is None else c
            self.c = torch.tensor(offset, dtype=b.dtype, device=b.device)

    def __call__(self, x):
        linear_term = self.b @ x
        return linear_term + self.c

    def type(self):
        return LINEAR

    def _rebuild(self, convert):
        """Return a new LinearObjective with ``convert`` applied to ``b`` then ``c``."""
        return self.__class__(convert(self.b), convert(self.c))

    def float(self):
        return self._rebuild(lambda t: t.float())

    def double(self):
        return self._rebuild(lambda t: t.double())

    def to(self, device):
        return self._rebuild(lambda t: t.to(device))
class QuadraticObjective(Objective):
    """Quadratic objective ``x @ Q @ x + b @ x + c``.

    For a 1-D ``x`` this is ``x^T Q x + b^T x + c``. Note there is no 1/2
    factor on the quadratic term.

    Args:
        Q: quadratic coefficient tensor.
        b: linear coefficients; ``None`` means zeros of length ``Q.size(0)``
            created with ``Q.new_zeros`` (same dtype/device as ``Q``).
        c: constant offset. ``None`` means 0; a tensor is stored as-is; any
            other value is wrapped in a tensor with ``Q``'s dtype and device.
    """

    def __init__(self, Q, b=None, c=None):
        super(QuadraticObjective, self).__init__()
        self.Q = Q
        if b is None:
            b = Q.new_zeros(Q.size(0))
        self.b = b
        if isinstance(c, torch.Tensor):
            self.c = c
        else:
            offset = 0.0 if c is None else c
            self.c = torch.tensor(offset, dtype=Q.dtype, device=Q.device)

    def __call__(self, x):
        quadratic_term = x @ self.Q @ x
        linear_term = self.b @ x
        return quadratic_term + linear_term + self.c

    def type(self):
        return QUADRATIC

    def _rebuild(self, convert):
        """Return a new QuadraticObjective with ``convert`` applied to Q, b, c in order."""
        return self.__class__(convert(self.Q), convert(self.b), convert(self.c))

    def float(self):
        return self._rebuild(lambda t: t.float())

    def double(self):
        return self._rebuild(lambda t: t.double())

    def to(self, device):
        return self._rebuild(lambda t: t.to(device))