Labels: export-triaged, module: aotinductor, oncall: export, oncall: pt2, topic: performance
🐛 Describe the bug
| | First iteration time (s) | Average time over 30 calls (s) |
|---|---|---|
| Eager | 0.21249 | 0.021328 |
| `torch.compile` (tlparse) | 25.765 | 6.1492; fastest: 0.009088 |
| `torch.compile` with `mark_dynamic` (tlparse) | 63.959 | 4.6029 (one fewer recompilation); fastest: 0.008864 |
| AOTI compiled artifact w/ dynamic shapes | 0.033352 | 0.0034717 |
More info: https://docs.google.com/document/d/1XPtQ0XoPv-VxUkx-7H9G6i68w7jLPMeu6cdUQedUrug/edit?tab=t.0
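For context, the AOTI row above comes from the export-and-package path that is commented out in the repro script below. A minimal sketch of that flow, using the same APIs and package path as the script (the `build_aoti_model` helper name is made up for illustration):

```python
# Sketch of the AOTI path measured in the last table row; it mirrors the
# commented-out lines in the repro script below.
import torch
from torch.export import Dim, export


def build_aoti_model(qwen2, example_ids):
    # Export with batch and sequence dims marked dynamic, ahead-of-time
    # compile into a .pt2 package, then load the package back as a callable.
    ep = export(
        qwen2,
        (example_ids,),
        dynamic_shapes=({0: Dim.DYNAMIC, 1: Dim.DYNAMIC},),
    )
    path = torch._inductor.aoti_compile_and_package(
        ep, package_path="deepseek_qwen2_aoti_dynamic.pt2"
    )
    return torch._inductor.aoti_load_package(path)
```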
```python
import time

import torch
from torch.export import Dim, export
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"


def test_model(model, tokenizer):
    class Qwen2(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.qwen = model

        def forward(self, x):
            # Drop the KV cache from the output so the exported graph
            # returns plain tensors.
            result = self.qwen(x)
            result.past_key_values = ()
            return result

    qwen2 = Qwen2().to(device)

    prompt = "What are the benefits of using AI in healthcare?"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Generate a response from the model.
    input_ids = input_ids.to(device)
    torch._dynamo.mark_dynamic(input_ids, 1)  # sequence length is dynamic

    torch._dynamo.reset()
    torch._inductor.config.force_disable_caches = True
    model = torch.compile(model, fullgraph=True)
    # ep = export(qwen2, (torch.cat([input_ids, input_ids, input_ids]).to(device),), dynamic_shapes=({0: Dim.DYNAMIC, 1: Dim.DYNAMIC},))
    # path = torch._inductor.aoti_compile_and_package(ep, package_path="deepseek_qwen2_aoti_dynamic.pt2")
    # model = torch._inductor.aoti_load_package(path)

    # The first call pays the compilation cost.
    start = time.time()
    output = model(input_ids)
    end = time.time()
    print(f"Initial time taken: {end - start}")

    logits = output.logits
    next_token_id = torch.argmax(logits[:, -1])
    decoded_response = tokenizer.decode(next_token_id, skip_special_tokens=True)
    print("Prompt:", prompt)
    print("Response:", decoded_response)

    def generate_response(model, input_ids, max_length=30):
        times = 0
        response = []
        for i in range(max_length):
            input_ids = input_ids.to(device)
            start = time.time()
            output = model(input_ids)
            end = time.time()
            print(f"Time on iteration {i}: {end - start}")
            times += end - start
            logits = output.logits
            next_token_id = torch.argmax(logits[:, -1])
            response.append(next_token_id.item())
            # Greedy decoding: append the new token, so the sequence
            # length grows by one on every iteration.
            input_ids = torch.cat(
                [input_ids, next_token_id.unsqueeze(0).unsqueeze(0)], dim=-1
            )
        print(f"Avg time per call: {times / max_length}")
        return response

    response_ids = generate_response(model, input_ids)
    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=True)
    print("Prompt:", prompt)
    print("Response:", decoded_response)


if __name__ == "__main__":
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).eval().to(device)
    test_model(model, tokenizer)
```
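One caveat when reading the per-iteration numbers above: CUDA kernels launch asynchronously, so wrapping a call in `time.time()` without synchronizing can shift GPU time between iterations. A minimal sketch of a synchronized measurement (the `timed_call` helper is hypothetical, not part of the repro):

```python
import time

import torch


def timed_call(model, input_ids):
    # Flush pending GPU work before and after the call so the window
    # captures only this call's device time.
    torch.cuda.synchronize()
    start = time.time()
    output = model(input_ids)
    torch.cuda.synchronize()
    return output, time.time() - start
```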
Versions
main
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @suo @ydwu4 @desertfire @chenyang78 @yushangdi @benjaminglass1