Skip to content

Commit

Permalink
Refactor group/selection unit list getters
Browse files Browse the repository at this point in the history
Also fixes #1220, huge overallocation of Lua tables
  • Loading branch information
sprunk authored and lhog committed Feb 23, 2024
1 parent 75c3ef7 commit b0cacea
Showing 1 changed file with 59 additions and 143 deletions.
202 changes: 59 additions & 143 deletions rts/Lua/LuaUnsyncedRead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,59 @@ static int GetSolidObjectSelectionVolume(lua_State* L, const CSolidObject* obj)



template <typename T>
static void PushNumberContainerAsArray(lua_State* const L, const T &v)
{
lua_createtable(L, v.size(), 0);
for (size_t i = 0; const auto &x : v) {
lua_pushnumber(L, x);
lua_rawseti(L, -2, ++i);
}
}

template <typename T>
static size_t PushUnitListSortedByDef(lua_State *const L, const T &units)
{
using unitDefID_t = int;
using unitID_t = int;

std::map <unitDefID_t, std::vector <unitID_t>> unitsByDef;

for (const auto unitID : units)
unitsByDef[unitHandler.GetUnit(unitID)->unitDef->id].push_back(unitID);

lua_createtable(L, 0, unitsByDef.size());

for (const auto & [unitDefID, unitIDs] : unitsByDef) {
assert(!unitIDs.empty());

PushNumberContainerAsArray(L, unitIDs);
lua_rawseti(L, -2, unitDefID);
}

return unitsByDef.size();
}

template <typename T>
static size_t PushSparseUnitTallyByDef(lua_State *const L, const T &v)
{
std::vector <size_t> counts (unitDefHandler->NumUnitDefs() + 1, 0);
size_t numDefKeys = 0;
for (const int unitID: v)
if (!counts[unitHandler.GetUnit(unitID)->unitDef->id]++)
numDefKeys++;

lua_createtable(L, 0, numDefKeys);
for (size_t i = 0; i < counts.size(); ++i) {
if (counts[i] == 0)
continue;

lua_pushnumber(L, counts[i]);
lua_rawseti(L, -2, i);
}

return numDefKeys;
}


/******************************************************************************
Expand Down Expand Up @@ -2309,27 +2362,10 @@ int LuaUnsyncedRead::GetSpectatingState(lua_State* L)
*/
int LuaUnsyncedRead::GetSelectedUnits(lua_State* L)
{
unsigned int count = 0;
const auto& selUnits = selectedUnitsHandler.selectedUnits;

// { [1] = number unitID, ... }
lua_createtable(L, selUnits.size(), 0);

for (const int unitID: selUnits) {
lua_pushnumber(L, unitID);
lua_rawseti(L, -2, ++count);
}
PushNumberContainerAsArray(L, selectedUnitsHandler.selectedUnits);
return 1;
}


static std::vector< std::pair<int, std::vector<const CUnit*> > > gsusUnitDefMap;
static std::vector< std::pair<int, int> > gsucCountMap;

static std::vector< std::pair<int, std::vector<const CUnit*> > > ggusUnitDefMap;
static std::vector< std::pair<int, int> > ggucCountMap;


/*** Get selected units aggregated by unitDefID
*
* @function Spring.GetSelectedUnitsSorted
Expand All @@ -2338,44 +2374,7 @@ static std::vector< std::pair<int, int> > ggucCountMap;
*/
int LuaUnsyncedRead::GetSelectedUnitsSorted(lua_State* L)
{
gsusUnitDefMap.clear();
gsusUnitDefMap.resize(unitDefHandler->NumUnitDefs() + 1);

int numDefKeys = 0;

for (const int unitID: selectedUnitsHandler.selectedUnits) {
const CUnit* unit = unitHandler.GetUnit(unitID);
const UnitDef* unitDef = unit->unitDef;

gsusUnitDefMap[unitDef->id].first = unitDef->id;
gsusUnitDefMap[unitDef->id].second.push_back(unit);
}

// { [number unitDefID] = { [1] = [number unitID], ...}, ... }
lua_createtable(L, 0, gsusUnitDefMap.size());

for (const std::pair<int, std::vector<const CUnit*> >& p: gsusUnitDefMap) {
const std::vector<const CUnit*>& v = p.second;

if (v.empty())
continue;

{
// inner array-table
lua_createtable(L, v.size(), 0);

for (unsigned int i = 0; i < v.size(); i++) {
lua_pushnumber(L, v[i]->id);
lua_rawseti(L, -2, i + 1);
}

// push the UnitDef index
lua_rawseti(L, -2, p.first);
}

numDefKeys += 1;
}

const auto numDefKeys = PushUnitListSortedByDef(L, selectedUnitsHandler.selectedUnits);
lua_pushnumber(L, numDefKeys);

return 2;
Expand All @@ -2391,33 +2390,7 @@ int LuaUnsyncedRead::GetSelectedUnitsSorted(lua_State* L)
*/
int LuaUnsyncedRead::GetSelectedUnitsCounts(lua_State* L)
{
gsucCountMap.clear();
gsucCountMap.resize(unitDefHandler->NumUnitDefs() + 1, {0, 0});

int numDefKeys = 0;

// tally the types
for (const int unitID: selectedUnitsHandler.selectedUnits) {
const CUnit* unit = unitHandler.GetUnit(unitID);
const UnitDef* unitDef = unit->unitDef;

gsucCountMap[unitDef->id].first = unitDef->id;
gsucCountMap[unitDef->id].second += 1;
}

// { [number unitDefID] = number count, ... }
lua_createtable(L, 0, gsucCountMap.size());

for (const std::pair<int, int>& p: gsucCountMap) {
if (p.second == 0)
continue;

lua_pushnumber(L, p.second); // push the UnitDef unit count (value)
lua_rawseti(L, -2, p.first); // push the UnitDef index (key)

numDefKeys += 1;
}

const auto numDefKeys = PushSparseUnitTallyByDef(L, selectedUnitsHandler.selectedUnits);
lua_pushnumber(L, numDefKeys);

return 2;
Expand Down Expand Up @@ -4093,15 +4066,7 @@ int LuaUnsyncedRead::GetGroupUnits(lua_State* L)

const CGroup* group = uiGroupHandlers[gu->myTeam].GetGroup(groupID);

lua_createtable(L, group->units.size(), 0);

unsigned int count = 0;

for (const int unitID: group->units) {
lua_pushnumber(L, unitID);
lua_rawseti(L, -2, ++count);
}

PushNumberContainerAsArray(L, group->units);
return 1;
}

Expand All @@ -4121,36 +4086,7 @@ int LuaUnsyncedRead::GetGroupUnitsSorted(lua_State* L)

const CGroup* group = uiGroupHandlers[gu->myTeam].GetGroup(groupID);

ggusUnitDefMap.clear();
ggusUnitDefMap.resize(unitDefHandler->NumUnitDefs() + 1);

for (const int unitID: group->units) {
const CUnit* unit = unitHandler.GetUnit(unitID);
const UnitDef* unitDef = unit->unitDef;

ggusUnitDefMap[unitDef->id].first = unitDef->id;
ggusUnitDefMap[unitDef->id].second.push_back(unit);
}

lua_createtable(L, 0, ggusUnitDefMap.size());

for (const auto& el: ggusUnitDefMap) {
const std::vector<const CUnit*>& v = el.second;

if (v.empty())
continue;

lua_pushnumber(L, el.first); // push the UnitDef index
lua_createtable(L, v.size(), 0); {

for (size_t i = 0; i < v.size(); i++) {
lua_pushnumber(L, v[i]->id);
lua_rawseti(L, -2, i + 1);
}
}
lua_rawset(L, -3);
}

PushUnitListSortedByDef(L, group->units);
return 1;
}

Expand All @@ -4170,27 +4106,7 @@ int LuaUnsyncedRead::GetGroupUnitsCounts(lua_State* L)

const CGroup* group = uiGroupHandlers[gu->myTeam].GetGroup(groupID);

ggucCountMap.clear();
ggucCountMap.resize(unitDefHandler->NumUnitDefs() + 1, {0, 0});

for (const int unitID: group->units) {
const CUnit* unit = unitHandler.GetUnit(unitID);
const UnitDef* unitDef = unit->unitDef;

ggucCountMap[unitDef->id].first = unitDef->id;
ggucCountMap[unitDef->id].second += 1;
}

lua_createtable(L, 0, ggucCountMap.size());

for (const auto& el: ggucCountMap) {
if (el.second == 0)
continue;

lua_pushnumber(L, el.second); // push the UnitDef unit count
lua_rawseti(L, -2, el.first); // push the UnitDef index
}

PushSparseUnitTallyByDef(L, group->units);
return 1;
}

Expand Down

0 comments on commit b0cacea

Please sign in to comment.