Skip to content

Commit 6d8d9b4

Browse files
committed
Align geomspace and linspace with input arrays towards NumPy implementation
1 parent 290ab65 commit 6d8d9b4

File tree

1 file changed

+81
-74
lines changed

1 file changed

+81
-74
lines changed

dpnp/dpnp_algo/dpnp_arraycreation.py

Lines changed: 81 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,46 @@
4646

4747

4848
def _as_usm_ndarray(a, usm_type, sycl_queue):
49+
"""Converts input object to `dpctl.tensor.usm_ndarray`"""
50+
4951
if isinstance(a, dpnp_array):
50-
return a.get_array()
52+
a = a.get_array()
5153
return dpt.asarray(a, usm_type=usm_type, sycl_queue=sycl_queue)
5254

5355

56+
def _check_has_zero_val(a):
57+
"""Check if any element in input object is equal to zero"""
58+
59+
if dpnp.isscalar(a):
60+
if a == 0:
61+
return True
62+
elif hasattr(a, "any"):
63+
if (a == 0).any():
64+
return True
65+
elif any(val == 0 for val in a):
66+
return True
67+
return False
68+
69+
70+
def _get_usm_allocations(objs, device=None, usm_type=None, sycl_queue=None):
71+
"""
72+
Get common USM allocations based on a list of input objects and an explicit
73+
device, a SYCL queue, or a USM type if specified.
74+
75+
"""
76+
77+
alloc_usm_type, alloc_sycl_queue = get_usm_allocations(objs)
78+
79+
if sycl_queue is None and device is None:
80+
sycl_queue = alloc_sycl_queue
81+
82+
if usm_type is None:
83+
usm_type = alloc_usm_type or "device"
84+
return usm_type, dpnp.get_normalized_queue_device(
85+
sycl_queue=sycl_queue, device=device
86+
)
87+
88+
5489
def dpnp_geomspace(
5590
start,
5691
stop,
@@ -62,76 +97,57 @@ def dpnp_geomspace(
6297
endpoint=True,
6398
axis=0,
6499
):
65-
usm_type_alloc, sycl_queue_alloc = get_usm_allocations([start, stop])
66-
67-
if sycl_queue is None and device is None:
68-
sycl_queue = sycl_queue_alloc
69-
sycl_queue_normalized = dpnp.get_normalized_queue_device(
70-
sycl_queue=sycl_queue, device=device
100+
usm_type, sycl_queue = _get_usm_allocations(
101+
[start, stop], device=device, usm_type=usm_type, sycl_queue=sycl_queue
71102
)
72103

73-
if usm_type is None:
74-
_usm_type = "device" if usm_type_alloc is None else usm_type_alloc
75-
else:
76-
_usm_type = usm_type
104+
if _check_has_zero_val(start) or _check_has_zero_val(stop):
105+
raise ValueError("Geometric sequence cannot include zero")
77106

78-
start = _as_usm_ndarray(start, _usm_type, sycl_queue_normalized)
79-
stop = _as_usm_ndarray(stop, _usm_type, sycl_queue_normalized)
107+
start = dpnp.array(start, usm_type=usm_type, sycl_queue=sycl_queue)
108+
stop = dpnp.array(stop, usm_type=usm_type, sycl_queue=sycl_queue)
80109

81110
dt = numpy.result_type(start, stop, float(num))
82-
dt = map_dtype_to_device(dt, sycl_queue_normalized.sycl_device)
111+
dt = map_dtype_to_device(dt, sycl_queue.sycl_device)
83112
if dtype is None:
84113
dtype = dt
85114

86-
if dpnp.any(start == 0) or dpnp.any(stop == 0):
87-
raise ValueError("Geometric sequence cannot include zero")
115+
# promote both arguments to the same dtype
116+
start = start.astype(dt, copy=False)
117+
stop = stop.astype(dt, copy=False)
88118

89-
out_sign = dpt.ones(
90-
dpt.broadcast_arrays(start, stop)[0].shape,
91-
dtype=dt,
92-
usm_type=_usm_type,
93-
sycl_queue=sycl_queue_normalized,
94-
)
95-
# Avoid negligible real or imaginary parts in output by rotating to
96-
# positive real, calculating, then undoing rotation
97-
if dpnp.issubdtype(dt, dpnp.complexfloating):
98-
all_imag = (start.real == 0.0) & (stop.real == 0.0)
99-
if dpnp.any(all_imag):
100-
start[all_imag] = start[all_imag].imag
101-
stop[all_imag] = stop[all_imag].imag
102-
out_sign[all_imag] = 1j
103-
104-
both_negative = (dpt.sign(start) == -1) & (dpt.sign(stop) == -1)
105-
if dpnp.any(both_negative):
106-
dpt.negative(start[both_negative], out=start[both_negative])
107-
dpt.negative(stop[both_negative], out=stop[both_negative])
108-
dpt.negative(out_sign[both_negative], out=out_sign[both_negative])
109-
110-
log_start = dpt.log10(start)
111-
log_stop = dpt.log10(stop)
119+
# Allow negative real values and ensure a consistent result for complex
120+
# (including avoiding negligible real or imaginary parts in output) by
121+
# rotating start to positive real, calculating, then undoing rotation.
122+
out_sign = dpnp.sign(start)
123+
start = start / out_sign
124+
stop = stop / out_sign
125+
126+
log_start = dpnp.log10(start)
127+
log_stop = dpnp.log10(stop)
112128
res = dpnp_logspace(
113129
log_start,
114130
log_stop,
115131
num=num,
116132
endpoint=endpoint,
117133
base=10.0,
118-
dtype=dtype,
119-
usm_type=_usm_type,
120-
sycl_queue=sycl_queue_normalized,
121-
).get_array()
134+
dtype=dt,
135+
usm_type=usm_type,
136+
sycl_queue=sycl_queue,
137+
)
122138

139+
# Make sure the endpoints match the start and stop arguments. This is
140+
# necessary because np.exp(np.log(x)) is not necessarily equal to x.
123141
if num > 0:
124142
res[0] = start
125143
if num > 1 and endpoint:
126144
res[-1] = stop
127145

128-
res = out_sign * res
146+
res *= out_sign
129147

130148
if axis != 0:
131-
res = dpt.moveaxis(res, 0, axis)
132-
133-
res = dpt.astype(res, dtype, copy=False)
134-
return dpnp_array._create_from_usm_ndarray(res)
149+
res = dpnp.moveaxis(res, 0, axis)
150+
return res.astype(dtype, copy=False)
135151

136152

137153
def dpnp_linspace(
@@ -252,45 +268,36 @@ def dpnp_logspace(
252268
dtype=None,
253269
axis=0,
254270
):
255-
if not dpnp.isscalar(base):
256-
usm_type_alloc, sycl_queue_alloc = get_usm_allocations(
257-
[start, stop, base]
258-
)
259-
260-
if sycl_queue is None and device is None:
261-
sycl_queue = sycl_queue_alloc
262-
sycl_queue = dpnp.get_normalized_queue_device(
263-
sycl_queue=sycl_queue, device=device
264-
)
265-
266-
if usm_type is None:
267-
usm_type = "device" if usm_type_alloc is None else usm_type_alloc
268-
else:
269-
usm_type = usm_type
271+
usm_type, sycl_queue = _get_usm_allocations(
272+
[start, stop, base],
273+
device=device,
274+
usm_type=usm_type,
275+
sycl_queue=sycl_queue,
276+
)
270277

271-
start = _as_usm_ndarray(start, usm_type, sycl_queue)
272-
stop = _as_usm_ndarray(stop, usm_type, sycl_queue)
273-
base = _as_usm_ndarray(base, usm_type, sycl_queue)
278+
if not dpnp.isscalar(base):
279+
base = dpnp.array(base, usm_type=usm_type, sycl_queue=sycl_queue)
280+
start = dpnp.array(start, usm_type=usm_type, sycl_queue=sycl_queue)
281+
stop = dpnp.array(stop, usm_type=usm_type, sycl_queue=sycl_queue)
274282

275-
[start, stop, base] = dpt.broadcast_arrays(start, stop, base)
276-
base = dpt.expand_dims(base, axis=axis)
283+
start, stop, base = dpnp.broadcast_arrays(start, stop, base)
284+
base = dpnp.expand_dims(base, axis=axis)
277285

278-
# assume res as not a tuple, because retstep is False
286+
# assume `res` as not a tuple, because retstep is False
279287
res = dpnp_linspace(
280288
start,
281289
stop,
282290
num=num,
283-
device=device,
284291
usm_type=usm_type,
285292
sycl_queue=sycl_queue,
286293
endpoint=endpoint,
287294
axis=axis,
288-
).get_array()
295+
)
289296

290-
dpt.pow(base, res, out=res)
297+
dpnp.pow(base, res, out=res)
291298
if dtype is not None:
292-
res = dpt.astype(res, dtype, copy=False)
293-
return dpnp_array._create_from_usm_ndarray(res)
299+
res = res.astype(dtype, copy=False)
300+
return res
294301

295302

296303
class dpnp_nd_grid:

0 commit comments

Comments
 (0)