Skip to content

Commit b24ed59

Browse files
committed
Lazify hs.plot.markers
1 parent 7c4251e commit b24ed59

File tree

3 files changed

+44
-33
lines changed

3 files changed

+44
-33
lines changed

hyperspy/drawing/markers.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
# You should have received a copy of the GNU General Public License
1717
# along with HyperSpy. If not, see <https://www.gnu.org/licenses/#GPL>.
1818

19-
import logging
2019
from copy import deepcopy
2120

2221
import matplotlib.collections as mpl_collections
@@ -25,11 +24,7 @@
2524
from matplotlib.transforms import IdentityTransform
2625

2726
from hyperspy.events import Event, Events
28-
from hyperspy.misc._markers import markers_dict_to_markers
29-
from hyperspy.misc.dask_utils import _get_navigation_dimension_chunk_slice
30-
from hyperspy.misc.utils import is_dask_array, isiterable
31-
32-
_logger = logging.getLogger(__name__)
27+
from hyperspy.misc import _markers, dask_utils, utils
3328

3429

3530
def convert_positions(peaks, signal_axes):
@@ -194,16 +189,16 @@ def __init__(
194189
for key, value in self.kwargs.items():
195190
# Populate `_iterable_argument_keys`
196191
if (
197-
isiterable(value)
192+
utils.isiterable(value)
198193
and not isinstance(value, str)
199194
and key != self._position_key
200195
):
201196
self._iterable_argument_keys.append(key)
202197

203198
# Handling dask arrays
204-
if is_dask_array(value) and value.dtype == object:
199+
if utils.is_dask_array(value) and value.dtype == object:
205200
self.dask_kwargs[key] = self.kwargs[key]
206-
elif is_dask_array(value): # and value.dtype != object:
201+
elif utils.is_dask_array(value): # and value.dtype != object:
207202
self.kwargs[key] = value.compute()
208203
# Patches or verts shouldn't be cast to array
209204
elif (
@@ -414,7 +409,7 @@ def remove_items(self, indices, keys=None, navigation_indices=None):
414409
# Don't remove when it doesn't have the same length as the
415410
# position kwargs because it is a "cycling" argument
416411
if (
417-
isiterable(value)
412+
utils.isiterable(value)
418413
and not isinstance(value, str)
419414
and len(value) == len(self.kwargs[self._position_key])
420415
):
@@ -498,7 +493,7 @@ def _get_cache_dask_kwargs_chunk(self, indices):
498493

499494
chunks = {key: value.chunks for key, value in self.dask_kwargs.items()}
500495
chunk_slices = {
501-
key: _get_navigation_dimension_chunk_slice(indices, chunk)
496+
key: dask_utils._get_navigation_dimension_chunk_slice(indices, chunk)
502497
for key, chunk in chunks.items()
503498
}
504499
to_compute = {}
@@ -612,7 +607,7 @@ def from_signal(
612607
return cls(**kwargs)
613608

614609
def __deepcopy__(self, memo):
615-
new_marker = markers_dict_to_markers(self._to_dictionary())
610+
new_marker = _markers.markers_dict_to_markers(self._to_dictionary())
616611
return new_marker
617612

618613
def _to_dictionary(self):
@@ -843,4 +838,6 @@ def plot_colorbar(self):
843838

844839

845840
def is_iterating(arg):
846-
return (isinstance(arg, np.ndarray) or is_dask_array(arg)) and arg.dtype == object
841+
return (
842+
isinstance(arg, np.ndarray) or utils.is_dask_array(arg)
843+
) and arg.dtype == object

hyperspy/drawing/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from hyperspy import signals
4141
from hyperspy.defaults_parser import preferences
4242
from hyperspy.docstrings.signal import HISTOGRAM_BIN_ARGS, HISTOGRAM_RANGE_ARGS
43-
from hyperspy.misc.utils import is_dask_array, isiterable, to_numpy
43+
from hyperspy.misc import utils
4444

4545
_logger = logging.getLogger(__name__)
4646

@@ -507,11 +507,11 @@ def _transpose_if_required(signal, expected_dimension):
507507
def _parse_array(signal, normalise=False):
508508
"""Convenience function to parse array from a signal."""
509509
data = signal.data
510-
if is_dask_array(data):
510+
if utils.is_dask_array(data):
511511
data = data.compute()
512512
if normalise:
513513
data = (data - data.min()) / (data.max() - data.min())
514-
return to_numpy(data)
514+
return utils.to_numpy(data)
515515

516516

517517
def plot_images(
@@ -929,7 +929,7 @@ def __check_single_colorbar(cbar):
929929

930930
# Get the figure from ax is provided
931931
if ax is not None:
932-
if isiterable(ax):
932+
if utils.isiterable(ax):
933933
if isinstance(ax, np.ndarray):
934934
# plt.subplots can return numpy array
935935
# convert and flatten to support list and array
@@ -1057,7 +1057,7 @@ def transparent_single_color_cmap(color):
10571057
ax = fig.add_axes([0, 0, 1, 1])
10581058
else:
10591059
ax = fig.add_subplot()
1060-
elif isiterable(ax):
1060+
elif utils.isiterable(ax):
10611061
raise ValueError(
10621062
"When using `overlay=True`, `ax` must be a matplotlib axis."
10631063
)
@@ -1133,7 +1133,7 @@ def transparent_single_color_cmap(color):
11331133
# Below is for non-overlayed images
11341134
else:
11351135
if ax is not None:
1136-
if not isiterable(ax):
1136+
if not utils.isiterable(ax):
11371137
ax = (ax,)
11381138

11391139
# Loop through each image, adding subplot for each one
@@ -1633,13 +1633,13 @@ def _reverse_legend(ax_, legend_loc_):
16331633
raise ValueError("The `ax` parameter is not supported for 'heatmap' style.")
16341634
# To avoid ambiguity, don't support iterable with overalp and cascase style
16351635
elif style in ["overlap", "cascade"]:
1636-
if isiterable(ax):
1636+
if utils.isiterable(ax):
16371637
raise ValueError(
16381638
"When using 'overlap' or 'cascade' style, `ax` must be a matplotlib axis."
16391639
)
16401640
fig = ax.get_figure()
16411641
else:
1642-
if isiterable(ax):
1642+
if utils.isiterable(ax):
16431643
if isinstance(ax, np.ndarray):
16441644
# plt.subplots can return numpy array
16451645
# convert and flatten to support list and array

hyperspy/utils/markers.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,9 @@
3434
3535
"""
3636

37-
from hyperspy.drawing._markers.arrows import Arrows
38-
from hyperspy.drawing._markers.circles import Circles
39-
from hyperspy.drawing._markers.ellipses import Ellipses
40-
from hyperspy.drawing._markers.horizontal_lines import HorizontalLines
41-
from hyperspy.drawing._markers.lines import Lines
42-
from hyperspy.drawing._markers.points import Points
43-
from hyperspy.drawing._markers.polygons import Polygons
44-
from hyperspy.drawing._markers.rectangles import Rectangles
45-
from hyperspy.drawing._markers.squares import Squares
46-
from hyperspy.drawing._markers.texts import Texts
47-
from hyperspy.drawing._markers.vertical_lines import VerticalLines
48-
from hyperspy.drawing.markers import Markers
37+
import importlib
38+
39+
# ruff: noqa: F822
4940

5041
__all__ = [
5142
"Arrows",
@@ -62,6 +53,29 @@
6253
"VerticalLines",
6354
]
6455

56+
_import_mapping = {
57+
"Arrows": "_markers.arrows",
58+
"Circles": "_markers.circles",
59+
"Ellipses": "_markers.ellipses",
60+
"HorizontalLines": "_markers.horizontal_lines",
61+
"Lines": "_markers.lines",
62+
"Markers": "markers",
63+
"Points": "_markers.points",
64+
"Polygons": "_markers.polygons",
65+
"Rectangles": "_markers.rectangles",
66+
"Squares": "_markers.squares",
67+
"Texts": "_markers.texts",
68+
"VerticalLines": "_markers.vertical_lines",
69+
}
70+
6571

6672
def __dir__():
6773
return sorted(__all__)
74+
75+
76+
def __getattr__(name):
77+
if name in __all__:
78+
import_path = f"hyperspy.drawing.{_import_mapping.get(name)}"
79+
return getattr(importlib.import_module(import_path), name)
80+
81+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

0 commit comments

Comments
 (0)