-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
60 lines (47 loc) · 1.76 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import logging
import random
import sys
from typing import Optional
import warnings
import numpy as np
import torch
sys.path.append('../')
from src.modeling import ModelProvider
def create_new_model_tuple(model_provider: ModelProvider):
    """Build a fresh model via ``model_provider`` and return it with its tokenizer.

    Args:
        model_provider: Provider whose ``create_new_model_tuple()`` returns a
            ``(model, model_utils)`` pair; ``model_utils`` must support
            indexing with the key ``'tokenizer'``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    new_model, extras = model_provider.create_new_model_tuple()
    return new_model, extras['tokenizer']
class OpenAIRequestLogFilter(logging.Filter):
    """Logging filter that rejects OpenAI POST request log lines.

    A record is dropped when its formatted message contains the substring
    ``'HTTP Request: POST https://api.openai.com'``; every other record is
    allowed through.
    """

    def filter(self, record):
        """Return ``False`` for OpenAI POST request lines, ``True`` otherwise."""
        rendered = record.getMessage()
        is_openai_post = 'HTTP Request: POST https://api.openai.com' in rendered
        return not is_openai_post
def configure_logging(logger, *, log_file='../eval_log.txt'):
    """Configure logging for the server.

    Sets ``logger`` to INFO and attaches two INFO-level handlers that share the
    format ``'%(asctime)s - %(name)s - %(levelname)s - %(message)s'``: one
    writes to stdout, and the other appends to ``log_file``.

    The function also:
    - ignores warnings whose message matches ``.*Could not load referrer policy.*``;
    - sets the ``trafilatura``, ``LiteLLM`` and ``httpx`` loggers to INFO;
    - installs ``OpenAIRequestLogFilter`` on ``httpx`` so that OpenAI POST
      request lines are dropped.

    Safe to call repeatedly. Handlers installed by an earlier call are removed
    and closed before fresh ones are added, and the httpx filter is only added
    once. Previously, each call stacked another set of handlers and every
    record was emitted multiple times.

    Args:
        logger: The ``logging.Logger`` to configure.
        log_file: Path of the log file. It is opened in append mode with UTF-8
            encoding. A relative path resolves against the current working
            directory, not against this module. Defaults to
            ``'../eval_log.txt'``, the previously hard-coded location.
    """
    logger.setLevel(logging.INFO)

    # Remove handlers that an earlier call of this function installed (tagged
    # below), so repeated configuration does not duplicate output.
    for old_handler in list(logger.handlers):
        if getattr(old_handler, '_added_by_configure_logging', False):
            logger.removeHandler(old_handler)
            old_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    def _install(handler):
        # Shared setup for every handler this function owns.
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        handler._added_by_configure_logging = True
        logger.addHandler(handler)

    _install(logging.StreamHandler(sys.stdout))
    # Also write the log to a file. Use explicit UTF-8 so non-ASCII messages
    # don't fail under a non-UTF-8 default locale encoding.
    _install(logging.FileHandler(log_file, encoding='utf-8'))

    warnings.filterwarnings('ignore', message=".*Could not load referrer policy.*")

    trafilatura_logger = logging.getLogger('trafilatura')
    trafilatura_logger.setLevel(logging.INFO)
    lite_llm = logging.getLogger('LiteLLM')
    lite_llm.setLevel(logging.INFO)

    httpx_logger = logging.getLogger('httpx')
    httpx_logger.setLevel(logging.INFO)
    # logging dedupes filters by identity only, so a new instance on each call
    # would accumulate; add one only if none is present yet.
    if not any(isinstance(f, OpenAIRequestLogFilter) for f in httpx_logger.filters):
        httpx_logger.addFilter(OpenAIRequestLogFilter())
def set_seed(seed: Optional[int] = None):
    """Seed Python's, NumPy's and PyTorch's global RNGs for reproducibility.

    Args:
        seed: Seed value to apply to ``random``, ``numpy.random`` and ``torch``,
            in that order. When ``None`` (the default), nothing is seeded and
            the RNG states are left untouched.
    """
    if seed is None:
        return
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)