FAST  3.2.0
Framework for Heterogeneous Medical Image Computing and Visualization
TensorFlowEngine.hpp
Go to the documentation of this file.
1 #pragma once
2 
4 #include <TensorFlowExport.hpp>
5 
6 // Forward declare
7 namespace tensorflow {
8 class Session;
9 class SavedModelBundle;
10 }
11 
12 namespace fast {
13 
15  public:
16  void load() override;
17  void run() override;
18  std::string getName() const override;
19  ~TensorFlowEngine() override;
21  std::vector<ModelFormat> getSupportedModelFormats() const {
23  };
24 
26  return ModelFormat::PROTOBUF;
27  };
28 
35  std::vector<InferenceDeviceInfo> getDeviceList();
36  protected:
37  std::unique_ptr<tensorflow::Session> mSession;
38  std::unique_ptr<tensorflow::SavedModelBundle> mSavedModelBundle;
39  std::vector<std::string> mLearningPhaseTensors;
40 
41 };
42 
43 DEFINE_INFERENCE_ENGINE(TensorFlowEngine, INFERENCEENGINETENSORFLOW_EXPORT)
44 
45 
46 // Forward declare
47 class TensorFlowTensorWrapper;
48 
52 class TensorFlowTensor : public Tensor {
54  public:
55  void create(TensorFlowTensorWrapper* tensorflowTensor);
57  private:
58  TensorFlowTensorWrapper* m_tensorflowTensor;
59  float* getHostDataPointer() override;
60  bool hasAnyData() override;
61 };
62 
63 }
InferenceEngine
Definition: OpenVINOEngine.hpp:6
fast::TensorFlowTensor::~TensorFlowTensor
~TensorFlowTensor()
fast::TensorFlowTensor::create
void create(TensorFlowTensorWrapper *tensorflowTensor)
fast::TensorFlowEngine::getDeviceList
std::vector< InferenceDeviceInfo > getDeviceList()
fast
Definition: AffineTransformation.hpp:7
fast::ModelFormat::SAVEDMODEL
@ SAVEDMODEL
fast::TensorFlowEngine::getSupportedModelFormats
std::vector< ModelFormat > getSupportedModelFormats() const
Definition: TensorFlowEngine.hpp:21
FAST_OBJECT
#define FAST_OBJECT(className)
Definition: Object.hpp:9
fast::Tensor
Definition: Tensor.hpp:12
fast::ModelFormat
ModelFormat
Definition: InferenceEngine.hpp:48
fast::TensorFlowEngine::mSession
std::unique_ptr< tensorflow::Session > mSession
Definition: TensorFlowEngine.hpp:37
fast::TensorFlowEngine::load
void load() override
DEFINE_INFERENCE_ENGINE
#define DEFINE_INFERENCE_ENGINE(classType, exportStatement)
Definition: InferenceEngine.hpp:10
fast::TensorFlowEngine::getPreferredImageOrdering
ImageOrdering getPreferredImageOrdering() const override
fast::ModelFormat::PROTOBUF
@ PROTOBUF
fast::TensorFlowEngine::getPreferredModelFormat
ModelFormat getPreferredModelFormat() const
Definition: TensorFlowEngine.hpp:25
fast::TensorFlowEngine::mLearningPhaseTensors
std::vector< std::string > mLearningPhaseTensors
Definition: TensorFlowEngine.hpp:39
fast::TensorFlowEngine::~TensorFlowEngine
~TensorFlowEngine() override
fast::TensorFlowEngine::getName
std::string getName() const override
fast::TensorFlowTensor
Definition: TensorFlowEngine.hpp:52
InferenceEngine.hpp
tensorflow
Definition: TensorFlowEngine.hpp:7
fast::ImageOrdering
ImageOrdering
Definition: InferenceEngine.hpp:21
fast::TensorFlowEngine
Definition: TensorFlowEngine.hpp:14
fast::TensorFlowEngine::mSavedModelBundle
std::unique_ptr< tensorflow::SavedModelBundle > mSavedModelBundle
Definition: TensorFlowEngine.hpp:38
fast::TensorFlowEngine::run
void run() override
fast::TensorFlowEngine::TensorFlowEngine
TensorFlowEngine()