
Commit b5ace32

use spawn
1 parent: 7e42fa6

3 files changed (+52, −41 lines)


bench.py

Lines changed: 22 additions & 17 deletions
@@ -5,23 +5,28 @@
 # from vllm import LLM, SamplingParams
 
 
-seed(0)
-num_seqs = 256
-max_input_len = 1024
-max_ouput_len = 1024
+def main():
+    seed(0)
+    num_seqs = 256
+    max_input_len = 1024
+    max_ouput_len = 1024
 
-path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
-llm = LLM(path, enforce_eager=False, max_model_len=4096)
+    path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
+    llm = LLM(path, enforce_eager=False, max_model_len=4096)
 
-prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
-sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
-# uncomment the following line for vllm
-# prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
+    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
+    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
+    # uncomment the following line for vllm
+    # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
 
-llm.generate(["Benchmark: "], SamplingParams())
-t = time.time()
-llm.generate(prompt_token_ids, sampling_params)
-t = (time.time() - t)
-total_tokens = sum(sp.max_tokens for sp in sampling_params)
-throughput = total_tokens / t
-print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+    llm.generate(["Benchmark: "], SamplingParams())
+    t = time.time()
+    llm.generate(prompt_token_ids, sampling_params)
+    t = (time.time() - t)
+    total_tokens = sum(sp.max_tokens for sp in sampling_params)
+    throughput = total_tokens / t
+    print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+
+
+if __name__ == "__main__":
+    main()
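The untimed generate(["Benchmark: "], SamplingParams()) call before the timer looks like a warmup, keeping one-time setup cost out of the measurement. For a sense of scale, a back-of-the-envelope reading of the reported metric; the 30 s wall time below is an assumed figure for illustration, not a measurement:

# Rough expectation for bench.py's throughput numbers. max_tokens is drawn
# from randint(100, 1024), so the mean output length is (100 + 1024) / 2 = 562
# tokens; with ignore_eos=True every sequence runs to its max_tokens, so
# total_tokens matches what is actually generated.
num_seqs = 256
expected_tokens_per_seq = (100 + 1024) / 2      # mean of randint(100, 1024)
total = num_seqs * expected_tokens_per_seq      # ~143,872 tokens in expectation
assumed_wall_time_s = 30.0                      # hypothetical, for illustration only
print(f"~{total:.0f} tokens; at {assumed_wall_time_s:.0f}s -> ~{total / assumed_wall_time_s:.0f} tok/s")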

example.py

Lines changed: 27 additions & 22 deletions
@@ -3,27 +3,32 @@
 from transformers import AutoTokenizer
 
 
-path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
-tokenizer = AutoTokenizer.from_pretrained(path)
-llm = LLM(path, enforce_eager=True)
+def main():
+    path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
+    tokenizer = AutoTokenizer.from_pretrained(path)
+    llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)
 
-sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
-prompts = [
-    "introduce yourself",
-    "list all prime numbers within 100",
-]
-prompts = [
-    tokenizer.apply_chat_template(
-        [{"role": "user", "content": prompt}],
-        tokenize=False,
-        add_generation_prompt=True,
-        enable_thinking=True
-    )
-    for prompt in prompts
-]
-outputs = llm.generate(prompts, sampling_params)
+    sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
+    prompts = [
+        "introduce yourself",
+        "list all prime numbers within 100",
+    ]
+    prompts = [
+        tokenizer.apply_chat_template(
+            [{"role": "user", "content": prompt}],
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=True
+        )
+        for prompt in prompts
+    ]
+    outputs = llm.generate(prompts, sampling_params)
 
-for prompt, output in zip(prompts, outputs):
-    print("\n")
-    print(f"Prompt: {prompt!r}")
-    print(f"Completion: {output['text']!r}")
+    for prompt, output in zip(prompts, outputs):
+        print("\n")
+        print(f"Prompt: {prompt!r}")
+        print(f"Completion: {output['text']!r}")
+
+
+if __name__ == "__main__":
+    main()
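Besides the main() wrapper, the diff pins tensor_parallel_size=1 explicitly. With the spawn-based engine change below, values above 1 should also work; a hedged sketch of a two-rank variant (the `from nanovllm import ...` path and the availability of two GPUs are assumptions, not shown in this commit):

# Hypothetical multi-GPU variant of example.py. Assumes the package exposes
# LLM and SamplingParams at the top level and that two CUDA devices are
# visible; neither is confirmed by this commit.
import os
from nanovllm import LLM, SamplingParams

def main():
    path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
    llm = LLM(path, enforce_eager=True, tensor_parallel_size=2)
    outputs = llm.generate(["introduce yourself"], SamplingParams(temperature=0.6, max_tokens=64))
    print(outputs[0]["text"])

if __name__ == "__main__":   # required once the engine spawns worker processes
    main()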

nanovllm/engine/llm_engine.py

Lines changed: 3 additions & 2 deletions
@@ -20,9 +20,10 @@ def __init__(self, model, **kwargs):
         config = Config(model, **config_kwargs)
         self.ps = []
         self.events = []
+        ctx = mp.get_context("spawn")
         for i in range(1, config.tensor_parallel_size):
-            event = mp.Event()
-            process = mp.Process(target=ModelRunner, args=(config, i, event))
+            event = ctx.Event()
+            process = ctx.Process(target=ModelRunner, args=(config, i, event))
             process.start()
             self.ps.append(process)
             self.events.append(event)
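This is the substance of the commit: worker processes now come from an explicit "spawn" context rather than the platform default start method (fork on Linux), the usual fix when children must not inherit an already-initialized CUDA context. Because spawn starts a fresh interpreter and re-imports the main module in each child, any module-level code would re-run in every worker, which is why bench.py and example.py above gained the if __name__ == "__main__": guard. A minimal standalone sketch of the pattern, using only the standard library (ModelRunner and the engine wiring are not reproduced here):

import multiprocessing as mp

def worker(rank, event):
    # Stand-in for ModelRunner(config, rank, event).
    print(f"rank {rank} up")
    event.set()

def main():
    ctx = mp.get_context("spawn")      # explicit start method, as in the diff
    event = ctx.Event()                # the Event must come from the same context
    p = ctx.Process(target=worker, args=(1, event))
    p.start()
    event.wait()                       # block until the child signals readiness
    p.join()

if __name__ == "__main__":             # spawn re-imports this module in children
    main()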
