Skip to content

FeaturesLineWidget #200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
replace label by object_id
  • Loading branch information
zoccoler committed Aug 1, 2023
commit ece6ed104fc3395f47c644356643e629193a6a75
46 changes: 26 additions & 20 deletions src/napari_matplotlib/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:

class FeaturesLineWidget(LineBaseWidget):
"""
Widget to do line plots of two features from a layer, grouped by label.
Widget to do line plots of two features from a layer, grouped by object_id.
"""

n_layers_input = Interval(1, 1)
Expand All @@ -77,7 +77,7 @@ def __init__(
# Add split-by selector
self._selectors["object_id"] = QComboBox()
self._selectors["object_id"].currentTextChanged.connect(self._draw)
self.layout().addWidget(QLabel(f"object-id:"))
self.layout().addWidget(QLabel(f"object_id:"))
self.layout().addWidget(self._selectors["object_id"])

for dim in ["x", "y"]:
Expand Down Expand Up @@ -121,17 +121,17 @@ def y_axis_key(self, key: str) -> None:
self._draw()

@property
def label_axis_key(self) -> Union[str, None]:
def object_id_axis_key(self) -> Union[str, None]:
"""
Key for the label factor.
Key for the object_id factor.
"""
if self._selectors["object_id"].count() == 0:
return None
else:
return self._selectors["object_id"].currentText()

@label_axis_key.setter
def label_axis_key(self, key: str) -> None:
@object_id_axis_key.setter
def object_id_axis_key(self, key: str) -> None:
self._selectors["object_id"].setCurrentText(key)
self._draw()

Expand All @@ -150,14 +150,20 @@ def _get_valid_axis_keys(self) -> List[str]:
else:
return self.layers[0].features.keys()

def _check_valid_labels_data_and_set_color_cycle(self):
def _check_valid_object_id_data_and_set_color_cycle(self):
# If no features, return False
# If no object_id_axis_key, return False
if self.layers[0].features is None \
or len(self.layers[0].features) == 0 \
or self.object_id_axis_key is None:
return False
feature_table = self.layers[0].features
# Get sorted unique labels
labels_from_table = np.unique(feature_table[self.label_axis_key].values).astype(int)
# Return True if object_ids from table match labels from layer, otherwise False
object_ids_from_table = np.unique(feature_table[self.object_id_axis_key].values).astype(int)
labels_from_layer = np.unique(self.layers[0].data)[1:] # exclude zero
if np.array_equal(labels_from_table, labels_from_layer):
if np.array_equal(object_ids_from_table, labels_from_layer):
# Set color cycle
self._set_color_cycle(labels_from_table.tolist())
self._set_color_cycle(object_ids_from_table.tolist())
return True
return False

Expand All @@ -172,20 +178,20 @@ def _ready_to_plot(self) -> bool:

feature_table = self.layers[0].features
valid_keys = self._get_valid_axis_keys()
valid_labels_data = self._check_valid_labels_data_and_set_color_cycle()
valid_object_id_data = self._check_valid_object_id_data_and_set_color_cycle()

return (
feature_table is not None
and len(feature_table) > 0
and self.x_axis_key in valid_keys
and self.y_axis_key in valid_keys
and self.label_axis_key in valid_keys
and valid_labels_data
and self.object_id_axis_key in valid_keys
and valid_object_id_data
)

def draw(self) -> None:
"""
Plot lines for two features from the currently selected layer, grouped by labels.
Plot lines for two features from the currently selected layer, grouped by object_id.
"""
if self._ready_to_plot():
# draw calls _get_data and then plots the data
Expand All @@ -202,7 +208,7 @@ def _set_color_cycle(self, labels):
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
"""
Get the plot data from the ``features`` attribute of the first
selected layer grouped by labels.
selected layer grouped by object_id.

Returns
-------
Expand All @@ -218,10 +224,10 @@ def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
"""
feature_table = self.layers[0].features

# Sort features by 'label' and x_axis_key
feature_table = feature_table.sort_values(by=[self.label_axis_key, self.x_axis_key])
# Get data for each label
grouped = feature_table.groupby(self.label_axis_key)
# Sort features by object_id and x_axis_key
feature_table = feature_table.sort_values(by=[self.object_id_axis_key, self.x_axis_key])
# Get data for each object_id (usually label)
grouped = feature_table.groupby(self.object_id_axis_key)
x = np.array([sub_df[self.x_axis_key].values for label, sub_df in grouped]).T.squeeze()
y = np.array([sub_df[self.y_axis_key].values for label, sub_df in grouped]).T.squeeze()

Expand Down