
Commit 778b61c

sywangyi and regisss authored
[gaudi] Remove unnecessary reinitialize to HeterogeneousNextTokenChooser to make sampling output correct (#3284)
Signed-off-by: Wang, Yi A <[email protected]> Co-authored-by: regisss <[email protected]>
1 parent 3d2e7c8 commit 778b61c

File tree

4 files changed: +95, -336 lines


backends/gaudi/examples/docker_commands/docker_commands.md

Lines changed: 12 additions & 183 deletions
````diff
@@ -19,11 +19,7 @@ docker run -p 8080:80 \
  --ipc=host \
  -v $volume:/data \
  -e HF_TOKEN=$hf_token \
- -e MAX_TOTAL_TOKENS=2048 \
- -e PREFILL_BATCH_BUCKET_SIZE=2 \
- -e BATCH_BUCKET_SIZE=32 \
- -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ ghcr.io/huggingface/text-generation-inference:3.3.4-gaudi \
  --model-id $model \
  --max-input-tokens 1024 --max-total-tokens 2048 \
  --max-batch-prefill-tokens 2048 --max-batch-size 32 \
@@ -43,60 +39,7 @@ docker run -p 8080:80 \
  --ipc=host \
  -v $volume:/data \
  -e HF_TOKEN=$hf_token \
- -e MAX_TOTAL_TOKENS=2048 \
- -e BATCH_BUCKET_SIZE=256 \
- -e PREFILL_BATCH_BUCKET_SIZE=4 \
- -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
- --model-id $model \
- --sharded true --num-shard 8 \
- --max-input-tokens 1024 --max-total-tokens 2048 \
- --max-batch-prefill-tokens 4096 --max-batch-size 256 \
- --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
- ```
-
- ### Llama2-7B on 1 Card (BF16)
-
- ```bash
- model=meta-llama/Llama-2-7b-chat-hf
- hf_token=YOUR_ACCESS_TOKEN
- volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
- docker run -p 8080:80 \
- --runtime=habana \
- --cap-add=sys_nice \
- --ipc=host \
- -v $volume:/data \
- -e HF_TOKEN=$hf_token \
- -e MAX_TOTAL_TOKENS=2048 \
- -e PREFILL_BATCH_BUCKET_SIZE=2 \
- -e BATCH_BUCKET_SIZE=32 \
- -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
- --model-id $model \
- --max-input-tokens 1024 --max-total-tokens 2048 \
- --max-batch-prefill-tokens 2048 --max-batch-size 32 \
- --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
- ```
-
- ### Llama2-70B on 8 cards (BF16)
-
- ```bash
- model=meta-llama/Llama-2-70b-chat-hf
- hf_token=YOUR_ACCESS_TOKEN
- volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
- docker run -p 8080:80 \
- --runtime=habana \
- --cap-add=sys_nice \
- --ipc=host \
- -v $volume:/data \
- -e HF_TOKEN=$hf_token \
- -e MAX_TOTAL_TOKENS=2048 \
- -e BATCH_BUCKET_SIZE=256 \
- -e PREFILL_BATCH_BUCKET_SIZE=4 \
- -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ ghcr.io/huggingface/text-generation-inference:3.3.4-gaudi \
  --model-id $model \
  --sharded true --num-shard 8 \
  --max-input-tokens 1024 --max-total-tokens 2048 \
@@ -115,49 +58,20 @@ docker run -p 8080:80 \
  --cap-add=sys_nice \
  --ipc=host \
  -v $volume:/data \
- -e PREFILL_BATCH_BUCKET_SIZE=1 \
- -e BATCH_BUCKET_SIZE=1 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ ghcr.io/huggingface/text-generation-inference:3.3.4-gaudi \
  --model-id $model \
  --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
  --max-total-tokens 8192 --max-batch-size 4
  ```

  ## FP8 Precision

- Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You need to measure the statistics of the model first before running the model in FP8 precision.
-
- ## Llama3.1-8B on 1 Card (FP8)
-
- ```bash
- model=meta-llama/Meta-Llama-3.1-8B-Instruct
- hf_token=YOUR_ACCESS_TOKEN
- volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
- docker run -p 8080:80 \
- --runtime=habana \
- --cap-add=sys_nice \
- --ipc=host \
- -v $volume:/data \
- -v $PWD/quantization_config:/usr/src/quantization_config \
- -v $PWD/hqt_output:/usr/src/hqt_output \
- -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
- -e HF_TOKEN=$hf_token \
- -e MAX_TOTAL_TOKENS=2048 \
- -e PREFILL_BATCH_BUCKET_SIZE=2 \
- -e BATCH_BUCKET_SIZE=32 \
- -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
- --model-id $model \
- --max-input-tokens 1024 --max-total-tokens 2048 \
- --max-batch-prefill-tokens 2048 --max-batch-size 32 \
- --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
- ```
+ You could also set kv cache dtype to FP8 when launching the server, fp8_e4m3fn is supported in Gaudi

- ## Llama3.1-70B on 8 cards (FP8)
+ ## Llama3-8B on 1 Card (FP8)

  ```bash
- model=meta-llama/Meta-Llama-3.1-70B-Instruct
+ model=RedHatAI/Meta-Llama-3-8B-Instruct-FP8-KV
  hf_token=YOUR_ACCESS_TOKEN
  volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

@@ -166,53 +80,19 @@ docker run -p 8080:80 \
  --cap-add=sys_nice \
  --ipc=host \
  -v $volume:/data \
- -v $PWD/quantization_config:/usr/src/quantization_config \
- -v $PWD/hqt_output:/usr/src/hqt_output \
- -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
- -e HF_TOKEN=$hf_token \
- -e MAX_TOTAL_TOKENS=2048 \
- -e BATCH_BUCKET_SIZE=256 \
- -e PREFILL_BATCH_BUCKET_SIZE=4 \
- -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
- --model-id $model \
- --sharded true --num-shard 8 \
- --max-input-tokens 1024 --max-total-tokens 2048 \
- --max-batch-prefill-tokens 4096 --max-batch-size 256 \
- --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
- ```
-
- ## Llama2-7B on 1 Card (FP8)
-
- ```bash
- model=meta-llama/Llama-2-7b-chat-hf
- hf_token=YOUR_ACCESS_TOKEN
- volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
- docker run -p 8080:80 \
- --runtime=habana \
- --cap-add=sys_nice \
- --ipc=host \
- -v $volume:/data \
- -v $PWD/quantization_config:/usr/src/quantization_config \
- -v $PWD/hqt_output:/usr/src/hqt_output \
- -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
  -e HF_TOKEN=$hf_token \
- -e MAX_TOTAL_TOKENS=2048 \
- -e PREFILL_BATCH_BUCKET_SIZE=2 \
- -e BATCH_BUCKET_SIZE=32 \
- -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ ghcr.io/huggingface/text-generation-inference:3.3.4-gaudi \
  --model-id $model \
+ --kv-cache-dtype fp8_e4m3fn \
  --max-input-tokens 1024 --max-total-tokens 2048 \
  --max-batch-prefill-tokens 2048 --max-batch-size 32 \
  --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
  ```

- ## Llama2-70B on 8 Cards (FP8)
+ ## Llama3-70B on 8 cards (FP8)

  ```bash
- model=meta-llama/Llama-2-70b-chat-hf
+ model=RedHatAI/Meta-Llama-3-70B-Instruct-FP8
  hf_token=YOUR_ACCESS_TOKEN
  volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

@@ -221,63 +101,12 @@ docker run -p 8080:80 \
  --cap-add=sys_nice \
  --ipc=host \
  -v $volume:/data \
- -v $PWD/quantization_config:/usr/src/quantization_config \
- -v $PWD/hqt_output:/usr/src/hqt_output \
- -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
  -e HF_TOKEN=$hf_token \
- -e MAX_TOTAL_TOKENS=2048 \
- -e BATCH_BUCKET_SIZE=256 \
- -e PREFILL_BATCH_BUCKET_SIZE=4 \
- -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ ghcr.io/huggingface/text-generation-inference:3.3.4-gaudi \
  --model-id $model \
+ --kv-cache-dtype fp8_e4m3fn \
  --sharded true --num-shard 8 \
  --max-input-tokens 1024 --max-total-tokens 2048 \
  --max-batch-prefill-tokens 4096 --max-batch-size 256 \
  --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
  ```
-
- ## Llava-v1.6-Mistral-7B on 1 Card (FP8)
-
- ```bash
- model=llava-hf/llava-v1.6-mistral-7b-hf
- volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
- docker run -p 8080:80 \
- --runtime=habana \
- --cap-add=sys_nice \
- --ipc=host \
- -v $volume:/data \
- -v $PWD/quantization_config:/usr/src/quantization_config \
- -v $PWD/hqt_output:/usr/src/hqt_output \
- -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
- -e PREFILL_BATCH_BUCKET_SIZE=1 \
- -e BATCH_BUCKET_SIZE=1 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
- --model-id $model \
- --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
- --max-total-tokens 8192 --max-batch-size 4
- ```
-
- ## Llava-v1.6-Mistral-7B on 8 Cards (FP8)
-
- ```bash
- model=llava-hf/llava-v1.6-mistral-7b-hf
- volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
- docker run -p 8080:80 \
- --runtime=habana \
- --cap-add=sys_nice \
- --ipc=host \
- -v $volume:/data \
- -v $PWD/quantization_config:/usr/src/quantization_config \
- -v $PWD/hqt_output:/usr/src/hqt_output \
- -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
- -e PREFILL_BATCH_BUCKET_SIZE=1 \
- -e BATCH_BUCKET_SIZE=1 \
- ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
- --model-id $model \
- --sharded true --num-shard 8 \
- --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
- --max-total-tokens 8192 --max-batch-size 4
- ```
````
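
The documentation diff above bumps the image tag from `3.1.1-gaudi` to `3.3.4-gaudi`, removes the bucketing/padding environment variables from the examples, and replaces the measurement-based FP8 examples with prequantized `RedHatAI` checkpoints plus `--kv-cache-dtype fp8_e4m3fn`. Once one of these servers is up, a quick way to exercise the sampling path this commit fixes is to post a request with sampling parameters. A minimal sketch, assuming a server started with one of the commands above is listening on `localhost:8080`; the prompt and sampling values are placeholders, and the request shape follows TGI's `/generate` API:

```python
# Minimal sketch: exercise the sampling path of a TGI server started with one of
# the commands above. The URL and prompt are assumptions for illustration; the
# payload follows TGI's /generate API (do_sample, temperature, top_p, seed).
import requests

TGI_URL = "http://localhost:8080/generate"  # default mapping from `-p 8080:80`

payload = {
    "inputs": "What is Deep Learning?",
    "parameters": {
        "max_new_tokens": 64,
        "do_sample": True,   # take the sampling path this commit fixes
        "temperature": 0.7,
        "top_p": 0.9,
        "seed": 42,          # fixed seed so repeated calls are comparable
    },
}

response = requests.post(TGI_URL, json=payload, timeout=120)
response.raise_for_status()
print(response.json()["generated_text"])
```

With a fixed `seed`, repeated calls should return the same text, which makes it easy to compare sampled output before and after this change.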

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 5 additions & 7 deletions
````diff
@@ -140,12 +140,6 @@ def __init__(
          self.hidden_size = config.hidden_size
          self.head_size = self.hidden_size // self.num_heads

-         # Setting defaults for baichuan custom config which doesn't apply them.
-         config.rope_theta = getattr(config, "rope_theta", 10000)
-         config.num_key_value_heads = getattr(
-             config, "num_key_value_heads", config.num_attention_heads
-         )
-
          self.rotary_emb = rotary_emb

          # `config.attention_multiplier` is used in Granite
@@ -476,7 +470,11 @@ def __init__(self, prefix, config, weights):
          # Skip fp8 quant for first and last layers
          self.layers = nn.ModuleList()
          self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
-
+         # Setting defaults for baichuan custom config which doesn't apply them.
+         config.rope_theta = getattr(config, "rope_theta", 10000)
+         config.num_key_value_heads = getattr(
+             config, "num_key_value_heads", config.num_attention_heads
+         )
          rotary_emb = PositionRotaryEmbedding.static(
              config=config,
              dim=config.hidden_size // config.num_attention_heads,
````
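
The modeling change relocates the Baichuan config defaulting from the attention layer's `__init__` to the model's `__init__`, so `rope_theta` and `num_key_value_heads` are filled in once, before the shared `PositionRotaryEmbedding.static(...)` is constructed, rather than per layer after the rotary embedding already exists. The pattern itself is plain `getattr` defaulting on the config object; a minimal, self-contained sketch, where `apply_baichuan_defaults` and `SimpleNamespace` are illustrative stand-ins rather than the real TGI/transformers names:

```python
# Illustrative sketch of the getattr-defaulting pattern moved in this commit.
# SimpleNamespace stands in for the real transformers config object.
from types import SimpleNamespace


def apply_baichuan_defaults(config):
    """Fill in fields that Baichuan-style custom configs omit."""
    config.rope_theta = getattr(config, "rope_theta", 10000)
    config.num_key_value_heads = getattr(
        config, "num_key_value_heads", config.num_attention_heads
    )
    return config


# A config that only defines the attention head count.
cfg = SimpleNamespace(num_attention_heads=32, hidden_size=4096)
apply_baichuan_defaults(cfg)
assert cfg.rope_theta == 10000
assert cfg.num_key_value_heads == 32
```

The values themselves are unchanged by the commit; only where they are applied moves, so every consumer of the config sees the same defaults.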

backends/gaudi/server/text_generation_server/models/flash_causal_lm.py

Lines changed: 36 additions & 33 deletions
````diff
@@ -1076,22 +1076,23 @@ def prepare_for_decode(
          (0, padded_bs - self.cache_lengths_tensor.shape[0]),
          value=0,
      )
-     next_token_chooser_parameters = []
-     next_token_chooser_parameters.extend([r.parameters for r in self.requests])
-     pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
-     # update past grammar states
-     fsm_grammar_states = [0] * padded_bs
-
-     for i, req in enumerate(self.requests):
-         fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
-
-     self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-         next_token_chooser_parameters,
-         self.next_token_chooser.dtype,
-         self.next_token_chooser.device,
-         self.next_token_chooser.tokenizer,
-         fsm_grammar_states,
-     )
+     if len(self.next_token_chooser.do_sample) != padded_bs:
+         next_token_chooser_parameters = []
+         next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+         pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
+         # update past grammar states
+         fsm_grammar_states = [0] * padded_bs
+
+         for i, req in enumerate(self.requests):
+             fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+
+         self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+             next_token_chooser_parameters,
+             self.next_token_chooser.dtype,
+             self.next_token_chooser.device,
+             self.next_token_chooser.tokenizer,
+             fsm_grammar_states,
+         )

  def prepare_for_prefill(
      self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
@@ -1379,23 +1380,25 @@ def prepare_for_prefill(
              self.all_input_ids_tensor[i]
          )
      self.all_input_ids_tensor = all_input_ids_tensor
-
-     next_token_chooser_parameters = []
-     next_token_chooser_parameters.extend([r.parameters for r in self.requests])
-     pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
-     # update past grammar states
-     fsm_grammar_states = [0] * max_padded_bs
-
-     for i, req in enumerate(self.requests):
-         fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
-
-     self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-         next_token_chooser_parameters,
-         self.next_token_chooser.dtype,
-         self.next_token_chooser.device,
-         self.next_token_chooser.tokenizer,
-         fsm_grammar_states,
-     )
+     if len(self.next_token_chooser.do_sample) != max_padded_bs:
+         next_token_chooser_parameters = []
+         next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+         pad_next_token_chooser_parameters(
+             next_token_chooser_parameters, max_padded_bs
+         )
+         # update past grammar states
+         fsm_grammar_states = [0] * max_padded_bs
+
+         for i, req in enumerate(self.requests):
+             fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+
+         self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+             next_token_chooser_parameters,
+             self.next_token_chooser.dtype,
+             self.next_token_chooser.device,
+             self.next_token_chooser.tokenizer,
+             fsm_grammar_states,
+         )

  if ADAPTER_TO_INDEX:
      if adapter_set:
````
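
This is the substantive fix behind the commit title: `prepare_for_decode` and `prepare_for_prefill` used to rebuild `HeterogeneousNextTokenChooser` from the request parameters on every call, and the commit guards that rebuild so it only happens when the padded batch size no longer matches `len(self.next_token_chooser.do_sample)`. Presumably the unconditional rebuild was discarding per-request sampling state on every step, which is what made sampling output incorrect. A minimal, self-contained sketch of the guard pattern, with hypothetical names (`SamplingChooser`, `maybe_rebuild_chooser`) standing in for the real TGI classes:

```python
# Minimal sketch of the guard introduced above: only rebuild a stateful
# "chooser" when the padded batch size it was built for no longer matches.
# SamplingChooser and maybe_rebuild_chooser are illustrative, not TGI names.
import random
from dataclasses import dataclass, field
from typing import List


@dataclass
class SamplingChooser:
    do_sample: List[bool]
    # Per-request RNG state; rebuilding the chooser from scratch resets it.
    generators: List[random.Random] = field(default_factory=list)

    def __post_init__(self):
        self.generators = [random.Random(i) for i in range(len(self.do_sample))]


def maybe_rebuild_chooser(chooser: SamplingChooser, padded_bs: int) -> SamplingChooser:
    # Mirrors the guard added in prepare_for_decode / prepare_for_prefill:
    # reuse the existing chooser unless the padded batch size changed.
    if len(chooser.do_sample) != padded_bs:
        chooser = SamplingChooser(do_sample=[True] * padded_bs)
    return chooser


chooser = SamplingChooser(do_sample=[True, True])
assert maybe_rebuild_chooser(chooser, padded_bs=2) is chooser      # reused: state preserved
assert maybe_rebuild_chooser(chooser, padded_bs=4) is not chooser  # batch grew: rebuilt
```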
