Checkpointing and Reusing TensorFlow Models

In my last two posts I introduced TensorFlow and wrote a very simple predictive model. In doing so I introduced many of the key concepts of TensorFlow:

  • The Session, the core of the TensorFlow object model,
  • Computational graphs and some of their elements: placeholders, variables, and Tensors,
  • Training models by iteratively calling Session.run on the training steps created by Optimizer objects.

In this post I want to show how you can save and reuse the results of your TensorFlow models. As we discussed last time, training a model means finding variable values that suit a particular purpose, for example finding the slope and intercept of the line that best fits a series of points. Training can be computationally expensive because we have to search for the best variable values through optimization. Suppose we want to use the results of a trained model over and over again without retraining it each time. You can do this in TensorFlow using the Saver object.

A Saver object can save and restore the values of TensorFlow Variables. A typical scenario has three steps:

  1. Create a Saver and tell it which variables you want to save,
  2. Save the variables to a file,
  3. Restore the variables from a file when they are needed.

A Saver deals only with Variables. It does not work with placeholders, sessions, expressions, or any other kind of TensorFlow object. Here is a simple example that saves and restores two variables:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # TF1-style graphs and Sessions

def save(checkpoint_file='hello.chk'):
    with tf.Session() as session:
        x = tf.Variable([42.0, 42.1, 42.3], name='x')
        y = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name='y')
        not_saved = tf.Variable([-1, -2], name='not_saved')
        session.run(tf.global_variables_initializer())

        print(session.run(tf.global_variables()))
        # Only x and y are written to the checkpoint.
        saver = tf.train.Saver([x, y])
        saver.save(session, checkpoint_file)

def restore(checkpoint_file='hello.chk'):
    # Dummy values; validate_shape=False lets restore() supply the real shapes.
    x = tf.Variable(-1.0, validate_shape=False, name='x')
    y = tf.Variable(-1.0, validate_shape=False, name='y')
    with tf.Session() as session:
        saver = tf.train.Saver()
        saver.restore(session, checkpoint_file)
        print(session.run(tf.global_variables()))

def reset():
    # Discard every op and variable in the default graph.
    tf.reset_default_graph()

Try calling save(), reset(), and then restore(), and compare the outputs to verify that everything worked. When you create a Saver, you should specify a list (or dictionary) of the Variable objects you wish to save. (If you don’t, TensorFlow will save all the variables in the current graph.) The shapes and values of these variables are stored in binary format when you call save(), and retrieved when you call restore(). Notice that in restore(), when I create x and y, I give them dummy values and set validate_shape=False. This is because I want the Saver to determine the values and shapes when the variables are restored. If you’re wondering why the reset() function is there, remember that variables are added to the default computational graph, which outlives any single Session. I want to “clear out” that graph so I don’t have multiple x and y objects floating around as we call save() and restore().
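Here is a minimal sketch of the dictionary form, assuming the TF1-style API available through tensorflow.compat.v1 (so it also runs on TensorFlow 2). The dictionary keys become the names stored in the checkpoint, so a value saved from a Variable named w can be restored into a Variable with a completely different name. The function names save_weights and load_weights, the key 'weights', and the temporary checkpoint path are made up for this illustration.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and Sessions


def save_weights(path):
    """Save variable w under the checkpoint key 'weights' (hypothetical helper)."""
    tf.reset_default_graph()
    w = tf.Variable([1.0, 2.0, 3.0], name='w')
    # Dictionary form: checkpoint key -> Variable.
    saver = tf.train.Saver({'weights': w})
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        return saver.save(session, path)


def load_weights(path):
    """Restore the 'weights' entry into a Variable with a different name."""
    tf.reset_default_graph()
    v = tf.Variable([0.0, 0.0, 0.0], name='renamed')
    saver = tf.train.Saver({'weights': v})
    with tf.Session() as session:
        saver.restore(session, path)  # overwrites the placeholder zeros
        return session.run(v)


checkpoint = os.path.join(tempfile.mkdtemp(), 'dict.chk')
save_weights(checkpoint)
print(load_weights(checkpoint))
```

Because the restoring Saver matches on the key rather than on the Variable's own name, this is one way to load saved values into a graph that was built differently from the one that saved them.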

When you use Saver in real models, you should keep a few facts in mind:

  1. If you want to do anything useful with the Variables you restore, you may need to recreate the rest of the computational graph.
  2. The computational graph that you use with restored Variables need not be the same as the one that you used when saving. That can be useful!
  3. Saver has additional methods that can be helpful if your computation spans machines, or if you want to avoid overwriting old checkpoints on successive calls to save().
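For the third point, here is a small sketch of checkpoint rotation, again assuming the tensorflow.compat.v1 API. Passing global_step to save() appends the step number to the file name, max_to_keep limits how many checkpoints stay on disk, and tf.train.latest_checkpoint finds the newest one in a directory. The counter variable and the model.chk prefix are illustrative choices, not part of the original example.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and Sessions

tf.reset_default_graph()
counter = tf.Variable(0, name='counter')  # hypothetical stand-in for model state
increment = tf.assign_add(counter, 1)

checkpoint_dir = tempfile.mkdtemp()
prefix = os.path.join(checkpoint_dir, 'model.chk')

# Keep at most the two most recent checkpoints on disk.
saver = tf.train.Saver(max_to_keep=2)
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for step in range(1, 4):
        session.run(increment)
        # global_step adds '-<step>' to the file name, so each save
        # writes a new checkpoint instead of overwriting the last one.
        saver.save(session, prefix, global_step=step)

print(saver.last_checkpoints)  # surviving checkpoints, oldest first

# Later, pick up from the newest checkpoint in the directory.
latest = tf.train.latest_checkpoint(checkpoint_dir)
with tf.Session() as session:
    saver.restore(session, latest)
    restored = session.run(counter)
print(latest, restored)
```

After three saves with max_to_keep=2, only the checkpoints from steps 2 and 3 should remain, and restoring the newest one brings counter back to 3.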

At the end of this post I have included a modification of my line-fitting example that can optionally save and restore model results. The interesting parts are the lines that create the Saver and call restore() or save(). You can call it like this:

fit_line(5, checkpoint_file='vars.chk')
reset()
fit_line(5, checkpoint_file='vars.chk', restore=True)

With this version, I could easily “score” new data points x using my trained model.
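As a sketch of what that scoring might look like, assuming the tensorflow.compat.v1 API: the hypothetical score() function below builds a much smaller graph than fit_line, restores only m and b by name, and evaluates m * x + b on however many new points you feed it. The write_checkpoint() helper is a stand-in made up so the example runs on its own; with the real model you would point score() at the file written by fit_line.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and Sessions


def write_checkpoint(path, m_value, b_value):
    """Hypothetical stand-in for a trained model: store 'm' and 'b'."""
    tf.reset_default_graph()
    m = tf.Variable([m_value], name='m')
    b = tf.Variable([b_value], name='b')
    saver = tf.train.Saver()
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver.save(session, path)


def score(path, x_new):
    """Restore m and b, then evaluate y = m * x + b on new points."""
    tf.reset_default_graph()
    m = tf.Variable([0.0], name='m')
    b = tf.Variable([0.0], name='b')
    # Unlike fit_line, this placeholder accepts any number of points.
    x = tf.placeholder(tf.float32, [None], name='x')
    y = m * x + b
    # Restore only m and b; other entries in the file are ignored.
    saver = tf.train.Saver([m, b])
    with tf.Session() as session:
        saver.restore(session, path)
        return session.run(y, {x: x_new})


checkpoint = os.path.join(tempfile.mkdtemp(), 'vars.chk')
write_checkpoint(checkpoint, 2.0, 0.5)
print(score(checkpoint, np.array([0.0, 1.0, 2.0, 3.0])))
```

The scoring graph has no optimizer and no y_act placeholder at all, which illustrates that the graph you use with restored Variables need not match the one used for training.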

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # TF1-style graphs and Sessions

# make_data(n) is the helper from my previous post; it returns n sample
# x values (x_in) and the corresponding actual y values (y_star).

def fit_line(n=1, log_progress=False, iter_scale=200,
             restore=False, checkpoint_file=None):
    with tf.Session() as session:
        x = tf.placeholder(tf.float32, [n], name='x')
        m = tf.Variable([1.0], name='m')
        b = tf.Variable([1.0], name='b')
        y = tf.add(tf.multiply(m, x), b)  # fit y_i = m * x_i + b
        y_act = tf.placeholder(tf.float32, [n], name='y_')

        # minimize sum of squared error between trained and actual.
        error = tf.reduce_sum(tf.square(y - y_act))
        train_step = tf.train.AdamOptimizer(0.05).minimize(error)

        x_in, y_star = make_data(n)

        # The Saver covers m, b, and the optimizer's internal variables.
        saver = tf.train.Saver()
        feed_dict = {x: x_in, y_act: y_star}
        if restore:
            print("Loading variables from '%s'." % checkpoint_file)
            saver.restore(session, checkpoint_file)
            y_i, m_i, b_i = session.run([y, m, b], feed_dict)
        else:
            session.run(tf.global_variables_initializer())
            for i in range(iter_scale * n):
                y_i, m_i, b_i, _ = session.run([y, m, b, train_step],
                                               feed_dict)
                err = np.linalg.norm(y_i - y_star, 2)
                if log_progress:
                    print("%3d | %.4f %.4f %.4e" % (i, m_i[0], b_i[0], err))

            print("Done training! m = %f, b = %f, err = %e, iter = %d"
                  % (m_i[0], b_i[0], err, i))
            if checkpoint_file is not None:
                print("Saving variables to '%s'." % checkpoint_file)
                saver.save(session, checkpoint_file)

        print("      x: %s" % x_in)
        print("Trained: %s" % y_i)
        print(" Actual: %s" % y_star)

Author: natebrix

Follow me on twitter at @natebrix.
