A team of researchers at MIT and the MIT-IBM Watson AI Lab has developed a new technique that enables on-device training using less than a quarter of a megabyte of memory. This is an impressive achievement because other training solutions can require more than 500 megabytes of memory, far exceeding the 256-kilobyte capacity of most microcontrollers.
Training a machine-learning model on an intelligent edge device allows the model to adapt to new data and make better predictions. However, training usually requires a lot of memory, so it is often carried out on computers in a data center before the model is deployed on a device. That approach is far more costly than the team's new technique and raises privacy concerns, since user data must be sent to a central server.
The researchers designed their algorithms and framework to reduce the amount of computation needed to train a model, making the process faster and more memory efficient. With this technique, a machine-learning model can be trained on a microcontroller in just a few minutes.
The technique also protects privacy by keeping data on the device, which is important when sensitive data is involved. At the same time, the framework improves the model's accuracy compared with other training approaches.
Song Han is an associate professor in the Department of Electrical Engineering and Computer Science (EECS), a member of the MIT-IBM Watson AI Lab, and senior author of the research paper.
“Our study enables IoT devices to not only perform inference but also continuously update the AI models to newly collected data, paving the way for lifelong on-device learning,” Han said. “The low resource utilization makes deep learning more accessible and can have a broader reach, especially for low-power edge devices.”
Han's co-authors on the paper include co-lead authors and EECS PhD students Ji Lin and Ligeng Zhu; MIT postdocs Wei-Ming Chen and Wei-Chen Wang; and Chuang Gan, a principal research staff member at the MIT-IBM Watson AI Lab.
Making the Training Process More Efficient
To make training more efficient and less memory-intensive, the team relied on two algorithmic solutions. The first, known as sparse update, uses an algorithm that identifies the most important weights to update during each round of training. The algorithm freezes weights one at a time until accuracy dips to a set threshold, then stops. Only the remaining weights are updated, and the activations corresponding to the frozen weights don't need to be stored in memory.
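The greedy freezing loop described above can be sketched in a few lines of Python. This is a toy illustration under stated assumptions, not the paper's actual selection procedure: the per-layer `importance` scores and the `evaluate` callback are hypothetical stand-ins (in practice `evaluate` would fine-tune with only the listed layers trainable and report validation accuracy), and the sketch freezes whole layers rather than individual weights.

```python
from typing import Callable, Dict, List, Sequence


def select_sparse_update(
    layers: Sequence[str],
    importance: Dict[str, float],
    evaluate: Callable[[List[str]], float],
    min_accuracy: float,
) -> List[str]:
    """Freeze layers one at a time, least important first, and stop as soon
    as freezing the next one would push accuracy below `min_accuracy`.
    Returns the layers that remain trainable."""
    trainable = list(layers)
    for name in sorted(layers, key=lambda n: importance[n]):
        candidate = [n for n in trainable if n != name]
        # Always keep at least one layer trainable; stop at the threshold.
        if not candidate or evaluate(candidate) < min_accuracy:
            break
        trainable = candidate  # this layer stays frozen
    return trainable


# Toy demo with made-up importance scores (hypothetical, not from the paper).
IMPORTANCE = {"conv1": 0.02, "conv2": 0.05, "conv3": 0.10, "fc": 0.30}


def toy_evaluate(trainable: List[str]) -> float:
    # Stand-in for "fine-tune with only these layers, then measure accuracy".
    return 0.5 + sum(IMPORTANCE[n] for n in trainable)


selected = select_sparse_update(list(IMPORTANCE), IMPORTANCE, toy_evaluate, 0.88)
print(selected)  # ['conv3', 'fc']
```

With these toy numbers, freezing `conv1` and then `conv2` keeps the proxy accuracy at or above 0.88, while also freezing `conv3` would drop it to 0.80, so the loop stops and leaves `conv3` and `fc` trainable.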
“Updating the whole model is very expensive because there is a lot of activation, so people tend to update only the last layer, but as you can imagine, this hurts the accuracy,” Han said. “For our method, we selectively update those important weights and make sure the accuracy is fully preserved.”
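Han's point about activations can be made concrete with a rough memory count. For a layer computing y = Wx, the weight gradient dL/dW = (dL/dy)`x^T` needs the stored input x, while the input gradient dL/dx = `W^T`(dL/dy) does not, so only layers whose weights are updated must keep their inputs. The sketch below assumes hypothetical layer shapes and 8-bit (one-byte) activations, and it ignores the smaller buffers that nonlinearities such as ReLU also need; the numbers are illustrative, not figures from the paper.

```python
from typing import Dict, Iterable

# Hypothetical input-activation sizes per layer, in elements, for a small CNN.
# At one byte per element (8-bit activations, assumed) this is also bytes.
INPUT_ACT_BYTES: Dict[str, int] = {
    "conv1": 96 * 96 * 3,
    "conv2": 48 * 48 * 16,
    "conv3": 24 * 24 * 32,
    "fc": 12 * 12 * 64,
}


def stored_activation_bytes(trainable: Iterable[str]) -> int:
    """Activation memory kept for backprop: only layers whose weights are
    updated need their input x, because dL/dW = (dL/dy) x^T uses x while
    dL/dx = W^T (dL/dy) does not. Nonlinearity buffers are ignored."""
    return sum(INPUT_ACT_BYTES[name] for name in trainable)


full = stored_activation_bytes(INPUT_ACT_BYTES)   # update every layer
last = stored_activation_bytes(["fc"])             # last layer only
sparse = stored_activation_bytes(["conv3", "fc"])  # a sparse subset
print(full, last, sparse)  # 92160 9216 27648
```

In this toy network, updating every layer keeps 92,160 bytes of activations, the last layer alone keeps 9,216, and a two-layer sparse choice keeps 27,648: a middle ground between the expensive full update and the accuracy-limiting last-layer-only update Han describes.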
The second solution is quantized training, which simplifies the weights. An algorithm first rounds the weights to just eight bits through quantization, which reduces the memory needed for both training and inference (inference being the process of applying a model to data to generate a prediction). The algorithm then applies a technique called quantization-aware scaling (QAS), which acts as a multiplier to adjust the ratio between weight and gradient. This avoids the drop in accuracy that quantized training could otherwise cause.
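The article does not give a formula for QAS, so the following is a minimal sketch based on our reading of the paper and a short chain-rule argument, using made-up weight and gradient values. It quantizes weights to 8 bits with a symmetric per-tensor scale `s` (one common scheme, assumed here), so each weight is stored as `s * q` with `q` an integer in [-127, 127] that takes 1 byte instead of 4 for a 32-bit float. Because `w = s * q`, the gradient with respect to `q` is `s` times the gradient with respect to `w`, which inflates the weight-to-gradient norm ratio by a factor of about 1/s². Multiplying the gradient by s⁻² (our assumed form of the QAS multiplier) restores the original ratio. The sketch covers weights only.

```python
import math
from typing import List, Tuple


def quantize_int8(w: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization: w ~= s * q, q in [-127, 127]."""
    max_abs = max(abs(x) for x in w) or 1.0  # guard against an all-zero tensor
    s = max_abs / 127.0
    q = [max(-127, min(127, round(x / s))) for x in w]
    return q, s


def norm(v: List[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


# Made-up fp32 weights and their loss gradients, for illustration only.
w = [0.02, -0.05, 0.11, -0.08]
g_w = [0.004, -0.001, 0.002, 0.003]

q, s = quantize_int8(w)  # each q entry fits in 1 byte instead of 4

# Chain rule: w = s * q, so dL/dq = s * dL/dw.
g_q = [s * g for g in g_w]

ratio_fp = norm(w) / norm(g_w)
ratio_q = norm(q) / norm(g_q)  # about ratio_fp / s**2: badly mis-scaled

# Assumed QAS form: rescale the quantized-weight gradient by s**-2.
g_qas = [g / s ** 2 for g in g_q]
ratio_qas = norm(q) / norm(g_qas)  # back to about ratio_fp

print(q)  # [23, -58, 127, -92]
print(f"{ratio_fp:.1f} {ratio_q / ratio_fp:.3g} {ratio_qas / ratio_fp:.4f}")
```

Without the rescaling, the ratio in the quantized domain is off by more than a factor of a million here, so a learning rate tuned for floating-point training would barely move the integer weights; after rescaling, the ratio matches the floating-point one up to rounding error.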
The researchers also developed a system called the tiny training engine, which runs these algorithmic innovations on a simple microcontroller that lacks an operating system. The system changes the order of steps in the training process so that more work is completed at compile time, before the model is deployed on the edge device.
“We push a lot of the computation, such as auto-differentiation and graph optimization, to compile time. We also aggressively prune the redundant operators to support sparse updates. Once at runtime, we have much less workload to do on the device,” Han said.
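The compile-time pruning Han describes can be illustrated with a toy scheduler. Given the layer order and the set of trainable layers, it emits a fixed list of backward operations ahead of time: weight gradients only for trainable layers, and input gradients only down to the earliest trainable layer, since nothing below it needs them. The op labels (`grad_weight`, `grad_input`, `update`) are hypothetical, and applying each update right after its gradient is computed (once the input gradient has used the old weights) is our assumption about what reordering the training steps might look like; the actual engine works on a full computation graph.

```python
from typing import List, Sequence, Set


def compile_backward(layers: Sequence[str], trainable: Set[str]) -> List[str]:
    """Precompute a fixed backward schedule for a sparse update.

    Weight gradients are emitted only for trainable layers, and input
    gradients only for layers above the earliest trainable one; everything
    else is pruned before deployment. Op labels are hypothetical."""
    positions = [i for i, name in enumerate(layers) if name in trainable]
    if not positions:
        return []  # nothing to train, so no backward pass at all
    first = min(positions)
    schedule: List[str] = []
    for i in range(len(layers) - 1, first - 1, -1):  # walk backward
        name = layers[i]
        if name in trainable:
            schedule.append(f"grad_weight:{name}")
        if i > first:
            # Lower trainable layers still need a gradient flowing into them.
            schedule.append(f"grad_input:{name}")
        if name in trainable:
            # Update right away (after grad_input used the old weights) so
            # the weight-gradient buffer can be freed early.
            schedule.append(f"update:{name}")
    return schedule


LAYERS = ["conv1", "conv2", "conv3", "fc"]
plan = compile_backward(LAYERS, {"conv3", "fc"})
print(plan)
print(len(plan), "ops vs", len(compile_backward(LAYERS, set(LAYERS))), "for a full update")
```

For a sparse update of `conv3` and `fc`, the precomputed schedule has 5 operations instead of 11 for a full update, and at runtime the device only has to walk this fixed list rather than derive gradients on the fly.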
Highly Efficient Technique
While traditional techniques designed for lightweight training typically need around 300 to 600 megabytes of memory, the team's optimization needed only 157 kilobytes to train a machine-learning model on a microcontroller.
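A quick back-of-the-envelope check puts these figures side by side. The sketch assumes binary units (1 KB = 1,024 bytes, 1 MB = 1,024 KB); with decimal units the ratios come out slightly lower, at roughly 1,900x to 3,800x.

```python
KB = 1024        # assumption: binary units, 1 KB = 1,024 bytes
MB = 1024 * KB   # and 1 MB = 1,024 KB

tiny_training = 157 * KB                  # memory reported for the new method
baseline_low, baseline_high = 300 * MB, 600 * MB
mcu_sram = 256 * KB                       # typical microcontroller capacity

print("fits in 256 KB:", tiny_training <= mcu_sram)  # True
print(f"reduction: {baseline_low / tiny_training:.0f}x "
      f"to {baseline_high / tiny_training:.0f}x")    # 1957x to 3913x
```

Either way, the new method uses roughly three orders of magnitude less memory than the conventional approaches and leaves about 99 KB of a 256 KB budget free.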
The researchers tested the framework by training a computer vision model to detect people in images, and the model learned to complete this task in just 10 minutes. The method also trained a model more than 20 times faster than other approaches.
The researchers next plan to apply these techniques to language models and other types of data. They also want to use what they have learned to shrink larger models without a loss in accuracy, which could help reduce the carbon footprint of training large-scale machine-learning models.