Learn how to draw a segmentation mask on a slice with the click of a button.
May 27, 2024
10 min Read
By: Abhilaksh Singh Reen
In Part 1, we saw how to create a basic extension and module in 3D Slicer that could load a volume for us. Today, we'll create a segment on that loaded volume and update it using Python.
To our previous UI, we want to add a Horizontal Box below the Load Volume button. We will have 5 items in this box:
1) A Text Box (LineEdit) for the X position.
2) Another Text Box for the Y position.
3) One last Text Box for the size.
4) A button to draw a circle.
5) A button to draw a square.
We can click the Edit UI button in Slicer under Reload & Test to open up our UI in Qt Designer. I will be making these changes directly in the TutorialModule/Resources/UI/TutorialModule.ui file. Add the following lines after the item block containing the Load Volume QPushButton.
<item>
<layout class="QHBoxLayout" name="serverConfigQHBoxLayout">
<item>
<widget class="QLineEdit" name="shapePositionXTextBox">
<property name="placeholderText">
<string>Position X</string>
</property>
</widget>
</item>
<item>
<widget class="QLineEdit" name="shapePositionYTextBox">
<property name="placeholderText">
<string>Position Y</string>
</property>
</widget>
</item>
<item>
<widget class="QLineEdit" name="shapeSizeTextBox">
<property name="placeholderText">
<string>Size</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="drawCircleButton">
<property name="text">
<string>Circle</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="drawSquareButton">
<property name="text">
<string>Square</string>
</property>
</widget>
</item>
</layout>
</item>
The entire TutorialModule.ui file should now look something like the following:
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>TutorialModule</class>
<widget class="qMRMLWidget" name="TutorialModule">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>316</width>
<height>338</height>
</rect>
</property>
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<widget class="QLineEdit" name="volumePathTextBox"/>
</item>
<item>
<widget class="QPushButton" name="loadVolumeButton">
<property name="text">
<string>Load Volume</string>
</property>
</widget>
</item>
<item>
<layout class="QHBoxLayout" name="serverConfigQHBoxLayout">
<item>
<widget class="QLineEdit" name="shapePositionXTextBox">
<property name="placeholderText">
<string>Position X</string>
</property>
</widget>
</item>
<item>
<widget class="QLineEdit" name="shapePositionYTextBox">
<property name="placeholderText">
<string>Position Y</string>
</property>
</widget>
</item>
<item>
<widget class="QLineEdit" name="shapeSizeTextBox">
<property name="placeholderText">
<string>Size</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="drawCircleButton">
<property name="text">
<string>Circle</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="drawSquareButton">
<property name="text">
<string>Square</string>
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="qMRMLSegmentEditorWidget" name="embeddedSegmentEditorWidget">
<property name="autoShowSourceVolumeNode">
<bool>true</bool>
</property>
<property name="maximumNumberOfUndoStates">
<number>10</number>
</property>
</widget>
</item>
<item>
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>20</width>
<height>40</height>
</size>
</property>
</spacer>
</item>
</layout>
</widget>
<customwidgets>
<customwidget>
<class>qMRMLWidget</class>
<extends>QWidget</extends>
<header>qMRMLWidget.h</header>
<container>1</container>
</customwidget>
</customwidgets>
<resources/>
<connections/>
</ui>
In the TutorialModuleWidget class, we'll add three new functions at the end:
    def getPositionAndSize(self):
        "Validates and extracts position and size information from the 3 text boxes."

        def isValidNonNegativeFloatStr(input_str):
            # Let float() do the parsing so that inputs such as "1.2.3" are rejected instead of raising.
            try:
                input_num = float(input_str)
            except ValueError:
                return False
            # NaN fails every comparison and infinity fails the upper bound.
            return 0.0 <= input_num < float("inf")

        errors = []
        positionXStr = self.ui.shapePositionXTextBox.text
        if not isValidNonNegativeFloatStr(positionXStr):
            errors.append("Position X must be a valid non-negative number.")
        positionYStr = self.ui.shapePositionYTextBox.text
        if not isValidNonNegativeFloatStr(positionYStr):
            errors.append("Position Y must be a valid non-negative number.")
        sizeStr = self.ui.shapeSizeTextBox.text
        if not isValidNonNegativeFloatStr(sizeStr):
            errors.append("Size must be a valid non-negative number.")

        if len(errors) > 0:
            slicer.util.errorDisplay("\n".join(errors))
            return False, False

        position = (float(positionXStr), float(positionYStr))
        size = float(sizeStr)
        return position, size

    def onDrawCircleButtonClick(self):
        "Gets position and size, checks validity, and calls the logic function for drawing circle segment on slice."
        position, size = self.getPositionAndSize()
        if position is False:
            return
        self.logic.drawCircleSegmentOnSlice(position, size)

    def onDrawSquareButtonClick(self):
        "Gets position and size, checks validity, and calls the logic function for drawing square segment on slice."
        position, size = self.getPositionAndSize()
        if position is False:
            return
        self.logic.drawSquareSegmentOnSlice(position, size)
We have created on-click handlers for the drawCircleButton and drawSquareButton buttons. Let's connect them by adding the following lines to the setup function, right after we connect the Load Volume button.
self.ui.drawCircleButton.connect('clicked(bool)', self.onDrawCircleButtonClick)
self.ui.drawSquareButton.connect('clicked(bool)', self.onDrawSquareButtonClick)
The entire setup function should look something like the following:
    def setup(self):
        """
        Called when the user opens the module the first time and the widget is initialized.
        """
        ScriptedLoadableModuleWidget.setup(self)

        # Load widget from .ui file (created by Qt Designer).
        # Additional widgets can be instantiated manually and added to self.layout.
        uiWidget = slicer.util.loadUI(self.resourcePath('UI/TutorialModule.ui'))
        self.layout.addWidget(uiWidget)
        self.ui = slicer.util.childWidgetVariables(uiWidget)

        # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's
        # "mrmlSceneChanged(vtkMRMLScene*)" signal is connected to each MRML widget's
        # "setMRMLScene(vtkMRMLScene*)" slot.
        uiWidget.setMRMLScene(slicer.mrmlScene)

        # Create logic class. Logic implements all computations that should be possible to run
        # in batch mode, without a graphical user interface.
        self.logic = TutorialModuleLogic()

        # Connections
        # These connections ensure that we update parameter node when scene is closed
        self.addObserver(slicer.mrmlScene, slicer.mrmlScene.StartCloseEvent, self.onSceneStartClose)
        self.addObserver(slicer.mrmlScene, slicer.mrmlScene.EndCloseEvent, self.onSceneEndClose)

        # Buttons
        self.ui.loadVolumeButton.connect('clicked(bool)', self.onLoadVolumeButtonClick)
        self.ui.drawCircleButton.connect('clicked(bool)', self.onDrawCircleButtonClick)
        self.ui.drawSquareButton.connect('clicked(bool)', self.onDrawSquareButtonClick)

        # Make sure parameter node is initialized (needed for module reload)
        self.initializeParameterNode()
And here's the entire TutorialModuleWidget class:
class TutorialModuleWidget(ScriptedLoadableModuleWidget, VTKObservationMixin):
    """Uses ScriptedLoadableModuleWidget base class, available at:
    https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def __init__(self, parent=None):
        """
        Called when the user opens the module the first time and the widget is initialized.
        """
        ScriptedLoadableModuleWidget.__init__(self, parent)
        VTKObservationMixin.__init__(self)  # needed for parameter node observation
        self.logic = None
        self._parameterNode = None
        self._updatingGUIFromParameterNode = False

    def setup(self):
        """
        Called when the user opens the module the first time and the widget is initialized.
        """
        ScriptedLoadableModuleWidget.setup(self)

        # Load widget from .ui file (created by Qt Designer).
        # Additional widgets can be instantiated manually and added to self.layout.
        uiWidget = slicer.util.loadUI(self.resourcePath('UI/TutorialModule.ui'))
        self.layout.addWidget(uiWidget)
        self.ui = slicer.util.childWidgetVariables(uiWidget)

        # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's
        # "mrmlSceneChanged(vtkMRMLScene*)" signal is connected to each MRML widget's
        # "setMRMLScene(vtkMRMLScene*)" slot.
        uiWidget.setMRMLScene(slicer.mrmlScene)

        # Create logic class. Logic implements all computations that should be possible to run
        # in batch mode, without a graphical user interface.
        self.logic = TutorialModuleLogic()

        # Connections
        # These connections ensure that we update parameter node when scene is closed
        self.addObserver(slicer.mrmlScene, slicer.mrmlScene.StartCloseEvent, self.onSceneStartClose)
        self.addObserver(slicer.mrmlScene, slicer.mrmlScene.EndCloseEvent, self.onSceneEndClose)

        # Buttons
        self.ui.loadVolumeButton.connect('clicked(bool)', self.onLoadVolumeButtonClick)
        self.ui.drawCircleButton.connect('clicked(bool)', self.onDrawCircleButtonClick)
        self.ui.drawSquareButton.connect('clicked(bool)', self.onDrawSquareButtonClick)

        # Make sure parameter node is initialized (needed for module reload)
        self.initializeParameterNode()

    def cleanup(self):
        """
        Called when the application closes and the module widget is destroyed.
        """
        self.removeObservers()

    def enter(self):
        """
        Called each time the user opens this module.
        """
        # Make sure parameter node exists and observed
        self.initializeParameterNode()

    def exit(self):
        """
        Called each time the user opens a different module.
        """
        # Do not react to parameter node changes (GUI will be updated when the user enters into the module)
        self.removeObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode)

    def onSceneStartClose(self, caller, event):
        """
        Called just before the scene is closed.
        """
        # Parameter node will be reset, do not use it anymore
        self.setParameterNode(None)

    def onSceneEndClose(self, caller, event):
        """
        Called just after the scene is closed.
        """
        # If this module is shown while the scene is closed then recreate a new parameter node immediately
        if self.parent.isEntered:
            self.initializeParameterNode()

    def initializeParameterNode(self):
        """
        Ensure parameter node exists and observed.
        """
        # Parameter node stores all user choices in parameter values, node selections, etc.
        # so that when the scene is saved and reloaded, these settings are restored.
        self.setParameterNode(self.logic.getParameterNode())

        # Select default input nodes if nothing is selected yet to save a few clicks for the user
        if not self._parameterNode.GetNodeReference("InputVolume"):
            firstVolumeNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLScalarVolumeNode")
            if firstVolumeNode:
                self._parameterNode.SetNodeReferenceID("InputVolume", firstVolumeNode.GetID())

    def setParameterNode(self, inputParameterNode):
        """
        Set and observe parameter node.
        Observation is needed because when the parameter node is changed then the GUI must be updated immediately.
        """
        if inputParameterNode:
            self.logic.setDefaultParameters(inputParameterNode)

        # Unobserve previously selected parameter node and add an observer to the newly selected.
        # Changes of parameter node are observed so that whenever parameters are changed by a script or any other module
        # those are reflected immediately in the GUI.
        if self._parameterNode is not None and self.hasObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode):
            self.removeObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode)
        self._parameterNode = inputParameterNode
        if self._parameterNode is not None:
            self.addObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode)

        # Initial GUI update
        self.updateGUIFromParameterNode()

    def updateGUIFromParameterNode(self, caller=None, event=None):
        """
        This method is called whenever parameter node is changed.
        The module GUI is updated to show the current state of the parameter node.
        """
        if self._parameterNode is None or self._updatingGUIFromParameterNode:
            return

        # Make sure GUI changes do not call updateParameterNodeFromGUI (it could cause infinite loop)
        self._updatingGUIFromParameterNode = True

        # All the GUI updates are done
        self._updatingGUIFromParameterNode = False

    def updateParameterNodeFromGUI(self, caller=None, event=None):
        """
        This method is called when the user makes any change in the GUI.
        The changes are saved into the parameter node (so that they are restored when the scene is saved and loaded).
        """
        if self._parameterNode is None or self._updatingGUIFromParameterNode:
            return

        wasModified = self._parameterNode.StartModify()  # Modify all properties in a single batch

        self._parameterNode.SetNodeReferenceID("InputVolume", self.ui.inputSelector.currentNodeID)
        self._parameterNode.SetNodeReferenceID("OutputVolume", self.ui.outputSelector.currentNodeID)
        self._parameterNode.SetParameter("Threshold", str(self.ui.imageThresholdSliderWidget.value))
        self._parameterNode.SetParameter("Invert", "true" if self.ui.invertOutputCheckBox.checked else "false")
        self._parameterNode.SetNodeReferenceID("OutputVolumeInverse", self.ui.invertedOutputSelector.currentNodeID)

        self._parameterNode.EndModify(wasModified)

    def onLoadVolumeButtonClick(self):
        volumePath = self.ui.volumePathTextBox.text
        volumeNode = self.logic.loadVolume(volumePath)
        if volumeNode is None:
            return

        segmentationNode = self.logic.getSegmentationNode()
        segmentEditorNode = self.logic.getSegmentEditorNode()

        self.ui.embeddedSegmentEditorWidget.setMRMLScene(slicer.mrmlScene)
        self.ui.embeddedSegmentEditorWidget.setSegmentationNodeSelectorVisible(False)
        self.ui.embeddedSegmentEditorWidget.setSourceVolumeNodeSelectorVisible(False)
        self.ui.embeddedSegmentEditorWidget.setMRMLSegmentEditorNode(segmentEditorNode)
        self.ui.embeddedSegmentEditorWidget.setSegmentationNode(segmentationNode)
        self.ui.embeddedSegmentEditorWidget.setSourceVolumeNode(volumeNode)
    def getPositionAndSize(self):
        "Validates and extracts position and size information from the 3 text boxes."

        def isValidNonNegativeFloatStr(input_str):
            # Let float() do the parsing so that inputs such as "1.2.3" are rejected instead of raising.
            try:
                input_num = float(input_str)
            except ValueError:
                return False
            # NaN fails every comparison and infinity fails the upper bound.
            return 0.0 <= input_num < float("inf")

        errors = []
        positionXStr = self.ui.shapePositionXTextBox.text
        if not isValidNonNegativeFloatStr(positionXStr):
            errors.append("Position X must be a valid non-negative number.")
        positionYStr = self.ui.shapePositionYTextBox.text
        if not isValidNonNegativeFloatStr(positionYStr):
            errors.append("Position Y must be a valid non-negative number.")
        sizeStr = self.ui.shapeSizeTextBox.text
        if not isValidNonNegativeFloatStr(sizeStr):
            errors.append("Size must be a valid non-negative number.")

        if len(errors) > 0:
            slicer.util.errorDisplay("\n".join(errors))
            return False, False

        position = (float(positionXStr), float(positionYStr))
        size = float(sizeStr)
        return position, size

    def onDrawCircleButtonClick(self):
        "Gets position and size, checks validity, and calls the logic function for drawing circle segment on slice."
        position, size = self.getPositionAndSize()
        if position is False:
            return
        self.logic.drawCircleSegmentOnSlice(position, size)

    def onDrawSquareButtonClick(self):
        "Gets position and size, checks validity, and calls the logic function for drawing square segment on slice."
        position, size = self.getPositionAndSize()
        if position is False:
            return
        self.logic.drawSquareSegmentOnSlice(position, size)
Let's create the two functions in TutorialModuleLogic for drawing the segments. For now, we'll just print the position and size we're receiving. Inside TutorialModuleLogic, add the following two functions:
    def drawCircleSegmentOnSlice(self, position, size):
        print(f"Position: {position}")
        print(f"Size: {size}")

    def drawSquareSegmentOnSlice(self, position, size):
        print(f"Position: {position}")
        print(f"Size: {size}")
Head back to Slicer and hit Reload. Now try entering some numbers in the text boxes and clicking the buttons. If the inputs are valid non-negative numbers, you should see the position and size printed to the console; otherwise, you should see an appropriate error message.
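The accept/reject behaviour described above can also be tried outside Slicer. Here is a minimal sketch, assuming a standalone helper named is_valid_non_negative_float_str (a hypothetical name, not part of the module) that relies on float() for parsing, so malformed inputs such as "1.2.3" are rejected rather than raising an exception:

```python
def is_valid_non_negative_float_str(input_str):
    """Return True if input_str parses as a finite number >= 0 (hypothetical helper)."""
    try:
        value = float(input_str)
    except (TypeError, ValueError):
        # Empty strings, letters, and malformed numbers such as "1.2.3" end up here.
        return False
    # NaN fails every comparison and infinity fails the upper bound.
    return 0.0 <= value < float("inf")


# Strings that should be accepted and rejected by the text boxes.
accepted = [s for s in ["12", "3.5", "0"] if is_valid_non_negative_float_str(s)]
rejected = [s for s in ["", "abc", "-1", "1.2.3", "nan"] if not is_valid_non_negative_float_str(s)]
```

Note that float() also accepts forms such as "1e3" or " 5 ", so this check is slightly more permissive than inspecting the characters of the string.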
Great! Let's draw the segments.
To draw a segment on a particular slice, we first need the index of that slice, which we will later use to index into the 3D segment array. We'll create a function in TutorialModuleLogic that returns the slice indices for all three views.
    def getCurrentSliceIndices(self):
        "Get the current slice indices for Axial, Coronal, and Sagittal."
        if self.volumeNode is None:
            return {
                'axial': -1,
                'coronal': -1,
                'sagittal': -1,
            }

        sliceIndices = {}
        for viewName, planeName in zip(["Red", "Green", "Yellow"], ["axial", "coronal", "sagittal"]):
            sliceWidget = slicer.app.layoutManager().sliceWidget(viewName)
            sliceLogic = sliceWidget.sliceLogic()
            sliceOffset = sliceLogic.GetSliceOffset()
            sliceIndex = sliceLogic.GetSliceIndexFromOffset(sliceOffset) - 1
            sliceIndices[planeName] = sliceIndex

        return sliceIndices
Now, we'll add the real logic to draw the circle and the square. But first, we have to understand how segment manipulation works. Open up the Python Console in Slicer by going to View > Python Console or by pressing Ctrl + 3. Load a Volume and open up the Segment Editor module from the Favorite Modules section. Then, run the following lines of Python:
volumePath = "example-volume-path"
volumeNode = slicer.util.loadVolume(volumePath)
segmentationNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
proceduralSegmentId = "test_segment"
segmentation = segmentationNode.GetSegmentation()
segmentation.AddEmptySegment(proceduralSegmentId, "Test", None)
You should see a new segment called 'Test' appear in the segment editor (make sure to select the appropriate Segmentation from the dropdown). Hover over the Red view and you should see the index of the Axial slice in the Data Probe section (the third value of the volume tuple).
We'll hardcode a slice index and update the segment on that slice.
import numpy as np

sliceIndex = 36

segmentArray = slicer.util.arrayFromSegmentBinaryLabelmap(
    segmentationNode,
    proceduralSegmentId,
    volumeNode
)

slice2dArray = segmentArray[sliceIndex]
slice2dArray = np.random.randint(2, size=slice2dArray.shape, dtype=slice2dArray.dtype)
segmentArray[sliceIndex] = slice2dArray

slicer.util.updateSegmentBinaryLabelmapFromArray(
    segmentArray,
    segmentationNode,
    proceduralSegmentId,
    volumeNode
)
After executing the above lines, you should see the segment on slice 36 change to a random mask.
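The read, modify, and write-back pattern used here is ordinary NumPy indexing, so it can be rehearsed without Slicer. The sketch below uses a zero-filled array as a stand-in for the result of arrayFromSegmentBinaryLabelmap; the shape, the uint8 dtype, and the slice index are made-up values for illustration only.

```python
import numpy as np

# Stand-in for the 3D segment array; the first axis selects the slice,
# matching the segmentArray[sliceIndex] indexing used in the console session.
segment_array = np.zeros((4, 5, 6), dtype=np.uint8)

slice_index = 2

# Build a random 0/1 mask with the same shape and dtype as one 2D slice.
rng = np.random.default_rng(seed=0)
random_mask = rng.integers(
    0, 2, size=segment_array[slice_index].shape, dtype=segment_array.dtype
)

# Write the mask back into the 3D array; all other slices stay untouched.
segment_array[slice_index] = random_mask
```

In Slicer, the modified array only becomes visible after it is passed back through updateSegmentBinaryLabelmapFromArray, as in the lines above.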
Let's update the __init__ function of TutorialModuleLogic to initialize some default values.
    def __init__(self):
        """
        Called when the logic class is instantiated. Can be used for initializing member variables.
        """
        ScriptedLoadableModuleLogic.__init__(self)
        self.loadedVolumePath = None
        self.volumeNode = None
        self.segmentationNode = None
        self.segmentEditorNode = None
Now, we can create the getSegmentEditorNode and getSegmentationNode functions that we have referenced in the Widget class.
    def getSegmentEditorNode(self):
        segmentEditorSingletonTag = "SegmentEditor"
        segmentEditorNode = slicer.mrmlScene.GetSingletonNode(segmentEditorSingletonTag, "vtkMRMLSegmentEditorNode")
        if segmentEditorNode is None:
            segmentEditorNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLSegmentEditorNode")
            segmentEditorNode.UnRegister(None)
            segmentEditorNode.SetSingletonTag(segmentEditorSingletonTag)
            segmentEditorNode = slicer.mrmlScene.AddNode(segmentEditorNode)
        # Store the node whether it was just created or was already present in the scene.
        self.segmentEditorNode = segmentEditorNode
        return self.segmentEditorNode

    def getSegmentationNode(self):
        # Reuse the existing node on later calls instead of returning None.
        if self.segmentationNode is not None:
            return self.segmentationNode

        self.segmentationNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
        self.proceduralSegmentId = "procedural_segment"
        segmentation = self.segmentationNode.GetSegmentation()
        segmentation.AddEmptySegment(self.proceduralSegmentId, "Procedural", None)
        return self.segmentationNode
And, finally, create the functions for drawing the circle and the square. Both use NumPy, so make sure import numpy as np appears at the top of the module file.
    def drawCircleSegmentOnSlice(self, position, size):
        sliceIndices = self.getCurrentSliceIndices()
        if sliceIndices['axial'] == -1:
            slicer.util.errorDisplay("Please load a volume first.")
            return

        segmentArray = slicer.util.arrayFromSegmentBinaryLabelmap(
            self.segmentationNode,
            self.proceduralSegmentId,
            self.volumeNode
        )
        slice2dArray = segmentArray[sliceIndices['axial']]

        # Draw Circle Logic
        radius = size / 2
        x_coords, y_coords = np.meshgrid(
            np.arange(slice2dArray.shape[0]),
            np.arange(slice2dArray.shape[1]),
            indexing='ij'
        )
        distances = np.sqrt((x_coords - position[0])**2 + (y_coords - position[1])**2)
        slice2dArray = np.where(distances <= radius, 1, slice2dArray)
        ###

        segmentArray[sliceIndices['axial']] = slice2dArray
        slicer.util.updateSegmentBinaryLabelmapFromArray(
            segmentArray,
            self.segmentationNode,
            self.proceduralSegmentId,
            self.volumeNode
        )

    def drawSquareSegmentOnSlice(self, position, size):
        sliceIndices = self.getCurrentSliceIndices()
        if sliceIndices['axial'] == -1:
            slicer.util.errorDisplay("Please load a volume first.")
            return

        segmentArray = slicer.util.arrayFromSegmentBinaryLabelmap(
            self.segmentationNode,
            self.proceduralSegmentId,
            self.volumeNode
        )
        slice2dArray = segmentArray[sliceIndices['axial']]

        # Draw Square Logic
        x_min = int(position[0] - size / 2)
        x_max = int(position[0] + size / 2)
        y_min = int(position[1] - size / 2)
        y_max = int(position[1] + size / 2)
        x_coords, y_coords = np.meshgrid(
            np.arange(slice2dArray.shape[0]),
            np.arange(slice2dArray.shape[1]),
            indexing='ij'
        )
        isInside = (x_coords >= x_min) & (x_coords <= x_max) & (y_coords >= y_min) & (y_coords <= y_max)
        slice2dArray = np.where(isInside, 1, slice2dArray)
        ###

        segmentArray[sliceIndices['axial']] = slice2dArray
        slicer.util.updateSegmentBinaryLabelmapFromArray(
            segmentArray,
            self.segmentationNode,
            self.proceduralSegmentId,
            self.volumeNode
        )
Head back to Slicer and hit Reload. Now you can load a volume, enter the position and size, and test out the Circle and Square buttons. Everything should work as expected.
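If you want to check the shape logic before reloading the module, the mask construction can be run on a plain NumPy array. The following is a sketch under stated assumptions: draw_circle and draw_square are hypothetical standalone helpers, the first position value is treated as the row index and the second as the column index (as the meshgrid with indexing='ij' implies), and the square is filled with array slicing rather than a meshgrid.

```python
import numpy as np


def draw_circle(slice_2d, center, diameter):
    """Return a copy of slice_2d with a filled circle set to 1."""
    rows, cols = np.meshgrid(
        np.arange(slice_2d.shape[0]),
        np.arange(slice_2d.shape[1]),
        indexing="ij",
    )
    distances = np.sqrt((rows - center[0]) ** 2 + (cols - center[1]) ** 2)
    # np.where builds a new array, so cast back to the input dtype.
    return np.where(distances <= diameter / 2, 1, slice_2d).astype(slice_2d.dtype)


def draw_square(slice_2d, center, side):
    """Return a copy of slice_2d with a filled axis-aligned square set to 1."""
    half = side / 2
    r_min, r_max = int(center[0] - half), int(center[0] + half)
    c_min, c_max = int(center[1] - half), int(center[1] + half)
    out = slice_2d.copy()
    # Bounds are inclusive; clamp at 0 so negative values do not wrap around.
    out[max(r_min, 0):max(r_max + 1, 0), max(c_min, 0):max(c_max + 1, 0)] = 1
    return out


blank = np.zeros((9, 9), dtype=np.uint8)
circle = draw_circle(blank, center=(4, 4), diameter=4)
square = draw_square(blank, center=(4, 4), side=4)
```

Because both bounds of the square are inclusive, a side of 4 centred on a voxel covers 5 voxels in each direction.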
It would help to have an option in the GUI that lets the user select which plane (Axial, Coronal, or Sagittal) they want to draw the segment on, but for the purpose of this tutorial, we'll just hardcode that in our logic.
Let's do this on the Coronal (Green) plane. We can perform the update as follows:
        slice2dArray = segmentArray[:, sliceIndices['coronal'], :]

        # Draw Circle Logic
        radius = size / 2
        x_coords, y_coords = np.meshgrid(
            np.arange(slice2dArray.shape[0]),
            np.arange(slice2dArray.shape[1]),
            indexing='ij'
        )
        distances = np.sqrt((x_coords - position[0])**2 + (y_coords - position[1])**2)
        slice2dArray = np.where(distances <= radius, 1, slice2dArray)
        ###

        segmentArray[:, sliceIndices['coronal'], :] = slice2dArray
Reload the extension and try drawing a circle. It should work as expected.
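Rather than hardcoding the indexing for each plane, the plane name could be mapped to an array axis. This is only a sketch: get_plane_slice and PLANE_AXES are hypothetical names, and the mapping assumes the axis order implied by the indexing used in this tutorial (axial on the first axis, coronal on the second), with sagittal on the third axis as an extrapolation that may not hold for every volume orientation.

```python
import numpy as np

# Assumed mapping from plane name to the axis of the 3D segment array.
PLANE_AXES = {"axial": 0, "coronal": 1, "sagittal": 2}


def get_plane_slice(segment_array, plane, slice_index):
    """Return a 2D view of one slice along the axis that belongs to the given plane."""
    index = [slice(None)] * 3
    index[PLANE_AXES[plane]] = slice_index
    # Basic indexing returns a view, so writing to it modifies segment_array.
    return segment_array[tuple(index)]


segment_array = np.zeros((3, 4, 5), dtype=np.uint8)
coronal_slice = get_plane_slice(segment_array, "coronal", 1)
coronal_slice[:] = 1  # equivalent to segment_array[:, 1, :] = 1
```

Since the helper returns a view, the write-back line (segmentArray[...] = slice2dArray) is only needed when the slice is replaced by a new array, as np.where does in the drawing code.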
Many NIfTI images have different dimensions along different axes. For example, the Axial slice may be 256x256 and the Coronal slice 75x512. In this case, Slicer changes the aspect ratio of the slice to bring it closer to a square. Hence, our "circle" also appears shrunk in one dimension.
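One possible way to compensate, assuming the distortion comes from voxels that are not the same size along each axis, is to measure distances in millimetres instead of voxel indices. The sketch below is not part of the module: draw_circle_mm is a hypothetical helper, and spacing_mm is assumed to hold the (row, column) voxel size of the slice. In Slicer, volumeNode.GetSpacing() reports the voxel size, but its axis order is reversed relative to the NumPy array, so the two values for the chosen plane would have to be picked accordingly.

```python
import numpy as np


def draw_circle_mm(slice_2d, center, diameter_mm, spacing_mm):
    """Fill a circle whose diameter is given in millimetres.

    center is (row, column) in voxel indices; spacing_mm is the assumed
    (row, column) voxel size in millimetres for this slice.
    """
    rows, cols = np.meshgrid(
        np.arange(slice_2d.shape[0]),
        np.arange(slice_2d.shape[1]),
        indexing="ij",
    )
    # Convert index offsets to physical offsets before measuring distance.
    d_row = (rows - center[0]) * spacing_mm[0]
    d_col = (cols - center[1]) * spacing_mm[1]
    distances = np.sqrt(d_row ** 2 + d_col ** 2)
    return np.where(distances <= diameter_mm / 2, 1, slice_2d).astype(slice_2d.dtype)


blank = np.zeros((21, 21), dtype=np.uint8)
# Rows are 5 mm apart and columns 1 mm apart in this made-up example.
mask = draw_circle_mm(blank, center=(10, 10), diameter_mm=10, spacing_mm=(5.0, 1.0))
```

In index space the result is an ellipse that spans far fewer rows than columns, which should look circular once the slice is displayed with its physical proportions.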
In this article, we've seen how we can update segments through code in a 3D Slicer module. We can now proceed to do something more meaningful than drawing simple shapes. In a previous post, we created a web server to run TotalSegmentator, which can segment an entire NIfTI image. In Part 3, we'll update our extension to not just draw circles or squares but to work with our TotalSegmentator API.
See you next time :)