fast::InferenceEngine class

Abstract base class for neural network inference engines (e.g. TensorFlow, TensorRT, OpenVINO and ONNX Runtime).

Base classes

class Object
Base class for all FAST objects.

Derived classes

class ONNXRuntimeEngine
Microsoft's ONNX Runtime inference engine with DirectX/ML support.
class OpenVINOEngine
class TensorFlowEngine
class TensorRTEngine

Public types

using pointer = std::shared_ptr<InferenceEngine>

Public static functions

static auto detectNodeType(const TensorShape& shape) -> NodeType
Detect node type from shape.
static auto detectImageOrdering(const TensorShape& shape, bool hasBatchDim = true) -> ImageOrdering
Detect image ordering from shape.

Public functions

void setFilename(std::string filename) virtual
void setModelAndWeights(std::vector<uint8_t> model, std::vector<uint8_t> weights) virtual
auto getFilename() const -> std::string virtual
void run() pure virtual
void addInputNode(NeuralNetworkNode node) virtual
void addOutputNode(NeuralNetworkNode node) virtual
void setInputNodeShape(std::string name, TensorShape shape) virtual
void setOutputNodeShape(std::string name, TensorShape shape) virtual
auto getInputNode(std::string name) const -> NeuralNetworkNode virtual
auto getOutputNode(std::string name) const -> NeuralNetworkNode virtual
auto getOutputNodes() const -> std::map<std::string, NeuralNetworkNode> virtual
auto getInputNodes() const -> std::map<std::string, NeuralNetworkNode> virtual
void setInputData(std::string inputNodeName, std::shared_ptr<Tensor> tensor) virtual
auto getOutputData(std::string inputNodeName) -> std::shared_ptr<Tensor> virtual
void load() pure virtual
auto isLoaded() const -> bool virtual
auto getPreferredImageOrdering() const -> ImageOrdering pure virtual
auto getName() const -> std::string pure virtual
auto getSupportedModelFormats() const -> std::vector<ModelFormat> pure virtual
auto getPreferredModelFormat() const -> ModelFormat pure virtual
auto isModelFormatSupported(ModelFormat format) -> bool virtual
void setDeviceType(InferenceDeviceType type) virtual
void setDevice(int index = -1, InferenceDeviceType type = InferenceDeviceType::ANY) virtual
auto getDeviceList() -> std::vector<InferenceDeviceInfo> virtual
auto getMaxBatchSize() -> int virtual
void setMaxBatchSize(int size) virtual
void loadCustomPlugins(std::vector<std::string> filenames) virtual
void setImageOrdering(ImageOrdering ordering) virtual
Set the image dimension ordering manually, e.g. channel-last or channel-first.

Protected functions

void setIsLoaded(bool loaded) virtual

Protected variables

std::map<std::string, NeuralNetworkNode> mInputNodes
std::map<std::string, NeuralNetworkNode> mOutputNodes
int m_deviceIndex
InferenceDeviceType m_deviceType
int m_maxBatchSize
std::vector<uint8_t> m_model
std::vector<uint8_t> m_weights
ImageOrdering m_imageOrdering

Function documentation

static NodeType fast::InferenceEngine::detectNodeType(const TensorShape& shape)

Detect node type from shape.


static ImageOrdering fast::InferenceEngine::detectImageOrdering(const TensorShape& shape, bool hasBatchDim = true)

Detect image ordering from shape.

shape Shape to check.
hasBatchDim Whether the first dimension is the batch dimension.

void fast::InferenceEngine::setDeviceType(InferenceDeviceType type) virtual


Set which device type the inference engine should use (assuming the inference engine supports multiple devices, as OpenVINO does).

void fast::InferenceEngine::setDevice(int index = -1, InferenceDeviceType type = InferenceDeviceType::ANY) virtual

index Index of the device to use. -1 means any device can be used.

Specify which device index and/or device type to use.

std::vector<InferenceDeviceInfo> fast::InferenceEngine::getDeviceList() virtual

Returns a vector with info on each device.

Get a list of devices available for this inference engine.

void fast::InferenceEngine::loadCustomPlugins(std::vector<std::string> filenames) virtual

Load custom operator (op) plugins from the given files. Must be called before load().

void fast::InferenceEngine::setImageOrdering(ImageOrdering ordering) virtual

Set the image dimension ordering manually, e.g. channel-last or channel-first.