diff --git a/.gitignore b/.gitignore
index 26a9bfe..6240411 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
*.iml
+.idea
target
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
index 85a9054..d8bb2bb 100644
--- a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
@@ -17,7 +17,9 @@
package org.tensorflow.ndarray;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
/**
* The shape of a Tensor or {@link NdArray}.
@@ -74,8 +76,8 @@ public static Shape scalar() {
* Shape scalar = Shape.of()
* }
*
- * @param dimensionSizes number of elements in each dimension of this shape, if any, or
- * {@link Shape#UNKNOWN_SIZE} if unknown.
+ * @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link
+ * Shape#UNKNOWN_SIZE} if unknown.
* @return a new shape
*/
public static Shape of(long... dimensionSizes) {
@@ -108,13 +110,34 @@ public long size() {
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
*
* @param i the index of the dimension to get the size for. If this Shape has a known number of
- * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
- * case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
- * size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in
+ * which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
+ * returns the size of the last dimension, {@code size(-2)} the size of the second to last
+ * dimension etc.
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
* otherwise.
+ * @deprecated Renamed to {@link #get(int)}.
*/
- public long size(int i) {
+ @Deprecated
+ public long size(int i){
+ return get(i);
+ }
+
+ /**
+ * The size of the dimension with the given index.
+ *
+ *
If {@link Shape#isUnknown()} is true or the size of the dimension with the given index has
+ * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
+ *
+ * @param i the index of the dimension to get the size for. If this Shape has a known number of
+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in
+ * which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
+ * returns the size of the last dimension, {@code size(-2)} the size of the second to last
+ * dimension etc.
+ * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
+ * otherwise.
+ */
+ public long get(int i) {
if (dimensionSizes == null) {
return UNKNOWN_SIZE;
} else if (i >= 0) {
@@ -177,6 +200,24 @@ public long[] asArray() {
}
}
+ /**
+ * Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change
+ * this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
+ */
+ public List toListOrNull() {
+ long[] array = asArray();
+ if (array == null) {
+ return null;
+ }
+
+ List list = new ArrayList<>(array.length);
+ for (long l : array) {
+ list.add(l);
+ }
+
+ return list;
+ }
+
@Override
public int hashCode() {
return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode();
@@ -186,6 +227,7 @@ public int hashCode() {
* Equals implementation for Shapes. Two Shapes are considered equal iff:
*
*
+ *
*
* - the number of dimensions is defined and equal for both
*
- the size of each dimension is defined and equal for both
@@ -236,7 +278,8 @@ public Shape head() {
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
* shape
*
- * @param n the number of leading dimensions to get, must be <= than {@link Shape#numDimensions()}
+ * @param n the number of leading dimensions to get, must be <= than {@link
+ * Shape#numDimensions()}
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
* this Shape
*/
@@ -252,7 +295,9 @@ public Shape take(int n) {
/** Returns a new Shape, with this Shape's first dimension removed. */
public Shape tail() {
- if (dimensionSizes.length < 2) return Shape.of();
+ if (dimensionSizes.length < 2) {
+ return Shape.of();
+ }
return Shape.of(Arrays.copyOfRange(dimensionSizes, 1, dimensionSizes.length));
}
@@ -276,15 +321,21 @@ public Shape takeLast(int n) {
}
/**
- * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
+ * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code
+ * begin} to {@code end}.
+ *
* @param begin Where to start the sub-shape.
* @param end Where to end the sub-shape, exclusive.
* @return the sub-shape bounded by begin and end.
*/
- public Shape subShape(int begin, int end){
+ public Shape subShape(int begin, int end) {
if (end > numDimensions()) {
throw new ArrayIndexOutOfBoundsException(
- "End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions.");
+ "End index "
+ + end
+ + " out of bounds: shape only has "
+ + numDimensions()
+ + " dimensions.");
}
if (begin < 0) {
throw new ArrayIndexOutOfBoundsException(
@@ -423,7 +474,7 @@ public boolean isCompatibleWith(Shape shape) {
return false;
}
for (int i = 0; i < numDimensions(); i++) {
- if (!isCompatible(size(i), shape.size(i))) {
+ if (!isCompatible(get(i), shape.get(i))) {
return false;
}
}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java b/ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java
index 7d847bd..249e69a 100644
--- a/ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java
@@ -3798,9 +3798,9 @@ private static int[] computeArrayDims(NdArray> ndArray, int expectedRank) {
}
int[] arrayShape = new int[expectedRank];
for (int i = 0; i < expectedRank; ++i) {
- long dimSize = shape.size(i);
+ long dimSize = shape.get(i);
if (dimSize > Integer.MAX_VALUE) {
- throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.size(i) + ")");
+ throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.get(i) + ")");
}
arrayShape[i] = (int)dimSize;
}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java
index 7d0f022..9327013 100644
--- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java
@@ -28,7 +28,7 @@ public static DimensionalSpace create(Shape shape) {
// Start from the last dimension, where all elements are continuous
for (int i = dimensions.length - 1, elementSize = 1; i >= 0; --i) {
- dimensions[i] = new Axis(shape.size(i), elementSize);
+ dimensions[i] = new Axis(shape.get(i), elementSize);
elementSize *= dimensions[i].numElements();
}
return new DimensionalSpace(dimensions, shape);
@@ -189,7 +189,9 @@ public long positionOf(long[] coords) {
return position;
}
- /** Succinct description of the shape meant for debugging. */
+ /**
+ * Succinct description of the shape meant for debugging.
+ */
@Override
public String toString() {
return Arrays.toString(dimensions);
diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java
index 26ac533..ec020f4 100644
--- a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java
+++ b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java
@@ -24,10 +24,10 @@
import static org.tensorflow.ndarray.index.Indices.at;
import static org.tensorflow.ndarray.index.Indices.even;
import static org.tensorflow.ndarray.index.Indices.flip;
-import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.odd;
import static org.tensorflow.ndarray.index.Indices.range;
import static org.tensorflow.ndarray.index.Indices.seq;
+import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;
import java.nio.BufferOverflowException;
@@ -132,15 +132,15 @@ public void iterateElements() {
long value = 0L;
for (NdArray matrix : matrix3d.elements(0)) {
assertEquals(2L, matrix.shape().numDimensions());
- assertEquals(4L, matrix.shape().size(0));
- assertEquals(5L, matrix.shape().size(1));
+ assertEquals(4L, matrix.shape().get(0));
+ assertEquals(5L, matrix.shape().get(1));
for (NdArray vector : matrix.elements(0)) {
- assertEquals(1L, vector.shape().numDimensions()) ;
- assertEquals(5L, vector.shape().size(0));
+ assertEquals(1L, vector.shape().numDimensions());
+ assertEquals(5L, vector.shape().get(0));
for (NdArray scalar : vector.scalars()) {
- assertEquals(0L, scalar.shape().numDimensions()) ;
+ assertEquals(0L, scalar.shape().numDimensions());
scalar.setObject(valueOf(value++));
try {
scalar.elements(0);
@@ -162,7 +162,7 @@ public void iterateElements() {
@Test
public void slices() {
NdArray matrix3d = allocate(Shape.of(5, 4, 5));
-
+
T val100 = valueOf(100L);
matrix3d.setObject(val100, 1, 0, 0);
T val101 = valueOf(101L);
@@ -318,8 +318,8 @@ public void equalsAndHashCode() {
NdArray array4 = allocate(Shape.of(1, 2, 2));
@SuppressWarnings("unchecked")
- T[][][] values = (T[][][])(new Object[][][] {
- { { valueOf(0L), valueOf(1L) }, { valueOf(2L), valueOf(0L) } }
+ T[][][] values = (T[][][]) (new Object[][][]{
+ {{valueOf(0L), valueOf(1L)}, {valueOf(2L), valueOf(0L)}}
});
StdArrays.copyTo(values[0], array1);
diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java
index d2e3e43..c1247ab 100644
--- a/ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java
+++ b/ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java
@@ -16,9 +16,15 @@
*/
package org.tensorflow.ndarray;
-import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
-import static org.junit.jupiter.api.Assertions.*;
+import org.junit.jupiter.api.Test;
public class ShapeTest {
@@ -26,22 +32,22 @@ public class ShapeTest {
public void allKnownDimensions() {
Shape shape = Shape.of(5, 4, 5);
assertEquals(3, shape.numDimensions());
- assertEquals(5, shape.size(0));
- assertEquals(4, shape.size(1));
- assertEquals(5, shape.size(2));
+ assertEquals(5, shape.get(0));
+ assertEquals(4, shape.get(1));
+ assertEquals(5, shape.get(2));
assertEquals(100, shape.size());
- assertArrayEquals(new long[] {5, 4, 5}, shape.asArray());
+ assertArrayEquals(new long[]{5, 4, 5}, shape.asArray());
try {
- shape.size(3);
+ shape.get(3);
fail();
} catch (IndexOutOfBoundsException e) {
// as expected
}
- assertEquals(5, shape.size(-1));
- assertEquals(4, shape.size(-2));
- assertEquals(5, shape.size(-3));
+ assertEquals(5, shape.get(-1));
+ assertEquals(4, shape.get(-2));
+ assertEquals(5, shape.get(-3));
try {
- shape.size(-4);
+ shape.get(-4);
fail();
} catch (IndexOutOfBoundsException e) {
// as expected
@@ -133,7 +139,7 @@ public void testShapeModification() {
long[] internalShape = one.asArray();
assertNotNull(internalShape);
internalShape[0] = 42L;
- assertEquals(2L, one.size(0));
+ assertEquals(2L, one.get(0));
}
@Test