// FAST 3.2.0
// Framework for Heterogeneous Medical Image Computing and Visualization
// InferenceEngine.hpp — source listing (extracted from generated documentation)
1 #pragma once
2 
#include <FAST/Data/Tensor.hpp>
// NOTE(review): the documentation extraction dropped original header lines 3 and 5
// (the cross-reference index below mentions TensorShape.hpp and DataTypes.hpp) —
// restore them from the upstream header.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
6 
// Macro that generates the plugin entry point (the "load" factory function)
// for a given inference engine class.
// C linkage (extern "C") is required so the function name is not mangled and
// can therefore be found by GetProcAddress on Windows; see
// https://stackoverflow.com/questions/19422550/why-getprocaddress-is-not-working
// An export statement (e.g. __declspec(dllexport)) is needed on Windows.
// Raw `new` is intentional here: the loader owns the returned pointer across
// the shared-library boundary.
#define DEFINE_INFERENCE_ENGINE(classType, exportStatement) \
extern "C" exportStatement \
InferenceEngine* load() { \
    return new classType(); \
}
16 namespace fast {
17 
/**
 * Ordering of channels in image tensors (NCHW vs NHWC style).
 * NOTE(review): the enumerators were dropped by the documentation extraction;
 * reconstructed from this file's own cross-reference index (ChannelFirst,
 * ChannelLast). The declaration order is assumed — confirm against the
 * upstream header before relying on underlying values.
 */
enum class ImageOrdering {
    ChannelFirst,
    ChannelLast
};
25 
/**
 * Kind of data carried by a network input/output node:
 * an image, or a generic tensor.
 */
enum class NodeType {
    IMAGE,
    TENSOR,
};
30 
/**
 * Category of compute device an inference engine may run on.
 * ANY places no restriction on the device type.
 */
enum class InferenceDeviceType {
    ANY,
    CPU,
    GPU,
    VPU,   // vision processing unit (e.g. a dedicated inference accelerator)
    OTHER,
};
38 
40  std::string name;
42  int index;
43 };
44 
/**
 * Serialized neural-network model formats an engine may accept
 * (TensorFlow protobuf/SavedModel, ONNX, OpenVINO IR, TensorRT UFF).
 */
enum class ModelFormat {
    PROTOBUF,
    SAVEDMODEL,
    ONNX,
    OPENVINO,
    UFF
};
55 
59 FAST_EXPORT std::string getModelFileExtension(ModelFormat format);
60 
64 FAST_EXPORT ModelFormat getModelFormat(std::string filename);
65 
71 FAST_EXPORT std::string getModelFormatName(ModelFormat format);
72 
76 class FAST_EXPORT InferenceEngine : public Object {
77  public:
78  typedef std::shared_ptr<InferenceEngine> pointer;
79  struct NetworkNode {
83  std::shared_ptr<Tensor> data;
84  };
85  virtual void setFilename(std::string filename);
86  virtual void setModelAndWeights(std::vector<uint8_t> model, std::vector<uint8_t> weights);
87  virtual std::string getFilename() const;
88  virtual void run() = 0;
89  virtual void addInputNode(uint portID, std::string name, NodeType type = NodeType::IMAGE, TensorShape shape = {});
90  virtual void addOutputNode(uint portID, std::string name, NodeType type = NodeType::IMAGE, TensorShape shape = {});
91  virtual void setInputNodeShape(std::string name, TensorShape shape);
92  virtual void setOutputNodeShape(std::string name, TensorShape shape);
93  virtual NetworkNode getInputNode(std::string name) const;
94  virtual NetworkNode getOutputNode(std::string name) const;
95  virtual std::unordered_map<std::string, NetworkNode> getOutputNodes() const;
96  virtual std::unordered_map<std::string, NetworkNode> getInputNodes() const;
97  virtual void setInputData(std::string inputNodeName, std::shared_ptr<Tensor> tensor);
98  virtual std::shared_ptr<Tensor> getOutputData(std::string inputNodeName);
99  virtual void load() = 0;
100  virtual bool isLoaded() const;
101  virtual ImageOrdering getPreferredImageOrdering() const = 0;
102  virtual std::string getName() const = 0;
103  virtual std::vector<ModelFormat> getSupportedModelFormats() const = 0;
104  virtual ModelFormat getPreferredModelFormat() const = 0;
105  virtual bool isModelFormatSupported(ModelFormat format);
111  virtual void setDeviceType(InferenceDeviceType type);
117  virtual void setDevice(int index = -1, InferenceDeviceType type = InferenceDeviceType::ANY);
123  virtual std::vector<InferenceDeviceInfo> getDeviceList();
124 
125  virtual int getMaxBatchSize();
126  virtual void setMaxBatchSize(int size);
127  protected:
128  virtual void setIsLoaded(bool loaded);
129 
130  std::unordered_map<std::string, NetworkNode> mInputNodes;
131  std::unordered_map<std::string, NetworkNode> mOutputNodes;
132 
133  int m_deviceIndex = -1;
135  int m_maxBatchSize = 1;
136 
137  std::vector<uint8_t> m_model;
138  std::vector<uint8_t> m_weights;
139  private:
140  std::string m_filename = "";
141  bool m_isLoaded = false;
142 };
143 
144 }
InferenceEngine
Definition: OpenVINOEngine.hpp:6
fast::getModelFormatName
FAST_EXPORT std::string getModelFormatName(ModelFormat format)
fast::NodeType
NodeType
Definition: InferenceEngine.hpp:26
fast::ImageOrdering::ChannelFirst
@ ChannelFirst
fast::NodeType::TENSOR
@ TENSOR
fast::InferenceEngine::pointer
std::shared_ptr< InferenceEngine > pointer
Definition: InferenceEngine.hpp:78
fast::TensorShape
Definition: TensorShape.hpp:9
TensorShape.hpp
fast::ModelFormat::UFF
@ UFF
fast
Definition: AffineTransformation.hpp:7
fast::ImageOrdering::ChannelLast
@ ChannelLast
fast::InferenceDeviceInfo
Definition: InferenceEngine.hpp:39
fast::ModelFormat::SAVEDMODEL
@ SAVEDMODEL
fast::InferenceEngine::m_model
std::vector< uint8_t > m_model
Definition: InferenceEngine.hpp:137
fast::InferenceEngine::NetworkNode::shape
TensorShape shape
Definition: InferenceEngine.hpp:82
fast::InferenceDeviceInfo::name
std::string name
Definition: InferenceEngine.hpp:40
fast::Object
Definition: Object.hpp:34
fast::InferenceDeviceType::OTHER
@ OTHER
fast::InferenceDeviceInfo::type
InferenceDeviceType type
Definition: InferenceEngine.hpp:41
fast::InferenceDeviceType::VPU
@ VPU
fast::InferenceEngine::mInputNodes
std::unordered_map< std::string, NetworkNode > mInputNodes
Definition: InferenceEngine.hpp:130
fast::getModelFormat
FAST_EXPORT ModelFormat getModelFormat(std::string filename)
fast::InferenceDeviceType::GPU
@ GPU
fast::format
std::string format(std::string format, Args &&... args)
Definition: Utility.hpp:33
fast::ModelFormat
ModelFormat
Definition: InferenceEngine.hpp:48
fast::InferenceDeviceInfo::index
int index
Definition: InferenceEngine.hpp:42
fast::InferenceEngine::NetworkNode::type
NodeType type
Definition: InferenceEngine.hpp:81
fast::ModelFormat::PROTOBUF
@ PROTOBUF
DataTypes.hpp
fast::InferenceEngine::m_weights
std::vector< uint8_t > m_weights
Definition: InferenceEngine.hpp:138
fast::InferenceDeviceType
InferenceDeviceType
Definition: InferenceEngine.hpp:31
fast::ModelFormat::ONNX
@ ONNX
fast::getModelFileExtension
FAST_EXPORT std::string getModelFileExtension(ModelFormat format)
fast::NodeType::IMAGE
@ IMAGE
Tensor.hpp
fast::InferenceEngine::NetworkNode::portID
uint portID
Definition: InferenceEngine.hpp:80
fast::ModelFormat::OPENVINO
@ OPENVINO
fast::InferenceDeviceType::CPU
@ CPU
fast::ImageOrdering
ImageOrdering
Definition: InferenceEngine.hpp:21
uint
unsigned int uint
Definition: DataTypes.hpp:16
fast::InferenceDeviceType::ANY
@ ANY
fast::InferenceEngine::mOutputNodes
std::unordered_map< std::string, NetworkNode > mOutputNodes
Definition: InferenceEngine.hpp:131
fast::InferenceEngine::NetworkNode::data
std::shared_ptr< Tensor > data
Definition: InferenceEngine.hpp:83
fast::InferenceEngine::NetworkNode
Definition: InferenceEngine.hpp:79