
DeepSeek-R1-Distill-Qwen-1.5B: torch.compile slower than AOTInductor #149582

@angelayi

Description

🐛 Describe the bug

| | First iteration time (s) | Average time over 30 calls (s) |
|---|---|---|
| Eager | 0.21249 | 0.021328 |
| torch.compile (tlparse) | 25.765 | 6.1492, fastest: 0.009088 |
| torch.compile with mark_dynamic (tlparse) | 63.959 | 4.6029 (one less recompilation), fastest: 0.008864 |
| AOTI compiled artifact w/ dynamic shapes | 0.033352 | 0.0034717 |

More info: https://docs.google.com/document/d/1XPtQ0XoPv-VxUkx-7H9G6i68w7jLPMeu6cdUQedUrug/edit?tab=t.0
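
The times above were collected by wrapping each forward call in time.time(); since CUDA kernel launches are asynchronous, a synchronizing wrapper gives a tighter per-call number. A minimal sketch, assuming a CUDA device and any callable model:

import time
import torch

def timed_call(fn, *args):
    # Synchronize before and after so the wall-clock window covers all queued
    # CUDA work, not just the asynchronous kernel launches.
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*args)
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

Full repro: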

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.export import export, Dim

device = "cuda"

def test_model(model, tokenizer):
    class Qwen2(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.qwen = model

        def forward(self, x):
            result = self.qwen(x)
            result.past_key_values = ()  # drop the KV cache from the returned output
            return result

    qwen2 = Qwen2().to(device)

    prompt = "What are the benefits of using AI in healthcare?"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Move the prompt to the device and mark its sequence dimension as dynamic
    input_ids = input_ids.to(device)
    torch._dynamo.mark_dynamic(input_ids, 1)  # dim 1 = sequence length

    torch._dynamo.reset()
    torch._inductor.config.force_disable_caches = True  # compile from scratch, no cache reuse
    model = torch.compile(model, fullgraph=True)
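    # To reproduce the AOTI row of the table instead, comment out the
    # torch.compile call above and uncomment the export / AOTI lines below.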

    # ep = export(qwen2, (torch.cat([input_ids, input_ids, input_ids]).to(device),), dynamic_shapes=({0: Dim.DYNAMIC, 1: Dim.DYNAMIC},))
    # path = torch._inductor.aoti_compile_and_package(ep, package_path="deepseek_qwen2_aoti_dynamic.pt2")
    # model = torch._inductor.aoti_load_package(path)

    start = time.time()
    output = model(input_ids)
    end = time.time()
    print(f"Initial time taken: {end - start}")

    logits = output.logits
    next_token_id = torch.argmax(logits[:, -1])
    decoded_response = tokenizer.decode(next_token_id, skip_special_tokens=True)
    print("Prompt:", prompt)
    print("Response:", decoded_response)

    def generate_response(model, input_ids, max_length=30):
        times = 0
        response = []
        for i in range(max_length):
            input_ids = input_ids.to(device)
            start = time.time()
            output = model(input_ids)
            end = time.time()
            print(f"Time on iteration {i}: {end - start}")
            times += end - start
            logits = output.logits
            next_token_id = torch.argmax(logits[:, -1])
            response.append(next_token_id.item())
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0).unsqueeze(0)], dim=-1)
        print(f"Avg time per call: {times / max_length}")
        return response

    response_ids = generate_response(model, input_ids)
    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=True)
    print("Prompt:", prompt)
    print("Response:", decoded_response)


if __name__ == "__main__":
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).eval().to(device)

    test_model(model, tokenizer)
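
The gap between the average and the fastest torch.compile call suggests Dynamo recompiles as the prompt grows by one token per step. A minimal way to confirm this, assuming a PyTorch build where the "recompiles" logging artifact is available (equivalently TORCH_LOGS="recompiles" in the environment), is to turn on recompile logging before calling test_model:

import torch

# Print the failing guard each time Dynamo recompiles, e.g. a size guard on
# the growing sequence dimension of input_ids.
torch._logging.set_logs(recompiles=True)

test_model(model, tokenizer)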

Versions

main

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @suo @ydwu4 @desertfire @chenyang78 @yushangdi @benjaminglass1
