From 3321759484809b7f1350e5bcb60b9a16cf7a1fb7 Mon Sep 17 00:00:00 2001 From: Mathieu Blondel Date: Thu, 7 Jul 2022 18:03:22 +0200 Subject: [PATCH] LBFGS benchmark using a CNN. --- benchmarks/lbfgs_benchmark.py | 96 +++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/benchmarks/lbfgs_benchmark.py b/benchmarks/lbfgs_benchmark.py index 0c1d05a8..a6e0faa5 100644 --- a/benchmarks/lbfgs_benchmark.py +++ b/benchmarks/lbfgs_benchmark.py @@ -29,6 +29,10 @@ import matplotlib.pyplot as plt +from flax import linen as nn +import tensorflow_datasets as tfds +import tensorflow as tf + FLAGS = flags.FLAGS @@ -36,6 +40,8 @@ flags.DEFINE_integer("n_samples", default=10000, help="Number of samples.") flags.DEFINE_integer("n_features", default=200, help="Number of features.") flags.DEFINE_string("task", "binary_logreg", "Task to benchmark.") +flags.DEFINE_integer("batch_size", default=1024, help="Batch size.") +flags.DEFINE_string("dataset", default="mnist", help="Dataset to use.") def binary_logreg(linesearch): @@ -108,6 +114,91 @@ def run_multiclass_logreg(): plt.show() +def cnn(linesearch): + + def load_dataset(dataset, batch_size): + """Loads the dataset as a generator of batches.""" + train_ds, ds_info = tfds.load(f"{dataset}:3.*.*", split="train", + as_supervised=True, with_info=True) + train_ds = train_ds.repeat() + train_ds = train_ds.shuffle(10 * batch_size, seed=0) + train_ds = train_ds.batch(batch_size) + return tfds.as_numpy(train_ds), ds_info + + class CNN(nn.Module): + """A simple CNN model.""" + num_classes: int + net_width: int + + @nn.compact + def __call__(self, x): + x = nn.Conv(features=self.net_width, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=self.net_width*2, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = nn.Dense(features=self.net_width*4)(x) + x = nn.relu(x) + x = nn.Dense(features=self.num_classes)(x) + return x + + # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make + # it unavailable to JAX. + tf.config.experimental.set_visible_devices([], 'GPU') + + train_ds, ds_info = load_dataset(FLAGS.dataset, FLAGS.batch_size) + train_ds = iter(train_ds) + + # Initialize parameters. + input_shape = (1,) + ds_info.features["image"].shape + rng = jax.random.PRNGKey(0) + num_classes = ds_info.features["label"].num_classes + net_width = 4 + net = CNN(num_classes, net_width) + + logistic_loss = jax.vmap(jaxopt.loss.multiclass_logistic_loss) + + def loss_fun(params, data): + """Compute the loss of the network.""" + inputs, labels = data + x = inputs.astype(jnp.float32) / 255. + logits = net.apply({"params": params}, x) + loss_value = jnp.mean(logistic_loss(labels, logits)) + return loss_value + + net = CNN(num_classes, 4) + params = net.init(rng, jnp.zeros(input_shape))["params"] + + opt = jaxopt.LBFGS(fun=loss_fun, linesearch=linesearch) + state = opt.init_state(params) + jitted_update = jax.jit(opt.update) + + errors = onp.zeros(FLAGS.maxiter) + + for it in range(FLAGS.maxiter): + batch = next(train_ds) + params, state = jitted_update(params, state, batch) + errors[it] = state.error + + return errors + + +def run_cnn(): + errors_backtracking = cnn("backtracking") + errors_zoom = cnn("zoom") + + plt.figure() + plt.plot(jnp.arange(FLAGS.maxiter), errors_backtracking, label="backtracking") + plt.plot(jnp.arange(FLAGS.maxiter), errors_zoom, label="zoom") + plt.xlabel("Iterations") + plt.ylabel("Gradient error") + plt.yscale("log") + plt.legend(loc="best") + plt.show() + + def main(argv): if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") @@ -120,8 +211,13 @@ def main(argv): if FLAGS.task == "binary_logreg": run_binary_logreg() + elif FLAGS.task == "multiclass_logreg": run_multiclass_logreg() + + elif FLAGS.task == "cnn": + run_cnn() + else: raise ValueError("Invalid task name.")