diff --git a/.gitignore b/.gitignore index 26a9bfe..6240411 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ *.iml +.idea target diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java index 85a9054..d8bb2bb 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java @@ -17,7 +17,9 @@ package org.tensorflow.ndarray; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; /** * The shape of a Tensor or {@link NdArray}. @@ -74,8 +76,8 @@ public static Shape scalar() { * Shape scalar = Shape.of() * } * - * @param dimensionSizes number of elements in each dimension of this shape, if any, or - * {@link Shape#UNKNOWN_SIZE} if unknown. + * @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link + * Shape#UNKNOWN_SIZE} if unknown. * @return a new shape */ public static Shape of(long... dimensionSizes) { @@ -108,13 +110,34 @@ public long size() { * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned. * * @param i the index of the dimension to get the size for. If this Shape has a known number of - * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which - * case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the - * size of the last dimension, {@code size(-2)} the size of the second to last dimension etc. + * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in + * which case the position is counted from the end of the shape. E.g.: {@code size(-1)} + * returns the size of the last dimension, {@code size(-2)} the size of the second to last + * dimension etc. * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE} * otherwise. + * @deprecated Renamed to {@link #get(int)}. */ - public long size(int i) { + @Deprecated + public long size(int i){ + return get(i); + } + + /** + * The size of the dimension with the given index. + * + *

If {@link Shape#isUnknown()} is true or the size of the dimension with the given index has + * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned. + * + * @param i the index of the dimension to get the size for. If this Shape has a known number of + * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in + * which case the position is counted from the end of the shape. E.g.: {@code size(-1)} + * returns the size of the last dimension, {@code size(-2)} the size of the second to last + * dimension etc. + * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE} + * otherwise. + */ + public long get(int i) { if (dimensionSizes == null) { return UNKNOWN_SIZE; } else if (i >= 0) { @@ -177,6 +200,24 @@ public long[] asArray() { } } + /** + * Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change + * this Shape's state. Returns null if {@link Shape#isUnknown()} is true. + */ + public List toListOrNull() { + long[] array = asArray(); + if (array == null) { + return null; + } + + List list = new ArrayList<>(array.length); + for (long l : array) { + list.add(l); + } + + return list; + } + @Override public int hashCode() { return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode(); @@ -186,6 +227,7 @@ public int hashCode() { * Equals implementation for Shapes. Two Shapes are considered equal iff: * *

+ * *