-
Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathserver.py
133 lines (105 loc) · 4.06 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import asyncio
import logging
import os
import random
import tempfile
import traceback
import uuid
import aiohttp
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline
# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(__name__)
class TextToImageInput(BaseModel):
    """Request body for the image generation endpoint.

    ``model`` and ``prompt`` are required; ``size`` and ``n`` may be omitted,
    in which case they are ``None``.
    """

    # Required fields.
    model: str
    prompt: str

    # Optional fields; both default to None when absent from the request.
    size: str | None = None
    n: int | None = None
class HttpClient:
    """Holder for a single shared ``aiohttp.ClientSession``.

    Call :meth:`start` to create the session, call the instance to obtain it,
    and ``await`` :meth:`stop` to close it again.
    """

    # Quoted so the annotation is not evaluated while the class body executes.
    session: "aiohttp.ClientSession | None" = None

    def start(self):
        """Create the session unless a usable one already exists.

        Re-creating it unconditionally would drop the previous session without
        closing it.
        """
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()

    async def stop(self):
        """Close the session, if there is one, and forget it.

        Safe to call when :meth:`start` never ran and safe to call repeatedly.
        """
        # Clear the reference first so the holder never keeps a closed session.
        session, self.session = self.session, None
        if session is not None:
            await session.close()

    def __call__(self) -> "aiohttp.ClientSession":
        """Return the current session.

        Raises:
            RuntimeError: If no session exists (``start`` was not called or
                ``stop`` already ran). An explicit raise is used instead of
                ``assert`` because assertions are removed under ``python -O``.
        """
        if self.session is None:
            raise RuntimeError("HttpClient session has not been started")
        return self.session
class TextToImagePipeline:
    """Holder for the shared Stable Diffusion 3 pipeline and its device.

    Both attributes stay ``None`` until :meth:`start` has loaded the model.
    """

    # Quoted so the annotations are not evaluated while the class body executes.
    pipeline: "StableDiffusion3Pipeline | None" = None
    device: "str | None" = None

    def start(self):
        """Pick an accelerator and load the pipeline onto it.

        CUDA is preferred; MPS is used when CUDA is unavailable. The model is
        read from the ``MODEL_PATH`` environment variable, falling back to a
        per-device default checkpoint.

        Raises:
            RuntimeError: If neither CUDA nor MPS is available.
        """
        # Decide device, default checkpoint and log line first so the actual
        # loading code exists only once.
        if torch.cuda.is_available():
            device = "cuda"
            default_model = "stabilityai/stable-diffusion-3.5-large"
            banner = "Loading CUDA"
        elif torch.backends.mps.is_available():
            device = "mps"
            default_model = "stabilityai/stable-diffusion-3.5-medium"
            banner = "Loading MPS for Mac M Series"
        else:
            raise RuntimeError("No CUDA or MPS device available")

        model_path = os.getenv("MODEL_PATH", default_model)
        logger.info(banner)
        self.device = device
        self.pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
        ).to(device=device)
# FastAPI application object; the routes and middleware in this file attach to it.
app = FastAPI()

# Base URL used when building links to generated images (override with SERVICE_URL).
service_url = os.getenv("SERVICE_URL", "http://localhost:8000")

# Generated images are written under the system temp directory and served at /images.
image_dir = os.path.join(tempfile.gettempdir(), "images")
# exist_ok=True replaces the exists()/makedirs() pair, which could fail if the
# directory was created by another process between the check and the call.
os.makedirs(image_dir, exist_ok=True)
app.mount("/images", StaticFiles(directory=image_dir), name="images")

# Shared singletons; both are initialised by the startup hook.
http_client = HttpClient()
shared_pipeline = TextToImagePipeline()
# Configure CORS settings: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods, e.g., GET, POST, OPTIONS, etc.
    allow_headers=["*"],  # Allows all headers
)
@app.on_event("startup")
def startup():
    """Create the shared HTTP session and load the diffusion pipeline."""
    http_client.start()
    shared_pipeline.start()


@app.on_event("shutdown")
async def shutdown():
    """Close the shared HTTP session that ``startup`` opened.

    Without this hook the session is never closed. The ``None`` check keeps
    shutdown from failing when no session was ever created.
    """
    if http_client.session is not None:
        await http_client.stop()
def save_image(image):
    """Write ``image`` into the served image directory and return its URL.

    Args:
        image: Object exposing ``save(path)``.
            NOTE(review): presumably a PIL image taken from the pipeline
            output -- confirm against the caller.

    Returns:
        A URL of the form ``<service_url>/images/<filename>``.
    """
    # "draw" followed by the first 8 hex digits of a random UUID.
    filename = f"draw{uuid.uuid4().hex[:8]}.png"
    image_path = os.path.join(image_dir, filename)
    logger.info("Saving image to %s", image_path)
    image.save(image_path)
    # URLs always use "/" as separator; os.path.join would use the platform
    # separator and produce backslashes on Windows.
    return f"{service_url.rstrip('/')}/images/{filename}"
@app.get("/")
@app.post("/")
@app.options("/")
async def base():
    """Return the welcome message for GET, POST and OPTIONS requests on ``/``."""
    greeting = "Welcome to Diffusers! Where you can use diffusion models to generate images"
    return greeting
@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
    """Generate an image for ``image_input.prompt`` and return its URL.

    Only ``prompt`` is read from the request; the first image of the pipeline
    output is saved and linked.

    Returns:
        ``{"data": [{"url": <image url>}]}``

    Raises:
        HTTPException: Status 500 carrying the error message when generation
            or saving fails. The traceback is logged on the server instead of
            being sent to the client.
    """
    try:
        # get_running_loop() is the supported way to get the loop in a coroutine.
        loop = asyncio.get_running_loop()

        # Build a per-request pipeline with its own scheduler instance
        # (presumably so concurrent requests do not share scheduler state).
        base_pipeline = shared_pipeline.pipeline
        scheduler = base_pipeline.scheduler.from_config(base_pipeline.scheduler.config)
        pipeline = StableDiffusion3Pipeline.from_pipe(base_pipeline, scheduler=scheduler)

        # Fresh random seed for every request.
        generator = torch.Generator(device=shared_pipeline.device)
        generator.manual_seed(random.randint(0, 10000000))

        # Run inference in the default executor to keep the event loop free.
        output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
        logger.info("output: %s", output)
        image_url = save_image(output.images[0])
        return {"data": [{"url": image_url}]}
    except HTTPException:
        # Already an HTTP error; pass it through untouched.
        raise
    except Exception as e:
        logger.exception("Image generation failed")
        # str() guards against a non-string ``message`` attribute.
        detail = str(getattr(e, "message", e))
        raise HTTPException(status_code=500, detail=detail) from e
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Listen on every interface, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)