Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions benchmarks/lbfgs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,19 @@

import matplotlib.pyplot as plt

from flax import linen as nn
import tensorflow_datasets as tfds
import tensorflow as tf


# Command-line configuration for the benchmark. `maxiter`, `batch_size` and
# `dataset` are read by the CNN task; `task` selects which benchmark runs.
FLAGS = flags.FLAGS

flags.DEFINE_integer("maxiter", default=30, help="Max # of iterations.")
flags.DEFINE_integer("n_samples", default=10000, help="Number of samples.")
flags.DEFINE_integer("n_features", default=200, help="Number of features.")
# Keyword arguments are used here for consistency with the other flags.
flags.DEFINE_string("task", default="binary_logreg", help="Task to benchmark.")
flags.DEFINE_integer("batch_size", default=1024, help="Batch size.")
flags.DEFINE_string("dataset", default="mnist", help="Dataset to use.")


def binary_logreg(linesearch):
Expand Down Expand Up @@ -108,6 +114,91 @@ def run_multiclass_logreg():
plt.show()


def cnn(linesearch):

def load_dataset(dataset, batch_size):
  """Returns an endless stream of shuffled training batches plus metadata.

  Args:
    dataset: name of the TFDS dataset; any `3.*.*` version is requested.
    batch_size: number of examples per batch.

  Returns:
    A pair `(batches, info)`: `batches` is the NumPy view of the repeated,
    shuffled and batched "train" split, and `info` is the dataset info object
    returned by `tfds.load`.
  """
  versioned_name = f"{dataset}:3.*.*"
  data, info = tfds.load(
      versioned_name, split="train", as_supervised=True, with_info=True)
  # Repeat first, then shuffle with a fixed seed, then batch.
  batches = (
      data.repeat()
      .shuffle(10 * batch_size, seed=0)
      .batch(batch_size)
  )
  return tfds.as_numpy(batches), info

class CNN(nn.Module):
  """Small convolutional classifier: two conv stages and two dense layers.

  Attributes:
    num_classes: size of the output (logit) layer.
    net_width: number of features of the first convolution; the second
      convolution uses twice as many and the hidden dense layer four times.
  """
  num_classes: int
  net_width: int

  @nn.compact
  def __call__(self, x):
    # Two identical conv -> relu -> 2x2 average-pool stages that differ only
    # in their feature count. Submodules are created in the same order as in
    # an unrolled version, so their automatic names are unchanged.
    for conv_features in (self.net_width, self.net_width*2):
      x = nn.Conv(features=conv_features, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Collapse everything but the leading (batch) axis.
    batch = x.shape[0]
    x = x.reshape((batch, -1))

    hidden = nn.relu(nn.Dense(features=self.net_width*4)(x))
    logits = nn.Dense(features=self.num_classes)(hidden)
    return logits

# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
# NOTE(review): TF is presumably only present because the tensorflow_datasets
# input pipeline depends on it; the model and optimizer run in JAX.
tf.config.experimental.set_visible_devices([], 'GPU')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting - why do we have TF here at all? Is it because we import tfds above?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is due to TensorFlow Datasets, which imports (parts of) TensorFlow. Without this line, TF reserves the GPU's memory and JAX can't use it.


train_ds, ds_info = load_dataset(FLAGS.dataset, FLAGS.batch_size)
train_ds = iter(train_ds)

# Initialize parameters.
input_shape = (1,) + ds_info.features["image"].shape
rng = jax.random.PRNGKey(0)
num_classes = ds_info.features["label"].num_classes
net_width = 4
net = CNN(num_classes, net_width)

logistic_loss = jax.vmap(jaxopt.loss.multiclass_logistic_loss)

def loss_fun(params, data):
"""Compute the loss of the network."""
inputs, labels = data
x = inputs.astype(jnp.float32) / 255.
logits = net.apply({"params": params}, x)
loss_value = jnp.mean(logistic_loss(labels, logits))
return loss_value

net = CNN(num_classes, 4)
params = net.init(rng, jnp.zeros(input_shape))["params"]

opt = jaxopt.LBFGS(fun=loss_fun, linesearch=linesearch)
state = opt.init_state(params)
jitted_update = jax.jit(opt.update)

errors = onp.zeros(FLAGS.maxiter)

for it in range(FLAGS.maxiter):
batch = next(train_ds)
params, state = jitted_update(params, state, batch)
errors[it] = state.error

return errors


def run_cnn():
  """Runs the CNN benchmark for both line searches and plots their errors.

  Each line search ("backtracking", then "zoom") is benchmarked with `cnn`,
  and the recorded per-iteration errors are drawn on a shared log-scale plot.
  """
  # Run every benchmark before any figure is created.
  curves = {name: cnn(name) for name in ("backtracking", "zoom")}

  plt.figure()
  for name, errors in curves.items():
    plt.plot(jnp.arange(FLAGS.maxiter), errors, label=name)
  plt.xlabel("Iterations")
  plt.ylabel("Gradient error")
  plt.yscale("log")
  plt.legend(loc="best")
  plt.show()


def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
Expand All @@ -120,8 +211,13 @@ def main(argv):

if FLAGS.task == "binary_logreg":
run_binary_logreg()

elif FLAGS.task == "multiclass_logreg":
run_multiclass_logreg()

elif FLAGS.task == "cnn":
run_cnn()

else:
raise ValueError("Invalid task name.")

Expand Down