Scaling machine learning pipelines using PyTorch can be a pain.
You typically start a PyTorch-based machine learning project by defining the model architecture. Then you run it on a CPU machine and progressively build up a training pipeline. Once the pipeline works, you run the same code on a GPU or TPU machine for faster gradient computations, updating the PyTorch code to move the model and all the tensors to GPU/TPU memory with `.to(device)` calls. Now comes the difficult part: what if you want to use distributed training for the same pipeline? You have to overhaul the code and retest it to make sure nothing breaks.
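For concreteness, here is a minimal sketch of that device-placement pattern, assuming a toy feed-forward model and synthetic data (the layer sizes and training loop are illustrative, not taken from a real project):

```python
import torch
import torch.nn as nn

# Pick whatever accelerator is available; everything below must land on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy model, optimizer, and loss for illustration only.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for _ in range(10):  # toy training loop with synthetic batches
    # Every batch has to be moved to the same device as the model by hand.
    x = torch.randn(16, 32).to(device)
    y = torch.randn(16, 1).to(device)

    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
```

Scattering these manual `.to(device)` calls through the code is exactly what makes the later jump to distributed training painful.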
Why sweat the small stuff? Let's use PyTorch Lightning instead.