Integrating Ray Tune, Hugging Face Transformers and W&B

Ruan Chaves Rodrigues
Feb 2, 2022

Update 03/21/2021: I published my modified version of run_glue.py as a public gist on GitHub.

Update 03/22/2021: My pull request has been accepted. The CustomTrainer subclass is no longer needed for current and future versions of the transformers library: you can now use the Trainer directly.

There are a few articles, notebooks and code samples that teach how to integrate Ray Tune and Hugging Face Transformers, but they either leave out Weights & Biases or do not work anymore due to changes made to the library.

After some hours of experimentation, I figured out the right way to integrate them. First of all, there is a bug in the Trainer that won’t allow you to use W&B and Ray Tune at the same time. I have already submitted a PR on this, but regardless of whether they accept it or not, you can currently fix this bug by creating a subclass that inherits from the Trainer:
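
The gist embedded at this point is not reproduced here. Below is a minimal sketch of such a subclass, assuming the problem is that Trainer._hp_search_setup raises an AttributeError for the extra "wandb" key that Ray’s WandbLogger reads from the trial config, since that key has no matching TrainingArguments field:

```python
from transformers import Trainer


class CustomTrainer(Trainer):
    def _hp_search_setup(self, trial):
        # Ray Tune passes the whole trial config to the Trainer, including the
        # "wandb" entry that is only meant for Ray's WandbLogger. The stock
        # Trainer raises an AttributeError for any key without a matching
        # TrainingArguments field, so drop it before delegating to the parent.
        if isinstance(trial, dict):
            trial.pop("wandb", None)
        super()._hp_search_setup(trial)
```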

Before you actually instantiate a CustomTrainer object, you’ll have to create two functions: model_init and hp_space_fn. model_init simply has to return your model, and hp_space_fn has to return the config that will be used by Ray Tune and W&B.

A few points regarding the code below:

* You can get your wandb API key at wandb.ai/authorize. I like to set it as an environment variable and run my scripts as API_KEY=… WANDB_PROJECT=my_project_name python run_glue.py ….

* The model_init function is meant to run inside a modified version of run_glue.py.
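
The full code lives in the gist linked in the update above; here is a sketch of the two functions, with the search space chosen purely for illustration and the checkpoint name standing in for whatever run_glue.py resolves:

```python
import os

from ray import tune
from transformers import AutoModelForSequenceClassification

MODEL_NAME = "bert-base-cased"  # stand-in for the checkpoint run_glue.py resolves


def model_init():
    # Re-instantiates the model from scratch at the start of every trial.
    return AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)


def hp_space_fn(trial):  # `trial` is unused for the Ray backend
    # The search space is up to you; these distributions are just examples.
    config = {
        "learning_rate": tune.loguniform(1e-5, 5e-5),
        "num_train_epochs": tune.choice([2, 3, 4]),
        "per_device_train_batch_size": tune.choice([8, 16, 32]),
    }
    # The extra "wandb" entry is read by Ray's WandbLogger, not by the Trainer.
    wandb_config = {
        "wandb": {
            "project": os.environ.get("WANDB_PROJECT", "my_project_name"),
            "api_key": os.environ.get("API_KEY"),
            "log_config": True,
        }
    }
    config.update(wandb_config)
    return config
```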

Now you’re ready to run the hyperparameter_search Trainer method. Any additional parameters (such as time_budget_s) will be passed directly to tune.run, as stated in the docs.
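
A sketch of the call, assuming training_args, train_dataset and eval_dataset have already been built by the surrounding run_glue.py script, and with n_trials and time_budget_s picked arbitrarily:

```python
from ray.tune.integration.wandb import WandbLogger

trainer = CustomTrainer(
    args=training_args,            # TrainingArguments built by run_glue.py
    train_dataset=train_dataset,   # datasets prepared by run_glue.py
    eval_dataset=eval_dataset,
    model_init=model_init,         # re-creates the model for every trial
)

best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    n_trials=10,                   # illustrative value
    time_budget_s=3600,            # forwarded untouched to tune.run
    hp_space=hp_space_fn,
    loggers=[WandbLogger],         # Ray-side W&B logging
)
```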

The best hyperparameters will be returned as a dictionary that can be accessed at best_run.hyperparameters.

There may be better ways to do this, but this approach simply works. All code was tested on transformers version 4.4.0.dev0.

P.S.: If for some reason you want to completely disable wandb, it is enough to omit the loggers argument in your call to trainer.hyperparameter_search, comment out the config.update(wandb_config) line in hp_space_fn, and remove the WandbCallback from the trainer:
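
In sketch form, reusing the names from the examples above:

```python
from transformers.integrations import WandbCallback

# In hp_space_fn, keep only the search space:
#     config.update(wandb_config)  # <- comment out this line

trainer.remove_callback(WandbCallback)  # no Trainer-side W&B reporting

best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    n_trials=10,
    hp_space=hp_space_fn,
    # no `loggers` argument, so Ray Tune does not report to W&B either
)
```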


Ruan Chaves Rodrigues

Machine Learning Engineer. MSc student at the EMLCT programme. Personal website: https://ruanchaves.github.io/