Inteligência artificial
Otimizando memória para inferência e ajuste fino de modelos de linguagem grande
Grandes modelos de linguagem (LLMs) como GPT-4, Bloom e LLaMA alcançaram capacidades notáveis ao escalar até bilhões de parâmetros. No entanto, a implantação desses modelos massivos para inferência ou ajuste fino é um desafio devido aos seus imensos requisitos de memória. Neste blog técnico, exploraremos técnicas para estimar e otimizar o consumo de memória durante a inferência LLM e o ajuste fino em várias configurações de hardware.
Compreendendo os requisitos de memória
A memória necessária para carregar um LLM é determinada principalmente pelo número de parâmetros e pela precisão numérica usada para armazenar os parâmetros. Uma regra simples é:
- Carregar um modelo com X bilhões de parâmetros requer aproximadamente 4X GB de VRAM em 32-bit precisão flutuante
- Carregar um modelo com X bilhões de parâmetros requer aproximadamente 2X GB de VRAM em 16-bit precisão bfloat16/float16
Por exemplo, carregar o modelo GPT-175 de parâmetro 3B exigiria aproximadamente 350 GB de VRAM com precisão bfloat16. Atualmente, as maiores GPUs disponíveis comercialmente, como NVIDIA A100 e H100, oferecem apenas 80 GB de VRAM, necessitando de paralelismo de tensor e técnicas de paralelismo de modelo.
Durante a inferência, o consumo de memória é dominado pelos parâmetros do modelo e pelos tensores de ativação temporários produzidos. Uma estimativa de alto nível para o pico de uso de memória durante a inferência é a soma da memória necessária para carregar os parâmetros do modelo e a memória para ativações.
Quantificando a memória de inferência
Vamos quantificar os requisitos de memória para inferência usando o modelo OctoCode, que possui cerca de 15 bilhões de parâmetros no formato bfloat16 (~ 31 GB). Usaremos o Biblioteca de transformadores para carregar o modelo e gerar texto:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder",
torch_dtype=torch.bfloat16,
device_map="auto",
pad_token_id=0)
tokenizer = AutoTokenizer.from_pretrained("bigcode/octocoder")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
prompt = "Question: Please write a Python function to convert bytes to gigabytes.\n\nAnswer:"
result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
def bytes_to_gigabytes(bytes):
return bytes / 1024 / 1024 / 1024
bytes_to_gigabytes(torch.cuda.max_memory_allocated())
Saída:
29.0260648727417O pico de uso de memória da GPU é de cerca de 29 GB, o que se alinha com nossa estimativa de 31 GB para carregar os parâmetros do modelo no formato bfloat16.
Otimizando Memória de Inferência com Quantização
Embora bfloat16 seja a precisão comum usada para treinar LLMs, os pesquisadores descobriram que quantizar os pesos do modelo para tipos de dados de menor precisão, como inteiros de 8 bits (int8) ou inteiros de 4 bits, pode reduzir significativamente o uso de memória com perda mínima de precisão para tarefas de inferência como geração de texto.
Vamos ver a economia de memória da quantização de 8 e 4 bits do modelo OctoCode:
</div>
# 8-bit quantization
model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", load_in_8bit=True,
pad_token_id=0)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
bytes_to_gigabytes(torch.cuda.max_memory_allocated())</pre>
Saída:
15.219234466552734# 4-bit quantization
model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", load_in_4bit=True,
low_cpu_mem_usage=True, pad_token_id=0)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
bytes_to_gigabytes(torch.cuda.max_memory_allocated())
Saída:
9.543574333190918Com a quantização de 8 bits, o requisito de memória cai de 31 GB para 15 GB, enquanto a quantização de 4 bits reduz ainda mais para apenas 9.5 GB! Isso permite executar o modelo OctoCode de 15B de parâmetro em GPUs de consumo como o RTX 3090 (24GB VRAM).
No entanto, observe que quantizações mais agressivas, como a de 4 bits, podem, às vezes, levar à degradação da precisão em comparação com a precisão de 8 bits ou bfloat16. Há um compromisso entre economia de memória e precisão que os usuários devem avaliar para seu caso de uso.
A quantização é uma técnica poderosa que pode permitir a implantação de LLM em ambientes com recursos limitados, como instâncias de nuvem, dispositivos de borda ou até mesmo telefones celulares, reduzindo drasticamente o consumo de memória.
Estimando memória para ajuste fino
Embora a quantização seja usada principalmente para inferência eficiente, técnicas como paralelismo de tensores e paralelismo de modelos são cruciais para gerenciar requisitos de memória durante o treinamento ou afinação de grandes modelos de linguagem.
O pico de consumo de memória durante o ajuste fino é normalmente 3 a 4 vezes maior que o inferido devido aos requisitos adicionais de memória para:
- Gradientes
- Estados do otimizador
- Ativações do passe direto armazenadas para retropropagação
Uma estimativa conservadora é que o ajuste fino de um LLM com X bilhões de parâmetros requer cerca de 4 * (2X) = 8XGB de VRAM com precisão bfloat16.
Por exemplo, o ajuste fino do modelo LLaMA de parâmetro 7B exigiria aproximadamente 7 * 8 = 56 GB de VRAM por GPU com precisão bfloat16. Isso excede a capacidade de memória das GPUs atuais, necessitando de técnicas distribuídas de ajuste fino.
Técnicas de ajuste fino distribuídas
Vários métodos de ajuste fino distribuído foram propostos para superar as restrições de memória da GPU para modelos grandes:
- Paralelismo de dados: a abordagem clássica de paralelismo de dados replica todo o modelo em várias GPUs enquanto divide e distribui os lotes de dados de treinamento. Isso reduz o tempo de treinamento linearmente com o número de GPUs, mas não reduz o pico de exigência de memória em cada GPU.
- ZeRO Estágio 3: uma forma avançada de paralelismo de dados que particiona os parâmetros do modelo, gradientes e estados do otimizador entre GPUs. Reduz a memória em comparação com o paralelismo de dados clássico, mantendo apenas os dados particionados necessários em cada GPU durante as diferentes fases de treinamento.
- Paralelismo tensorial: em vez de replicar o modelo, o paralelismo tensorial divide os parâmetros do modelo em linhas ou colunas e os distribui pelas GPUs. Cada GPU opera em um conjunto particionado de parâmetros, gradientes e estados de otimização, levando a economias substanciais de memória.
- Paralelismo de pipeline: esta técnica particiona as camadas do modelo em diferentes GPUs/workers, com cada dispositivo executando um subconjunto de camadas. As ativações são passadas entre trabalhadores, reduzindo o pico de memória, mas aumentando a sobrecarga de comunicação.
Estimar o uso de memória para esses métodos distribuídos não é trivial, pois a distribuição de parâmetros, gradientes, ativações e estados do otimizador varia entre as técnicas. Além disso, diferentes componentes, como o corpo do transformador e o cabeçote de modelagem da linguagem, podem apresentar diferentes comportamentos de alocação de memória.
A solução LLMem
Pesquisadores propuseram recentemente LLMem, uma solução que estima com precisão o consumo de memória da GPU ao aplicar métodos de ajuste fino distribuídos a LLMs em várias GPUs.
O LLMem considera fatores como recombinação de parâmetros antes da computação (ZeRO Estágio 3), coleta de saída na passagem para trás (paralelismo de tensor) e as diferentes estratégias de alocação de memória para o corpo do transformador e cabeçote de modelagem de linguagem.
Os resultados experimentais mostram que o LLMem pode estimar o pico de uso da memória da GPU para ajuste fino de LLMs em uma única GPU com taxas de erro de até 1.6%, superando a taxa média de erro do DNNMem de última geração de 42.6%. Ao aplicar métodos de ajuste fino distribuídos a LLMs com mais de um bilhão de parâmetros em múltiplas GPUs, o LLMem atinge uma impressionante taxa de erro média de 3.0%.









