Skip to content

Commit 3ee84a3

Browse files
authored
Merge pull request #1128 from c410-f3r/array
Add support for arbitrary arrays
2 parents 1937e21 + 29dbd99 commit 3ee84a3

File tree

6 files changed

+308
-104
lines changed

6 files changed

+308
-104
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1313
- Support PyPy 3.7. [#1538](https://github.com/PyO3/pyo3/pull/1538)
1414

1515
### Added
16+
- Add conversions for `[T; N]` for all `N` on Rust 1.51 and up. [#1128](https://github.com/PyO3/pyo3/pull/1128)
1617
- Add conversions between `OsStr`/`OsString`/`Path`/`PathBuf` and Python strings. [#1379](https://github.com/PyO3/pyo3/pull/1379)
1718
- Add `#[pyo3(from_py_with = "...")]` attribute for function arguments and struct fields to override the default from-Python conversion. [#1411](https://github.com/PyO3/pyo3/pull/1411)
1819
- Add FFI definition `PyCFunction_CheckExact` for Python 3.9 and later. [#1425](https://github.com/PyO3/pyo3/pull/1425)

build.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,17 @@ fn ensure_python_version(interpreter_config: &InterpreterConfig) -> Result<()> {
746746
Ok(())
747747
}
748748

749+
fn rustc_minor_version() -> Option<u32> {
750+
let rustc = env::var_os("RUSTC")?;
751+
let output = Command::new(rustc).arg("--version").output().ok()?;
752+
let version = core::str::from_utf8(&output.stdout).ok()?;
753+
let mut pieces = version.split('.');
754+
if pieces.next() != Some("rustc 1") {
755+
return None;
756+
}
757+
pieces.next()?.parse().ok()
758+
}
759+
749760
fn emit_cargo_configuration(interpreter_config: &InterpreterConfig) -> Result<()> {
750761
let target_os = cargo_env_var("CARGO_CFG_TARGET_OS").unwrap();
751762
let is_extension_module = cargo_env_var("CARGO_FEATURE_EXTENSION_MODULE").is_some();
@@ -850,6 +861,12 @@ fn emit_cargo_configuration(interpreter_config: &InterpreterConfig) -> Result<()
850861
println!("cargo:rustc-cfg=py_sys_config=\"{}\"", flag)
851862
}
852863

864+
// Enable use of const generics on Rust 1.51 and greater
865+
866+
if rustc_minor_version().unwrap_or(0) >= 51 {
867+
println!("cargo:rustc-cfg=min_const_generics");
868+
}
869+
853870
Ok(())
854871
}
855872

src/conversions/array.rs

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
use crate::{exceptions, PyErr};
2+
3+
#[cfg(min_const_generics)]
4+
mod min_const_generics {
5+
use super::invalid_sequence_length;
6+
use crate::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject};
7+
8+
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
9+
where
10+
T: ToPyObject,
11+
{
12+
fn into_py(self, py: Python) -> PyObject {
13+
self.as_ref().to_object(py)
14+
}
15+
}
16+
17+
impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
18+
where
19+
T: FromPyObject<'a>,
20+
{
21+
#[cfg(not(feature = "nightly"))]
22+
fn extract(obj: &'a PyAny) -> PyResult<Self> {
23+
create_array_from_obj(obj)
24+
}
25+
26+
#[cfg(feature = "nightly")]
27+
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
28+
create_array_from_obj(obj)
29+
}
30+
}
31+
32+
#[cfg(feature = "nightly")]
33+
impl<'source, T, const N: usize> FromPyObject<'source> for [T; N]
34+
where
35+
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
36+
{
37+
fn extract(obj: &'source PyAny) -> PyResult<Self> {
38+
use crate::{AsPyPointer, PyNativeType};
39+
// first try buffer protocol
40+
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
41+
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
42+
let mut array = [T::default(); N];
43+
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
44+
buf.release(obj.py());
45+
return Ok(array);
46+
}
47+
buf.release(obj.py());
48+
}
49+
}
50+
create_array_from_obj(obj)
51+
}
52+
}
53+
54+
fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]>
55+
where
56+
T: FromPyObject<'s>,
57+
{
58+
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
59+
let seq_len = seq.len()? as usize;
60+
if seq_len != N {
61+
return Err(invalid_sequence_length(N, seq_len));
62+
}
63+
array_try_from_fn(|idx| seq.get_item(idx as isize).and_then(PyAny::extract))
64+
}
65+
66+
// TODO use std::array::try_from_fn, if that stabilises:
67+
// (https://github.com/rust-lang/rust/pull/75644)
68+
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
69+
where
70+
F: FnMut(usize) -> Result<T, E>,
71+
{
72+
// Helper to safely create arrays since the standard library doesn't
73+
// provide one yet. Shouldn't be necessary in the future.
74+
struct ArrayGuard<T, const N: usize> {
75+
dst: *mut T,
76+
initialized: usize,
77+
}
78+
79+
impl<T, const N: usize> Drop for ArrayGuard<T, N> {
80+
fn drop(&mut self) {
81+
debug_assert!(self.initialized <= N);
82+
let initialized_part =
83+
core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
84+
unsafe {
85+
core::ptr::drop_in_place(initialized_part);
86+
}
87+
}
88+
}
89+
90+
// [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
91+
// APIs which would make this easier.
92+
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
93+
let mut guard: ArrayGuard<T, N> = ArrayGuard {
94+
dst: array.as_mut_ptr() as _,
95+
initialized: 0,
96+
};
97+
unsafe {
98+
let mut value_ptr = array.as_mut_ptr() as *mut T;
99+
for i in 0..N {
100+
core::ptr::write(value_ptr, cb(i)?);
101+
value_ptr = value_ptr.offset(1);
102+
guard.initialized += 1;
103+
}
104+
core::mem::forget(guard);
105+
Ok(array.assume_init())
106+
}
107+
}
108+
109+
#[cfg(test)]
110+
mod test {
111+
use super::*;
112+
use std::{
113+
panic,
114+
sync::atomic::{AtomicUsize, Ordering},
115+
};
116+
117+
#[test]
118+
fn array_try_from_fn() {
119+
static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
120+
struct CountDrop;
121+
impl Drop for CountDrop {
122+
fn drop(&mut self) {
123+
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
124+
}
125+
}
126+
let _ = catch_unwind_silent(move || {
127+
let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
128+
if idx == 2 {
129+
panic!("peek a boo");
130+
}
131+
Ok(CountDrop)
132+
});
133+
});
134+
assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
135+
}
136+
137+
#[test]
138+
fn test_extract_bytearray_to_array() {
139+
Python::with_gil(|py| {
140+
let v: [u8; 33] = py
141+
.eval(
142+
"bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
143+
None,
144+
None,
145+
)
146+
.unwrap()
147+
.extract()
148+
.unwrap();
149+
assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
150+
})
151+
}
152+
153+
// https://stackoverflow.com/a/59211505
154+
fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
155+
where
156+
F: FnOnce() -> R + panic::UnwindSafe,
157+
{
158+
let prev_hook = panic::take_hook();
159+
panic::set_hook(Box::new(|_| {}));
160+
let result = panic::catch_unwind(f);
161+
panic::set_hook(prev_hook);
162+
result
163+
}
164+
}
165+
}
166+
167+
#[cfg(not(min_const_generics))]
168+
mod array_impls {
169+
use super::invalid_sequence_length;
170+
use crate::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject};
171+
172+
macro_rules! array_impls {
173+
($($N:expr),+) => {
174+
$(
175+
impl<T> IntoPy<PyObject> for [T; $N]
176+
where
177+
T: ToPyObject
178+
{
179+
fn into_py(self, py: Python) -> PyObject {
180+
self.as_ref().to_object(py)
181+
}
182+
}
183+
184+
impl<'a, T> FromPyObject<'a> for [T; $N]
185+
where
186+
T: Copy + Default + FromPyObject<'a>,
187+
{
188+
#[cfg(not(feature = "nightly"))]
189+
fn extract(obj: &'a PyAny) -> PyResult<Self> {
190+
let mut array = [T::default(); $N];
191+
extract_sequence_into_slice(obj, &mut array)?;
192+
Ok(array)
193+
}
194+
195+
#[cfg(feature = "nightly")]
196+
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
197+
let mut array = [T::default(); $N];
198+
extract_sequence_into_slice(obj, &mut array)?;
199+
Ok(array)
200+
}
201+
}
202+
203+
#[cfg(feature = "nightly")]
204+
impl<'source, T> FromPyObject<'source> for [T; $N]
205+
where
206+
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
207+
{
208+
fn extract(obj: &'source PyAny) -> PyResult<Self> {
209+
let mut array = [T::default(); $N];
210+
// first try buffer protocol
211+
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
212+
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
213+
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
214+
buf.release(obj.py());
215+
return Ok(array);
216+
}
217+
buf.release(obj.py());
218+
}
219+
}
220+
// fall back to sequence protocol
221+
extract_sequence_into_slice(obj, &mut array)?;
222+
Ok(array)
223+
}
224+
}
225+
)+
226+
}
227+
}
228+
229+
#[cfg(not(min_const_generics))]
230+
array_impls!(
231+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
232+
25, 26, 27, 28, 29, 30, 31, 32
233+
);
234+
235+
#[cfg(not(min_const_generics))]
236+
fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
237+
where
238+
T: FromPyObject<'s>,
239+
{
240+
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
241+
let seq_len = seq.len()? as usize;
242+
if seq_len != slice.len() {
243+
return Err(invalid_sequence_length(slice.len(), seq_len));
244+
}
245+
for (value, item) in slice.iter_mut().zip(seq.iter()?) {
246+
*value = item?.extract::<T>()?;
247+
}
248+
Ok(())
249+
}
250+
}
251+
252+
fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
253+
exceptions::PyValueError::new_err(format!(
254+
"expected a sequence of length {} (got {})",
255+
expected, actual
256+
))
257+
}
258+
259+
#[cfg(test)]
260+
mod test {
261+
use crate::{PyResult, Python};
262+
263+
#[test]
264+
fn test_extract_small_bytearray_to_array() {
265+
Python::with_gil(|py| {
266+
let v: [u8; 3] = py
267+
.eval("bytearray(b'abc')", None, None)
268+
.unwrap()
269+
.extract()
270+
.unwrap();
271+
assert!(&v == b"abc");
272+
});
273+
}
274+
275+
#[test]
276+
fn test_extract_invalid_sequence_length() {
277+
Python::with_gil(|py| {
278+
let v: PyResult<[u8; 3]> = py
279+
.eval("bytearray(b'abcdefg')", None, None)
280+
.unwrap()
281+
.extract();
282+
assert_eq!(
283+
v.unwrap_err().to_string(),
284+
"ValueError: expected a sequence of length 3 (got 7)"
285+
);
286+
})
287+
}
288+
}

src/conversions/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
//! This module contains conversions between non-String Rust object and their string representation
2-
//! in Python
1+
//! This module contains conversions between various Rust object and their representation in Python.
32
3+
mod array;
44
mod osstr;
55
mod path;

src/types/list.rs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -178,26 +178,6 @@ where
178178
}
179179
}
180180

181-
macro_rules! array_impls {
182-
($($N:expr),+) => {
183-
$(
184-
impl<T> IntoPy<PyObject> for [T; $N]
185-
where
186-
T: ToPyObject
187-
{
188-
fn into_py(self, py: Python) -> PyObject {
189-
self.as_ref().to_object(py)
190-
}
191-
}
192-
)+
193-
}
194-
}
195-
196-
array_impls!(
197-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
198-
26, 27, 28, 29, 30, 31, 32
199-
);
200-
201181
impl<T> ToPyObject for Vec<T>
202182
where
203183
T: ToPyObject,

0 commit comments

Comments
 (0)