4646
4747
4848def _as_usm_ndarray (a , usm_type , sycl_queue ):
49+ """Converts input object to `dpctl.tensor.usm_ndarray`"""
50+
4951 if isinstance (a , dpnp_array ):
50- return a .get_array ()
52+ a = a .get_array ()
5153 return dpt .asarray (a , usm_type = usm_type , sycl_queue = sycl_queue )
5254
5355
56+ def _check_has_zero_val (a ):
57+ """Check if any element in input object is equal to zero"""
58+
59+ if dpnp .isscalar (a ):
60+ if a == 0 :
61+ return True
62+ elif hasattr (a , "any" ):
63+ if (a == 0 ).any ():
64+ return True
65+ elif any (val == 0 for val in a ):
66+ return True
67+ return False
68+
69+
70+ def _get_usm_allocations (objs , device = None , usm_type = None , sycl_queue = None ):
71+ """
72+ Get common USM allocations based on a list of input objects and an explicit
73+ device, a SYCL queue, or a USM type if specified.
74+
75+ """
76+
77+ alloc_usm_type , alloc_sycl_queue = get_usm_allocations (objs )
78+
79+ if sycl_queue is None and device is None :
80+ sycl_queue = alloc_sycl_queue
81+
82+ if usm_type is None :
83+ usm_type = alloc_usm_type or "device"
84+ return usm_type , dpnp .get_normalized_queue_device (
85+ sycl_queue = sycl_queue , device = device
86+ )
87+
88+
5489def dpnp_geomspace (
5590 start ,
5691 stop ,
@@ -62,76 +97,57 @@ def dpnp_geomspace(
6297 endpoint = True ,
6398 axis = 0 ,
6499):
65- usm_type_alloc , sycl_queue_alloc = get_usm_allocations ([start , stop ])
66-
67- if sycl_queue is None and device is None :
68- sycl_queue = sycl_queue_alloc
69- sycl_queue_normalized = dpnp .get_normalized_queue_device (
70- sycl_queue = sycl_queue , device = device
100+ usm_type , sycl_queue = _get_usm_allocations (
101+ [start , stop ], device = device , usm_type = usm_type , sycl_queue = sycl_queue
71102 )
72103
73- if usm_type is None :
74- _usm_type = "device" if usm_type_alloc is None else usm_type_alloc
75- else :
76- _usm_type = usm_type
104+ if _check_has_zero_val (start ) or _check_has_zero_val (stop ):
105+ raise ValueError ("Geometric sequence cannot include zero" )
77106
78- start = _as_usm_ndarray (start , _usm_type , sycl_queue_normalized )
79- stop = _as_usm_ndarray (stop , _usm_type , sycl_queue_normalized )
107+ start = dpnp . array (start , usm_type = usm_type , sycl_queue = sycl_queue )
108+ stop = dpnp . array (stop , usm_type = usm_type , sycl_queue = sycl_queue )
80109
81110 dt = numpy .result_type (start , stop , float (num ))
82- dt = map_dtype_to_device (dt , sycl_queue_normalized .sycl_device )
111+ dt = map_dtype_to_device (dt , sycl_queue .sycl_device )
83112 if dtype is None :
84113 dtype = dt
85114
86- if dpnp .any (start == 0 ) or dpnp .any (stop == 0 ):
87- raise ValueError ("Geometric sequence cannot include zero" )
115+ # promote both arguments to the same dtype
116+ start = start .astype (dt , copy = False )
117+ stop = stop .astype (dt , copy = False )
88118
89- out_sign = dpt .ones (
90- dpt .broadcast_arrays (start , stop )[0 ].shape ,
91- dtype = dt ,
92- usm_type = _usm_type ,
93- sycl_queue = sycl_queue_normalized ,
94- )
95- # Avoid negligible real or imaginary parts in output by rotating to
96- # positive real, calculating, then undoing rotation
97- if dpnp .issubdtype (dt , dpnp .complexfloating ):
98- all_imag = (start .real == 0.0 ) & (stop .real == 0.0 )
99- if dpnp .any (all_imag ):
100- start [all_imag ] = start [all_imag ].imag
101- stop [all_imag ] = stop [all_imag ].imag
102- out_sign [all_imag ] = 1j
103-
104- both_negative = (dpt .sign (start ) == - 1 ) & (dpt .sign (stop ) == - 1 )
105- if dpnp .any (both_negative ):
106- dpt .negative (start [both_negative ], out = start [both_negative ])
107- dpt .negative (stop [both_negative ], out = stop [both_negative ])
108- dpt .negative (out_sign [both_negative ], out = out_sign [both_negative ])
109-
110- log_start = dpt .log10 (start )
111- log_stop = dpt .log10 (stop )
119+ # Allow negative real values and ensure a consistent result for complex
120+ # (including avoiding negligible real or imaginary parts in output) by
121+ # rotating start to positive real, calculating, then undoing rotation.
122+ out_sign = dpnp .sign (start )
123+ start = start / out_sign
124+ stop = stop / out_sign
125+
126+ log_start = dpnp .log10 (start )
127+ log_stop = dpnp .log10 (stop )
112128 res = dpnp_logspace (
113129 log_start ,
114130 log_stop ,
115131 num = num ,
116132 endpoint = endpoint ,
117133 base = 10.0 ,
118- dtype = dtype ,
119- usm_type = _usm_type ,
120- sycl_queue = sycl_queue_normalized ,
121- ). get_array ()
134+ dtype = dt ,
135+ usm_type = usm_type ,
136+ sycl_queue = sycl_queue ,
137+ )
122138
139+ # Make sure the endpoints match the start and stop arguments. This is
140+ # necessary because np.exp(np.log(x)) is not necessarily equal to x.
123141 if num > 0 :
124142 res [0 ] = start
125143 if num > 1 and endpoint :
126144 res [- 1 ] = stop
127145
128- res = out_sign * res
146+ res * = out_sign
129147
130148 if axis != 0 :
131- res = dpt .moveaxis (res , 0 , axis )
132-
133- res = dpt .astype (res , dtype , copy = False )
134- return dpnp_array ._create_from_usm_ndarray (res )
149+ res = dpnp .moveaxis (res , 0 , axis )
150+ return res .astype (dtype , copy = False )
135151
136152
137153def dpnp_linspace (
@@ -252,45 +268,36 @@ def dpnp_logspace(
252268 dtype = None ,
253269 axis = 0 ,
254270):
255- if not dpnp .isscalar (base ):
256- usm_type_alloc , sycl_queue_alloc = get_usm_allocations (
257- [start , stop , base ]
258- )
259-
260- if sycl_queue is None and device is None :
261- sycl_queue = sycl_queue_alloc
262- sycl_queue = dpnp .get_normalized_queue_device (
263- sycl_queue = sycl_queue , device = device
264- )
265-
266- if usm_type is None :
267- usm_type = "device" if usm_type_alloc is None else usm_type_alloc
268- else :
269- usm_type = usm_type
271+ usm_type , sycl_queue = _get_usm_allocations (
272+ [start , stop , base ],
273+ device = device ,
274+ usm_type = usm_type ,
275+ sycl_queue = sycl_queue ,
276+ )
270277
271- start = _as_usm_ndarray (start , usm_type , sycl_queue )
272- stop = _as_usm_ndarray (stop , usm_type , sycl_queue )
273- base = _as_usm_ndarray (base , usm_type , sycl_queue )
278+ if not dpnp .isscalar (base ):
279+ base = dpnp .array (base , usm_type = usm_type , sycl_queue = sycl_queue )
280+ start = dpnp .array (start , usm_type = usm_type , sycl_queue = sycl_queue )
281+ stop = dpnp .array (stop , usm_type = usm_type , sycl_queue = sycl_queue )
274282
275- [ start , stop , base ] = dpt .broadcast_arrays (start , stop , base )
276- base = dpt .expand_dims (base , axis = axis )
283+ start , stop , base = dpnp .broadcast_arrays (start , stop , base )
284+ base = dpnp .expand_dims (base , axis = axis )
277285
278- # assume res as not a tuple, because retstep is False
286+ # assume ` res` as not a tuple, because retstep is False
279287 res = dpnp_linspace (
280288 start ,
281289 stop ,
282290 num = num ,
283- device = device ,
284291 usm_type = usm_type ,
285292 sycl_queue = sycl_queue ,
286293 endpoint = endpoint ,
287294 axis = axis ,
288- ). get_array ()
295+ )
289296
290- dpt .pow (base , res , out = res )
297+ dpnp .pow (base , res , out = res )
291298 if dtype is not None :
292- res = dpt .astype (res , dtype , copy = False )
293- return dpnp_array . _create_from_usm_ndarray ( res )
299+ res = res .astype (dtype , copy = False )
300+ return res
294301
295302
296303class dpnp_nd_grid :
0 commit comments