An easy-to-use LLM quantization package with user-friendly APIs, based on the GPTQ algorithm (weight-only quantization).
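Marlin int4*fp16 kernel support is enabled with the argument use_marlin=True when loading models.
auto-gptq has been integrated into 🤗 Transformers, optimum and peft, so running and training GPTQ models is now more accessible to everyone! See this blog and its resources for more details.
For earlier history, please turn to here.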
The results were generated using this script with an input batch size of 1, beam search as the decoding strategy, and the model forced to generate 512 tokens. The speed metric is tokens/s (larger is better).
The quantized model is loaded using the setup that gives the fastest inference speed.
model | GPU | num_beams | fp16 | gptq-int4 |
---|---|---|---|---|
llama-7b | 1xA100-40G | 1 | 18.87 | 25.53 |
llama-7b | 1xA100-40G | 4 | 68.79 | 91.30 |
moss-moon 16b | 1xA100-40G | 1 | 12.48 | 15.25 |
moss-moon 16b | 1xA100-40G | 4 | OOM | 42.67 |
moss-moon 16b | 2xA100-40G | 1 | 06.83 | 06.78 |
moss-moon 16b | 2xA100-40G | 4 | 13.10 | 10.80 |
gpt-j 6b | 1xRTX3060-12G | 1 | OOM | 29.55 |
gpt-j 6b | 1xRTX3060-12G | 4 | OOM | 47.36 |
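To read the table as relative speedups, the ratio of gptq-int4 to fp16 throughput can be computed per row. The following is a minimal sketch: the rows are copied from the table above, while the names ROWS and speedup are made up for this illustration and are not part of auto_gptq.

```python
# Rows copied from the table above:
# (model, GPU, num_beams, fp16 tokens/s, gptq-int4 tokens/s).
# None stands for an OOM entry.
ROWS = [
    ("llama-7b", "1xA100-40G", 1, 18.87, 25.53),
    ("llama-7b", "1xA100-40G", 4, 68.79, 91.30),
    ("moss-moon 16b", "1xA100-40G", 1, 12.48, 15.25),
    ("moss-moon 16b", "1xA100-40G", 4, None, 42.67),
    ("moss-moon 16b", "2xA100-40G", 1, 6.83, 6.78),
    ("moss-moon 16b", "2xA100-40G", 4, 13.10, 10.80),
    ("gpt-j 6b", "1xRTX3060-12G", 1, None, 29.55),
    ("gpt-j 6b", "1xRTX3060-12G", 4, None, 47.36),
]


def speedup(fp16, int4):
    """Return the int4/fp16 throughput ratio, or None if the fp16 run was OOM."""
    if fp16 is None:
        return None
    return int4 / fp16


for model, gpu, beams, fp16, int4 in ROWS:
    ratio = speedup(fp16, int4)
    label = "n/a (fp16 OOM)" if ratio is None else f"{ratio:.2f}x"
    print(f"{model:14s} {gpu:14s} num_beams={beams}  {label}")
```

A ratio above 1 means the quantized model generated tokens faster than fp16 in that setup. In this table the single-GPU rows are above 1, while the two-GPU moss-moon rows are at or below 1.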
For perplexity comparison, you can turn to here and here
AutoGPTQ is available on Linux and Windows only. You can install the latest stable release of AutoGPTQ from pip with pre-built wheels:
Platform version | Installation | Built against PyTorch |
---|---|---|
CUDA 11.8 | pip install auto-gptq --no-build-isolation --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ | 2.2.1+cu118 |
CUDA 12.1 | pip install auto-gptq --no-build-isolation | 2.2.1+cu121 |
ROCm 5.7 | pip install auto-gptq --no-build-isolation --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm573/ | 2.2.1+rocm5.7 |
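If the install step is scripted, the table can be encoded as a small lookup. This is only a sketch: the commands are copied from the table above, but the dictionary keys ("cuda11.8", "cuda12.1", "rocm5.7") and the function name install_command are labels chosen for this example, not anything AutoGPTQ defines.

```python
# Commands copied from the installation table; keys are labels made up for this sketch.
BASE = "pip install auto-gptq --no-build-isolation"
INDEX = "https://huggingface.github.io/autogptq-index/whl"

INSTALL_COMMANDS = {
    "cuda11.8": f"{BASE} --extra-index-url {INDEX}/cu118/",
    "cuda12.1": BASE,  # no extra index is listed for this platform
    "rocm5.7": f"{BASE} --extra-index-url {INDEX}/rocm573/",
}


def install_command(platform: str) -> str:
    """Return the pip command listed for the given platform label."""
    try:
        return INSTALL_COMMANDS[platform]
    except KeyError:
        known = ", ".join(sorted(INSTALL_COMMANDS))
        raise ValueError(f"unknown platform {platform!r}; expected one of: {known}") from None


print(install_command("cuda11.8"))
```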
To use the Triton backend (currently Linux only, no 3-bit quantization), install AutoGPTQ with the Triton dependency:
pip install auto-gptq[triton] --no-build-isolation
For older AutoGPTQ, please refer to the previous releases installation table.
On NVIDIA systems, AutoGPTQ does not support Maxwell or lower GPUs.
Clone the source code:
git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ
A few packages are required in order to build from source:
pip install numpy gekko pandas
Then, install locally from source:
pip install -vvv --no-build-isolation -e .
You can set BUILD_CUDA_EXT=0 to disable building the PyTorch extension, but this is strongly discouraged as AutoGPTQ then falls back on a slow Python implementation.
As a last resort, if the above command fails, you can try python setup.py install.
To install from source for AMD GPUs supporting ROCm, please specify the ROCM_VERSION environment variable. Example:
ROCM_VERSION=5.6 pip install -vvv --no-build-isolation -e .
Compilation can be sped up by specifying the PYTORCH_ROCM_ARCH variable (reference) in order to build for a single target device, for example gfx90a for MI200 series devices.
For ROCm systems, the packages rocsparse-dev, hipsparse-dev, rocthrust-dev, rocblas-dev and hipblas-dev are required to build.
Notice: make sure you're on commit 65c2e15 or later.
To install from source for Intel Gaudi 2 HPUs, set the BUILD_CUDA_EXT=0 environment variable to disable building the CUDA PyTorch extension. Example:
BUILD_CUDA_EXT=0 pip install -vvv --no-build-isolation -e .
Notice that Intel Gaudi 2 uses an optimized kernel for inference, and requires BUILD_CUDA_EXT=0 on non-CUDA machines.
Warning: this is just a showcase of the basic APIs in AutoGPTQ. It uses only one sample to quantize a very small model, so the quality of the resulting quantized model may be poor.
Below is an example of the simplest use of auto_gptq to quantize a model and run inference after quantization:
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import logging
logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)
pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-4bit"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
examples = [
    tokenizer(
        "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
    )
]
quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # setting this to False can significantly speed up inference, but perplexity may be slightly worse
)
# load un-quantized model, by default, the model will always be loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
model.quantize(examples)
# save quantized model
model.save_quantized(quantized_model_dir)
# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)
# push quantized model to Hugging Face Hub.
# to use use_auth_token=True, log in first via huggingface-cli login,
# or pass an explicit token with: use_auth_token="hf_xxxxxxx"
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)
# alternatively you can save and push at the same time
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, save_dir=quantized_model_dir, use_safetensors=True, commit_message=commit_message, use_auth_token=True)
# load quantized model to the first GPU
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")
# download quantized model from Hugging Face Hub and load to the first GPU
# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
# inference with model.generate
print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))
# or you can also use pipeline
pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
print(pipeline("auto-gptq is")[0]["generated_text"])
For more advanced features of model quantization, please refer to this script.
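Since a single calibration sample is only meant as a demonstration, a slightly more realistic setup tokenizes several texts and keeps only the two keys that model.quantize expects ("input_ids" and "attention_mask"). The helper below is a minimal sketch under that assumption; build_examples, ALLOWED_KEYS and the max_length default are hypothetical names and values chosen for illustration, not part of auto_gptq.

```python
from typing import Callable, Dict, List

# The quickstart above states that each example may only contain these keys.
ALLOWED_KEYS = ("input_ids", "attention_mask")


def build_examples(tokenizer: Callable, texts: List[str], max_length: int = 512) -> List[Dict]:
    """Tokenize each text and drop any key other than input_ids / attention_mask."""
    examples = []
    for text in texts:
        # Truncate long texts so every calibration sample stays within max_length tokens.
        encoded = tokenizer(text, truncation=True, max_length=max_length)
        examples.append({key: encoded[key] for key in ALLOWED_KEYS})
    return examples


# Usage with the objects from the quickstart above (not executed here):
# tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
# examples = build_examples(tokenizer, ["first calibration text", "second calibration text"])
# model.quantize(examples)
```

Filtering the keys matters for tokenizers that also return fields such as token_type_ids, which the quickstart comment says are not accepted.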
You can use tasks defined in auto_gptq.eval_tasks to evaluate a model's performance on a specific downstream task before and after quantization.
The predefined tasks support all causal language models implemented in 🤗 transformers and in this project.
tutorials provide step-by-step guidance to integrate auto_gptq with your own project and some best practice principles.
examples provide plenty of example scripts to use auto_gptq in different ways.
You can compare model.config.model_type with the table below to check whether the model you use is supported by auto_gptq.
For example, the model_type of WizardLM, vicuna and gpt4all is llama in each case, hence they are all supported by auto_gptq.
model type | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt |
---|---|---|---|---|---|
bloom | ✅ | ✅ | ✅ | ✅ |  |
gpt2 | ✅ | ✅ | ✅ | ✅ |  |
gpt_neox | ✅ | ✅ | ✅ | ✅ | ✅ requires this peft branch |
gptj | ✅ | ✅ | ✅ | ✅ | ✅ requires this peft branch |
llama | ✅ | ✅ | ✅ | ✅ | ✅ |
moss | ✅ | ✅ | ✅ | ✅ | ✅ requires this peft branch |
opt | ✅ | ✅ | ✅ | ✅ |  |
gpt_bigcode | ✅ | ✅ | ✅ | ✅ |  |
codegen | ✅ | ✅ | ✅ | ✅ |  |
falcon(RefinedWebModel/RefinedWeb) | ✅ | ✅ | ✅ | ✅ |  |
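The model_type check described above can be sketched as a simple set lookup. The set below is transcribed from the first column of the table (with the falcon row expanded into the three names it lists); SUPPORTED_MODEL_TYPES and is_supported are names made up for this example and are not claimed to be the library's own API.

```python
# Transcribed from the "model type" column of the table above.
SUPPORTED_MODEL_TYPES = {
    "bloom",
    "gpt2",
    "gpt_neox",
    "gptj",
    "llama",
    "moss",
    "opt",
    "gpt_bigcode",
    "codegen",
    "falcon",
    "RefinedWebModel",
    "RefinedWeb",
}


def is_supported(model_type: str) -> bool:
    """Return True if the given model_type appears in the support table."""
    return model_type in SUPPORTED_MODEL_TYPES


# Usage with a real checkpoint (needs transformers and network access, so not executed here):
# from transformers import AutoConfig
# model_type = AutoConfig.from_pretrained("facebook/opt-125m").model_type
# print(model_type, is_supported(model_type))

print(is_supported("llama"))
```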
Currently, auto_gptq supports LanguageModelingTask, SequenceClassificationTask and TextSummarizationTask; more tasks will come soon!
Tests can be run with:
pytest tests/ -s
AutoGPTQ defaults to using the exllamav2 int4*fp16 kernel for matrix multiplication.
Marlin is an optimized int4*fp16 kernel that was recently proposed at https://github.com/IST-DASLab/marlin. It is used in AutoGPTQ when loading a model with use_marlin=True. This kernel is available only on devices with compute capability 8.0 or 8.6 (Ampere GPUs).
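Because Marlin is limited to compute capability 8.0 and 8.6, it can be useful to decide the use_marlin flag at runtime. The sketch below assumes PyTorch is installed for the device query; marlin_supported and pick_use_marlin are hypothetical helper names, not part of auto_gptq.

```python
# Compute capabilities listed above as supported by the Marlin kernel.
MARLIN_CAPABILITIES = {(8, 0), (8, 6)}


def marlin_supported(capability) -> bool:
    """Return True if a (major, minor) compute capability can run Marlin."""
    return tuple(capability) in MARLIN_CAPABILITIES


def pick_use_marlin(device_index: int = 0) -> bool:
    """Query the CUDA device and decide whether to pass use_marlin=True."""
    import torch  # imported lazily so marlin_supported() works without PyTorch

    if not torch.cuda.is_available():
        return False
    return marlin_supported(torch.cuda.get_device_capability(device_index))


# Usage with the quickstart objects (not executed here):
# model = AutoGPTQForCausalLM.from_quantized(
#     quantized_model_dir, device="cuda:0", use_marlin=pick_use_marlin(0)
# )
```

On devices outside that list the flag stays False, so loading relies on the default kernel instead.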