-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
102 lines (80 loc) · 2.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import logging
import os
import secrets
from logging.config import dictConfig
from pathlib import Path

import uvicorn
from celery import Celery
from dotenv import load_dotenv
from fastapi import FastAPI, Depends, Security, status, HTTPException
from fastapi.security import (
    HTTPBearer,
    APIKeyHeader,
)

from src.api import router as text_summarization_router
from config.celery_config import CeleryConfig
from config.constants import TMP_UPLOAD_DIR_NAME, LOGS_DIR_NAME
# Runtime directories are anchored next to this file, not to the process cwd.
_app_root = Path(__file__).resolve().parent

# Directory for log files (created eagerly so file handlers can open into it).
log_dir = _app_root / LOGS_DIR_NAME
log_dir.mkdir(parents=True, exist_ok=True)

# Scratch directory for uploaded files.
tmp_upload_dir = _app_root / TMP_UPLOAD_DIR_NAME
tmp_upload_dir.mkdir(parents=True, exist_ok=True)

# Load environment variables from a .env file.
# NOTE(review): the project imports above (src.api, config.celery_config) run
# before this call; if they read os.environ at import time they will not see
# values that only exist in .env — confirm.
load_dotenv()
# logging configuration
class CustomFormatter(logging.Formatter):
    """Formatter that renders ``%(pathname)s`` relative to the current directory.

    The relative path is used only while this formatter builds its output; the
    ``LogRecord`` is restored afterwards, so other handlers/formatters that
    receive the same record still see the original ``pathname``.
    """

    def format(self, record):
        """Format *record*, showing its source path relative to ``os.getcwd()``.

        Args:
            record: The log record to format.

        Returns:
            The formatted log line.

        If no relative path can be computed (``os.path.relpath`` raises
        ``ValueError`` when the paths are on different drives on Windows, and
        ``os.getcwd`` raises ``OSError`` if the cwd no longer exists), the
        original path is used unchanged instead of failing the log call.
        """
        original_pathname = record.pathname
        try:
            display_pathname = os.path.relpath(original_pathname, os.getcwd())
        except (ValueError, OSError):
            display_pathname = original_pathname

        record.pathname = display_pathname
        try:
            return super().format(record)
        finally:
            # The record is shared by every handler; don't leave it mutated.
            record.pathname = original_pathname
# Configure the root logger to write to stdout and to a rotating log file.
dictConfig(
    {
        "version": 1,
        # dictConfig's default (True) disables every logger that already exists
        # when this runs, e.g. uvicorn's loggers when it imports "main:app" after
        # setting up its own logging, or module-level loggers created by the
        # imports above. Keep them enabled.
        "disable_existing_loggers": False,
        "formatters": {
            "default": {
                "()": CustomFormatter,
                "format": "[%(asctime)s] %(levelname)s in %(pathname)s: %(message)s",
                "datefmt": "%Y-%m-%d %H:%M:%S %Z",
            }
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stdout",
                "formatter": "default",
            },
            "file": {
                "class": "logging.handlers.RotatingFileHandler",
                # Write into the directory created from LOGS_DIR_NAME next to
                # this file, not a hard-coded path relative to the process cwd.
                "filename": str(log_dir / "app.log"),
                "maxBytes": 1048576,  # 1 MiB per file before rotating
                "backupCount": 5,
                # Explicit encoding instead of the platform/locale default.
                "encoding": "utf-8",
                "formatter": "default",
            },
        },
        "root": {"level": "DEBUG", "handlers": ["console", "file"]},
    }
)
app = FastAPI()
# NOTE(review): `security` is not referenced in this file; kept in case other
# modules import it.
security = HTTPBearer()
# Reads the raw value of the "Authorization" request header.
api_key_header = APIKeyHeader(name="Authorization")


async def authenticate_user(api_key_header: str = Security(api_key_header)):
    """FastAPI dependency that checks the request's API key.

    Args:
        api_key_header: Value of the ``Authorization`` header, injected by the
            module-level ``APIKeyHeader`` scheme.

    Returns:
        An empty dict on success.

    Raises:
        HTTPException: 401 if the ``API_KEY`` environment variable is unset or
            empty, or if the header value does not match it.
    """
    expected_key = os.getenv("API_KEY")
    # Fail closed when no key is configured. Otherwise compare in constant
    # time (as bytes, so non-ASCII input cannot raise) to avoid leaking the
    # key through response timing.
    if not expected_key or not secrets.compare_digest(
        api_key_header.encode("utf-8"), expected_key.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    return {}
# celery
# Celery application object; its settings are loaded from CeleryConfig using
# the "CELERY" namespace (per Celery's config_from_object docs, only
# CELERY_-prefixed attributes are read).
celery = Celery("t4d-ai-llm")
celery.config_from_object(
    CeleryConfig,
    namespace="CELERY",
)
# routes
# Mount the text-summarization router under /api; every route it defines
# runs authenticate_user first.
app.include_router(
    text_summarization_router,
    dependencies=[Depends(authenticate_user)],
    prefix="/api",
)
# home route
# Registered on `app` directly, so the router-level dependency above does not
# apply here; the explicit Depends(authenticate_user) enforces the API key.
# (Comments rather than a docstring: FastAPI publishes route docstrings in the
# OpenAPI description.)
@app.get("/api")
async def home(auth_user: dict = Depends(authenticate_user)):
    return {"message": "Welcome to the T4D's AI/LLM service"}
if __name__ == "__main__":
    # Local run: serve "main:app" on port 7001, restarting on changes under
    # the src/ and config/ directories.
    uvicorn.run(
        "main:app",
        port=7001,
        reload=True,
        reload_dirs=["src", "config"],
    )