CS/NLP

실습2-2. 텍스트 정보 추출 모델 API 생성

초코chip 2024. 3. 27. 21:10

배경

이전 연구에서 문자 데이터의 특정 정보를 효율적으로 추출하는 방법에 대해 탐구하였습니다.

이 과정에서 다양한 기계 학습 및 자연어 처리 모델 중 KoBERT 모델이 한국어 텍스트에 대해 우수한 성능을 보임을 확인했습니다.

https://chocochip125.tistory.com/220

 

텍스트 정보 추출 (with. 개체명 인식(NER))

보호되어 있는 글입니다. 내용을 보시려면 비밀번호를 입력하세요.

chocochip125.tistory.com

이에 본 프로젝트는 학습된 KoBERT 모델을 활용하여 문자 데이터로부터 필요한 도메인 정보를 식별하고 추출하는 API를 구현하는 것을 목표로 하였습니다.

 

구현 과정

  1. 모델 및 토크나이저 로드: 학습된 KoBERT 모델과 토크나이저를 메모리에 로드하여 요청 처리를 준비
  2. 입력 데이터 전처리: 사용자의 입력을 모델이 이해할 수 있는 형태로 변환
  3. 모델을 통한 예측 수행: 전처리된 데이터에 대해 분류 작업을 수행하고, 각 토큰의 카테고리를 예측
  4. 예측 결과 처리 및 정보 추출: 예측된 카테고리에 따라 원문에서 필요한 정보를 추출후 사용자가 이해하기 쉬운 형식으로 가공
  5. 사용자에게 결과 반환: 최종적으로 가공된 내용을 사용자에게 반환

 

1. 모델 및 토크나이저 로드

  • 이유: 텍스트 정보 추출 API를 만들기 위해 학습된 KoBERT 모델을 로드할 필요가 있습니다.
  • 방법: 프로그램 시작 시에 사전 학습된 KoBERT 모델과 토크나이저를 로드하여 메모리에 상주시킵니다.
  • 코드:
from transformers import AutoModelForTokenClassification, AutoTokenizer 

model_name = "./model/kobert" 
model = AutoModelForTokenClassification.from_pretrained(model_name) 
tokenizer = AutoTokenizer.from_pretrained(model_name)

 

2. 입력 데이터 전처리

  • 이유: 입력된 텍스트를 모델이 처리할 수 있는 형식으로 변환하기 위해서는 적절한 토큰화와 전처리가 필수적입니다.
  • 방법: 토크나이저를 사용하여 텍스트를 토큰화하고, 필요한 패딩을 적용하여 모든 입력이 동일한 길이를 갖도록 합니다.
  • 코드:
# 텍스트 전처리 함수
def preprocess_text(text: str):
    model.eval()
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512, return_offsets_mapping=True)
    return inputs
  • 입력 예시:
{
  "message": "[Web발신] [신한체크취소] 장*진(8730) 04/18 11:13 (금액)3,700원 구글플레이"
}
  • 결과:
{
  'input_ids': tensor([[   2,  362,    0,  363,  362,    0,  363, 7178,   44, 7344,   18,    0,
           40,   86,   59,  115,  108,  249,  110,   18, 5553,   40,  142,   46,
            0,    0,    3]]),
  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0]]),
  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1]]),
  'offset_mapping': tensor([[[ 0,  0], [ 0,  1], [ 1,  6], [ 6,  7], [ 8,  9], [ 9, 15],
           [15, 16], [17, 18], [18, 19], [19, 20], [20, 21], [21, 25], [25, 26], 
           [27, 29], [29, 30], [30, 32], [33, 35], [35, 36], [36, 38], [39, 40], 
           [40, 42], [42, 43], [43, 44], [44, 45], [45, 49], [50, 55], [ 0,  0]]])
}

 

 

3. 모델을 통한 예측 수행

  • 이유: 텍스트에서 필요한 정보를 식별하기 위해 각 토큰이 어느 카테고리에 속하는지 예측할 필요가 있습니다.
  • 방법: 모델에 전처리된 데이터를 입력하여 각 토큰에 대한 카테고리 예측을 수행합니다. 
  • 코드:
# 예측 수행 함수
def predict_categories(inputs):
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(outputs.logits, dim=-1)
    return predictions, inputs['offset_mapping'].detach().cpu().numpy()[0]
  • 결과:
tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 0, 4, 4,
         4, 2, 0]])

 

4. 예측 결과 처리 및 정보 추출

  • 이유: 모델의 예측 결과는 사람이 이해하기 어려운 형태이므로, 사용자에게 의미있는 형태로 제공하기 위해 추가적인 처리가 필요합니다. 
  • 방법: 예측 결과를 순회하며, 각 토큰의 카테고리에 따라 해당하는 텍스트를 추출합니다. 이렇게 추출된 텍스트를 사용자가 이해할 수 있는 형태로 가공합니다.
  • 코드:
id2tag = {0: 'O', 1: 'METHOD', 2: 'LOCATION', 3: 'TIME', 4: 'COST'}

# 정보 추출 함수
def extract_information(predictions, offset_mapping, original_text):
    labels = [id2tag[id] for id in predictions[0].cpu().numpy()]
    extracted_info = {"METHOD": "", "LOCATION": "", "TIME": "", "COST": ""}

    for i, (offset, label) in enumerate(zip(offset_mapping, labels)):
        if label != "O":
            start, end = offset
            extracted_text = original_text[start:end]
            extracted_info[label] += extracted_text + " "

    for key in extracted_info:
        extracted_info[key] = extracted_info[key].strip()

    return extracted_info
  • 결과:
{
    "METHOD": "신한체크취소",
    "LOCATION": "구글플레이",
    "TIME": "04/18 11:13",
    "COST": "3,700원"
}

 

5. 사용자에게 결과 반환

  • 이유: 최종적으로 가공된 정보를 사용자가 요청한 형식에 맞게 반환하는 것이 필요합니다.
  • 방법: 각 카테고리별로 추출된 정보를 JSON 형식으로 정리하여 사용자에게 응답으로 제공합니다.
  • 코드:
@app.post("/keywords/")
async def create_item(item: Item):
    try:
        # 텍스트 전처리
        preprocessed_inputs = preprocess_text(item.message)
        # 예측 수행
        predictions, offset_mapping = predict_categories(preprocessed_inputs)
        # 정보 추출
        extracted_info = extract_information(predictions, offset_mapping, item.message)

        return {"result": extracted_info}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
  • 최종 결과:
{
    "result": {
        "METHOD": "신한체크취소",
        "LOCATION": "구글플레이",
        "TIME": "04/18 11:13",
        "COST": "3,700원"
    }
}

 

결과 테스트

API의 성능을 검증하기 위해 Postman을 사용하여 실제 문자 정보에 대한 정보 추출 결과를 테스트하였습니다.

테스트 결과, 방법, 위치, 시간, 비용 등의 정보가 텍스트에서 정확히 식별되고 추출되는 것을 확인할 수 있었습니다.

댓글수0