본문 바로가기
AI

PyTorch 모델을 ONNX으로 변환

by icebear3000 2023. 3. 31.
반응형

ONNX 런타임은 ONNX 모델을 위한 엔진으로서 성능에 초점을 맞추고 있고 여러 다양한 플랫폼과 하드웨어(윈도우, 리눅스, 맥을 비롯한 플랫폼 뿐만 아니라 CPU, GPU 등의 하드웨어)에서 효율적인 추론을 가능하게 합니다. 이 튜토리얼을 진행하기 위해서는 ONNX 와 ONNX Runtime 설치가 필요합니다. ONNX와 ONNX 런타임의 바이너리 빌드를 pip install onnx onnxruntime 를 통해 받을 수 있습니다.

 

Load the PyTorch model

import torch
import torch.onnx
import onnx


path_to_pytorch_model = 'path_to_pytorch_model.pth'
model = torch.load(path_to_pytorch_model)
model.eval()  # Set the model to evaluation mode

모델 내보내기(Export the model to ONNX)

 모델을 ONNX로 내보내려면 올바른 모양의 예제 입력 텐서를 제공해야 합니다. input_shape를 LSTR 모델에 적합한 입력 모양으로 바꿉니다.

onnx_model_path = 'lstr_model.onnx'
input_shape = (1, 3, 224, 224)  # Replace with the correct input shape for the LSTR model
example_input = torch.randn(input_shape)

# Export the model
torch.onnx.export(model, example_input, onnx_model_path, export_params=True, opset_version=11)

 

Verify the ONNX model
 ONNX 모델을 내보낸 후 ONNX 패키지를 사용하여 유효한지 확인할 수 있습니다.

onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)
print("The ONNX model is valid.")

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

반응형

댓글