
Commit cdd8a15
removed cuda dependency from pt converters
Ishaan-Datta committed Sep 11, 2024
1 parent f1858cf commit cdd8a15
Showing 2 changed files with 3 additions and 4 deletions.
1 change: 0 additions & 1 deletion conversion_tools/PT_ONNX.py
@@ -4,7 +4,6 @@
 import onnx
 # import onnxruntime as ort
 import numpy as np
-import pycuda.driver as cuda
 
 OPSET_VERS = 13
 
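For context, the PyTorch-to-ONNX path never needed pycuda: torch.onnx.export traces the model through PyTorch itself, and a CPU tensor is enough as the example input. A minimal sketch of that flow, using hypothetical model_path/output_path/input_shape parameters rather than the file's actual signature:

import torch
import onnx

OPSET_VERS = 13

def export_to_onnx(model_path='./model.pt', output_path='./model.onnx', input_shape=(1, 3, 224, 224)):
    # Hypothetical sketch: load the saved model and trace it on CPU; no CUDA context is required.
    model = torch.load(model_path, map_location='cpu')
    model.eval()
    dummy_input = torch.randn(input_shape)
    torch.onnx.export(model, dummy_input, output_path, opset_version=OPSET_VERS)
    # Sanity-check the exported graph with the onnx package.
    onnx.checker.check_model(onnx.load(output_path))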
6 changes: 3 additions & 3 deletions conversion_tools/PT_TRT.py
@@ -1,8 +1,8 @@
 import argparse
 import torch
 from torch2trt import torch2trt
-import pycuda.driver as cuda
-import pycuda.autoinit
+# import pycuda.driver as cuda
+# import pycuda.autoinit
 import numpy as np
 
 def get_max_memory():
@@ -21,7 +21,7 @@ def convert_pt_to_trt(model_path='./model.pt', output_path='./model_trt.trt', FP
 
     input_data = torch.randn(input_shape).cuda()
     print("Building TensorRT engine. This may take a few minutes.")
-    model_trt = torch2trt(model, [input_data], fp16_mode=FP16_mode, max_batch_size=batch_size, max_workspace_size=get_max_memory())
+    model_trt = torch2trt(model, [input_data], fp16_mode=FP16_mode, max_batch_size=batch_size, max_workspace_size=15000000000) # get_max_memory()
     # torch.save(model_trt.state_dict(), output_file)
 
     with open(output_path, 'wb') as f:
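The hardcoded max_workspace_size of 15000000000 bytes stands in for the old get_max_memory() call, which presumably relied on pycuda to query device memory. If dynamic sizing is ever wanted back without reintroducing pycuda, one option (an assumption, not part of this commit) is to query free memory through PyTorch itself:

import torch

def get_max_memory(fraction=0.5):
    # Assumed replacement: query free/total device memory via PyTorch instead of pycuda.
    # fraction is an illustrative knob for how much free memory to hand to the TensorRT builder.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return int(free_bytes * fraction)

torch.cuda.mem_get_info wraps cudaMemGetInfo and is available in recent PyTorch releases, so this needs a CUDA-enabled PyTorch build but no separate pycuda install.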
