Python API Documentation#
The TensorRT Python API enables developers in Python-based development environments, and those looking to experiment with TensorRT, to easily parse models (for example, from ONNX) and generate and run PLAN files.
To view API changes between releases, refer to the TensorRT GitHub repository and use the compare tool.
This section illustrates the basic usage of the Python API, assuming you are starting with an ONNX model. The onnx_resnet50.py sample illustrates this use case in more detail. For the same workflow using the C++ API, refer to C++ API Documentation.
The Python API can be accessed through the tensorrt module:
import tensorrt as trt
The Build Phase#
The build phase uses the builder to optimize a model and produce an engine. To create a builder:
Create a logger. The Python bindings include a simple logger implementation that logs all messages at or above a given minimum severity to stdout:
logger = trt.Logger(trt.Logger.WARNING)
Note
Alternatively, define your own logger by deriving from the ILogger class:
class MyLogger(trt.ILogger):
    def __init__(self):
        trt.ILogger.__init__(self)

    def log(self, severity, msg):
        pass  # Your custom logging implementation here

logger = MyLogger()
Create the builder:
builder = trt.Builder(logger)
Building engines is intended as an offline process, so it can take significant time. The Reducing Engine Build Time section has tips on making the builder run faster.
Creating a Network Definition#
After the builder has been created, the first step in optimizing a model is to create a network definition:
Specify the network creation options using a combination of flags OR’d together (or 0 for none). Note that all networks are strongly typed in TensorRT 11, so you need not set the STRONGLY_TYPED flag (a warning will be emitted if you do). For more information, refer to the Strongly Typed Networks section.
Create the network:
network = builder.create_network(flag)
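As a rough illustration of how creation flags combine, the sketch below ORs one bit per flag into a mask, following the `1 << int(flag)` pattern used for NetworkDefinitionCreationFlag values. The CreationFlag enum and its members are hypothetical stand-ins so that the example runs without TensorRT installed; with the real API you would pass members of trt.NetworkDefinitionCreationFlag instead.

```python
from enum import IntEnum
from functools import reduce


class CreationFlag(IntEnum):
    """Hypothetical stand-in: each member's value is a bit position."""
    FLAG_A = 0
    FLAG_B = 1
    FLAG_C = 3


def flags_to_mask(flags):
    """OR together one bit per flag; an empty iterable yields 0 (no flags)."""
    return reduce(lambda mask, flag: mask | (1 << int(flag)), flags, 0)


no_flags = flags_to_mask([])
combined = flags_to_mask([CreationFlag.FLAG_A, CreationFlag.FLAG_C])
print(no_flags, bin(combined))
```

The resulting integer is what would be handed to create_network as the flag argument.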
Creating a Network Definition from Scratch (Advanced)
Instead of using a parser, you can define the network directly to TensorRT through the Network Definition API.
The code corresponding to this section can be found in network_api_pytorch_mnist.
This example creates a simple network with Input, Convolution, Pooling, MatrixMultiply, Shuffle, Activation, and Softmax layers.
Define a helper class to hold model metadata:
class ModelData(object):
    INPUT_NAME = "data"
    INPUT_SHAPE = (1, 1, 28, 28)
    OUTPUT_NAME = "prob"
    OUTPUT_SIZE = 10
    DTYPE = trt.float32
Import the weights from the PyTorch MNIST model:
weights = mnist_model.get_weights()
Create the logger, builder, and network:
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(0)
Create the input tensor, specifying the name, datatype, and shape:
input_tensor = network.add_input(name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)
Add a convolution layer, specifying the inputs, number of output maps, kernel shape, weights, bias, and stride:
conv1_w = weights["conv1.weight"].cpu().numpy()
conv1_b = weights["conv1.bias"].cpu().numpy()
conv1 = network.add_convolution_nd(
    input=input_tensor, num_output_maps=20, kernel_shape=(5, 5), kernel=conv1_w, bias=conv1_b
)
conv1.stride_nd = (1, 1)
Add a pooling layer, specifying the inputs, pooling type, window size, and stride:
pool1 = network.add_pooling_nd(input=conv1.get_output(0), type=trt.PoolingType.MAX, window_size=(2, 2))
pool1.stride_nd = trt.Dims2(2, 2)
Add the next pair of convolution and pooling layers:
conv2_w = weights["conv2.weight"].cpu().numpy()
conv2_b = weights["conv2.bias"].cpu().numpy()
conv2 = network.add_convolution_nd(pool1.get_output(0), 50, (5, 5), conv2_w, conv2_b)
conv2.stride_nd = (1, 1)
pool2 = network.add_pooling_nd(conv2.get_output(0), trt.PoolingType.MAX, (2, 2))
pool2.stride_nd = trt.Dims2(2, 2)
Add a Shuffle layer to reshape the input in preparation for matrix multiplication:
# Get the output from the previous pooling layer
pool2_output = pool2.get_output(0)
batch = pool2_output.shape[0]
mm_inputs = np.prod(pool2_output.shape[1:])
input_reshape = network.add_shuffle(pool2_output)
input_reshape.reshape_dims = trt.Dims2(batch, mm_inputs)
Add a MatrixMultiply layer. The model exporter provided transposed weights, so the kTRANSPOSE option is specified:
fc1_w = weights['fc1.weight'].numpy()
# Number of fc1 output features: the weight is laid out as (nbOutputs, mm_inputs)
nbOutputs = fc1_w.shape[0]
fc1_const = network.add_constant(trt.Dims2(nbOutputs, mm_inputs), fc1_w)
mm = network.add_matrix_multiply(input_reshape.get_output(0), trt.MatrixOperation.NONE,
                                 fc1_const.get_output(0), trt.MatrixOperation.TRANSPOSE)
Add bias, which will broadcast across the batch dimension:
bias_const = network.add_constant(trt.Dims2(1, nbOutputs), weights["fc1.bias"].numpy())
bias_add = network.add_elementwise(mm.get_output(0), bias_const.get_output(0), trt.ElementWiseOperation.SUM)
Add the ReLU activation layer:
relu1 = network.add_activation(input=bias_add.get_output(0), type=trt.ActivationType.RELU)
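Add the final fully connected layer, and mark the output of this layer as the output of the entire network. Note that add_matmul_as_fc is not a TensorRT API call; it appears to be a helper from the sample that bundles the reshape, matrix multiply, and bias-add steps shown above: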
fc2_w = weights['fc2.weight'].numpy()
fc2_b = weights['fc2.bias'].numpy()
fc2 = add_matmul_as_fc(network, relu1.get_output(0), ModelData.OUTPUT_SIZE, fc2_w, fc2_b)
fc2.get_output(0).name = ModelData.OUTPUT_NAME
network.mark_output(tensor=fc2.get_output(0))
The network representing the MNIST model has now been fully constructed. For instructions on how to build an engine and run an inference with this network, refer to the Building an Engine and Performing Inference sections.
For more information regarding layers, refer to the TensorRT Operator documentation.
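To make the shape bookkeeping in the walkthrough easier to follow, here is a minimal NumPy sketch that traces the spatial sizes through the two convolution and pooling pairs and then reproduces the Shuffle, MatrixMultiply (with a transposed second operand), bias, and ReLU steps on random data. It assumes no padding on the convolution and pooling layers. The helper conv_or_pool_out, the hidden width nb_outputs = 500, and the random weights are illustrative assumptions, not values taken from the sample.

```python
import numpy as np


def conv_or_pool_out(size, window, stride):
    """Output length of an unpadded convolution or pooling window."""
    return (size - window) // stride + 1


# Shapes follow the walkthrough: 1x1x28x28 input, then two pairs of
# conv (5x5, stride 1) and max-pool (2x2, stride 2) with 20 and 50 maps.
side = 28
side = conv_or_pool_out(side, 5, 1)   # conv1 -> 24
side = conv_or_pool_out(side, 2, 2)   # pool1 -> 12
side = conv_or_pool_out(side, 5, 1)   # conv2 -> 8
side = conv_or_pool_out(side, 2, 2)   # pool2 -> 4
batch, channels = 1, 50
mm_inputs = channels * side * side    # flattened features per sample

# Random stand-ins for the pooled activations and the fc1 parameters.
rng = np.random.default_rng(0)
nb_outputs = 500  # assumed hidden width, for illustration only
pool2_output = rng.standard_normal((batch, channels, side, side)).astype(np.float32)
fc1_w = rng.standard_normal((nb_outputs, mm_inputs)).astype(np.float32)
fc1_b = rng.standard_normal(nb_outputs).astype(np.float32)

# Shuffle: reshape to (batch, mm_inputs).
flat = pool2_output.reshape(batch, mm_inputs)
# MatrixMultiply with the second operand transposed, then a broadcast bias add.
fc1 = flat @ fc1_w.T + fc1_b.reshape(1, nb_outputs)
# ReLU activation.
relu1 = np.maximum(fc1, 0.0)

print(mm_inputs, relu1.shape)
```

Storing the weight as (nb_outputs, mm_inputs) is what makes the transpose necessary: the product of a (batch, mm_inputs) matrix with the transposed weight yields (batch, nb_outputs), and the (1, nb_outputs) bias then broadcasts across the batch dimension.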
Importing a Model Using the ONNX Parser#
Now, the network definition must be populated from the ONNX representation.
Create an ONNX parser to populate the network.
parser = trt.OnnxParser(network, logger)
Read the model file and process any errors.
success = parser.parse_from_file(model_path)
for idx in range(parser.num_errors):
    print(parser.get_error(idx))
if not success:
    pass  # Error handling code here
An important aspect of a TensorRT network definition is that it contains pointers to model weights, which the builder copies into the optimized engine. Since the network was created using the parser, the parser owns the memory occupied by the weights, so the parser object should not be deleted until after the builder has run.
Importing a Model Using the ONNX Parser with Custom Weights#
The ONNX parser API allows users to provide their own weights, also known as ONNX initializers, in host memory to override any weights found in the model. Instead of parsing the model immediately, use the following sequence:
parser = trt.OnnxParser(network, logger)
Load the model into the parser.
status = parser.load_model_proto(model)
assert status
Provide the name, pointer to data, and size of the data for the parser to use instead of the one found in the model. You can call this step multiple times to override multiple weights. These pointers must remain in scope until the parser is destroyed.
status = parser.load_initializer(name, data, dataSize)
assert status
Begin parsing with the user-defined weights.
status = parser.parse_model_proto()
assert status
The same idea extends to the IParserRefitter class, and similar APIs can be used to provide custom weights when refitting an engine built from an ONNX model. For more information, refer to the Refitting a Weight-Stripped Engine Directly from ONNX section.
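To make the name, pointer, and size triple more concrete, the following NumPy sketch prepares a replacement weight in host memory. The initializer name "conv1.weight" and the helper prepare_initializer are hypothetical. Passing an integer address as the data argument, and expressing the size in bytes, are assumptions about the Python binding that should be checked against the OnnxParser API reference. The sketch stops short of calling the parser so that it runs without TensorRT.

```python
import numpy as np


def prepare_initializer(name, array):
    """Return (name, address, size_in_bytes, owner) for a host-memory weight.

    The returned owner array must stay alive until the parser is destroyed,
    because only its address is handed over.
    """
    owner = np.ascontiguousarray(array)  # guarantees a dense C-order buffer
    return name, owner.ctypes.data, owner.nbytes, owner


# Hypothetical override for a 20-map 5x5 convolution weight.
override = np.zeros((20, 1, 5, 5), dtype=np.float32)
name, address, size, keep_alive = prepare_initializer("conv1.weight", override)
print(name, size)
```

Returning the owning array alongside the address is a deliberate choice: it gives the caller something to hold on to, which addresses the requirement that the data remain in scope until the parser is destroyed.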
Building an Engine#
The next step is to create a build configuration specifying how TensorRT should optimize the model:
Create a builder configuration object. This interface has many properties that you can set to control how TensorRT optimizes the network.
config = builder.create_builder_config()
Set the maximum workspace size. Layer implementations often require a temporary workspace, and this parameter limits the maximum size that any layer in the network can use. If insufficient workspace is provided, TensorRT might not be able to find an implementation for a layer. By default, the workspace is set to the total global memory size of the given device; restrict it when necessary, such as when multiple engines are to be built on a single device.
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1 GiB
Build the serialized engine:
serialized_engine = builder.build_serialized_network(network, config)
Save the serialized engine to disk for future use:
with open("sample.engine", "wb") as f:
    f.write(serialized_engine)
Note
Serialized engines are not cross-platform portable. They are specific to the exact GPU model on which they were built (in addition to the platform).
Building engines is intended as an offline process, so it can take significant time. The Reducing Engine Build Time section has tips on making the builder run faster.
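Because a serialized engine is tied to the GPU model and platform it was built on, one possible convention is to encode those properties in the engine file name so that an engine built elsewhere is not loaded by accident. The sketch below is only an illustration: engine_cache_path, the key layout, and the example GPU names and version string are hypothetical, and including the TensorRT version in the key is an extra precaution assumed here. In practice the values would come from your own device and library queries.

```python
import hashlib
from pathlib import Path


def engine_cache_path(cache_dir, onnx_bytes, gpu_name, trt_version):
    """Build a path that changes whenever the model, GPU, or TensorRT version changes."""
    model_hash = hashlib.sha256(onnx_bytes).hexdigest()[:16]
    # Replace characters that are awkward in file names.
    safe_gpu = "".join(c if c.isalnum() else "_" for c in gpu_name)
    return Path(cache_dir) / f"{model_hash}-{safe_gpu}-trt{trt_version}.engine"


# Placeholder inputs; real values would be read from the model file and device.
path_a = engine_cache_path("engines", b"model-bytes", "Example GPU 1", "X.Y.Z")
path_b = engine_cache_path("engines", b"model-bytes", "Example GPU 2", "X.Y.Z")
print(path_a.name != path_b.name)
```

With a scheme like this, a build step can check whether the path exists before invoking the builder, and rebuild only when the model or the target changes.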
Deserializing a Plan#
To load a previously serialized plan and run inference:
Create a runtime. Like the builder, the runtime requires a logger:
runtime = trt.Runtime(logger)
Warning
Engine files are executable artifacts that contain compiled CUDA tactics. Deserialize only engines you built yourself or received over a trusted channel. Never deserialize an engine file from an untrusted source.
After creating the runtime, deserialize the plan using one of the following approaches:
Note
The Python API also supports IStreamReader for deserialization, which is now deprecated.
Deserialize from an in-memory buffer. This method is straightforward and suitable for smaller models or when memory is not a constraint:
with open("model.plan", "rb") as f:
    model_data = f.read()
engine = runtime.deserialize_cuda_engine(model_data)
Alternatively, deserialize directly from a serialized engine object returned by the builder:
serialized_engine = builder.build_serialized_network(network, config)
engine = runtime.deserialize_cuda_engine(serialized_engine)
Deserialize using IStreamReaderV2. This method supports custom file handling and weight streaming, and can reduce peak memory usage by reading the plan in chunks as needed:
class MyStreamReaderV2(trt.IStreamReaderV2):
    def __init__(self, bytes):
        trt.IStreamReaderV2.__init__(self)
        self.bytes = bytes
        self.len = len(bytes)
        self.index = 0

    def read(self, size, cudaStreamPtr):
        assert self.index + size <= self.len
        data = self.bytes[self.index:self.index + size]
        self.index += size
        return data

    def seek(self, offset, where):
        # SeekPosition is referenced through the tensorrt module
        if where == trt.SeekPosition.SET:
            self.index = offset
        elif where == trt.SeekPosition.CUR:
            self.index += offset
        elif where == trt.SeekPosition.END:
            self.index = self.len - offset
        else:
            raise ValueError(f"Invalid seek position: {where}")

with open("model.plan", "rb") as f:
    plan_data = f.read()
reader_v2 = MyStreamReaderV2(plan_data)
engine = runtime.deserialize_cuda_engine(reader_v2)
The IStreamReaderV2 approach is particularly beneficial for large models or when using advanced features like GPUDirect or weight streaming. It can significantly reduce engine load time and memory usage.
When choosing a deserialization method, consider your specific requirements:
For small models or simple use cases, in-memory deserialization is often sufficient.
For large models or when memory efficiency is crucial, consider using IStreamReaderV2.
If you need custom file handling or weight streaming capabilities, IStreamReaderV2 provides the necessary flexibility.
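The reader shown earlier still loads the whole plan into memory before serving it. As a sketch of what chunked reading could look like, the class below answers read and seek requests from an open file handle instead. It is a host-only stand-in that runs without TensorRT: FileChunkReader and the local SeekPosition enum are hypothetical, and to use the idea for real you would derive from trt.IStreamReaderV2, call its __init__, and compare against trt.SeekPosition.

```python
import os
import tempfile
from enum import Enum


class SeekPosition(Enum):
    """Hypothetical stand-in for trt.SeekPosition."""
    SET = 0
    CUR = 1
    END = 2


class FileChunkReader:
    """Serves read/seek requests from an open file instead of a full in-memory copy."""

    def __init__(self, path):
        self._file = open(path, "rb")
        self._len = os.path.getsize(path)

    def read(self, size, cuda_stream_ptr=None):
        # The stream argument is ignored in this host-only sketch.
        return self._file.read(size)

    def seek(self, offset, where):
        if where == SeekPosition.SET:
            self._file.seek(offset)
        elif where == SeekPosition.CUR:
            self._file.seek(offset, os.SEEK_CUR)
        elif where == SeekPosition.END:
            # Same convention as the example above: offset counts back from the end.
            self._file.seek(self._len - offset)
        else:
            raise ValueError(f"Invalid seek position: {where}")

    def close(self):
        self._file.close()


# Exercise the reader on a small temporary file standing in for a plan.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(bytes(range(10)))
reader = FileChunkReader(tmp.name)
first = reader.read(4)
reader.seek(2, SeekPosition.END)
tail = reader.read(2)
reader.close()
os.remove(tmp.name)
print(list(first), list(tail))
```

Only the requested chunk is resident at any time, which is where the reduction in peak host memory would come from.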
Performing Inference#
After the engine is loaded, run inference through an execution context:
For production threading boundaries, refer to the Thread-Safety Deny-List. For memory bounding and multi-tenant OOM prevention, refer to Bounding TensorRT Memory in Production. For CUDA error isolation between tenants, refer to Cross-Context CUDA Error Isolation.
Create an execution context:
context = engine.create_execution_context()
An engine can have multiple execution contexts, allowing one set of weights to be used for multiple overlapping inference tasks.
Set device buffer addresses for input and output tensors using set_tensor_address:
context.set_tensor_address(name, ptr)
Note
Several Python packages allow you to allocate memory on the GPU, including, but not limited to, the official CUDA Python bindings, PyTorch, cuPy, and Numba.
Create a CUDA stream for asynchronous execution. If you already have one (for example, a PyTorch torch.cuda.Stream()), use its cuda_stream property; for Polygraphy streams, use the ptr attribute; or create one with cudaStreamCreate() from the CUDA Python bindings:
from cuda.bindings import runtime as cudart
err, stream_ptr = cudart.cudaStreamCreate()
Populate the input buffer and start inference using execute_async_v3:
context.execute_async_v3(stream_ptr)
Whether a network executes asynchronously depends on its structure and features. Features that can cause synchronous behavior include, but are not limited to, data-dependent shapes, DLA usage, loops, and synchronous plugins. It is common to enqueue data transfers with cudaMemcpyAsync() before and after the kernels, to move input data to the GPU if it is not already there and to copy the results back to the host.
To determine when the kernels (and possibly cudaMemcpyAsync()) are complete, use standard CUDA synchronization mechanisms such as events or waiting on the stream.
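Each address passed to set_tensor_address has to point at a device buffer large enough for that tensor. As a small host-side sketch of the sizing arithmetic only, the helper below computes byte counts from a shape and data type with NumPy. The helper tensor_nbytes and the example tensor names and shapes are hypothetical; real shapes and data types would be queried from the engine and execution context.

```python
import numpy as np


def tensor_nbytes(shape, dtype):
    """Bytes needed for a dense tensor of the given shape and NumPy dtype."""
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize


# Hypothetical I/O tensors for an image classifier.
io_tensors = {
    "input": ((1, 3, 224, 224), np.float32),
    "output": ((1, 1000), np.float32),
}
buffer_sizes = {
    name: tensor_nbytes(shape, dtype) for name, (shape, dtype) in io_tensors.items()
}
print(buffer_sizes)
```

These sizes are what you would pass to whichever allocator you use (CUDA Python bindings, PyTorch, CuPy, or Numba) before binding the resulting pointers.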
Complete End-to-End Example#
The snippets above introduce each Python API class in isolation. The script below stitches them into a single copy-paste-ready program that goes from an ONNX file to a serialized engine to a live inference call. Save it as trt_end_to_end.py and run it against an ONNX model with one float32 input and one float32 output, for example, the ResNet-50 v1 model referenced earlier in this section.
import sys
import numpy as np
import tensorrt as trt
from cuda.bindings import runtime as cudart
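
def _check(result):
    # CUDA Python calls return a tuple whose first element is the error code;
    # accept either that tuple or a bare error code.
    err = result[0] if isinstance(result, tuple) else result
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"CUDA error: {err}")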

def build_engine(onnx_path: str, engine_path: str) -> bytes:
    """Build a strongly typed TensorRT engine from an ONNX file and save it."""
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    )
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_path):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError(f"Failed to parse {onnx_path}")
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed")
    with open(engine_path, "wb") as f:
        f.write(serialized)
    return bytes(serialized)

def run_inference(engine_bytes: bytes, input_array: np.ndarray) -> np.ndarray:
    """Deserialize the engine and run a single inference on the supplied input."""
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(engine_bytes)
    context = engine.create_execution_context()
    input_name = engine.get_tensor_name(0)
    output_name = engine.get_tensor_name(1)
    context.set_input_shape(input_name, input_array.shape)
    output_shape = tuple(context.get_tensor_shape(output_name))
    host_input = np.ascontiguousarray(input_array, dtype=np.float32)
    host_output = np.empty(output_shape, dtype=np.float32)
    err, d_input = cudart.cudaMalloc(host_input.nbytes); _check(err)
    err, d_output = cudart.cudaMalloc(host_output.nbytes); _check(err)
    err, stream = cudart.cudaStreamCreate(); _check(err)
    _check(cudart.cudaMemcpyAsync(
        d_input, host_input.ctypes.data, host_input.nbytes,
        cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream,
    ))
    context.set_tensor_address(input_name, int(d_input))
    context.set_tensor_address(output_name, int(d_output))
    context.execute_async_v3(stream)
    _check(cudart.cudaMemcpyAsync(
        host_output.ctypes.data, d_output, host_output.nbytes,
        cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream,
    ))
    _check(cudart.cudaStreamSynchronize(stream))
    cudart.cudaFree(d_input)
    cudart.cudaFree(d_output)
    cudart.cudaStreamDestroy(stream)
    return host_output

if __name__ == "__main__":
    onnx_path = sys.argv[1] if len(sys.argv) > 1 else "resnet50-v1-12.onnx"
    engine_path = "model.engine"
    engine_bytes = build_engine(onnx_path, engine_path)
    # Replace this with a preprocessed image batch in your own application.
    dummy_input = np.random.rand(1, 3, 224, 224).astype(np.float32)
    output = run_inference(engine_bytes, dummy_input)
    print("Output shape:", output.shape)
    print("Top-1 class index:", int(np.argmax(output[0])))
Run the script with the path to your ONNX file. Each invocation builds and serializes model.engine, then runs one inference on a random tensor sized for ResNet-50.
python3 trt_end_to_end.py resnet50-v1-12.onnx
Once you have a working baseline, swap the random input for a preprocessed image batch, plug the program into your application, and use the focused sections above when you need finer control over network construction, custom weights, alternate deserialization paths, or asynchronous inference scheduling.
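As one way to replace the random input with real data, the sketch below converts a single RGB image into the 1x3x224x224 float32 layout that the script feeds to the engine. The normalization constants are the commonly used ImageNet channel statistics and are an assumption here; check the preprocessing your particular model was trained with. The preprocess helper is hypothetical, and image decoding and resizing are left out so the example depends only on NumPy.

```python
import numpy as np

# Commonly used ImageNet channel statistics; assumed, not taken from the model.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image_hwc_uint8):
    """Convert one 224x224 RGB uint8 image (HWC) to a normalized 1x3x224x224 float32 batch."""
    scaled = image_hwc_uint8.astype(np.float32) / 255.0
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the channel axis
    chw = np.transpose(normalized, (2, 0, 1))              # HWC -> CHW
    return np.ascontiguousarray(chw[np.newaxis, ...])      # add the batch dimension


# A synthetic image stands in for a decoded, already resized photo.
fake_image = np.random.default_rng(0).integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
batch = preprocess(fake_image)
print(batch.shape, batch.dtype)
```

The returned array is C-contiguous, so it can be passed to run_inference in place of dummy_input without a further copy on the host side.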
Next Steps#
See also
- Optimizing Performance
Benchmarking methodology and best practices for maximizing inference throughput and latency.
- Working with Dynamic Shapes
Building engines that handle variable input dimensions at runtime.
- Accuracy Considerations
Understanding precision trade-offs and mitigating accuracy loss with reduced precision.
- Working with Quantized Types
INT8, FP8, and FP4 quantization workflows including PTQ and QAT.
- Advanced Features
Strongly typed networks, layer precision control, DLA, and custom plugins.