[Bug] Using a short prompt with a quantized text_encoder results in an error


Introduction

We have encountered a bug in the Nunchaku library that occurs when a short prompt is passed through a quantized T5 text encoder (text_encoder_2 in the Flux pipeline, loaded from the mit-han-lab/svdq-flux.1-t5 checkpoint). Instead of generating an image, the process aborts with a failed assertion in the AWQ GEMV kernel. In this report, we outline the steps to reproduce the bug, list the environment, and discuss possible workarounds.

Checklist

Before submitting this bug report, we have taken the necessary steps to ensure that the issue is not already documented or resolved:

  • We have searched for related issues and FAQs on the Nunchaku GitHub discussion board, but were unable to find a solution.
  • The issue persists in the latest version of the Nunchaku library.
  • We have provided a minimal reproducible example to facilitate the debugging process.
  • We have checked if the issue is related to ComfyUI and reported it accordingly.

Describe the Bug

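When the Flux SVDQuant pipeline is run with the quantized T5 text encoder and a short prompt, the process aborts instead of generating an image. The failure is a C++ assertion inside Nunchaku's AWQ GEMV kernel (gemv_awq.cu), so it terminates the Python interpreter ("Aborted (core dumped)") rather than raising an exception that could be caught.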

Error Information

The error message is as follows:

python: /nunchaku/src/kernels/awq/gemv_awq.cu:307: gemv_awq(Tensor, Tensor, Tensor, Tensor, int, int, int, int)::<lambda()> [with half_t = __nv_bfloat16]: Assertion 'group_size == GROUP_SIZE' failed. Aborted (core dumped)
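The assertion compares a group_size argument against a GROUP_SIZE constant that is fixed when the kernel is compiled. In group-wise (AWQ-style) weight quantization, each run of group_size consecutive weights shares one scale and one zero point, and a kernel typically locates the right scale by dividing the weight index by the group size, so it can only handle the layout it was built for. The pure-Python sketch below illustrates that idea under our own assumptions: quantize_groupwise, gemv_row and the group size of 4 are hypothetical and do not reproduce Nunchaku's CUDA kernel or its actual group size.

```python
KERNEL_GROUP_SIZE = 4  # hypothetical; the real kernel's value is not given in this report


def quantize_groupwise(row, group_size, bits=4):
    """Asymmetric quantization where every `group_size` weights share a scale and zero point."""
    if len(row) % group_size:
        raise ValueError("row length must be a multiple of group_size")
    qmax = (1 << bits) - 1  # largest code: 15 for 4 bits
    codes, scales, zeros = [], [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / qmax or 1.0  # avoid division by zero for constant groups
        zero = round(-lo / scale)
        # Clamp each code into the representable range [0, qmax]
        codes.extend(min(qmax, max(0, round(w / scale) + zero)) for w in group)
        scales.append(scale)
        zeros.append(zero)
    return codes, scales, zeros


def gemv_row(codes, scales, zeros, x, group_size):
    """Dot product of one quantized weight row with x, dequantizing on the fly."""
    # Stand-in for a guard like the one in the error message: indexing
    # scales/zeros by i // group_size is only meaningful for the group size
    # the weights were actually quantized with.
    if group_size != KERNEL_GROUP_SIZE:
        raise AssertionError("group_size == GROUP_SIZE")
    acc = 0.0
    for i, (code, xi) in enumerate(zip(codes, x)):
        g = i // group_size  # which group, and therefore which scale/zero, applies
        acc += (code - zeros[g]) * scales[g] * xi
    return acc
```

Calling gemv_row with any group size other than KERNEL_GROUP_SIZE fails immediately, which mirrors how a real kernel built for one group size would reject weights or arguments prepared for another.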

Reproduction Steps

To reproduce the bug, follow these steps:

  1. Set up the environment listed below (PyTorch 2.5.1 and CUDA 12.4 on an RTX 3090 GPU) and install diffusers and Nunchaku.
  2. Import torch, FluxPipeline, NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel, and apply_cache_on_pipe.
  3. Load the INT4 Flux transformer with NunchakuFluxTransformer2dModel.from_pretrained and set its attention implementation to "nunchaku-fp16" with transformer.set_attention_impl.
  4. Load the quantized text encoder with NunchakuT5EncoderModel.from_pretrained.
  5. Create a FluxPipeline with FluxPipeline.from_pretrained, passing the quantized transformer and text encoder, and move it to the GPU.
  6. Enable caching on the pipeline with apply_cache_on_pipe.
  7. Run the pipeline on a short prompt such as "a dog".

Environment

The bug was encountered on the following environment:

  • Operating System: Ubuntu 22.04
  • Python Version: 3.10
  • PyTorch Version: 2.5.1
  • CUDA Version: 12.4
  • GPU: RTX 3090
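The list above does not record the diffusers, transformers, or Nunchaku versions, which would help others reproduce the bug. A small script such as the following can collect them alongside the system details. It is a sketch that assumes the packages are installed under the distribution names torch, diffusers, transformers, and nunchaku.

```python
import platform
from importlib import metadata


def package_version(name):
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def collect_env():
    info = {
        "os": platform.platform(),
        "python": platform.python_version(),
        "torch": package_version("torch"),
        "diffusers": package_version("diffusers"),
        "transformers": package_version("transformers"),
        "nunchaku": package_version("nunchaku"),
        "cuda": None,
        "gpu": None,
    }
    try:
        import torch  # only needed for the CUDA and GPU fields
    except ImportError:
        return info
    info["cuda"] = torch.version.cuda  # CUDA version PyTorch was built with
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```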

Reproduction Code

The reproduction code is as follows:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

precision = "int4"  # INT4 checkpoint, as used on the RTX 3090
# Quantized SVDQuant Flux transformer
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
transformer.set_attention_impl("nunchaku-fp16")  # edit 1: Nunchaku FP16 attention
# edit 2: quantized T5 encoder replaces the default text_encoder_2
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)  # edit 3: enable caching
# The short prompt below triggers the gemv_awq assertion and aborts the process
image = pipeline("a dog", num_inference_steps=30, guidance_scale=7.5).images[0]
image.save(f"flux.1-dev-{precision}.png")

Potential Solutions

Based on the reproduction code and the error message, we suspect that the issue lies in the quantized text encoder. The failing check appears to compare the quantization group size passed to the GEMV kernel with the group size the kernel was compiled for, so prompt length may only determine whether this kernel path is taken at all. To narrow the issue down, we suggest the following:

  • Rule out the quantized encoder: rerun the example with the default (unquantized) text_encoder_2 from black-forest-labs/FLUX.1-dev, and separately without apply_cache_on_pipe, to see which component triggers the failure.
  • Vary the prompt length: try longer prompts to confirm whether the failure depends on prompt length. Note that FluxPipeline pads the T5 input to max_sequence_length (512 tokens by default), so the encoder may see the same number of tokens regardless of how long the prompt is.
  • Update the Nunchaku library: make sure the latest release is installed, as the issue may have been fixed in a newer version.
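Because the failed assertion calls abort() and kills the interpreter ("Aborted (core dumped)"), these checks cannot be combined into one script guarded by try/except. One option is to run each variant in its own subprocess and inspect the exit status, as in the sketch below. The variant scripts are placeholders, and the return-code convention assumes a POSIX system such as the Ubuntu 22.04 machine listed above.

```python
import signal
import subprocess
import sys


def run_isolated(script: str, timeout: float = 1800.0) -> str:
    """Run `script` in a fresh Python interpreter and classify how it ended.

    A failed C/CUDA assertion calls abort(), which terminates the whole
    interpreter, so try/except in the calling script cannot catch it.
    """
    proc = subprocess.run(
        [sys.executable, "-c", script],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    if proc.returncode == 0:
        return "ok"
    # On POSIX, a child killed by signal N reports returncode -N.
    if proc.returncode == -signal.SIGABRT:
        return "aborted (SIGABRT)"
    return f"failed (exit code {proc.returncode})"


if __name__ == "__main__":
    # Placeholder variants: in practice each string would hold the full
    # reproduction script with one component toggled (for example, the
    # default text encoder, or no apply_cache_on_pipe call).
    variants = {
        "sanity check": "print('ok')",
        "simulated assertion failure": "import os; os.abort()",
    }
    for name, script in variants.items():
        print(f"{name}: {run_isolated(script)}")
```

Capturing stdout and stderr also keeps the assertion text (proc.stderr) available for each variant if it needs to be attached to the report.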


Q: How can I report a similar issue or ask a question?

A: If you encounter a similar issue or have a question, please open a new issue or start a discussion on the Nunchaku GitHub repository. Be sure to provide a minimal reproducible example and any relevant environment information to facilitate the debugging process.

Q: How can I contribute to the Nunchaku library?

A: If you would like to contribute to the Nunchaku library, please submit a pull request with your changes. Be sure to follow the contributing guidelines and ensure that your changes are thoroughly tested.