Source code for jet20.backend.plugins.rouding
import torch
from jet20.backend.plugins import Plugin
import logging
logger = logging.getLogger(__name__)
[docs]def round_(x,n=3):
return (x * 10**n).round() / (10**n)
[docs]class Rounding(Plugin):
[docs] def postprocess(self,p,x,config):
old_value = p.obj(x)
for i in range(config.rouding_precision,16):
_x = round_(x,i)
if p.eq and not p.eq.validate(_x,config.opt_constraint_tolerance):
continue
if p.le and not p.le.validate(_x):
continue
new_value = p.obj(_x)
if new_value <= old_value:
return p,_x
elif config.force_rouding:
logger.warning("objective get worse,before rouding: %s, after rouding:%s, p:%s",old_value.item(),new_value.item(),i)
return p,_x
logger.warning("rouding faild.")
return p,x