Skip to content

Commit

Permalink
Merge pull request #152 from Stanford-NavLab/ashwin/copy_no_cols_bug_fix
Browse files Browse the repository at this point in the history
Fixed issue with num_cols initialization and setting after `NavData.copy()`
  • Loading branch information
kanhereashwin authored Jan 29, 2024
2 parents ead1f77 + 12f3abd commit a63a0c3
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 8 deletions.
20 changes: 14 additions & 6 deletions gnss_lib_py/navdata/navdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ class NavData():
str_map : Dict
Map of the form {pandas column name : {array value : string}}.
Map is of the form {pandas column name : {}} for non string rows.
num_cols : int
Number of columns in array containing data, set to 0 by default
for empty NavData
curr_cols : int
Current number of column for iterator, set to 0 by default
Expand All @@ -70,7 +67,6 @@ def __init__(self, csv_path=None, pandas_df=None, numpy_array=None,
# Attributes for looping over all columns

self.curr_col = 0
self.num_cols = 0

if csv_path is not None:
self.from_csv_path(csv_path, **kwargs)
Expand All @@ -81,6 +77,7 @@ def __init__(self, csv_path=None, pandas_df=None, numpy_array=None,
else:
self._build_navdata()


if len(self) > 0:
self.rename(self._row_map(), inplace=True)
self.postprocess()
Expand Down Expand Up @@ -694,6 +691,18 @@ def shape(self):
shp = np.shape(self.array)
return shp

@property
def num_cols(self):
"""Return the number of columns in the NavData instance.
Returns
-------
num_cols : int
Number of columns in the NavData instance.
"""
num_cols = self.shape[1]
return num_cols

@property
def rows(self):
"""Return all row names in instance as a list
Expand Down Expand Up @@ -879,7 +888,6 @@ def __iter__(self):
Instantiation of NavData class with iteration initialized
"""
self.curr_col = 0
self.num_cols = np.shape(self.array)[1]
return self

def __next__(self):
Expand All @@ -890,7 +898,7 @@ def __next__(self):
x_curr : gnss_lib_py.navdata.navdata.NavData
Current column (based on iteration count)
"""
if self.curr_col >= self.num_cols:
if self.curr_col >= len(self):
raise StopIteration
x_curr = self.copy(rows=None, cols=self.curr_col)
self.curr_col += 1
Expand Down
9 changes: 9 additions & 0 deletions notebooks/tutorials/navdata/navdata.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,15 @@
"len(nav_data_csv)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nav_data_csv.num_cols"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gnss-lib-py"
version = "1.0.0"
version = "1.0.1"
description = "Modular Python tool for parsing, analyzing, and visualizing Global Navigation Satellite Systems (GNSS) data and state estimates"
authors = ["Derek Knowles <[email protected]>",
"Ashwin Kanhere <[email protected]>",
Expand Down
24 changes: 24 additions & 0 deletions tests/navdata/test_navdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,30 @@ def test_init_np(numpy_array):
with pytest.raises(TypeError):
data = NavData(numpy_array=pd.DataFrame([0]))


def test_len_rows(data):
"""Test that `len()` and `rows` return correct output.
Parameters
----------
data : gnss_lib_py.navdata.navdata.NavData
Simple version of NavData to use for test.
"""
assert len(data) == 6
assert data.rows == ['names', 'integers', 'floats', 'strings']


def test_num_cols(data):
"""Test that `num_cols` returns correct output.
Parameters
----------
data : gnss_lib_py.navdata.navdata.NavData
Simple version of NavData to use for test."""
assert data.num_cols == 6



@pytest.mark.parametrize('pandas_df',
[
lazy_fixture("df_simple"),
Expand Down
4 changes: 3 additions & 1 deletion tests/utils/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def pytest_collection_modifyitems(items):

module_mapping = {item: item.module.__name__ for item in items}
download_tests = [
"test_ephemeris_downloader"
"test_ephemeris_downloader",
"test_rinex_nav",
"test_rinex_obs"
]
sorted_items = [item for item in items if module_mapping[item] not in download_tests] \
+ [item for item in items if module_mapping[item] in download_tests]
Expand Down

0 comments on commit a63a0c3

Please sign in to comment.