Tips for Debugging PyTorch Errors
Debugging PyTorch errors can be very difficult, mostly due to poor error messages. Here I share some of the strategies I use to get a more informative error message or (hopefully) fix the issue.
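- Use the CPU instead of a GPU. CUDA errors (other than “out of memory”) are often uninformative: kernels run asynchronously, and the message may be vague or reported at an unrelated line. Running the same code on the CPU usually gives a much clearer error message and stack trace.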
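```python
import torch

# Run everything on the CPU to get a clearer error message
device = torch.device("cpu")
model = MyModel()  # your own model class
model = model.to(device)
# Remember to move each input batch to the same device before calling model(...)
```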
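- Disable asynchronous kernel launches (GPU). Because CUDA kernels launch asynchronously, an error can be reported at a later line than the one that caused it. Setting the environment variable CUDA_LAUNCH_BLOCKING=1 makes launches synchronous, so the stack trace points at the failing operation. It’s easiest to set this variable in your bash script. Check this Stack Overflow thread for more information.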
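```bash
# Optional: restrict the run to a single GPU (GPU 0)
export CUDA_VISIBLE_DEVICES=0
# Make CUDA kernel launches synchronous so errors are reported where they occur
export CUDA_LAUNCH_BLOCKING=1
python my_program.py
```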
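To see why the two tips above help, here is a minimal sketch with a deliberate bug: an out-of-range token id passed to an `nn.Embedding` (the 10-row table and the ids are made up for illustration). On the CPU, PyTorch raises an ordinary `IndexError` that names the problem. On a GPU the same bug typically surfaces as a vague "device-side assert triggered" error, which may be reported at a later line unless `CUDA_LAUNCH_BLOCKING=1` is set.

```python
import torch
from torch import nn


def lookup(device: str) -> torch.Tensor:
    # Hypothetical 10-row embedding table: valid ids are 0..9
    emb = nn.Embedding(num_embeddings=10, embedding_dim=4).to(device)
    ids = torch.tensor([3, 10], device=device)  # 10 is out of range (deliberate bug)
    return emb(ids)


try:
    lookup("cpu")
except IndexError as err:
    # On the CPU this is a plain Python exception with a readable message
    print(f"CPU error: {err}")
```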
Out of Memory Errors
You can check the amount of GPU memory in use with the command-line tool nvidia-smi. This is useful when trying to maximize your batch size.
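If you would rather check from inside your script, PyTorch keeps its own memory counters. This is a minimal sketch; the helper name `gpu_memory_report` is hypothetical, while `torch.cuda.memory_allocated`, `memory_reserved` and `max_memory_allocated` are real PyTorch functions. Expect these numbers to be lower than what nvidia-smi shows, since nvidia-smi also counts the CUDA context and memory used by other processes.

```python
import torch


def gpu_memory_report(device=None):
    """Return PyTorch's view of GPU memory in MiB, or None if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    mib = 2 ** 20
    return {
        # memory currently held by live tensors
        "allocated": torch.cuda.memory_allocated(device) / mib,
        # allocated memory plus blocks cached by PyTorch's allocator
        "reserved": torch.cuda.memory_reserved(device) / mib,
        # high-water mark of allocated memory since the start (or last reset)
        "peak_allocated": torch.cuda.max_memory_allocated(device) / mib,
    }


report = gpu_memory_report()
print(report if report is not None else "CUDA is not available")
```

Calling it right after a forward/backward pass shows how close a given batch size comes to the limit.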
Check batch dimensions. This is important when training language models that only accept inputs up to a certain length (e.g. standard BERT models, which accept at most 512 tokens). Note: Thanks to Pamela Shapiro for this tip!
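A cheap way to follow this tip is to validate each batch before the forward pass. This is a minimal sketch, assuming batches of token ids shaped `(batch_size, seq_len)` and a limit of 512 (the value for standard BERT checkpoints; check your own model's configuration). The helper `check_batch` is hypothetical. Note that cutting the tensor like this can drop special end-of-sequence tokens, so in practice it is usually better to let the tokenizer truncate.

```python
import torch

MAX_LEN = 512  # assumption: positional-embedding limit of a standard BERT checkpoint


def check_batch(input_ids: torch.Tensor, max_len: int = MAX_LEN,
                truncate: bool = False) -> torch.Tensor:
    """Validate a (batch_size, seq_len) batch of token ids."""
    if input_ids.dim() != 2:
        raise ValueError(f"expected (batch_size, seq_len), got {tuple(input_ids.shape)}")
    seq_len = input_ids.size(1)
    if seq_len > max_len:
        if not truncate:
            raise ValueError(f"sequence length {seq_len} exceeds the limit of {max_len}")
        input_ids = input_ids[:, :max_len]  # naive truncation, see the note above
    return input_ids


# Hypothetical batch that is too long for the model
batch = torch.randint(0, 30000, (8, 700))
batch = check_batch(batch, truncate=True)
print(batch.shape)  # torch.Size([8, 512])
```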
Only put what you need on the GPU. If you have a lot of data, ideally only the model and the current batch should be on the GPU. PyTorch reuses a tensor's memory once nothing references it, and torch.cuda.empty_cache() only hands cached, unused blocks back to the driver, so clearing the GPU cache rarely fixes an out-of-memory error.
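```python
# Bad memory usage: the whole dataset is copied to the GPU up front
data = data.to(device)
model = model.to(device)
for epoch in range(num_epochs):
    for batch in data:
        out = model(batch)

# Good memory usage: only the model and the current batch live on the GPU
model = model.to(device)
for epoch in range(num_epochs):
    for batch in data:
        batch = batch.to(device)
        out = model(batch)
```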
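Here is a runnable version of the "good" pattern above, assuming the data is wrapped in a `DataLoader`. The random tensors, the `nn.Linear` model, the optimizer and the hyperparameters are placeholders for your own; the point is only that the dataset stays in CPU memory and each batch is moved to the device inside the loop.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy dataset; it stays in CPU memory for the whole run
features = torch.randn(1000, 16)
labels = torch.randn(1000, 1)
loader = DataLoader(TensorDataset(features, labels), batch_size=32, shuffle=True)

model = nn.Linear(16, 1).to(device)  # the model lives on the device throughout
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

num_epochs = 2
for epoch in range(num_epochs):
    for x, y in loader:
        # Only the current batch is copied to the device
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
```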
Reduce batch size. I list this common trick last because the checks above can save you the pain of shrinking the batch size all the way to 1 and still getting memory errors.
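Instead of halving the batch size by hand, you can search for the largest size that fits. This is a minimal sketch, assuming you supply a hypothetical `try_batch(batch_size)` callable that runs one forward and backward pass at that size; `find_max_batch_size`, `is_oom` and the 4096 cap are illustrative, not a PyTorch API. It recognizes out-of-memory errors by their message (newer PyTorch releases raise `torch.cuda.OutOfMemoryError`, which is a `RuntimeError` subclass) and re-raises any other error so real bugs are not hidden.

```python
import torch


def is_oom(err: RuntimeError) -> bool:
    # CUDA OOM errors are RuntimeErrors whose message contains "out of memory"
    return "out of memory" in str(err)


def find_max_batch_size(try_batch, start=1, limit=4096):
    """Largest batch size <= limit for which try_batch(size) runs without OOM.

    Returns 0 if even `start` does not fit.
    """
    def fits(size):
        try:
            try_batch(size)
            return True
        except RuntimeError as err:
            if not is_oom(err):
                raise  # a different bug; don't hide it
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release blocks cached by the failed attempt
            return False

    if not fits(start):
        return 0
    good, bad = start, None  # good fits; bad (once found) does not
    # Phase 1: double until a size fails or the limit is reached
    while bad is None:
        candidate = min(good * 2, limit)
        if candidate == good:
            return good
        if fits(candidate):
            good = candidate
        else:
            bad = candidate
    # Phase 2: binary search between the last size that fit and the first that failed
    while bad - good > 1:
        mid = (good + bad) // 2
        if fits(mid):
            good = mid
        else:
            bad = mid
    return good


# Stand-in for a real training step that runs out of memory above 100 samples
def fake_try(size):
    if size > 100:
        raise RuntimeError("CUDA out of memory. Tried to allocate 2.00 GiB")


print(find_max_batch_size(fake_try))  # 100
```

The size found this way fits in isolation; memory fragmentation or longer inputs later in training can still push you over, so leaving some headroom is a sensible choice.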
If your loss is NaN, there are a few possible culprits. Check the following:
- NaNs in your features / labels. This is definitely the first place to look if your loss is NaN.
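```python
import torch

# X is your feature tensor; check your labels the same way
if torch.isnan(X).any():
    print("Oh no! Missing data!")
    # Replacing NaNs with 0 is one option; dropping or imputing may suit your data better
    X = torch.nan_to_num(X, nan=0.0)
    print("All fixed :)")
else:
    print("Data is all good!")
```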
- Custom loss function. If you are using a custom loss, swap it out for a built-in one (e.g. MSE) and see whether the NaN goes away.
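As an illustration of how a custom loss can produce NaNs even on clean data, consider a hypothetical hand-written root-mean-squared-error loss (an example chosen for this sketch, not one from the post). When the prediction exactly matches the target, the loss is 0, but the derivative of `sqrt` at 0 is infinite and autograd multiplies it by 0, which gives NaN. The built-in `F.mse_loss` does not have this problem.

```python
import torch
import torch.nn.functional as F


def custom_rmse(pred, target):
    # Hypothetical custom loss: its gradient is NaN when pred == target exactly
    return torch.sqrt(((pred - target) ** 2).mean())


def loss_and_grad(loss_fn, pred, target):
    pred = pred.clone().requires_grad_(True)  # fresh leaf tensor for each loss
    loss = loss_fn(pred, target)
    loss.backward()
    return loss.detach(), pred.grad


target = torch.tensor([1.0, 2.0, 3.0])
pred = target.clone()  # a perfect prediction

custom_loss, custom_grad = loss_and_grad(custom_rmse, pred, target)
builtin_loss, builtin_grad = loss_and_grad(F.mse_loss, pred, target)
print(custom_loss, custom_grad)    # loss is 0, gradient is nan
print(builtin_loss, builtin_grad)  # loss is 0, gradient is 0
```

For this particular loss, a common fix is adding a small epsilon inside the square root, e.g. `torch.sqrt(mse + 1e-8)`.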
Thank you, Rachel Wicks and Nathaniel Weir, for your help with my NaN issue :)