Source code for stadv.optimization

import tensorflow as tf
import numpy as np
from scipy.optimize import fmin_l_bfgs_b

def lbfgs(
        loss, flows, flows_x0, feed_dict=None, grad_op=None,
        fmin_l_bfgs_b_extra_kwargs=None, sess=None
):
    """Optimize a given loss with (SciPy's external) L-BFGS-B optimizer.

    It can be used to solve the optimization problem of Eq. (2) in Xiao et al.
    (arXiv:1801.02612).
    See `the documentation on scipy.optimize.fmin_l_bfgs_b
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.fmin_l_bfgs_b.html>`_
    for reference on the optimizer.

    Args:
        loss (tf.Tensor): loss (can be of any shape). The evaluated loss is
            summed to a scalar before being handed to the optimizer.
        flows (tf.Tensor): flows of shape `(B, 2, H, W)`, where the second
            dimension indicates the dimension on which the pixel shift is
            applied.
        flows_x0 (np.ndarray): Initial guess for the flows. If the input is
            not of type `np.ndarray`, it will be converted as such if
            possible.
        feed_dict (dict): feed dictionary to the ``tf.Session.run`` operation
            (for everything which might be needed to execute the graph
            beyond the input flows). It is copied internally and never
            modified in place.
        grad_op (tf.Tensor): gradient of the loss with respect to the flows.
            If not provided it will be computed from the input and added to
            the graph.
        fmin_l_bfgs_b_extra_kwargs (dict): extra arguments to
            ``scipy.optimize.fmin_l_bfgs_b`` (e.g. for modifying the stopping
            condition). The keys ``'func'``, ``'approx_grad'``, ``'fprime'``
            and ``'args'`` are reserved.
        sess (tf.Session): session within which the graph should be executed.
            If not provided a new session will be started (and closed again
            before returning, even if the optimization fails).

    Returns:
        `Dictionary` with keys ``'flows'`` (`np.ndarray`, estimated flows of
        the minimum), ``'loss'`` (`float`, value of loss at the minimum), and
        ``'info'`` (`dict`, information summary as returned by
        ``scipy.optimize.fmin_l_bfgs_b``).

    Raises:
        ValueError: if ``fmin_l_bfgs_b_extra_kwargs`` tries to override a
            reserved argument, or if the gradient d(loss)/d(flows) cannot be
            computed from the graph.
    """
    def tf_run(x):
        """Function to minimize as provided to ``scipy.optimize.fmin_l_bfgs_b``.

        Args:
            x (np.ndarray): current flows proposal at a given stage of the
                optimization (flattened `np.ndarray` of type `np.float64` as
                required by the backend FORTRAN implementation of L-BFGS-B).

        Returns:
            `Tuple` `(loss, loss_gradient)` of type `np.float64` as required
            by the backend FORTRAN implementation of L-BFGS-B.
        """
        run_feed[flows] = np.reshape(x, flows_shape)
        loss_val, gradient_val = sess_.run(
            [loss, loss_gradient], feed_dict=run_feed
        )
        # L-BFGS-B needs a scalar objective and a flat float64 gradient.
        loss_val = np.sum(loss_val, dtype=np.float64)
        gradient_val = np.asarray(gradient_val, dtype=np.float64).ravel()
        return loss_val, gradient_val

    flows_x0 = np.asarray(flows_x0, dtype=np.float64)
    flows_shape = flows_x0.shape

    # Private copy: the flows entry is overwritten at every optimizer step and
    # must not leak into the caller's dictionary.
    run_feed = {} if feed_dict is None else dict(feed_dict)
    if fmin_l_bfgs_b_extra_kwargs is None:
        fmin_l_bfgs_b_extra_kwargs = {}

    fmin_l_bfgs_b_kwargs = {
        'func': tf_run,
        'approx_grad': False,  # we want to use the gradients from TensorFlow
        'fprime': None,
        'args': ()
    }

    for key in fmin_l_bfgs_b_extra_kwargs:
        if key in fmin_l_bfgs_b_kwargs:
            raise ValueError(
                "The argument " + str(key) + " should not be overwritten by "
                "fmin_l_bfgs_b_extra_kwargs"
            )

    # define the default extra arguments to fmin_l_bfgs_b
    default_extra_kwargs = {
        'x0': flows_x0.flatten(),
        'factr': 10.0,
        'm': 20,
        'iprint': -1
    }

    fmin_l_bfgs_b_kwargs.update(default_extra_kwargs)
    fmin_l_bfgs_b_kwargs.update(fmin_l_bfgs_b_extra_kwargs)

    if grad_op is not None:
        loss_gradient = grad_op
    else:
        loss_gradient = tf.gradients(loss, flows, name='loss_gradient')[0]
        if loss_gradient is None:
            raise ValueError(
                "Cannot compute the gradient d(loss)/d(flows). Is the graph "
                "really differentiable?"
            )

    sess_ = tf.Session() if sess is None else sess
    try:
        flows_opt, loss_opt, info = fmin_l_bfgs_b(**fmin_l_bfgs_b_kwargs)
    finally:
        # Only close the session if it was created here; a caller-provided
        # session stays open.
        if sess is None:
            sess_.close()

    return {
        'flows': np.reshape(flows_opt, flows_shape),
        'loss': loss_opt,
        'info': info
    }