Tips for Debugging PyTorch Errors

Debugging PyTorch errors can be very difficult, mostly due to poor error messages. Here I share some of the strategies I use to get a more informative error message or (hopefully) fix the issue.

GPU Errors

  1. Use the CPU instead of a GPU. CUDA errors (other than “out of memory”) are usually uninformative. Running the same code on the CPU usually turns a cryptic CUDA error into a readable Python exception with a useful traceback.
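    import torch

    device = "cpu"  # switch from "cuda" to "cpu" to get a readable error
    model = MyModel()  # MyModel stands in for your own nn.Module
    model = model.to(device)
    # inputs must be on the same device, e.g. batch = batch.to(device)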
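  2. Disable asynchronous kernel launches (GPU). By default, CUDA kernels run asynchronously, so an error may only be reported at a later operation and the traceback can point to the wrong line. To make launches synchronous, set the environment variable CUDA_LAUNCH_BLOCKING=1 (only while debugging, since it slows training down). It’s easiest to set this variable in your bash script (export CUDA_LAUNCH_BLOCKING=1). Check this Stack Overflow thread for more information.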
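If you would rather not edit your bash script, a minimal sketch is to set the variable from Python itself. The assumption here is that it must be set before CUDA is initialized, so the safest place is before `import torch`; setting it after CUDA work has already started is unlikely to have any effect.

```python
import os

# CUDA reads this variable when it initializes, so set it before
# importing torch (and before any GPU work happens).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402  (import deliberately placed after the env var)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device}, CUDA_LAUNCH_BLOCKING={os.environ['CUDA_LAUNCH_BLOCKING']}")
```

With the variable set, a failing kernel raises its error at the line that launched it, so the traceback points at the real culprit.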

Out of Memory Errors

You can check how much GPU memory is in use with the nvidia-smi command on Linux. This is useful when looking for the largest batch size that fits in memory.
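nvidia-smi reports the total memory held by each process, which includes the CUDA context and PyTorch's cache. From inside a training script, a rough breakdown is available through `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved` and `torch.cuda.max_memory_allocated`. The helper below (`gpu_memory_report` is a made-up name) is a small sketch that returns `None` on machines without a GPU.

```python
import torch


def gpu_memory_report(device=None):
    """Return current and peak GPU memory in MiB, or None if no GPU is present."""
    if not torch.cuda.is_available():
        return None
    mib = 1024 ** 2
    return {
        # memory occupied by live tensors
        "allocated_mib": torch.cuda.memory_allocated(device) / mib,
        # memory held by PyTorch's caching allocator (allocated + cached)
        "reserved_mib": torch.cuda.memory_reserved(device) / mib,
        # highest allocated value since the program started (or the last reset)
        "peak_allocated_mib": torch.cuda.max_memory_allocated(device) / mib,
    }


report = gpu_memory_report()
print(report if report is not None else "No CUDA device available")
```

Calling it once after a training step gives the peak usage for your current batch size, which is roughly what you compare against your GPU's capacity.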

  1. Check batch dimensions. This is important when training language models that only accept inputs up to a certain length (e.g. BERT, which accepts at most 512 tokens). Thanks to Pamela Shapiro for this tip!

  2. Only put what you need on the GPU. If you have a lot of data, ideally only the model and the current batch should be on the GPU. PyTorch frees a tensor's memory as soon as nothing references it and keeps that memory cached for reuse, so clearing the GPU cache (torch.cuda.empty_cache()) rarely helps: it returns cached memory to the driver but cannot free tensors you are still holding.

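    # Bad memory usage: the whole dataset is copied to the GPU up front
    # (assumes data is a tensor, model is your nn.Module, device = "cuda")
    data = data.to(device)
    model = model.to(device)
    for epoch in range(num_epochs):
        for batch in data:
            out = model(batch)

    # Good memory usage: only the model and the current batch are on the GPU
    model = model.to(device)
    for epoch in range(num_epochs):
        for batch in data:
            batch = batch.to(device)  # copy one batch at a time
            out = model(batch)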
  3. Reduce batch size. I list this common trick last because working through the steps above first can save you the pain of reducing the batch size all the way to 1 and still running out of memory.
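Checking batch dimensions (tip 1) can be made explicit with a guard that runs before the forward pass. This is a sketch, assuming inputs shaped (batch, seq_len) and a 512-token limit as in standard BERT; `check_seq_len` and `MAX_LEN` are hypothetical names, and your model's config is the authority on its real limit.

```python
import torch

# Assumption: 512 tokens is the input limit of standard BERT models (the size
# of their position-embedding table); check your own model's config.
MAX_LEN = 512


def check_seq_len(input_ids: torch.Tensor, max_len: int = MAX_LEN) -> None:
    """Raise a readable error if a (batch, seq_len) tensor is too long."""
    if input_ids.dim() != 2:
        raise ValueError(f"expected shape (batch, seq_len), got {tuple(input_ids.shape)}")
    seq_len = input_ids.shape[1]
    if seq_len > max_len:
        raise ValueError(f"sequence length {seq_len} exceeds model limit {max_len}")


check_seq_len(torch.zeros(8, 128, dtype=torch.long))  # fits: no error

too_long = torch.zeros(8, 600, dtype=torch.long)
msg = None
try:
    check_seq_len(too_long)
except ValueError as err:
    msg = str(err)
print(msg)  # sequence length 600 exceeds model limit 512

# One simple fix: truncate to the model's limit (or split into chunks).
truncated = too_long[:, :MAX_LEN]
check_seq_len(truncated)  # now passes
```

Failing early with a plain Python error is much easier to read than the CUDA error an oversized input can trigger deep inside the model.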
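Reducing the batch size (tip 3) can be automated by halving it until a training step fits. The sketch below assumes a hypothetical `try_step(batch_size)` callable that runs one forward/backward pass and raises on out-of-memory; it relies on PyTorch's OOM error being a `RuntimeError` whose message contains "out of memory", which holds for both the older plain `RuntimeError` and the newer `torch.cuda.OutOfMemoryError` subclass.

```python
def find_max_batch_size(try_step, start=256, min_size=1):
    """Halve the batch size until try_step(batch_size) runs without running out of memory."""
    batch_size = start
    while batch_size >= min_size:
        try:
            try_step(batch_size)
            return batch_size
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise  # a different error: don't hide it
            batch_size //= 2
    # Still failing at the smallest size: revisit tips 1 and 2.
    raise RuntimeError(f"still out of memory at batch size {min_size}")


# Demo without a GPU: pretend anything above 48 samples runs out of memory.
def fake_step(batch_size):
    if batch_size > 48:
        raise RuntimeError("CUDA out of memory. (simulated)")


print(find_max_batch_size(fake_step, start=256))  # 256 -> 128 -> 64 -> 32, prints 32
```

Halving is a coarse search; it finds a batch size that fits, not necessarily the largest one.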

NaN Loss

If your loss is NaN, there are a few likely culprits. Check the following:

  1. NaNs in your features or labels. If your loss is NaN from the very first iteration, this is definitely the first place to look.
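    import torch

    # X is your feature tensor (check your labels the same way)
    if torch.isnan(X).any():
        print("Oh no! Missing data!")
        # naive fix: replace NaNs with 0; another imputation may suit your data better
        X = torch.nan_to_num(X, nan=0.0)
        print("All fixed :)")
    else:
        print("Data is all good!")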
  2. Custom loss function. If you are using a custom loss, temporarily swap it for a built-in one (e.g. MSE). If the NaNs go away, the problem is in your loss function.
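A minimal sketch of the swap: compute your custom loss and a built-in loss (`torch.nn.functional.mse_loss`) on the same batch and see which one is finite. `my_custom_loss` is a hypothetical relative-error loss that divides by the target, one common way to end up with 0/0 = NaN.

```python
import torch
import torch.nn.functional as F


def my_custom_loss(pred, target):
    # Hypothetical custom loss: mean relative error. Dividing by the target
    # gives 0/0 = NaN whenever pred and target are both zero.
    return ((pred - target) / target).abs().mean()


pred = torch.tensor([0.0, 1.0, 2.0])
target = torch.tensor([0.0, 1.0, 1.0])

custom = my_custom_loss(pred, target)
builtin = F.mse_loss(pred, target)  # built-in MSE on the same batch

print("custom loss:", custom.item())    # nan
print("built-in MSE:", builtin.item())  # about 0.3333

if not torch.isfinite(custom):
    print("Custom loss is not finite: the bug is probably in the loss function.")
```

Once the built-in loss trains without NaNs, you can add the pieces of your custom loss back one at a time to find the offending operation.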

Thank you, Rachel Wicks and Nathaniel Weir, for your help with my NaN issue :)

Comments? Questions? Let me know! @Alexir563