How to Design the API of a Machine Learning Library
We released the ML library Higgsfield two days ago. Higgsfield is a fault-tolerant, highly scalable cluster management and machine learning framework designed for training models with billions to trillions of parameters.
With Higgsfield, you can build systems like ChatGPT on large clusters of many nodes and utilize your GPU cluster at 100% efficiency. For more details, see the GitHub README. In this post, I want to share how I designed one small part of the API: the part that lets you write training logic for large language models like Llama 70B.
In the past nine years of my deep learning journey, I have come across a vast number of frameworks. Lua Torch was a fantastic framework that first died out because it lacked Python's ecosystem, then rose again as PyTorch. Theano was also a great framework, but its major drawback was that it was difficult to debug. I remember spending two weeks writing a Neural Turing Machine to solve the bAbI tasks in Theano (nowadays it would take a couple of hours in PyTorch). TensorFlow: I still don't understand what that was; it was a terrible framework. There was also Caffe, which was popular in computer vision. Julia is a language that attempted to introduce automatic differentiation as a built-in feature. And then there's JAX, which I was originally biased against because it's a Google product. But some close friends persuaded me to try it, and I actually liked it. However, I thought it would be difficult for JAX to gain widespread adoption in the community, since PyTorch already had a strong network effect and was gaining traction quickly. I didn't see how anyone could catch up with PyTorch. Another issue with JAX is that it imposes additional cognitive load on developers.
A lot has changed since then. PyTorch has rightfully become the leading deep learning framework, and at one point it seemed like we had reached a perfect equilibrium. But then new frameworks like Lightning AI and Hugging Face Accelerate appeared, along with new model-sharding methods like DeepSpeed and FSDP. Most importantly, large foundation models emerged that only big corporations like Meta can afford to train.
In the past, the focus was on creating new architectures. Now the focus is on building complex distributed training pipelines such as RLHF and RLAIF. There's little reason to change the architecture, since retraining a new model from scratch would cost a lot of money.
Next, I'll explain how I designed the API with all of this in mind. I went through four iterations, from launching a simple script to a convenient solution.
Version zero:
Call a script. There is no API design: all the user can do is fill out a large configuration file and run the script. The configuration file is very large because it has to cover every use case: different hardware, cluster sizes, dataset sizes, training requirements, and so on. Even so, we were able to release Higgsfield this way, because the foundation of the library is distributed training on a cluster. From a developer experience perspective, however, the interface is awful for the end user.
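For illustration only, here is a minimal sketch of what a version-zero interface amounts to. The keys, `DEFAULT_CONFIG`, `load_config`, and `main` are hypothetical and are not Higgsfield's real configuration; the point is that the user can only edit values, while the script decides everything else.

```python
import json
from typing import Any

# Hypothetical config keys for illustration; a real version-zero config would
# need many more to cover every hardware, cluster, and dataset combination.
DEFAULT_CONFIG: dict[str, Any] = {
    "model": "llama-7b",
    "num_nodes": 1,
    "gpus_per_node": 8,
    "precision": "bf16",
    "learning_rate": 1e-5,
    "dataset_path": "data/train.jsonl",
}

def load_config(text: str) -> dict[str, Any]:
    """Merge user overrides (JSON) into the defaults, rejecting unknown keys."""
    overrides = json.loads(text)
    unknown = set(overrides) - set(DEFAULT_CONFIG)
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return {**DEFAULT_CONFIG, **overrides}

def main(config: dict[str, Any]) -> str:
    # The script owns all behavior; the user only supplied values.
    world_size = config["num_nodes"] * config["gpus_per_node"]
    return f"training {config['model']} on {world_size} GPUs at lr={config['learning_rate']}"
```

Any need the config does not anticipate (a new dataset format, a custom metric) means adding yet another key, which is why such files keep growing.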
Version one:
Call the 'train' function. Now the user can write their own 'business logic', which includes preparing the dataset and the arguments for the experiment. Users run many different experiments and change arguments such as the learning rate. Because the 'train' function has to cover every possible use case for every user, it needs a huge list of arguments, which makes it very cumbersome and difficult to use.
I call this argument hell, and it's exactly why I really dislike the Hugging Face Trainer interface. A class of arguments that is 2,600 lines long is just too much!
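To make argument hell concrete, here is a hypothetical sketch (not Higgsfield's or Hugging Face's actual code) of how a one-size-fits-all arguments class grows: each new use case adds a field, and the `train` function has to branch on all of them. The field names and the stand-in `train` are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TrainArgs:
    # Hypothetical fields; each one exists because some user needed it once.
    learning_rate: float = 1e-5
    weight_decay: float = 0.0
    num_epochs: int = 1
    warmup_steps: int = 0
    gradient_accumulation_steps: int = 1
    max_grad_norm: Optional[float] = None
    precision: str = "bf16"
    cpu_offload: bool = False
    save_every_n_steps: Optional[int] = None
    log_every_n_steps: int = 10
    # ...a mature library ends up with hundreds more.

def train(args: TrainArgs) -> List[str]:
    """Stand-in for a one-size-fits-all train(): it only reports which code
    paths the arguments switch on, to show how the branching grows."""
    enabled = []
    if args.gradient_accumulation_steps > 1:
        enabled.append("gradient accumulation")
    if args.max_grad_norm is not None:
        enabled.append("gradient clipping")
    if args.cpu_offload:
        enabled.append("CPU offload")
    if args.save_every_n_steps:
        enabled.append("periodic checkpoints")
    return enabled
```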
Version two:
Allow the user to compose their own model and training algorithm. Looking at the APIs of once-successful libraries like Keras or scikit-learn, I was tempted to make the API Keras-like and give the user a false sense of control over what they're doing. When Keras came out, I was into complex computational graphs like Neural Turing Machines or differentiable k-dimensional trees, and the idea of a library where all you could do was stack convolutional layers was kinda gross to me. But its success with total machine learning noobs convinced me to give it a try.
Frameworks like Keras died out because they were impossible to debug properly, but now the Hugging Face Trainer is just as difficult to debug because it has 3,800 lines of code. It's a debugging nightmare.
If I had been pressed for time and had to release ASAP, I might have shipped this API, but I decided to give myself two more days and push the deadline back.
What worried me was Keras-style callback hell: a user who wants custom checkpointing or monitoring logic has to express it as callbacks, and that quickly turns into callback hell. With the added complexity of running on a cluster, it would be very difficult to debug, and I couldn't impose that complexity on a beginner.
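For contrast, here is a minimal sketch of the fit-with-callbacks style that version two would have led to, loosely modeled on the Keras approach. `Callback`, `fit`, and `CheckpointEveryN` are hypothetical names, not any library's real API. The library owns the loop, so even simple custom checkpointing has to be squeezed into a hook.

```python
import torch
from torch import nn

class Callback:
    """A hook a fit()-style trainer calls; users subclass it to inject logic."""
    def on_step_end(self, step: int, loss: float, model: nn.Module) -> None:
        pass

def fit(model: nn.Module, optimizer, batches, epochs: int = 1, callbacks=()) -> None:
    # The library owns the loop; user code can only run inside the hooks.
    loss_fn = nn.MSELoss()
    step = 0
    for _ in range(epochs):
        for x, y in batches:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            step += 1
            for cb in callbacks:
                cb.on_step_end(step, loss.item(), model)

class CheckpointEveryN(Callback):
    """Custom checkpointing expressed as a callback (kept in memory for the demo)."""
    def __init__(self, every: int):
        self.every = every
        self.checkpoints = {}

    def on_step_end(self, step, loss, model):
        if step % self.every == 0:
            # Clone tensors so later optimizer steps don't overwrite the snapshot.
            self.checkpoints[step] = {k: v.detach().clone() for k, v in model.state_dict().items()}
```

Every further need (early stopping, cluster-aware saving, custom metrics) becomes another hook or another callback class, and on a cluster each of them also has to be correct on every rank.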
Version three:
We ditch fit/predict and allow the user to write their own custom PyTorch loop. This gives them more flexibility and control over their code.
The problem is that plain custom PyTorch code doesn't work in a distributed setting out of the box, so we can't simply let the user write it. That's why solutions like Hugging Face Accelerate were developed: they wrap certain objects so that they work in a distributed setting.
At first glance, this seems like an amazing solution: it lets the user write any logic they want, and everything else just magically works. But that's exactly the problem. Because it works like magic, it can quickly become overwhelming, especially when there are many edge cases.
For example, when using FSDP, you can't just wrap all the objects at once. You need to wrap the model first, then create the optimizer and pass it the wrapped model's parameters, and only then wrap everything else, making sure to do it in the right order.
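The ordering constraint can be sketched in plain PyTorch. PyTorch's FSDP documentation asks for the optimizer to be created after wrapping, because FSDP (in its default mode) replaces the module's parameters with flattened ones. A real FSDP wrapper needs an initialized process group, so this runnable sketch uses a hypothetical stand-in, `ReplacingWrapper`, that likewise swaps every parameter for a new tensor. `setup_model_and_optimizer` and `optimizer_tracks` are also hypothetical helpers.

```python
import torch
from torch import nn

class ReplacingWrapper(nn.Module):
    """Hypothetical stand-in for a sharding wrapper such as FSDP: it swaps every
    parameter of the wrapped module for a new tensor, so references taken
    before wrapping become stale."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        for sub in module.modules():
            for name, p in list(sub.named_parameters(recurse=False)):
                setattr(sub, name, nn.Parameter(p.detach().clone()))

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

def setup_model_and_optimizer(model: nn.Module, wrap, lr: float = 1e-4):
    model = wrap(model)  # 1. wrap (shard) the model first
    # 2. only then build the optimizer, from the wrapped model's parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, optimizer

def optimizer_tracks(model: nn.Module, optimizer) -> bool:
    """True if the optimizer updates exactly the tensors the model uses."""
    opt_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return opt_ids == {id(p) for p in model.parameters()}
```

In a real job, `wrap` could be `functools.partial(FullyShardedDataParallel, ...)`, called after `torch.distributed.init_process_group`. Building the optimizer before wrapping would leave it updating tensors the model no longer uses.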
I really disliked this solution because it works like magic, and it's not clear to the user how that magic works. They also have to remember all the different incantations, for example calling things in one order in some cases and in another order in others.
In addition, I think it does too much. It would be fine if it only wrapped the model and optimizer for sharding. But I don't see the point in wrapping the data loader, since all that does is split the data across separate processes. Nor do I see the point in making magic out of ordinary things like gradient accumulation.
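To illustrate that splitting data across processes needs no wrapper, plain PyTorch already provides `torch.utils.data.DistributedSampler`. This sketch passes `num_replicas` and `rank` explicitly so it runs in a single process; in a real job the sampler reads them from the initialized process group by default. The `make_loader` helper is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def make_loader(dataset, rank: int, world_size: int, batch_size: int) -> DataLoader:
    # Each rank iterates over a disjoint 1/world_size slice of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```

With `shuffle=True`, calling `sampler.set_epoch(epoch)` at the start of each epoch gives every epoch a different shuffle that is still consistent across ranks.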
Version four:
Building on FSDP and rejecting magic. The final solution is to wrap only the model in FSDP for sharding and keep everything else as plain PyTorch. The ultimate goal of the PyTorch developers is to make sharding so simple that users won't notice any difference between distributed and regular code, and in my opinion they are doing a great job of getting there gradually. So I don't think it makes sense to make the API 'magical'; it's better to reuse as much as possible of the core work the PyTorch developers have already done.
As I mentioned at the beginning, I don't see the point in supporting the implementation of different architectures. The value lies in using pre-trained models like Llama or Falcon. In the past, deep learning engineering often meant creating new model architectures; now you need a ton of money to change an LLM's architecture and train it from scratch.
Therefore, I decided to create an API where you use foundation models with arguments for sharding and training, such as the ZeRO stage, mixed precision, and CPU offloading for memory-limited situations.
After that, you can use standard PyTorch code without any magic.
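One way such sharding and training arguments could map onto plain FSDP settings is sketched below. The `fsdp_kwargs` helper and its argument names are hypothetical, not Higgsfield's API. The FSDP classes (`ShardingStrategy`, `MixedPrecision`, `CPUOffload` from `torch.distributed.fsdp`) are real. Mapping ZeRO stage 3 to `FULL_SHARD` and stage 2 to `SHARD_GRAD_OP` is the usual analogy rather than an exact equivalence.

```python
import torch
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy

# Hypothetical mapping from DeepSpeed-style ZeRO stages to FSDP strategies.
_ZERO_TO_STRATEGY = {
    0: ShardingStrategy.NO_SHARD,       # replicate everything, similar to DDP
    2: ShardingStrategy.SHARD_GRAD_OP,  # shard gradients and optimizer state
    3: ShardingStrategy.FULL_SHARD,     # additionally shard parameters
}
_DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

def fsdp_kwargs(zero_stage: int = 3, precision: str = "bf16", cpu_offload: bool = False) -> dict:
    """Build keyword arguments for FullyShardedDataParallel(model, **kwargs)."""
    if zero_stage not in _ZERO_TO_STRATEGY:
        raise ValueError(f"unsupported zero_stage={zero_stage}; expected one of {sorted(_ZERO_TO_STRATEGY)}")
    dtype = _DTYPES[precision]
    kwargs = {
        "sharding_strategy": _ZERO_TO_STRATEGY[zero_stage],
        "cpu_offload": CPUOffload(offload_params=cpu_offload),
    }
    if dtype != torch.float32:
        # Parameters (for compute), gradient reduction, and buffers in low precision.
        kwargs["mixed_precision"] = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
    return kwargs
```

The resulting dict would be passed to `FullyShardedDataParallel` inside an initialized process group; everything after that is ordinary PyTorch.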
First, create the optimizer and the learning rate scheduler.
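A minimal sketch of this step in plain PyTorch, assuming AdamW with linear warmup followed by cosine decay. This is a common choice for LLM fine-tuning, not necessarily what the Higgsfield tutorial uses. The helper name and hyperparameters are hypothetical. Because only the model is wrapped, the optimizer is built from `model.parameters()` exactly as in single-GPU code.

```python
import math
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

def make_optimizer_and_scheduler(model: nn.Module, lr: float, warmup_steps: int, total_steps: int):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)

    def lr_factor(step: int) -> float:
        # Linear warmup: lr/warmup_steps, 2*lr/warmup_steps, ... up to lr.
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        # Then cosine decay from lr down to 0 at total_steps.
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    # LambdaLR sets each group's lr to base_lr * lr_factor(step).
    scheduler = LambdaLR(optimizer, lr_lambda=lr_factor)
    return optimizer, scheduler
```

In the loop, call `scheduler.step()` right after `optimizer.step()`.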
Next comes the standard PyTorch training loop, in which users can write their own custom logic for monitoring and saving checkpoints.
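A minimal sketch of such a loop on a toy regression problem. The model, data, logging interval, and checkpoint file names are hypothetical stand-ins; the point is that monitoring and checkpointing are ordinary `if` statements inside the loop rather than callbacks. The sketch runs in a single process. Checkpointing a sharded FSDP model additionally requires FSDP's own state-dict handling, which is not shown here.

```python
import os
import tempfile
import torch
from torch import nn

def train(ckpt_dir: str, steps: int = 200, log_every: int = 50, ckpt_every: int = 100) -> list:
    torch.manual_seed(0)
    model = nn.Linear(1, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(64, 1)
    y = 3 * x + 1  # toy target: the model should learn w=3, b=1
    history = []
    for step in range(1, steps + 1):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        history.append(loss.item())
        # Custom monitoring: an ordinary if-statement, no callback needed.
        if step % log_every == 0:
            print(f"step {step}: loss {loss.item():.4f}")
        # Custom checkpointing: same idea, written directly in the loop.
        if step % ckpt_every == 0:
            torch.save(
                {"step": step, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
                os.path.join(ckpt_dir, f"step_{step}.pt"),
            )
    return history

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        losses = train(d)
        print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```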
Gradient accumulation is also written in plain PyTorch.
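Gradient accumulation in plain PyTorch takes only a few lines: scale each micro-batch loss by the number of accumulation steps, call `backward()` on every micro-batch, and step the optimizer only every `accum_steps` micro-batches. The `train_with_accumulation` helper below is a hypothetical sketch. With equal-size micro-batches and a mean loss, one accumulated step matches one full-batch step.

```python
import torch
from torch import nn

def train_with_accumulation(model, optimizer, micro_batches, accum_steps: int) -> int:
    """Run accumulated training; returns the number of optimizer steps taken."""
    loss_fn = nn.MSELoss()
    optimizer.zero_grad()
    steps = 0
    for i, (x, y) in enumerate(micro_batches, start=1):
        # Divide so the summed gradients equal the gradient of the mean loss.
        loss = loss_fn(model(x), y) / accum_steps
        loss.backward()  # gradients add up in .grad across micro-batches
        if i % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            steps += 1
    return steps
```

With DDP or FSDP, running the non-final micro-batches inside the wrapped model's `no_sync()` context skips gradient synchronization on those backward passes. For FSDP, this trades communication for extra memory, since unsharded gradients are kept until the synchronizing step.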
The same goes for gradient clipping.
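Gradient clipping is one call between `backward()` and `optimizer.step()`. The sketch uses `torch.nn.utils.clip_grad_norm_`, which rescales gradients in place so their total norm is at most `max_norm` and returns the norm measured before clipping. For an FSDP-wrapped model, the module's own `clip_grad_norm_` method is the counterpart, because each rank holds only a shard of the gradients. The `train_step` and `grad_norm` helpers are hypothetical.

```python
import torch
from torch import nn

def train_step(model, optimizer, x, y, max_norm: float = 1.0) -> float:
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    # Rescale gradients in place; the return value is the pre-clipping norm.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return total_norm.item()

def grad_norm(model: nn.Module) -> float:
    """Total L2 norm of all parameter gradients."""
    return torch.norm(torch.stack([p.grad.norm() for p in model.parameters() if p.grad is not None])).item()
```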
For a more in-depth understanding, you can check out the tutorial.
And that's how my little API design adventure ended. Soon I'll write part two, about how I designed the API for RLHF and RLAIF. If you're interested, follow higgsfield on GitHub and star it so you don't miss the next API update.