FAST  3.2.0
Framework for Heterogeneous Medical Image Computing and Visualization
Public Member Functions | Protected Attributes | List of all members
fast::TensorFlowEngine Class Reference

#include <TensorFlowEngine.hpp>

+ Inheritance diagram for fast::TensorFlowEngine:
+ Collaboration diagram for fast::TensorFlowEngine:

Public Member Functions

void load () override
 
void run () override
 
std::string getName () const override
 
 ~TensorFlowEngine () override
 
ImageOrdering getPreferredImageOrdering () const override
 
std::vector< ModelFormat > getSupportedModelFormats () const
 
ModelFormat getPreferredModelFormat () const
 
 TensorFlowEngine ()
 
std::vector< InferenceDeviceInfo > getDeviceList ()
 
- Public Member Functions inherited from fast::InferenceEngine
virtual void setFilename (std::string filename)
 
virtual void setModelAndWeights (std::vector< uint8_t > model, std::vector< uint8_t > weights)
 
virtual std::string getFilename () const
 
virtual void addInputNode (uint portID, std::string name, NodeType type=NodeType::IMAGE, TensorShape shape={})
 
virtual void addOutputNode (uint portID, std::string name, NodeType type=NodeType::IMAGE, TensorShape shape={})
 
virtual void setInputNodeShape (std::string name, TensorShape shape)
 
virtual void setOutputNodeShape (std::string name, TensorShape shape)
 
virtual NetworkNode getInputNode (std::string name) const
 
virtual NetworkNode getOutputNode (std::string name) const
 
virtual std::unordered_map< std::string, NetworkNode > getOutputNodes () const
 
virtual std::unordered_map< std::string, NetworkNode > getInputNodes () const
 
virtual void setInputData (std::string inputNodeName, std::shared_ptr< Tensor > tensor)
 
virtual std::shared_ptr< Tensor > getOutputData (std::string inputNodeName)
 
virtual bool isLoaded () const
 
virtual bool isModelFormatSupported (ModelFormat format)
 
virtual void setDeviceType (InferenceDeviceType type)
 
virtual void setDevice (int index=-1, InferenceDeviceType type=InferenceDeviceType::ANY)
 
virtual int getMaxBatchSize ()
 
virtual void setMaxBatchSize (int size)
 
- Public Member Functions inherited from fast::Object
 Object ()
 
virtual ~Object ()
 
Reporter & getReporter ()
 

Protected Attributes

std::unique_ptr< tensorflow::Session > mSession
 
std::unique_ptr< tensorflow::SavedModelBundle > mSavedModelBundle
 
std::vector< std::string > mLearningPhaseTensors
 
- Protected Attributes inherited from fast::InferenceEngine
std::unordered_map< std::string, NetworkNode > mInputNodes
 
std::unordered_map< std::string, NetworkNode > mOutputNodes
 
int m_deviceIndex = -1
 
InferenceDeviceType m_deviceType = InferenceDeviceType::ANY
 
int m_maxBatchSize = 1
 
std::vector< uint8_t > m_model
 
std::vector< uint8_t > m_weights
 
- Protected Attributes inherited from fast::Object
std::weak_ptr< Object > mPtr
 

Additional Inherited Members

- Public Types inherited from fast::InferenceEngine
typedef std::shared_ptr< InferenceEngine > pointer
 
- Public Types inherited from fast::Object
typedef std::shared_ptr< Object > pointer
 
- Static Public Member Functions inherited from fast::Object
static std::string getStaticNameOfClass ()
 
- Protected Member Functions inherited from fast::InferenceEngine
virtual void setIsLoaded (bool loaded)
 
- Protected Member Functions inherited from fast::Object
Reporter & reportError ()
 
Reporter & reportWarning ()
 
Reporter & reportInfo ()
 
ReporterEnd reportEnd () const
 

Constructor & Destructor Documentation

◆ ~TensorFlowEngine()

fast::TensorFlowEngine::~TensorFlowEngine ( )
override

◆ TensorFlowEngine()

fast::TensorFlowEngine::TensorFlowEngine ( )

Member Function Documentation

◆ getDeviceList()

std::vector<InferenceDeviceInfo> fast::TensorFlowEngine::getDeviceList ( )
virtual

Get a list of devices available for this inference engine.

Returns
vector with info on each device

Reimplemented from fast::InferenceEngine.

◆ getName()

std::string fast::TensorFlowEngine::getName ( ) const
override virtual

Implements fast::InferenceEngine.

◆ getPreferredImageOrdering()

ImageOrdering fast::TensorFlowEngine::getPreferredImageOrdering ( ) const
override virtual

Implements fast::InferenceEngine.

◆ getPreferredModelFormat()

ModelFormat fast::TensorFlowEngine::getPreferredModelFormat ( ) const
inline virtual

Implements fast::InferenceEngine.

◆ getSupportedModelFormats()

std::vector<ModelFormat> fast::TensorFlowEngine::getSupportedModelFormats ( ) const
inline virtual

Implements fast::InferenceEngine.

◆ load()

void fast::TensorFlowEngine::load ( )
override virtual

Implements fast::InferenceEngine.

◆ run()

void fast::TensorFlowEngine::run ( )
override virtual

Implements fast::InferenceEngine.

Member Data Documentation

◆ mLearningPhaseTensors

std::vector<std::string> fast::TensorFlowEngine::mLearningPhaseTensors
protected

◆ mSavedModelBundle

std::unique_ptr<tensorflow::SavedModelBundle> fast::TensorFlowEngine::mSavedModelBundle
protected

◆ mSession

std::unique_ptr<tensorflow::Session> fast::TensorFlowEngine::mSession
protected

The documentation for this class was generated from the following file:
TensorFlowEngine.hpp