142 lines
4.7 KiB
Python
142 lines
4.7 KiB
Python
import grpc
|
|
from concurrent import futures
|
|
import utils.image_pb2 as pb2
|
|
import utils.image_pb2_grpc as pb2_grpc
|
|
from PIL import Image
|
|
import io
|
|
import datetime
|
|
from botocore.client import Config
|
|
import traceback
|
|
import logging
|
|
import logging.handlers
|
|
import boto3
|
|
import json
|
|
import timm
|
|
import os
|
|
import uuid
|
|
import random
|
|
|
|
import torch
|
|
from torchvision import models, transforms
|
|
|
|
|
|
###############################################
|
|
# Config #
|
|
###############################################
|
|
# Load runtime settings. Decode explicitly as UTF-8 (the encoding JSON
# mandates) so parsing does not depend on the host locale, e.g. cp949 on
# Korean Windows machines.
with open('./config.json', 'r', encoding='utf-8') as f:
    cfg = json.load(f)

SEED = cfg['model']['seed']  # NOTE(review): not used by the code shown here; presumably for RNG seeding — confirm
MODEL_NAME = cfg['model']['name']  # timm architecture name
NUM_CLASSES = cfg['model']['num_classes']
DEVICE_CFG = cfg['model']['device']
# Use the configured device only when CUDA is available; otherwise fall back to CPU.
DEVICE = DEVICE_CFG if torch.cuda.is_available() else "cpu"
MODEL_CKPT = cfg['model']['ckpt_path']
# basename also honours the platform's native separator, unlike split('/').
MODEL_FILE_NAME = os.path.basename(MODEL_CKPT)

# Predicted class index -> label string sent back to clients.
CATEGORIES = {
    0: '모래',      # sand
    1: '자갈',      # gravel
    2: '덮개',      # cover
    3: '빈차',      # empty vehicle
    4: '레미콘',    # ready-mixed concrete (remicon) truck
    5: '차량없음',  # no vehicle
}

# MinIO settings for fetching the checkpoint. These are planned to be dropped
# once model upload/download is available through bwc.
MINIO_BUCKET = cfg['minio_bucket']
MINIO_URL = cfg['minio_url']
MINIO_ACC_KEY = cfg['minio_access_key']
MINIO_SCR_KEY = cfg['minio_secret_key']
MINIO_REGION = cfg['minio_region_name']
|
|
|
|
###############################################
|
|
# Logger Setting #
|
|
###############################################
|
|
# Configure the root logger so records from every module reach the log file.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# RotatingFileHandler does not create missing directories; without this the
# module fails at import time with FileNotFoundError on a fresh checkout.
os.makedirs('./logs', exist_ok=True)

# Roll over at 1,024,000 bytes and keep 3 backups. The encoding is explicit
# because log lines include Korean class labels, which some locales
# (e.g. ASCII / Latin-1) cannot encode.
log_fileHandler = logging.handlers.RotatingFileHandler(
    filename="./logs/log_inference.log",
    maxBytes=1024000,
    backupCount=3,
    mode='a',
    encoding='utf-8')

log_fileHandler.setFormatter(formatter)
logger.addHandler(log_fileHandler)
|
|
|
|
###############################################
|
|
# Model download #
|
|
###############################################
|
|
|
|
#model_storage = boto3.client('s3',
|
|
# endpoint_url=MINIO_URL,
|
|
# aws_access_key_id=MINIO_ACC_KEY,
|
|
# aws_secret_access_key=MINIO_SCR_KEY,
|
|
# config=Config(signature_version='s3v4'),
|
|
# region_name=MINIO_REGION)
|
|
#
|
|
## minio에서 model ckpt 파일 다운로드
|
|
#if not os.path.isfile(MODEL_CKPT):
|
|
# model_storage.download_file(MINIO_BUCKET,f'{MODEL_NAME}/{MODEL_FILE_NAME}', MODEL_CKPT)
|
|
# print('Model is downloaded')
|
|
|
|
###############################################
|
|
# Model Class #
|
|
###############################################
|
|
|
|
class Model:
    """Wrap a timm image classifier for single-image inference.

    The network is built with ``timm.create_model`` (no pretrained weights)
    and then populated from a ``state_dict`` checkpoint.
    """

    def __init__(self, ckpt_path, num_classes, device, model_name=None, image_size=384):
        """Build the network, load its weights and put it in eval mode.

        Args:
            ckpt_path: Path to a ``state_dict`` checkpoint readable by ``torch.load``.
            num_classes: Size of the classifier head; must match the checkpoint.
            device: Torch device the model and its inputs are placed on.
            model_name: timm architecture name. ``None`` (default) uses the
                module-level ``MODEL_NAME``.
            image_size: Inputs are resized to ``image_size x image_size``.
        """
        if model_name is None:
            model_name = MODEL_NAME
        logger.info(f"DEVICE: {device}")
        self.device = device
        self.model = timm.create_model(
            model_name, pretrained=False, num_classes=num_classes
        ).to(device)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
        # Switch to inference behaviour once, instead of on every request.
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    def inference(self, image):
        """Classify one image and return the predicted class index as an int.

        ``image`` must be accepted by ``self.transform`` (e.g. an RGB PIL image).
        """
        # The transform yields one sample; unsqueeze adds a batch dimension of 1.
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
        return torch.argmax(logits, dim=-1).item()
|
|
|
|
|
|
|
|
|
|
class Inference_Agent(pb2_grpc.ImageServiceServicer):
    """gRPC servicer that classifies uploaded images with a ``Model``."""

    def __init__(self, model):
        """Store the ``Model`` instance used to serve every request."""
        self.model = model

    def UploadImage(self, request, context):
        """Decode ``request.image_data``, classify it and reply with the label.

        Undecodable image bytes abort the RPC with INVALID_ARGUMENT rather
        than surfacing as an opaque UNKNOWN error.
        """
        try:
            image = Image.open(io.BytesIO(request.image_data)).convert("RGB")
        except (OSError, ValueError):
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # data it cannot decode. context.abort raises, ending the RPC.
            logger.exception(f'failed to decode image: filename = {request.filename}')
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'could not decode image_data')

        formatted_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Use the injected model, not the module-level global `model`, which
        # only exists when this file is executed as a script.
        # Model.inference already disables gradient tracking internally.
        pred = self.model.inference(image)
        # runAction(request.filename, pred)

        # Fall back to the raw index if the checkpoint predicts a class that
        # CATEGORIES does not list, instead of failing with KeyError.
        label = CATEGORIES.get(pred, str(pred))

        message = f'{formatted_now}: filename = {request.filename}, predicted class = {label}'
        logger.info(message)
        print(message)

        result = f"Predicted class = {label}"
        return pb2.ImageResponse(message="Image Result", inference_result=result)
|
|
|
|
def serve(model, address='[::]:50051', max_workers=4):
    """Start an insecure gRPC server exposing ``model`` and block until it stops.

    Args:
        model: ``Model`` instance handed to the ``Inference_Agent`` servicer.
        address: Address to bind; defaults to the IPv6 wildcard on port 50051.
        max_workers: Number of threads in the pool that handles RPCs.

    Raises:
        RuntimeError: if the server could not bind to ``address``.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    pb2_grpc.add_ImageServiceServicer_to_server(Inference_Agent(model), server)
    # Some grpcio versions report a failed bind by returning 0 instead of raising.
    if server.add_insecure_port(address) == 0:
        raise RuntimeError(f'failed to bind gRPC server to {address}')
    server.start()
    logger.info(f'gRPC server listening on {address}')
    print('Waitting for client...')
    server.wait_for_termination()
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the classifier once at start-up. It is bound to the module-level
    # name `model` on purpose: UploadImage may look that global up directly.
    model = Model(ckpt_path=MODEL_CKPT, num_classes=NUM_CLASSES, device=DEVICE)
    print('Model is loaded')

    # Hand the model to the gRPC server; this call blocks until termination.
    serve(model)
|