demo-inference/main.py

142 lines
4.7 KiB
Python
Raw Permalink Normal View History

2024-12-04 22:40:39 +00:00
import grpc
from concurrent import futures
import utils.image_pb2 as pb2
import utils.image_pb2_grpc as pb2_grpc
from PIL import Image
import io
import datetime
from botocore.client import Config
import traceback
import logging
import logging.handlers
import boto3
import json
import timm
import os
import uuid
import random
import torch
from torchvision import models, transforms
###############################################
# Config #
###############################################
# Read explicitly as UTF-8 (the JSON interchange encoding) so that loading the
# config does not depend on the host's locale encoding.
with open('./config.json', 'r', encoding='utf-8') as f:
    cfg = json.load(f)

SEED = cfg['model']['seed']
MODEL_NAME = cfg['model']['name']  # architecture name passed to timm.create_model
NUM_CLASSES = cfg['model']['num_classes']
DEVICE_CFG = cfg['model']['device']
# Use the configured device only when CUDA is available; otherwise run on CPU.
DEVICE = DEVICE_CFG if torch.cuda.is_available() else "cpu"
MODEL_CKPT = cfg['model']['ckpt_path']
MODEL_FILE_NAME = MODEL_CKPT.split('/')[-1]

# Predicted class index -> label placed in the response and log lines.
CATEGORIES = {0: '모래',      # sand
              1: '자갈',      # gravel
              2: '덮개',      # cover
              3: '빈차',      # empty vehicle
              4: '레미콘',    # ready-mixed concrete
              5: '차량없음'}  # no vehicle

# MinIO settings: planned to be dropped once model upload/download is
# available in bwc.
MINIO_BUCKET = cfg['minio_bucket']
MINIO_URL = cfg['minio_url']
MINIO_ACC_KEY = cfg['minio_access_key']
MINIO_SCR_KEY = cfg['minio_secret_key']
MINIO_REGION = cfg['minio_region_name']
###############################################
# Logger Setting #
###############################################
# Configure the root logger, so records propagated from other loggers are
# written to the same file.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# RotatingFileHandler does not create missing directories; ensure ./logs exists
# so a fresh checkout does not crash at import time.
os.makedirs('./logs', exist_ok=True)
log_fileHandler = logging.handlers.RotatingFileHandler(
    filename="./logs/log_inference.log",
    maxBytes=1024000,  # rotate after ~1 MB (1,024,000 bytes)
    backupCount=3,     # keep at most three rotated files
    mode='a',
    # Log lines include the Korean class labels; write UTF-8 regardless of locale.
    encoding='utf-8')
log_fileHandler.setFormatter(formatter)
logger.addHandler(log_fileHandler)
###############################################
# Model download #
###############################################
#model_storage = boto3.client('s3',
# endpoint_url=MINIO_URL,
# aws_access_key_id=MINIO_ACC_KEY,
# aws_secret_access_key=MINIO_SCR_KEY,
# config=Config(signature_version='s3v4'),
# region_name=MINIO_REGION)
#
## minio에서 model ckpt 파일 다운로드
#if not os.path.isfile(MODEL_CKPT):
# model_storage.download_file(MINIO_BUCKET,f'{MODEL_NAME}/{MODEL_FILE_NAME}', MODEL_CKPT)
# print('Model is downloaded')
###############################################
# Model Class #
###############################################
class Model:
    """Image classifier built with timm and loaded from a local checkpoint.

    The network is created without pretrained weights, the checkpoint's
    state dict is loaded onto ``device``, and the module is put into eval
    mode once at construction so request handling never toggles its state.
    """

    def __init__(self, ckpt_path, num_classes, device, model_name=None):
        """Build the network and load its weights.

        Args:
            ckpt_path: Path to a state-dict checkpoint readable by ``torch.load``.
            num_classes: Number of outputs of the classification head.
            device: Torch device string the model is loaded onto and run on.
            model_name: timm architecture name. Defaults to the module-level
                ``MODEL_NAME`` read from the config.
        """
        if model_name is None:
            model_name = MODEL_NAME
        logger.info(f"DEVICE: {device}")
        self.device = device
        self.model = timm.create_model(model_name, pretrained=False, num_classes=num_classes).to(device)
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
        # Switch to inference behaviour once here rather than on every request,
        # which would mutate module state from concurrent gRPC worker threads.
        self.model.eval()
        self.transform = transforms.Compose([transforms.Resize((384, 384)),
                                             transforms.ToTensor()])

    def inference(self, image):
        """Classify a single image.

        Args:
            image: Input accepted by ``self.transform`` (an RGB PIL image on
                this file's gRPC path).

        Returns:
            int: Index of the highest-scoring class (argmax over the last
            dimension of the model output for a batch of one).
        """
        batch = self.transform(image).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            logits = self.model(batch.to(self.device))
        return torch.argmax(logits, dim=-1).item()
class Inference_Agent(pb2_grpc.ImageServiceServicer):
    """gRPC ``ImageService`` servicer that classifies uploaded images."""

    def __init__(self, model):
        """Store the classifier used to answer requests.

        Args:
            model: Object exposing ``inference(image) -> int`` (a ``Model``
                instance in this file).
        """
        self.model = model

    def UploadImage(self, request, context):
        """Decode the uploaded image, classify it, and reply with the label.

        Args:
            request: Message carrying ``image_data`` (encoded image bytes)
                and ``filename`` (used in the log line).
            context: gRPC servicer context; used to reject undecodable input.

        Returns:
            ``pb2.ImageResponse`` whose ``inference_result`` is
            ``"Predicted class = <label>"``.
        """
        try:
            image = Image.open(io.BytesIO(request.image_data)).convert("RGB")
        except OSError as err:
            # PIL raises UnidentifiedImageError (an OSError subclass) for bytes
            # it cannot decode; report a client error instead of UNKNOWN.
            logger.warning(f'Cannot decode image {request.filename}: {err}')
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'cannot decode image: {err}')
        formatted_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with torch.no_grad():
            # Use the injected model; the module-level `model` global only
            # exists when this file is run as __main__.
            pred = self.model.inference(image)
        label = CATEGORIES[pred]
        #runAction(request.filename, pred)
        logger.info(f'{formatted_now}: filename = {request.filename}, predicted class = {label}')
        print(f'{formatted_now}: filename = {request.filename}, predicted class = {label}')
        result = f"Predicted class = {label}"
        return pb2.ImageResponse(message="Image Result", inference_result=result)
def serve(model, port=50051, max_workers=4):
    """Run the ImageService gRPC server and block until it terminates.

    Args:
        model: Classifier handed to ``Inference_Agent``.
        port: Port bound on ``[::]`` without TLS (insecure). Defaults to 50051.
        max_workers: Thread-pool size for handling RPCs. Defaults to 4.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    pb2_grpc.add_ImageServiceServicer_to_server(Inference_Agent(model), server)
    server.add_insecure_port(f'[::]:{port}')
    server.start()
    print('Waitting for client...')
    server.wait_for_termination()
if __name__ == "__main__":
    # Script entry point: load the classifier once, then serve requests
    # until the gRPC server is terminated.
    model = Model(ckpt_path=MODEL_CKPT, num_classes=NUM_CLASSES, device=DEVICE)
    print('Model is loaded')
    serve(model=model)