"""Configuration, constants and logging setup for the gRPC image-inference server."""
import datetime
import io
import json
import logging
import logging.handlers
import os
import random
import traceback
import uuid
from concurrent import futures

import boto3
import grpc
import timm
import torch
from botocore.client import Config
from PIL import Image
from torchvision import models, transforms

import utils.image_pb2 as pb2
import utils.image_pb2_grpc as pb2_grpc

###############################################
#                   Config                    #
###############################################
# The encoding is pinned so the file is parsed the same way under any locale.
with open('./config.json', 'r', encoding='utf-8') as f:
    cfg = json.load(f)

SEED = cfg['model']['seed']
MODEL_NAME = cfg['model']['name']
NUM_CLASSES = cfg['model']['num_classes']
DEVICE_CFG = cfg['model']['device']
# Use the configured device only when CUDA is available; otherwise run on CPU.
DEVICE = DEVICE_CFG if torch.cuda.is_available() else "cpu"
MODEL_CKPT = cfg['model']['ckpt_path']
# Last path component of the checkpoint path.
MODEL_FILE_NAME = MODEL_CKPT.split('/')[-1]
# Class index -> label used in logs and responses
# (sand, gravel, cover, empty vehicle, ready-mix concrete, no vehicle).
CATEGORIES = {0: '모래', 1: '자갈', 2: '덮개', 3: '빈차', 4: '레미콘', 5: '차량없음'}

# Will no longer be used once model upload/download becomes available in bwc.
MINIO_BUCKET = cfg['minio_bucket']
MINIO_URL = cfg['minio_url']
MINIO_ACC_KEY = cfg['minio_access_key']
MINIO_SCR_KEY = cfg['minio_secret_key']
MINIO_REGION = cfg['minio_region_name']

###############################################
#               Logger Setting                #
###############################################
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# The handler opens its file immediately, so the directory must already exist;
# without this a fresh checkout fails at import time with FileNotFoundError.
os.makedirs("./logs", exist_ok=True)
log_fileHandler = logging.handlers.RotatingFileHandler(
    filename="./logs/log_inference.log",
    maxBytes=1024000,
    backupCount=3,
    mode='a',
    # Log records contain the non-ASCII labels from CATEGORIES; do not depend
    # on the locale's default encoding being able to represent them.
    encoding='utf-8',
)
log_fileHandler.setFormatter(formatter)
logger.addHandler(log_fileHandler)

###############################################
#               Model download                #
###############################################
#model_storage = boto3.client('s3',
#                    endpoint_url=MINIO_URL,
#                    aws_access_key_id=MINIO_ACC_KEY,
#                    aws_secret_access_key=MINIO_SCR_KEY,
#                    config=Config(signature_version='s3v4'),
#                    region_name=MINIO_REGION)
#
## Download the model ckpt file from MinIO
#if not os.path.isfile(MODEL_CKPT):
#    model_storage.download_file(MINIO_BUCKET,f'{MODEL_NAME}/{MODEL_FILE_NAME}', MODEL_CKPT)
#    print('Model is downloaded')

###############################################
#                 Model Class                 #
###############################################
class Model:
    """Wrap a timm classifier loaded from a checkpoint, plus its preprocessing."""

    def __init__(self, ckpt_path, num_classes, device):
        """Build the MODEL_NAME architecture and load weights from ``ckpt_path``.

        Args:
            ckpt_path: Path of the checkpoint file passed to ``torch.load``.
            num_classes: Number of output classes for the created model.
            device: Device the model is moved to and inputs are sent to.
        """
        logger.info("DEVICE: %s", device)
        self.model = timm.create_model(
            MODEL_NAME, pretrained=False, num_classes=num_classes
        ).to(device)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
        # This object is only used for inference, so switch to eval mode once
        # here instead of on every request.
        self.model.eval()
        self.device = device
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
        ])

    def inference(self, image):
        """Classify one image and return the predicted class index.

        Args:
            image: Image accepted by the transform pipeline (UploadImage
                passes an RGB PIL image).

        Returns:
            The argmax over the model output's last dimension, as a Python
            number.
        """
        batch = self.transform(image).unsqueeze(0)  # add a leading batch dim
        with torch.no_grad():
            outputs = self.model(batch.to(self.device))
        return torch.argmax(outputs, dim=-1).item()


class Inference_Agent(pb2_grpc.ImageServiceServicer):
    """gRPC servicer that classifies uploaded images with the given model."""

    def __init__(self, model):
        """Store the ``Model`` instance used to serve requests."""
        self.model = model

    def UploadImage(self, request, context):
        """Decode ``request.image_data``, classify it and return the label.

        Args:
            request: Message providing ``image_data`` (encoded image bytes)
                and ``filename``.
            context: gRPC servicer context for this call.

        Returns:
            ``pb2.ImageResponse`` whose ``inference_result`` names the
            predicted category.

        If the bytes cannot be decoded as an image, the failure is logged and
        the RPC is aborted with ``INVALID_ARGUMENT``.
        """
        try:
            image = Image.open(io.BytesIO(request.image_data)).convert("RGB")
        except OSError:
            # PIL reports unidentifiable or truncated data as OSError subclasses.
            logger.exception("Could not decode image: filename = %s", request.filename)
            # abort() raises, so execution never continues past this line.
            context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "image_data could not be decoded as an image",
            )

        now = datetime.datetime.now()
        formatted_now = now.strftime("%Y-%m-%d %H:%M:%S")

        # Use the injected model rather than a module-level global, which only
        # exists when this file is executed as a script. Model.inference
        # already disables gradient tracking.
        pred = self.model.inference(image)

        #runAction(request.filename, pred)
        summary = f'{formatted_now}: filename = {request.filename}, predicted class = {CATEGORIES[pred]}'
        logger.info(summary)
        print(summary)

        result = f"Predicted class = {CATEGORIES[pred]}"
        return pb2.ImageResponse(message="Image Result", inference_result=result)


def serve(model, port=50051, max_workers=4):
    """Start the gRPC server and block until it terminates.

    Args:
        model: ``Model`` instance handed to the servicer.
        port: TCP port to listen on (insecure, all interfaces).
        max_workers: Size of the thread pool handling requests.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    pb2_grpc.add_ImageServiceServicer_to_server(Inference_Agent(model), server)
    server.add_insecure_port(f'[::]:{port}')
    server.start()
    print('Waitting for client...')
    server.wait_for_termination()


if __name__ == "__main__":
    model = Model(MODEL_CKPT, NUM_CLASSES, DEVICE)
    print('Model is loaded')
    serve(model)