Commit 28eeb1b

huggingface text_generation with local model server
1 parent 37f77b3 commit 28eeb1b

File tree

7 files changed: +1666 −2 lines changed

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
"""
Local text generation server that mimics the Hugging Face Inference API.
This allows you to use InferenceClient with a local model.
"""

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import uvicorn
from typing import Dict, Any, Optional


class TextGenerationRequest(BaseModel):
    inputs: str
    parameters: Optional[Dict[str, Any]] = {}


class TextGenerationResponse(BaseModel):
    generated_text: str


class LocalTextGenerationServer:
    def __init__(self, model_name: str = "gpt2"):
        print(f"🔄 Loading model: {model_name}")
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print("✅ Model loaded successfully!")

    def generate_text(self, prompt: str, parameters: Dict[str, Any] = None) -> str:
        if parameters is None:
            parameters = {}

        # Default parameters
        max_new_tokens = parameters.get("max_new_tokens", 50)
        temperature = parameters.get("temperature", 0.7)
        do_sample = parameters.get("do_sample", True)

        # Tokenize
        inputs = self.tokenizer.encode(prompt, return_tensors="pt")

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
                no_repeat_ngram_size=2,
            )

        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Return full text (like the HF API does)
        return generated_text


# Initialize the model server
print("🚀 Starting Local Text Generation Server...")
text_gen_server = LocalTextGenerationServer()

# Create FastAPI app
app = FastAPI(title="Local Text Generation API", version="1.0.0")


@app.get("/")
async def root():
    return {"message": "Local Text Generation Server", "status": "running"}


@app.post("/")
async def generate_text_endpoint(request: TextGenerationRequest):
    """
    Main text generation endpoint that mimics the HuggingFace Inference API format.
    """
    try:
        generated_text = text_gen_server.generate_text(
            request.inputs,
            request.parameters,
        )

        # Return in HF API format (list with generated_text)
        return [{"generated_text": generated_text}]

    except Exception as e:
        return {"error": str(e)}


@app.post("/generate")
async def generate_text_simple(request: TextGenerationRequest):
    """
    Alternative endpoint with a simpler response format.
    """
    try:
        generated_text = text_gen_server.generate_text(
            request.inputs,
            request.parameters,
        )

        return {"generated_text": generated_text}

    except Exception as e:
        return {"error": str(e)}


if __name__ == "__main__":
    print("🌐 Server will be available at: http://localhost:8000")
    print("📝 Test endpoint: POST http://localhost:8000/")
    print("📚 API docs: http://localhost:8000/docs")
    print("🛑 Press Ctrl+C to stop the server")

    uvicorn.run(
        "local_server:app",
        host="127.0.0.1",
        port=8000,
        reload=False,
        log_level="info",
    )
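
Note (not part of this commit): once the server is running, the root endpoint can be exercised directly with a plain HTTP POST. The sketch below assumes the default address printed by the script (http://localhost:8000) and the third-party requests library, which is not listed in the project's dependencies.

# Hypothetical smoke test for the local server above (assumes it is already
# running on http://localhost:8000 and that `requests` is installed).
import requests

payload = {
    "inputs": "The sky is",
    "parameters": {"max_new_tokens": 20, "temperature": 0.7, "do_sample": True},
}

# The root endpoint mimics the HF Inference API: on success it returns a list
# of {"generated_text": ...} objects.
resp = requests.post("http://localhost:8000/", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()[0]["generated_text"])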
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
import os
import time

import sentry_sdk
from sentry_sdk.integrations.huggingface_hub import HuggingfaceHubIntegration

from huggingface_hub import InferenceClient


def main():
    sentry_sdk.init(
        dsn=os.getenv("SENTRY_DSN", None),
        environment=os.getenv("ENV", "local"),
        traces_sample_rate=1.0,
        send_default_pii=True,
        debug=True,
        integrations=[
            HuggingfaceHubIntegration(include_prompts=True),
        ],
    )

    # Connect to local text generation server
    local_server_url = "http://localhost:8000"

    print(f"🔄 Connecting to local server: {local_server_url}")
    print("📝 Make sure to start the local server first:")
    print("    ./run_server.sh")
    print()

    with sentry_sdk.start_transaction(name="huggingface-hub-text-generation"):
        client = InferenceClient(model=local_server_url)

        prompt = "The sky is"
        try:
            print(f"🔄 Test: Generating text for '{prompt}'")

            # Use the real InferenceClient with the local server
            response = client.text_generation(
                prompt,
                max_new_tokens=40,
                temperature=0.7,
                do_sample=True,
            )

            print("✅ Success!")
            print(f"    Prompt: {prompt}")
            print(f"    Generated: {response}")
            print()

        except Exception as e:
            print(f"❌ Failed: {type(e).__name__}: {e}")
            print("💡 Make sure the local server is running:")
            print("    ./run_server.sh")
            print()


if __name__ == "__main__":
    main()
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
[project]
name = "test"
version = "0"
requires-python = ">=3.12"

dependencies = [
    "fastapi>=0.116.1",
    "huggingface-hub[inference]==0.22.0",
    "ipdb>=0.13.13",
    "sentry-sdk",
    "text-generation>=0.7.0",
    "torch>=2.8.0",
    "transformers>=4.40.2",
    "uvicorn>=0.35.0",
]

[tool.uv.sources]
sentry-sdk = { path = "../../sentry-python", editable = true }
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
#!/usr/bin/env bash

# exit on first error
set -euo pipefail

# Install uv if it's not installed
if ! command -v uv &> /dev/null; then
    curl -LsSf https://astral.sh/uv/install.sh | sh
fi

# Run the script
uv run python main.py
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

# exit on first error
set -euo pipefail

# Install uv if it's not installed
if ! command -v uv &> /dev/null; then
    curl -LsSf https://astral.sh/uv/install.sh | sh
fi

echo "Starting local text generation server accessible via huggingface_hub..."

# Run text generation server, accessible via huggingface_hub
uv run python local_server.py
