// neuralNetworkUltrasoundSegmentation.cpp source
#include <FAST/Tools/CommandLineParser.hpp>
#include <FAST/Algorithms/NeuralNetwork/SegmentationNetwork.hpp>
#include <FAST/Streamers/ImageFileStreamer.hpp>
#include <FAST/Visualization/SimpleWindow.hpp>
#include <FAST/Visualization/ImageRenderer/ImageRenderer.hpp>
#include <FAST/Visualization/SegmentationRenderer/SegmentationRenderer.hpp>
#include <FAST/Visualization/SegmentationLabelRenderer/SegmentationLabelRenderer.hpp>
#include <FAST/Algorithms/UltrasoundImageCropper/UltrasoundImageCropper.hpp>
#include <FAST/Algorithms/NeuralNetwork/InferenceEngineManager.hpp>
#include <FAST/Visualization/Widgets/PlaybackWidget/PlaybackWidget.hpp>

using namespace fast;

/**
 * Example program: streams ultrasound images from disk, runs a jugular vein
 * segmentation network on each frame and shows the image, the segmentation
 * overlay and the segment labels in a 2D window with a playback widget.
 *
 * Command line options:
 *   inference-engine     one of default / OpenVINO / TensorFlow / TensorRT / ONNXRuntime
 *   filename             path pattern of the image files to stream
 *   filename-timestamps  path to the timestamp file belonging to 'filename'
 *
 * After the window is closed, the runtime measurements collected by the
 * segmentation network are printed.
 */
int main(int argc, char** argv) {
    // --- Command line handling -------------------------------------------
    CommandLineParser parser("Neural network ultrasound segmentation example");
    parser.addChoice(
            "inference-engine",
            {"default", "OpenVINO", "TensorFlow", "TensorRT", "ONNXRuntime"},
            "default",
            "Which neural network inference engine to use");
    parser.addVariable(
            "filename",
            Config::getTestDataPath() + "US/JugularVein/US-2D_#.mhd",
            "Path to files to stream from disk");
    parser.addVariable(
            "filename-timestamps",
            Config::getTestDataPath() + "US/JugularVein/timestamps.fts",
            "Path to a file with timestamps related to 'filename'");
    parser.parse(argc, argv);

    // --- Image source ----------------------------------------------------
    auto streamer = ImageFileStreamer::create(parser.get("filename"), true);
    streamer->setTimestampFilename(parser.get("filename-timestamps"));

    // --- Inference engine selection --------------------------------------
    // An explicitly requested engine is loaded by name; "default" lets the
    // manager pick the best engine available.
    const auto requestedEngine = parser.get("inference-engine");
    InferenceEngine::pointer engine;
    if(requestedEngine != "default") {
        engine = InferenceEngineManager::loadEngine(requestedEngine);
    } else {
        engine = InferenceEngineManager::loadBestAvailableEngine();
    }

    // --- Segmentation network --------------------------------------------
    // The model file extension depends on the format preferred by the engine.
    const auto modelPath = join(
            Config::getTestDataPath(),
            "NeuralNetworkModels/jugular_vein_segmentation."
                + getModelFileExtension(engine->getPreferredModelFormat()));
    auto segmentation = SegmentationNetwork::create(modelPath, {}, {}, engine->getName());
    segmentation->connect(streamer);
    // NOTE(review): presumably rescales 8-bit intensities to [0, 1] - confirm against the model.
    segmentation->setScaleFactor(1.0f / 255.0f);
    segmentation->enableRuntimeMeasurements();

    // --- Visualization ---------------------------------------------------
    // Label 1 is shown as "Artery" in red, label 2 as "Vein" in blue.
    auto imageRenderer = ImageRenderer::create()
            ->connect(streamer);
    auto segmentationRenderer = SegmentationRenderer::create(
                {{1, Color::Red()}, {2, Color::Blue()}},
                0.25)
            ->connect(segmentation);
    auto labelRenderer = SegmentationLabelRenderer::create(
                {{1, "Artery"}, {2, "Vein"}},
                {{1, Color::Red()}, {2, Color::Blue()}},
                10)
            ->connect(segmentation);

    auto window = SimpleWindow2D::create(Color::Black())
            ->connect({imageRenderer, segmentationRenderer, labelRenderer})
            ->connect(new PlaybackWidget(streamer));
    window->run();

    // Report the runtimes measured by the segmentation network.
    segmentation->getAllRuntimes()->printAll();
}