[메디스태프 강의] ML 모델을 (웹)앱으로 만들기

2021. 10. 31. 20:47컴퓨터

이 강의는 국내 최고 의사 커뮤니티인 <메디스태프 - 젊은 의사커뮤니티>에서 강의를 했던 내용으로, 강의를 들으신 분들이 자세한 내용을 다시 찾아보시는데에 도움이 되시라고 전체 과정을 정리한 내용입니다.

들어가기에 앞서 : 왜?!!!!

많은 AI관련 의학 논문들이 출판되면서 단순히 'ML을 사용했더니 예측력이 좋았다'라는것의 학문적 의미의 한계를 많은 사람들이 느끼고 있습니다. 이에 따라 <실제 임상에서 사용되는 상황>에 대한 이야기가 포함된 논문들을 요구하는 경우가 많습니다. 또는 적어도 reporting guideline에서 이런 ML모델을 사용할 수 있는 방법을 (또는 자세한 weight가 포함된 model의 description을) 제시하도록 강제하는것이 앞으로 미래일 가능성이 높습니다. 이에 따라 단순히 ML모델을 만들고 결과를 도출한것이 끝이 아니고, 이 모델을 어떤 방식으로든 독자가 직접 테스트해볼 수 있도록 하는것이 중요해졌습니다. 그리고, 실제로 사용해보면 독자들이 이 기술이 얼마나 정확하고 훌륭한지 더욱 체감하는 효과도 있겠습니다.

예측 모델의 대표적인 Reporting guideline인 TRIPOD는 AI전용 모델이 개발중이며, 현재 일반 예측모델에서도 위와같이 사용할 수 있는 자세한 방법에 대해 reporting하도록 하고있습니다.

 

온디바이스 or 서버에서 돌리기

ML 모델을 사용자가 사용할 수 있도록 하는 방법은 실제로 여러가지가 있겠습니다. 모델 자체를 디바이스 안에다가 집어넣어서 (스마트폰 등) 디바이스 안에서 돌아가도록 하는 방법도 있고, 서버에서 돌리는 방법도 있겠습니다. 두 방법이 장단점이 있겠으며 (아래) 이 중 서버에서 모델을 돌리는 방법을 이 강의에서는 다룹니다. 이전 시간에 김병훈 선생님이 제작한 모델(Brain tumor segmentation model)을 구현하는것을 목표로 하며, 이 모델이 pytorch기반 모델이므로 python기반 FastAPI상에서 해당 모델을 구현하겠습니다.

 

 

클라이언트 고르기 (스마트폰 or 웹?)

또한 이 모델에 값을 입력하고 결과값을 볼 수 있는 클라이언트 (예. 스마트폰앱 또는 웹사이트)를 정하는 것도 필요합니다. 이 또한 장단점이 명확하며, 이 강좌에서는 웹사이트를 만드는것을 목표로 합니다. React.js framework를 이용할 예정입니다.

 

최종 목표

아래와같이 웹상에서 nii파일을 업로드하면, 이 파일을 서버로 보내서 결과물이 .gif형태로 아래 표시되는 웹사이트가 최종 구축 목표입니다.

최종 결과물 영상

 

백엔드 구축하기

서버 만들기 (아마존 AWS)

이 강의에서는 amazon AWS의 EC2를 이용하여 서버를 구축할예정입니다. t2.large (2vCPU, 8GB ram)인스턴스에 15Gb SSD를 사용하는 인스턴스입니다. (Oregon region에서 한달에 약 $70 정도 과금된다고 합니다)

 

아무 linux based OS나 상관없을것 같아서 Amazon Linux 2를 선택했습니다. (평소에는 ubuntu를 많이 사용합니다)

말씀드린대로 t2.large이며 바로 review & launch하고 싶지만 SSD 8Gb셋팅은 너무 부족하므로 15기가로 올리기위해 next: configure instance details로 넘어갑니다..

다음페이지는 그냥 넘어가고.. 드디어 add storage.. 15Gb SSD로 셋팅합니다.

이후에는 그냥 넘어가다가 ssh key는 원래 있던 key를 사용하시거나 본인의 상황에 맞게 새로운 key를 생성해서 설정하시면됩니다.

 

이제 EC2 instances menu에서 새로 생성된 인스턴스의  IP주소를 확인합니다. 참고로 elastic IP를 설정하지 않으시면 이 인스턴스를 껐다 켤때마다 IP주소가바뀌니 이를 유념하셔서 만약 여유가되신다면 elastic IP주소를 할당하시는것을 추천드립니다. (https://docs.aws.amazon.com/ko_kr/AWSEC2/latest/UserGuide/elastic-ip-addresses-eip.html)

 

또한 해당 인스턴스를 클릭해서 들어간 뒤, security 탭에서 security group설정을 들어갑니다.

이후 edit inbound rule를 선택해주시고

아래와같이 22 port, 80 port를 열어주십시요. (제대로 https로 하시려면 443 등 추가 포트를 열어주시면 좋습니다) 참고로 저는 letsencrypt를 애용하고 있습니다. (https://letsencrypt.org/ko/)

이제 ssh로 잘 접속이 되는지 확인해봅니다. 맥이라면 터미널프로그램이 자동으로 깔려있고, 윈도우즈이시거나 다른상황에서는 ssh client를 다운받아 설치하시면됩니다. (PuTTY나 iTerm 등) 이후에 아래와같이 ssh를 접속해주시면됩니다.

ssh -i [키_위치.pem] ec2-user@[IP주소]

이후 root권한 설정을 해주셔야하는데, 이는 추후에 1024이하 port에서 (이 강의에서는 80port) 서비스를 제공하기 위함입니다.

이와 관련된 설정은 아래 링크를 따라주시면됩니다.

https://goddaehee.tistory.com/193

 

[AWS] 8.AWS EC2 root 계정 활성화 시키기

[AWS] 8.AWS EC2 root 계정 활성화 시키기 안녕하세요. 갓대희 입니다. 이번 포스팅은 [ AWS EC2 (리눅스) root 계정 사용하기 ] 입니다. : ) 이전 포스팅을 통해 EC2 리눅스를 설치 해보았다. 다만 root.

goddaehee.tistory.com

 

이후 같은 서버에 root로 접속해주시면됩니다.

ssh -i [키_위치.pem] root@[IP주소]

 

추후에 파일을 쉽게 업로드하고 수정하기위해 sftp client에도 설정을 해줍니다. 저는 Cyberduck을 좋아하기 때문에 이를 통해 설정하였습니다.

 

 

필요한 패키지들 설치하기

pip를 이용하여 아래의 패키지들을 설치해주시면 됩니다.

라고 하려고했으나 pip조차없으니 이부터 설치해야합니다.

Amazon의 공식문서대로 하면되며, eb cli까지는 필요없으니 pip설치하기 섹션의 4번까지 하시면됩니다.(또한 위 셋팅대로 하시면 python3도 기본으로 설치되어있습니다.  python3 --version으로 확인가능합니다)

https://docs.aws.amazon.com/ko_kr/elasticbeanstalk/latest/dg/eb-cli3-install-linux.html

 

Linux에 Python, pip 및 EB CLI 설치 - AWS Elastic Beanstalk

Linux에 Python, pip 및 EB CLI 설치 EB CLI에는 Python 2.7, 3.4 또는 그 이상이 필요합니다. 배포가 Python과 함께 제공되지 않았거나 이전 버전과 함께 제공된 경우 pip 및 EB CLI를 설치하기 전에 Python을 설치

docs.aws.amazon.com

 

이후 필요한 패키지를 pip로 설정하시면됩니다.

pip install torch
pip install torchio
pip install monai
pip install fastapi
pip install uvicorn
pip install python-multipart
pip install aiofiles

 

FastAPI 로 서버 구성하기

이제부터 본론입니다. Python 기반 백엔드로는 대표적인것이 매우 가볍고 간단한 Flask, Django등이 있습니다. 이 중에서 저는 FastAPI를 사용하였는데, django의 'batteries-included' 방식과 Flask의 가볍고 심플함의 중간이라는 누군가의 썰을듣고 사용하고 있기 때문입니다. 최근 굉장히 popular해지고 있다고 알고 있습니다. 

아주 쉽게 getting started문서가 잘 작성되어있으므로 조금이라도 궁금하시면 잠깐만 들어가보시면 매우 쉽게 배울 수 있습니다.

https://fastapi.tiangolo.com/ko/

 

FastAPI

FastAPI FastAPI 프레임워크, 고성능, 간편한 학습, 빠른 코드 작성, 준비된 프로덕션 문서: https://fastapi.tiangolo.com 소스 코드: https://github.com/tiangolo/fastapi FastAPI는 현대적이고, 빠르며(고성능), 파이썬

fastapi.tiangolo.com

 

결론적으로 fastapi서버를 시작하려면 아래의 코드를 main.py에 넣고 (위 튜토리얼의 첫 코드입니다)

from fastapi import FastAPI

app = FastAPI()


@app.get("/")
async def root():
    return {"message": "Hello World"}

터미널에서 아래와같이 입력해주면됩니다. 

uvicorn main:app --host 0.0.0.0 --port 80 --reload

 이러고나면 이전에 확인된 우리의 아이피주소로 브라우져에서 입력하면 벌써 서버가 셋팅된것을 확인할 수 있습니다. 또한,  /docs 디렉토리에 들어가면 자동으로 openAPI 기반 swagger페이지까지(!!) 생성되어있는것을 알 수 있습니다.

 

 

4개의 파일 입력받기

이제 /predict라는 URL로 파일 4개를 입력받고 (t1, t2, t1ce, flair) 이 파일을 기반으로 작동하도록 구성해보도록 하겠습니다.

https://fastapi.tiangolo.com/tutorial/request-files/  공식 튜토리얼의 코드를 참고하여 아래와같이 코드를 작성합니다.

 

from fastapi import FastAPI, File, UploadFile

app = FastAPI()


@app.post("/predict")
async def getPrediction(t1_file: UploadFile = File(...),t2_file: UploadFile = File(...),t1ce_file: UploadFile = File(...),flair_file: UploadFile = File(...)):
    return {
        "t1" : t1_file.filename,
        "t2" : t2_file.filename,
        "t1ce" : t1ce_file.filename,
        "flair" : flair_file.filename,
    }

위와같이 t1_file, t2_file.. 이런식으로 4개의 파일을 받겠다고 getPrediction 함수에서 설정했으며 (참고로 이 함수 이름은 아무거나 지어도 전혀 상관없습니다), 우선 잘 작동하는지 보기위해 해당 파일이름 자체를 다시 반환하도록 설정했습니다.

 

http://[IP_ADDR]/docs 에서 잘 작동하는지 테스트해봅니다.

 

 

위 Try it out버튼을 클릭한 뒤 각 파일을 하나씩 설정해주신 뒤 execute 버튼을 누르면 아래와같이 서버가 해당 파일 이름들을 잘 출력해주는것을 확인할 수 있습니다.

 

 

 

이제 여기까지 잠시 정리하고 넘어가겠습니다.

지금까지 우리는 서버를 아마존에서 설정(구입)하고 거기에 fastapi라는 서버 프로그램을 깔았으며

이 서버로 하여금 4개의 파일을 입력받도록 만들었고,

이 입력받은 파일의 이름을 요청한 사람에게 반환하도록 설정했습니다.

반환한 형태는 JSON형태로, JSON에 대해서 모르신다면 간단히 찾아보시기를 추천드립니다. (매우 쉽습니다 ^^)

 

4개의 파일 임시 경로에 저장하기

업로드된 파일은 가상의 공간에 존재하며, 이것을 디스크의 특정 위치로 옮긴뒤 작업하기를 강력히 권장하고 있습니다.

그래서 임시디렉토리를 서버상에 아무데나 만들고, 여기다가 해당 파일들을 저장하는 작업을 하겠습니다.

 

아래와같이 서버상에 저장할 특정 폴더를 만들어주시고 시작하시면됩니다 (저는 /home/ec2-user/tmp 로 설정했습니다)

 

이후 main.py를 아래와같이 작성합니다.

 

from fastapi import FastAPI, File, UploadFile
import aiofiles
import time


app = FastAPI()

TMP_DIR = "/home/ec2-user/tmp"


async def saveFiles(file) :
    TIME = time.time() ### set current time as unique name for all file
    out_file_path =f"{TMP_DIR}/{TIME}-{file.filename}" 
    async with aiofiles.open(out_file_path, 'wb') as out_file:
        content = await file.read()  # async read
        await out_file.write(content)  # async write
    return out_file_path
    

@app.post("/predict")
async def getPrediction(t1_file: UploadFile = File(...),t2_file: UploadFile = File(...),t1ce_file: UploadFile = File(...),flair_file: UploadFile = File(...)):
    tmp_file = {}
    tmp_file["t1"] = await saveFiles(t1_file)
    tmp_file["t2"] = await saveFiles(t2_file)
    tmp_file["t1ce"] = await saveFiles(t1ce_file)
    tmp_file["flair"] = await saveFiles(flair_file)

    return tmp_file

/docs에서 또다시 테스트를 해보고나면 sftp client에서 확인해보면 아래와같이 파일이 잘 생성된것을 확인할 수 있습니다.

 

 

PyTorch로 모델 결과 구하기

이제 업로드된 파일로 김병훈 선생님의 코드를 훔쳐와서 적용하여, segmentation image를 도출합니다.

 

시작하기 전에 적절히 모델 pth파일도 서버 경로상에 올려놔야겠습니다.

저는 /home/ec2-user/server/model.pth에 올려놨습니다.

 

그리고 segmentation된 결과물이 저장되는 (.gif파일) 위치도 미리 지정해놓고, 이 위치에 폴더도 미리 만들어놔야합니다.

저는 /home/ec2-user/server/static/output 으로 해놨습니다.

 

아래와 같이 main.py를 수정합니다.

 

from fastapi import FastAPI, File, UploadFile
import aiofiles
import time

import os
import torch
import torchio as tio
import monai

app = FastAPI()

TMP_DIR = "/home/ec2-user/tmp"


MODEL_PATH = '/home/ec2-user/server/model.pth'
SAVE_DIR = '/home/ec2-user/server/static/output'

async def saveFiles(file) :
    TIME = time.time() 
    out_file_path =f"{TMP_DIR}/{TIME}-{file.filename}" 
    async with aiofiles.open(out_file_path, 'wb') as out_file:
        content = await file.read()  # async read
        await out_file.write(content)  # async write
    return out_file_path
    

def get_segmentation(model, data, device):
    model.to(device)
    input = torch.cat([data[sequence]['data'].unsqueeze(0) for sequence in ['t1', 't2', 't1ce', 'flair']], dim=1).to(device) 
    output = model(input).cpu().detach()
    pred = torch.nn.functional.one_hot(output.argmax(dim=1).squeeze(0)).permute(3,0,1,2) 
    return tio.LabelMap(tensor=pred)

@app.post("/predict")
async def getPrediction(t1_file: UploadFile = File(...),t2_file: UploadFile = File(...),t1ce_file: UploadFile = File(...),flair_file: UploadFile = File(...)):
    tmp_file = {}
    tmp_file["t1"] = await saveFiles(t1_file)
    tmp_file["t2"] = await saveFiles(t2_file)
    tmp_file["t1ce"] = await saveFiles(t1ce_file)
    tmp_file["flair"] = await saveFiles(flair_file)

    tio_images = {}
    for sequence in tmp_file :
        tio_images[sequence] =tio.ScalarImage(tmp_file[sequence]) 

    model = monai.networks.nets.BasicUNet(spatial_dims=3, in_channels=4, out_channels=2)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    subject = tio.Subject(tio_images)
    transforms = [
        tio.ToCanonical(),
        tio.Resample(3),
        tio.CropOrPad((64,64,48)),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
    transform = tio.Compose(transforms)
    dataset = tio.SubjectsDataset([subject], transform=transform)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seg = get_segmentation(model, dataset[0], device)

    TIME = time.time()
    segmented_gif = f"{SAVE_DIR}/{TIME}-segmented.gif"
    seg.to_gif(axis=2, duration=10, output_path=segmented_gif, loop=0)

    return {
        "success" : True,
        "segmented" : segmented_gif
    }

 

또다시 실행을 한 뒤 : 

sftp로 파일이 잘 생성되었는지 확인해봅니다 :

 

정확히 모르지만 어쨌든 잘 생성된것같습니다!

이제 모든 이미지 sequence에 대하여  gif를 만들어서 output directory에 잘 넣어주면 pytorch로 할것은 모두 마무리된것 같습니다.

 

비슷하게 아래와같이 수정합니다.

 

from fastapi import FastAPI, File, UploadFile
import aiofiles
import time

import os
import torch
import torchio as tio
import monai

app = FastAPI()

TMP_DIR = "/home/ec2-user/tmp"


MODEL_PATH = '/home/ec2-user/server/model.pth'
SAVE_DIR = '/home/ec2-user/server/static/output'

async def saveFiles(file) :
    TIME = time.time() 
    out_file_path =f"{TMP_DIR}/{TIME}-{file.filename}" 
    async with aiofiles.open(out_file_path, 'wb') as out_file:
        content = await file.read()  # async read
        await out_file.write(content)  # async write
    return out_file_path
    

def get_segmentation(model, data, device):
    model.to(device)
    input = torch.cat([data[sequence]['data'].unsqueeze(0) for sequence in ['t1', 't2', 't1ce', 'flair']], dim=1).to(device) 
    output = model(input).cpu().detach()
    pred = torch.nn.functional.one_hot(output.argmax(dim=1).squeeze(0)).permute(3,0,1,2) 
    return tio.LabelMap(tensor=pred)

@app.post("/predict")
async def getPrediction(t1_file: UploadFile = File(...),t2_file: UploadFile = File(...),t1ce_file: UploadFile = File(...),flair_file: UploadFile = File(...)):
    tmp_file = {}
    tmp_file["t1"] = await saveFiles(t1_file)
    tmp_file["t2"] = await saveFiles(t2_file)
    tmp_file["t1ce"] = await saveFiles(t1ce_file)
    tmp_file["flair"] = await saveFiles(flair_file)

    tio_images = {}
    for sequence in tmp_file :
        tio_images[sequence] =tio.ScalarImage(tmp_file[sequence]) 

    model = monai.networks.nets.BasicUNet(spatial_dims=3, in_channels=4, out_channels=2)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    subject = tio.Subject(tio_images)
    transforms = [
        tio.ToCanonical(),
        tio.Resample(3),
        tio.CropOrPad((64,64,48)),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
    transform = tio.Compose(transforms)
    dataset = tio.SubjectsDataset([subject], transform=transform)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seg = get_segmentation(model, dataset[0], device)


    TIME = time.time() 

    seg.to_gif(axis=2, duration=10, output_path=f"{SAVE_DIR}/{TIME}-segmented.gif", loop=0)
    paths['segmented'] = f"output/{TIME}-segmented.gif" 
    for sequence in dataset[0] :
        output_path = f"{SAVE_DIR}/{TIME}-{sequence}.gif" 
        dataset[0][sequence].to_gif(axis=2, duration=10, output_path=output_path, loop=0)
        paths[sequence] = f"output/{TIME}-{sequence}.gif"
    
    return {
        "success" : True
    }

 

 

정리하기

이제 client(웹페이지)에게 전달할 내용을 고민해보겠습니다.

위 예제에서 그랬던것 처럼 파일의 절대 경로를 알려주는것 (/home/ec2-user/...) 은 아무 도움이 되지 않습니다.

해당 gif파일에 대한 인터넷상에서 접속가능한 URL을 줘야할텐데, 이건 static file serving이 필요한 부분입니다.

 

즉, 브라우져가 서버의 특정 주소로 접속하면 서버상의 특정 파일을 보내줘서 '읽을 수 있도록'해주는 것입니다.

static file serving은 fastapi의 아래 문서에 잘 나와있습니다.

https://fastapi.tiangolo.com/tutorial/static-files/

 

Static Files - FastAPI

Static Files You can serve static files automatically from a directory using StaticFiles. Use StaticFiles Import StaticFiles. "Mount" a StaticFiles() instance in a specific path. from fastapi import FastAPI from fastapi.staticfiles import StaticFiles app =

fastapi.tiangolo.com

 

이에 따라 아래 import를 추가해주고

from fastapi.staticfiles import StaticFiles

 

아래의 기존 predict 를 정의해준 것이 모두 끝나고 나서 맨 뒤에다가 아래 항목을 작성해줍니다.

 

...
...

@app.post("/predict")
async def getPrediction(t1_file: UploadFile = File(...),t2_file: UploadFile = File(...),t1ce_file: UploadFile = File(...),flair_file: UploadFile = File(...)):

...
...



app.mount("/", StaticFiles(directory="static", html=True), name="static")

app.mount안의 코드를 잘 보면

첫번째 "/"은 외부에서 접속할 주소를 나타내며 (결국 우리의 경우 http://[IP_ADDRESS]/ 이렇게 그냥 가장 root 주소로 접속하는 경우를 뜻함)

directory="static"은 실제 경로상에서 main.py기준으로 어느 폴더를 보여주기를 희망하는지를 물어보며

(저는 main.py가 /home/ec2-user/server/main.py이므로 /home/ec2-user/server/static/ 폴더를 지칭하게 됩니다)

html=True option으로 html파일을 서빙하는 목적이라고 명시해주며 (그렇게 하면 index.html이렇게 써주지 않아도 기본적으로 index.html을 보여줍니다... 잘 모르겠으면 걍 써주세요)

name은 내부적으로 사용할 이름입니다.

 

이제 https://[IP_ADDRESS]/ 로 접속해보면 아래와 같이 보입니다.

 

왜냐면 해당 폴더(/home/ec2-user/server/static/)에 아무것도 없기 때문입니다. 그래서 여기에 index.html을 샘플로 만들어서 잘 작동하는지 보겠습니다.

 

/home/ec2/user/server/static/index.html 에다가 아래 파일을 넣어줍니다.

<b>hello</b> world!

그러면 잘 작동하는것을 볼 수 있습니다. 그러면 우리가 실제로 원하는 gif파일은

/home/ec2-user/server/static/output폴더에 있는데, 여기있는것은 잘 보일까요?

넵! 위와같이 잘 보입니다.

 

그러면 이제 서버에서 클라이언트로 무엇을 전달하면될까요?

제 생각에는 "/output/1635666652.807742-flair.gif" 이렇게 http://[IP_ADDRESS]/ 이후의 경로만 잘 전달해주면 될것 같습니다.

이에 따라 아래와같이 출력되도록 구성해보도록 하겠습니다 : 

{
	"success" : True,
    "paths" : {
    	"t1" : "output/1635666652.807742-t1.gif",
        "t2" : "output/1635666652.807742-t2.gif",
        "t1ce" : "output/1635666652.807742-t1ce.gif",
        "flair" : "output/1635666652.807742-flair.gif",
        "segmented" : "output/1635666652.807742-segmented.gif",
    }
}

 

main.py를 아래와같이 수정해줍니다.

 

from fastapi import FastAPI, File, UploadFile
import aiofiles
import time

import os
import torch
import torchio as tio
import monai

from fastapi.staticfiles import StaticFiles

app = FastAPI()



TMP_DIR = "/home/ec2-user/tmp"

MODEL_PATH = '/home/ec2-user/server/model.pth'
SAVE_DIR = '/home/ec2-user/server/static/output'

async def saveFiles(file) :
    TIME = time.time() 
    out_file_path =f"{TMP_DIR}/{TIME}-{file.filename}" 
    async with aiofiles.open(out_file_path, 'wb') as out_file:
        content = await file.read()  # async read
        await out_file.write(content)  # async write
    return out_file_path
    

def get_segmentation(model, data, device):
    model.eval()
    model.to(device)
    input = torch.cat([data[sequence]['data'].unsqueeze(0) for sequence in ['t1', 't2', 't1ce', 'flair']], dim=1).to(device) 
    output = model(input).cpu().detach()
    pred = torch.nn.functional.one_hot(output.argmax(dim=1).squeeze(0)).permute(3,0,1,2) 
    return tio.LabelMap(tensor=pred)

@app.post("/predict")
async def getPrediction(t1_file: UploadFile = File(...),t2_file: UploadFile = File(...),t1ce_file: UploadFile = File(...),flair_file: UploadFile = File(...)):
    tmp_file = {}
    tmp_file["t1"] = await saveFiles(t1_file)
    tmp_file["t2"] = await saveFiles(t2_file)
    tmp_file["t1ce"] = await saveFiles(t1ce_file)
    tmp_file["flair"] = await saveFiles(flair_file)

    tio_images = {}
    for sequence in tmp_file :
        tio_images[sequence] =tio.ScalarImage(tmp_file[sequence]) 

    model = monai.networks.nets.BasicUNet(spatial_dims=3, in_channels=4, out_channels=2)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    subject = tio.Subject(tio_images)
    transforms = [
        tio.ToCanonical(),
        tio.Resample(3),
        tio.CropOrPad((64,64,48)),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
    transform = tio.Compose(transforms)
    dataset = tio.SubjectsDataset([subject], transform=transform)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seg = get_segmentation(model, dataset[0], device)

    paths = {}

    TIME = time.time() 

    seg.to_gif(axis=2, duration=10, output_path=f"{SAVE_DIR}/{TIME}-segmented.gif", loop=0)
    paths['segmented'] = f"output/{TIME}-segmented.gif" 
    for sequence in dataset[0] :
        output_path = f"{SAVE_DIR}/{TIME}-{sequence}.gif" 
        dataset[0][sequence].to_gif(axis=2, duration=10, output_path=output_path, loop=0)
        paths[sequence] = f"output/{TIME}-{sequence}.gif"
    

    return {
        "success" : True,
        "paths" : paths
    }


app.mount("/", StaticFiles(directory="static", html=True), name="static")

Swagger에서 예상한대로 잘 출력이됩니다.

 

 

 

파일 처리하기

이제 끝내고 싶지만, 생각해보면 사용자들이 업로드하는 파일들이 그대로 다 서버에 쌓이므로, 15기가가 금방 다 찰것으로 기대할 수 있습니다. 이에 따라 결과물은 바로 삭제하면 client에서 확인하기도 전에 삭제가 될 것이므로 조금 문제가 되겠지만, 용량이 큰 원본파일은 프로세스가 끝나고 나면 불필요하므로 삭제하는것이 좋을것 같습니다. 

이에 따라 마지막에 return하기 직전에 원본 파일을 아래와같이 지웁니다 :

 

...
...

    for sequence in tmp_file :
        os.remove(tmp_file[sequence])

    return {
        "success" : True,
        "paths" : paths
    }


app.mount("/", StaticFiles(directory="static", html=True), name="static")

 

기존의 파일을 지우고나면 이제는 tmp폴더에 더이상 파일이 쌓이지 않습니다.

 

결과물 (.gif)파일 지우는것에 대하여
파일을 바로 지우면 client에서 해당 이미지를 보기도 전에 지워져버리기 때문에 어떤 delay를 주고 지우는것이 좋겠습니다. 이에 대해서는 cron.daily등의 linux상에서 매일 특정시간에 실행되는 bash script를 통해 실현이 가능합니다.
아래의 링크 또는 검색을 통해 (cron delete files after 1 day...) 추가로 공부해보실 수 있습니다.
https://arstech.net/cron-job-for-linux-to-delete-files-older-than-x-days/

 

 

정말 마지막, CORS설정

CORS 설정이라는것이 필요합니다. 보안관련된 문제인데, API를 아무데서나 외부에서 막 리퀘스트하면 보안상 여러가지 문제가 발생한다는 점에서, 정해진 위치에서만 해당 서버에 api콜을 할 수 있다는 개념입니다.

여기서 위치라고 하면 실제 사용하는 사람의 위치가 아니라, 이 서버에 요청을 보내는 client 앱의 위치라는 뜻으로,

우리의 앱의 경우에는 웹사이트가 설치된 서버가 되겠습니다.

결국 최종에는 우리의 웹앱 자체가 같은 서버에 올라갈 예정이므로 아무 문제가 없겠지만,

아래에서 프론트 개발중에는 localhost에서 돌아갈 예정이기때문에 이 내용을 추가해줘야합니다.

아래 공식 가이드라인에 따라 cors middleware를 설치해주신다음에

main.py에는 아래와같이 추가해주시면됩니다.

 

https://fastapi.tiangolo.com/tutorial/cors/

 

CORS (Cross-Origin Resource Sharing) - FastAPI

CORS (Cross-Origin Resource Sharing) CORS or "Cross-Origin Resource Sharing" refers to the situations when a frontend running in a browser has JavaScript code that communicates with a backend, and the backend is in a different "origin" than the frontend. O

fastapi.tiangolo.com

 

 

 

이제 백엔드는 모두 완성되었습니다.............

 

from fastapi import FastAPI, File, UploadFile
import aiofiles
import time

import os
import torch
import torchio as tio
import monai

from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

origins = [
    "http://localhost:3000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)



TMP_DIR = "/home/ec2-user/tmp"

MODEL_PATH = '/home/ec2-user/server/model.pth'
SAVE_DIR = '/home/ec2-user/server/static/output'

async def saveFiles(file) :
    TIME = time.time() 
    out_file_path =f"{TMP_DIR}/{TIME}-{file.filename}" 
    async with aiofiles.open(out_file_path, 'wb') as out_file:
        content = await file.read()  # async read
        await out_file.write(content)  # async write
    return out_file_path
    

def get_segmentation(model, data, device):
    model.eval()
    model.to(device)
    input = torch.cat([data[sequence]['data'].unsqueeze(0) for sequence in ['t1', 't2', 't1ce', 'flair']], dim=1).to(device) 
    output = model(input).cpu().detach()
    pred = torch.nn.functional.one_hot(output.argmax(dim=1).squeeze(0)).permute(3,0,1,2) 
    return tio.LabelMap(tensor=pred)

@app.post("/predict")
async def getPrediction(t1_file: UploadFile = File(...),t2_file: UploadFile = File(...),t1ce_file: UploadFile = File(...),flair_file: UploadFile = File(...)):
    tmp_file = {}
    tmp_file["t1"] = await saveFiles(t1_file)
    tmp_file["t2"] = await saveFiles(t2_file)
    tmp_file["t1ce"] = await saveFiles(t1ce_file)
    tmp_file["flair"] = await saveFiles(flair_file)

    tio_images = {}
    for sequence in tmp_file :
        tio_images[sequence] =tio.ScalarImage(tmp_file[sequence]) 

    model = monai.networks.nets.BasicUNet(spatial_dims=3, in_channels=4, out_channels=2)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    subject = tio.Subject(tio_images)
    transforms = [
        tio.ToCanonical(),
        tio.Resample(3),
        tio.CropOrPad((64,64,48)),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]
    transform = tio.Compose(transforms)
    dataset = tio.SubjectsDataset([subject], transform=transform)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seg = get_segmentation(model, dataset[0], device)

    paths = {}

    TIME = time.time() 

    seg.to_gif(axis=2, duration=10, output_path=f"{SAVE_DIR}/{TIME}-segmented.gif", loop=0)
    paths['segmented'] = f"output/{TIME}-segmented.gif" 
    for sequence in dataset[0] :
        output_path = f"{SAVE_DIR}/{TIME}-{sequence}.gif" 
        dataset[0][sequence].to_gif(axis=2, duration=10, output_path=output_path, loop=0)
        paths[sequence] = f"output/{TIME}-{sequence}.gif"
    

    for sequence in tmp_file :
        os.remove(tmp_file[sequence])

    return {
        "success" : True,
        "paths" : paths
    }


app.mount("/", StaticFiles(directory="static", html=True), name="static")

 

프론트엔드 만들기

프론트엔드는 리액트로 만들예정입니다.

리액트에 대해서는 너무 강좌가 많을정도로 대중적으로 쓰이는 프레임워크라서 쉽게 검색해서 공부해보실 수 있습니다.

 

리액트 등 프레임워크 구성하기

우선 간단한 예제를 위하여 npx create-react-app을 통하여 새로운 프로젝트를 만드는데,

서버에서 바로 하지 않고 로컬에서 진행합니다

 

터미널상에서 해당 프로젝트를 만들고 싶은 폴더에 가서 npx create-react-app을 실행합니다.

저는 프로젝트 이름을 tumor로 설정하겠으며, 그러면 tumor라는 폴더가 만들어집니다.

npx create-react-app tumor

 

이제 해당 디렉토리에 가서 필요한 프레임워크 두가지를 더 추가해줍니다.

 

cd tumor
npm install axios
npm install react-bootstrap bootstrap@5.1.3

 

axios는 network request를 잘 해주는 프레임워크로, 많이들 사용하고 있으며

(https://www.npmjs.com/package/axios)

 

axios

Promise based HTTP client for the browser and node.js

www.npmjs.com

 

react-bootstrap은 너무 허접해보이는 것을 방지하기 위해서 좀 이쁜 HTML +CSS  컴포넌트를 제공해줍니다.

(https://react-bootstrap.github.io/)

 

React-Bootstrap

The most popular front-end framework, rebuilt for React.

react-bootstrap.github.io

 

 

본인이 좋아하는 editor에서 생성된 폴더를 열어줍니다. (저는 VSCode)

 

 

React-bootstrap관련 설정

Bootstrap은 추가 설정이 조금 더 필요합니다. 공식 가이드에 따라 아래와같이 수정이 필요하며, 

/tumor/public/index.html의 <head> tag안에 추가해주시면됩니다.

https://react-bootstrap.github.io/getting-started/introduction/

 

React-Bootstrap

The most popular front-end framework, rebuilt for React.

react-bootstrap.github.io

<!DOCTYPE html>
<html lang="en">
  <head>
  	<!--- 아래를 추가해주세요 --->
    <script src="https://unpkg.com/react/umd/react.production.min.js" crossorigin></script>
    <script
      src="https://unpkg.com/react-dom/umd/react-dom.production.min.js"
      crossorigin></script>
    <script
      src="https://unpkg.com/react-bootstrap@next/dist/react-bootstrap.min.js"
      crossorigin></script>
    <script>var Alert = ReactBootstrap.Alert;</script>
    <link
      rel="stylesheet"
      href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
      integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
      crossorigin="anonymous"
    />
    <!--- 위를 추가해주세요 --->
    <meta charset="utf-8" />
    <link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta name="theme-color" content="#000000" />
    <meta

 

레이아웃 구성하기


우리는 이 중 src/App.js에서만 작업할 예정으로 이 파일을 띄워주시면 됩니다.

이제 다 지우고 새로운 마음으로 시작합니다.

 

import './App.css';

function App() {
  return (
    <div className="App">
   	  Hello world!
    </div>
  );
}

export default App;

 

이제 해당 폴더 (/tumor)에서 npm start 을 입력하여 실행시켜봅니다

 

npm start

브라우져에서 해당 위치 (http://localhost:3000)에서 웹페이지가 잘 작동하시는것을 볼 수 있으며, 내용을 수정하고 저장하면 자동으로 리프레쉬되며 내용이 바뀝니다.

 

간단히 아래와같이 제목을 바꿔보고 저장해보겠습니다.

 

import './App.css';

function App() {
  return (
    <div className="App">
      <h3>Tumor Detector</h3>
    </div>
  );
}

export default App;

 

 

 

레이아웃으로는 아래의 항목들이 필요합니다.

  1. 파일을 업로드할 수 있는 인풋 4개 (t1, t2, t1ce, flair)
  2. 업로드 버튼
  3. 받은 결과를 표시할 곳

이 레이아웃을 구성할 수 있는 방법은 여러가지가 있겠지만,

React-bootstrap의 <Row> <Col>을 이용하려고 합니다.

<Row> 는 한 가로줄을 구성하고, Col은 칸(?)을 생성해줍니다.

자세한 내용은 간단히 아래 사이트에서 읽어볼 수 있습니다. 조금 귀찮지만 이 페이지 몇개만 투자해서 읽어보시면 웹페이지 레이아웃은 마스터하실 수 있습니다.

https://react-bootstrap.github.io/layout/grid/ 

 

React-Bootstrap

The most popular front-end framework, rebuilt for React.

react-bootstrap.github.io

 

해당 컴포넌트들을 import해주시고 아래와같이 레이아웃을 구성해 봅니다.

 

 

import './App.css';
import {Row, Col, Container} from 'react-bootstrap'

function App() {
  return (
    <div className="App">
      <h3>Tumor Detector</h3>
      <Container fluid>
        <Row>
          <Col>input</Col>
          <Col>input</Col>
          <Col>input</Col>
          <Col>input</Col>
        </Row>
        <Row>
          Button
        </Row>
        <Row>
          <Col>output</Col>
          <Col>output</Col>
          <Col>output</Col>
          <Col>output</Col>
          <Col>output</Col>

        </Row>
      </Container>
    </div>
  );
}

export default App;

 

Input component만들어보기

리액트는 기본적으로 컴포넌트들을 만들고, 이 컴포넌트들의 조합으로 이뤄지는 개념입니다.

이에 따라 input 에 들어갈 컴포넌트들을 만들어보겠습니다.

Javascript function이며 아래와같이 문서 아래에 추가해줍니다.

 

function InputComponent(props) {
  return(

    <Col>
        <input type="file" />
    </Col>
  )
}

그러면 이제 위 스크립트의 <Col>input</Col>부분을 해당 InputComponent로 대체할 수 있습니다

import './App.css';
import {Row, Col, Container, InputGroup} from 'react-bootstrap'

function App() {
  return (
    <div className="App">
      <h3>Tumor Detector</h3>
      <Container fluid>
        <Row>
          <InputComponent />
          <InputComponent />
          <InputComponent />
          <InputComponent />
        </Row>
        <Row>
          Button
        </Row>
        <Row>
          <Col>output</Col>
          <Col>output</Col>
          <Col>output</Col>
          <Col>output</Col>
          <Col>output</Col>

        </Row>
      </Container>
    </div>
  );
}

function InputComponent(props) {
  return(

    <Col>
        <input type="file" />
    </Col>
  )
}
export default App;

위 InputComponent function은 props라는 argument를 받게 되어있습니다.

이 props는 다른곳에서 이 컴포넌트를 사용할때 attribute로 설정되는 값들을 받아오는 역할을 합니다.

 

즉, 아래와 같이 component에서 label (또는 어떤 이름이어도 상관없습니다) 이라는 attribute를 정의하고나면 아래와같이 해당 컴포넌트에서 갖다가 쓸수 있게됩니다.

<Component label="this is label" />
function Component(props) {
	return(<span>props.label</span>)
}

 

이 개념을 그대로 이용하여 InputComponent를 아래와같이 수정합니다.

function InputComponent(props) {
  return(

    <Col>
        {props.label} image :
        <input type="file" onChange={props.handleChange} />
    </Col>
  )
}

 

 

또한 output component도 React-bootstrap의 Figure component를 활용하여 새롭게 구성해 봅니다 :

 

import './App.css';
import {Row, Col, Container, Figure} from 'react-bootstrap'

const SEQUENCES = ["t1","t2","flair","t1ce"];
const OUTPUT_SEQUENCES = ["t1","t2","flair","t1ce","segmented"];
const SERVER_URL = "3.35.4.26";

function App() {
  
  return (
    <div className="App">
      <h3>Tumor Detector</h3>
      <Container fluid>
        <Row>
          {SEQUENCES.map((element) => <InputComponent key={element} label={element} handleChange={(e) => {
            console.log(e.target.files[0])
          }} />)}
        </Row>
        <Row>
          Button
        </Row>
        <Row>
          {OUTPUT_SEQUENCES.map((element) => <OutputComponent key={element} label={element} />)}


        </Row>
      </Container>
    </div>
  );
}

function InputComponent(props) {
  return(
    <Col>
        {props.label} image :
        <input type="file" onChange={props.handleChange} />
    </Col>
  )
}


function OutputComponent(props) {
  return (
    <Col>
      <Figure>
        <Figure.Image
          width={128}
          height={128}
          alt={props.label}
          src={"http://" + SERVER_URL + props.path}
        />
        <Figure.Caption>
          {props.label}
        </Figure.Caption>
      </Figure>
    </Col>
  )
}
export default App;

 

 

위 코드를 잘 살펴보면, InputComponent에 handleChange라는 attribute를 전달하며,

이 attribute에다가 함수를 전달하고 있습니다. 이 함수는 e 라는 instance를 받아서 e.target.files[0]을 console.log함수에 전달합니다. e.target.files[0] 은 사용자가 선택한 파일의 이름이라고 보시면되겠으며,

console.log()은 디버그를 위한 함수로, 브라우져에서 (사용자에게 보이지않게) 디버그 메세지를 출력할 수 있습니다.

 

브라우져에서 (저는 현재 chrome을 사용중이며) 우측 클릭 ->inspector와 같은 메뉴를 클릭하여 디버그 창을 같이 켜놓으면 아래와같이 보입니다.

 

이제 파일을 선택해보면 여기에 해당 파일이 출력됩니다.

 

이제 이렇게 사용자가 선택한 파일을 이제 프로그램상으로 받을 수 있어졌습니다.

 

이렇게 받은 변수는 javascript object를 만들어서 저장해놓으면 나중에 서버로 request를 보낼때 사용할 수 있습니다.

files라는 object를 정의하고, 업데이트시 여기에 저장되도록 스크립트를 변경합니다.

 

import './App.css';
import {Row, Col, Container, Figure} from 'react-bootstrap'
import Button from '@restart/ui/esm/Button';

const SEQUENCES = ["t1","t2","flair","t1ce"];
const OUTPUT_SEQUENCES = ["t1","t2","flair","t1ce","segmented"];
const SERVER_URL = "3.35.4.26";

var files = {
  "t1" : null,
  "t2" : null,
  "t1ce" : null,
  "flair" : null
};

function App() {
  return (
    <div className="App">
      <h3>Tumor Detector</h3>
      <Container fluid>
        <Row>
          {SEQUENCES.map((element) => <InputComponent key={element} label={element} handleChange={(e) => {
            files[element] = e.target.files[0];
          }} />)}
        </Row>
        <Row>
          Button
          
        </Row>
        <Row>
          {OUTPUT_SEQUENCES.map((element) => <OutputComponent key={element} label={element} />)}


        </Row>
      </Container>
    </div>
  );
}

function InputComponent(props) {
  return(
    <Col>
        {props.label} image :
        <input type="file" onChange={props.handleChange} />
    </Col>
  )
}


function OutputComponent(props) {
  return (
    <Col>
      <Figure>
        <Figure.Image
          width={128}
          height={128}
          alt={props.label}
          src={"http://" + SERVER_URL + props.path}
        />
        <Figure.Caption>
          {props.label}
        </Figure.Caption>
      </Figure>
    </Col>
  )
}
export default App;

 

Axios로 request보내기

이제는 파일이 준비가 되었으니 이 파일을 서버로 보내기만 하면될것 같습니다.

axios라는 프레임워크를 이용하게 될텐데요,

우선 이전에 "Button"이라고 쓰여있던 부분을 실제로 React-bootstrap의 Button component로 바꾸고, onClick 이벤트에 axios 함수를 불러 처리하도록 합니다.

 

우선 axios를 import하고,

import axios from 'axios';

Button부분에 아래 스크립트를 추가합니다.

 

<Button onClick={() => {
  if(Object.values(files).every((element) => element !== null)) {

    const formData = new FormData();

    for (const [key,value] of Object.entries(files)) {
      formData.append(key+"_file", value);
    }

    axios.post('http://'+SERVER_URL+'/predict',formData,{
      headers: { "Content-Type": "multipart/form-data"}
    }).then((response) => {
      console.log(response);
    }).catch((e) => {
      alert('error');
    }) ;
  } else {
    alert('please select files for all sequences');
  }
}}>Submit</Button>

 

익숙하지 않으시다면 좀 어렵게 느껴지실수 있고, 그냥 자세한 문법등은 무시하시고 대략 이해만 하시면 될것 같습니다.

Button이라는 컴포넌트를 만들고,

거기에서 onClick이벤트로 아래 내용을 실행합니다.

1. files object를 모두 검사해서 빈값이 없는지 확인하고(null), 빈값이 있으면 'please select files for all sequences'라고 에러 메세지를 출력한다

2. FormData instance를 만들고, 여기에 각 시퀀스별로 _file을 붙여서 보낼 값을 정의한다 (ie. t1_file, t2_file...) <- 이 부분은 이전에 만든 backend에 맞춰서 폼을 구성하는것입니다.

3. 이후 axios.post를 통해 post method로 우리 서버/predict주소로 해당 리퀘스트를 보낸다.

4. 성공적으로 response가 도착하면 console.log으로 디버그 메세지를 출력한다.

 

이렇게 해서 실행하면, 아래와같이 (우리가 백엔드 짤 때 의도한대로) 잘 response가 도착하는것을 볼 수 있습니다.

 

그런데, 처음에 버튼을 누르고나서 사실 아무반응이 없어서 '버그인가?'싶으셨을 수 있습니다. 

즉, 로딩을 표시하는 방법이 없습니다.

로딩을 표시하기위해서는 react state에 대한 이해가 필요합니다.

 

 

React State에 대한 이해

리액트는 state기반 프레임워크입니다. 이 컨셉을 처음 접하고는 정말 너무 편리해서 저는 충격적이었는데요, 이 state라는 개념을 이해하는게 아주 쉬운일은 아닙니다. 관련된 문서를 한번 읽어보시기를 권유드립니다.

https://ko.reactjs.org/docs/state-and-lifecycle.html

 

State and Lifecycle – React

A JavaScript library for building user interfaces

ko.reactjs.org

간단하게만 말씀드리자면,

react component에서 state란, 해당 컴포넌트의 출력과 관련된 여러가지 변수를 저장할 수 있는 창고로 생각하시면됩니다.

그리고, 여기안에 저장되는 변수들의 값이 변하게된다면, 출력에 반영되게 됩니다.

 

예를 들어 name이라는 변수가 state안에 저장되고 있다고 정의하고,

이 name이라는 변수를 이용해서 본문에서 출력하고 있다고 정의하면,

이 name을 변경하므로 해서 본문의 출력이 변경되게 되는 것입니다.

 

 

그렇다면 우리의 케이스를 고민해보면,

loading이라는 변수를 만들고,

만약에 현재 서버에 요청을 보내놓은 상태이면 이 loading state를 true로 변경해놓고,

 

loading state가 true라면 출력상에서 버튼에다가 돌아가는 loading그림을 보여주면 적절할것 같습니다.

 

useState hook 이라는 react 개념을 이용하게되며, 이를 이용하여 아래와같이 코드를 수정합니다.

https://ko.reactjs.org/docs/hooks-state.html

 

Using the State Hook – React

A JavaScript library for building user interfaces

ko.reactjs.org

 

 

import './App.css';
import {Row, Col, Container, Figure, Button, Spinner} from 'react-bootstrap'
import axios from 'axios';
import { useState } from 'react';

const SEQUENCES = ["t1","t2","flair","t1ce"];
const OUTPUT_SEQUENCES = ["t1","t2","flair","t1ce","segmented"];
const SERVER_URL = "3.35.4.26";

var files = {
  "t1" : null,
  "t2" : null,
  "t1ce" : null,
  "flair" : null
};

function App() {
  const [loading, setLoading] = useState(false);

  return (
    <div className="App">
      <h3>Tumor Detector</h3>
      <Container fluid>
        <Row>
          {SEQUENCES.map((element) => <InputComponent key={element} label={element} handleChange={(e) => {
            files[element] = e.target.files[0];
          }} />)}
        </Row>
        <Row>
          <Button onClick={() => {
            if(Object.values(files).every((element) => element !== null)) {
              
              const formData = new FormData();

              for (const [key,value] of Object.entries(files)) {
                formData.append(key+"_file", value);
              }
              setLoading(true);

              axios.post('http://'+SERVER_URL+'/predict',formData,{
                headers: { "Content-Type": "multipart/form-data"}
              }).then((response) => {
                console.log(response);
                setLoading(false);
              }).catch((e) => {
                alert('error');
                setLoading(false);
              }) ;
            } else {
              alert('please select files for all sequences');
            }
          }}>{loading ? <span><Spinner
            as="span"
            animation="grow"
            size="sm"
            role="status"
            aria-hidden="true"
          /><span> loading</span></span> : <span>Submit</span>}</Button>
          
        </Row>
        <Row>
          {OUTPUT_SEQUENCES.map((element) => <OutputComponent key={element} label={element} />)}


        </Row>
      </Container>
    </div>
  );
}

function InputComponent(props) {
  return(
    <Col>
        {props.label} image :
        <input type="file" onChange={props.handleChange} />
    </Col>
  )
}


function OutputComponent(props) {
  return (
    <Col>
      <Figure>
        <Figure.Image
          width={128}
          height={128}
          alt={props.label}
          src={"http://" + SERVER_URL + props.path}
        />
        <Figure.Caption>
          {props.label}
        </Figure.Caption>
      </Figure>
    </Col>
  )
}
export default App;

위와같이 버튼을 클릭하고나면 로딩 애니메이션이 보여서 사용자가 지금 작업중이라는것을 인지할 수 있도록 합니다.

 

또한, 아래 결과이미지들의 주소 또한 state에 저장하게된다면,

response가 오고나서 이 주소를 저장하는것만으로 아래 이미지도 표시할 수 있게 됩니다.

 

여기에 추가로 loaded 라는 변수까지 정의하여서, loading 이 끝난 시점에서만 이미지를 보여서 엑박이 출력되지 않도록 수정해보면 아래와 같습니다.

 

 

import './App.css';
import {Row, Col, Container, Figure, Button, Spinner} from 'react-bootstrap'
import axios from 'axios';
import { useState } from 'react';

const SEQUENCES = ["t1","t2","flair","t1ce"];
const OUTPUT_SEQUENCES = ["t1","t2","flair","t1ce","segmented"];
const SERVER_URL = "3.35.4.26";

var files = {
  "t1" : null,
  "t2" : null,
  "t1ce" : null,
  "flair" : null
};

function App() {
  const [loading, setLoading] = useState(false);
  const [loaded, setLoaded] = useState(false);
  const [urls, setUrls] = useState({
    "t1": null,
    "t2" : null,
    "t1ce" : null,
    "flair" : null
  });

  return (
    <div className="App">
      <h3>Tumor Detector</h3>
      <Container fluid>
        <Row>
          {SEQUENCES.map((element) => <InputComponent key={element} label={element} handleChange={(e) => {
            files[element] = e.target.files[0];
          }} />)}
        </Row>
        <Row>
          <Button onClick={() => {
            if(Object.values(files).every((element) => element !== null)) {
              
              const formData = new FormData();

              for (const [key,value] of Object.entries(files)) {
                formData.append(key+"_file", value);
              }
              setLoading(true);
              setLoaded(false); // reset loaded state, so hide result images

              axios.post('http://'+SERVER_URL+'/predict',formData,{
                headers: { "Content-Type": "multipart/form-data"}
              }).then((response) => {
                console.log(response);
                setLoading(false);
                setLoaded(true); // set loaded ==true only when successful
                setUrls(response.data.paths);

              }).catch((e) => {
                alert('error');
                setLoading(false);
              }) ;
            } else {
              alert('please select files for all sequences');
            }
          }}>{loading ? <span><Spinner
            as="span"
            animation="grow"
            size="sm"
            role="status"
            aria-hidden="true"
          /><span> loading</span></span> : <span>Submit</span>}</Button>
          
        </Row>
        {loading === false && loaded === true && <Row>
          {OUTPUT_SEQUENCES.map((element) => <OutputComponent key={element} label={element} path={urls[element]} />)}


        </Row>}
      </Container>
    </div>
  );
}

function InputComponent(props) {
  return(
    <Col>
        {props.label} image :
        <input type="file" onChange={props.handleChange} />
    </Col>
  )
}


function OutputComponent(props) {
  return (
    <Col>
      <Figure>
        <Figure.Image
          width={128}
          height={128}
          alt={props.label}
          src={"http://" + SERVER_URL + "/" + props.path}
        />
        <Figure.Caption>
          {props.label}
        </Figure.Caption>
      </Figure>
    </Col>
  )
}
export default App;

 

아래와같이 결과를 받기전에는 최종 이미지 관련 섹션이 뜨지 않다가,

 

결과가 나오면 그때 내용이 뜹니다.

 

리액트 사이트 웹에 올리기

현재까지는 개발목적으로 로컬에서 돌리고 있었는데, 실제로 서버에 올리려면 build 과정을 거쳐야 합니다.

/tumor 폴더에서

npm run build

을 입력하면 /tumor/build 디렉토리가 생성되며 파일들이 만들어집니다.

 

이 파일들을 sftp client를 이용하여 /home/ec-user/server/static/ 폴더안에 올려줍니다.

 

그렇게 하면

http://[IP_ADDRESS]/

로 접속하면 만들어진 웹사이트를 볼 수 있습니다.

 

웹사이트 꾸미기

이제 아무래도 좀더 보기 좋게 꾸미는 일만 남았습니다.

React-bootstrap framework에서 다양한 component를 보면서 하나하나 추가해나가는 재미가 있습니다.

 

제 최선의 센스를 이용하여서 아래와 같이 수정하였으며,

이 전체 소스코드는 github에서 확인하실 수 있습니다.

 

 

이상이었습니다. 감사합니다 ^^

  • 프로필사진
    이상협2021.11.19 23:54

    선생님께서 쓰신 Machine learning–based model for prediction of outcomes in acute stroke 논문을 읽고 감명 받았었습니다.

    그래서 저도 파이썬 부터 배우기 시작해서, Traumatic brain injury 환자의 예후를 예측하는 머신 러닝 모델에 대해 논문을 쓰게 되어 감사한 마음을 가지고 있습니다.

    이번에도 좋은 내용을 공유해주셔서 감사합니다.

    • 프로필사진
      JNHEO2022.10.04 05:48 신고

      이상협 선생님 안녕하십니까? 제가 덧글을 거의 일년이 다 지나서 이제서야 봤습니다.. 별것도 아닌 성과에 대해 좋게 봐주시고 이렇게 글까지 남겨주셔서 감사드립니다! 앞으로 훌륭한 연구 많이 하시는것 잘 보고 저도 많이 배울 수 있도록 하겠습니다. 다시한번 이렇게 덧글주셔서 감사드립니다!