Skip to content

Commit

Permalink
Remove LRU cashe for invalid tokens (#235)
Browse files Browse the repository at this point in the history
* Add issued_at leeway and not_before leeway to jwt verifier. Remove LRU cache for tokens and always verify them using jwt library

* add unit tests

* download token 5 minutes before expiry

* use LRU cache for valid tokens

* request new token during the half life of the existing token

---------

Co-authored-by: Ravi Nagarjun Akella <raakella1@$HOSTNAME>
  • Loading branch information
raakella1 and Ravi Nagarjun Akella authored May 24, 2024
1 parent aa54f72 commit ccabd78
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 76 deletions.
2 changes: 1 addition & 1 deletion conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class SISLConan(ConanFile):
name = "sisl"
version = "8.6.8"
version = "8.7.0"
homepage = "https://github.com/eBay/sisl"
description = "Library for fast data structures, utilities"
topics = ("ebay", "components", "core", "efficiency")
Expand Down
16 changes: 1 addition & 15 deletions include/sisl/auth_manager/auth_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,7 @@ class LRUCache;
* they were extracted from decoded token.
*/
struct CachedToken {
AuthVerifyStatus response_status;
std::string msg;
bool valid;
std::chrono::system_clock::time_point expires_at;

inline void set_invalid(AuthVerifyStatus code, const std::string& reason) {
valid = false;
response_status = code;
msg = reason;
}

inline void set_valid() {
valid = true;
response_status = AuthVerifyStatus::OK;
}
};

class AuthManager {
Expand All @@ -68,4 +54,4 @@ class AuthManager {
// key_id -> signing public key
mutable LRUCache< std::string, std::string > m_cached_keys;
};
} // namespace sisl
} // namespace sisl
6 changes: 2 additions & 4 deletions include/sisl/auth_manager/trf_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ class TrfClient {
private:
void validate_grant_path() const;
bool grant_path_exists() const { return std::filesystem::exists(SECURITY_DYNAMIC_CONFIG(trf_client->grant_path)); }
bool access_token_expired() const {
return (std::chrono::system_clock::now() >
m_expiry + std::chrono::seconds(SECURITY_DYNAMIC_CONFIG(auth_manager->leeway)));
}
// If leeway is set, this will force us to download token ahead of its expiry
bool access_token_expired() const { return (std::chrono::system_clock::now() > m_expiry); }
static bool get_file_contents(const std::string& file_name, std::string& contents);

private:
Expand Down
35 changes: 13 additions & 22 deletions src/auth_manager/auth_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,12 @@ AuthVerifyStatus AuthManager::verify(const std::string& token, std::string& msg)
// if we have it in cache, just use it to make the decision
auto const token_hash = md5_sum(token);
if (auto const ct = m_cached_tokens.get(token_hash); ct) {
if (ct->valid) {
auto now = std::chrono::system_clock::now();
if (now > ct->expires_at + std::chrono::seconds(SECURITY_DYNAMIC_CONFIG(auth_manager->leeway))) {
m_cached_tokens.put(token_hash,
CachedToken{AuthVerifyStatus::UNAUTH, "token expired", false, ct->expires_at});
}
auto now = std::chrono::system_clock::now();
if (now > ct->expires_at + std::chrono::seconds(SECURITY_DYNAMIC_CONFIG(auth_manager->expiry_leeway_secs))) {
msg = "token expired";
return AuthVerifyStatus::UNAUTH;
}
msg = ct->msg;
return ct->response_status;
return AuthVerifyStatus::OK;
}

// not found in cache
Expand All @@ -63,31 +60,23 @@ AuthVerifyStatus AuthManager::verify(const std::string& token, std::string& msg)
verify_decoded(decoded);
app_name = get_app(decoded);
cached_token.expires_at = decoded.get_expires_at();
cached_token.set_valid();
} catch (const incomplete_verification_error& e) {
} catch (const std::exception& e) {
// verification incomplete, the token validity is not determined, shouldn't
// cache
msg = e.what();
return AuthVerifyStatus::UNAUTH;
} catch (const std::exception& e) {
cached_token.set_invalid(AuthVerifyStatus::UNAUTH, e.what());
m_cached_tokens.put(token_hash, cached_token);
msg = cached_token.msg;
return cached_token.response_status;
}

// check client application

if (SECURITY_DYNAMIC_CONFIG(auth_manager->auth_allowed_apps) != "all") {
if (SECURITY_DYNAMIC_CONFIG(auth_manager->auth_allowed_apps).find(app_name) == std::string::npos) {
cached_token.set_invalid(AuthVerifyStatus::FORBIDDEN,
fmt::format("application '{}' is not allowed to perform the request", app_name));
msg = fmt::format("application '{}' is not allowed to perform the request", app_name);
return AuthVerifyStatus::FORBIDDEN;
}
}

m_cached_tokens.put(token_hash, cached_token);
msg = cached_token.msg;
return cached_token.response_status;
return AuthVerifyStatus::OK;
}

void AuthManager::verify_decoded(const jwt::decoded_jwt& decoded) const {
Expand Down Expand Up @@ -125,7 +114,9 @@ void AuthManager::verify_decoded(const jwt::decoded_jwt& decoded) const {
const auto verifier{jwt::verify()
.with_issuer(SECURITY_DYNAMIC_CONFIG(auth_manager->issuer))
.allow_algorithm(jwt::algorithm::rs256(signing_key))
.expires_at_leeway(SECURITY_DYNAMIC_CONFIG(auth_manager->leeway))};
.expires_at_leeway(SECURITY_DYNAMIC_CONFIG(auth_manager->expiry_leeway_secs))
.issued_at_leeway(SECURITY_DYNAMIC_CONFIG(auth_manager->iat_leeway_secs))
.not_before_leeway(SECURITY_DYNAMIC_CONFIG(auth_manager->nbf_leeway_secs))};

// if verification fails, an instance of std::system_error subclass is thrown.
verifier.verify(decoded);
Expand Down Expand Up @@ -166,4 +157,4 @@ std::string AuthManager::get_app(const jwt::decoded_jwt& decoded) const {
const auto end{client_id.find_first_of(",", start)};
return client_id.substr(start, end - start);
}
} // namespace sisl
} // namespace sisl
8 changes: 7 additions & 1 deletion src/auth_manager/security_config.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ table AuthManager {
tf_token_url: string;

// leeway to the token expiration
leeway: uint32 = 0;
expiry_leeway_secs: uint32 = 0;

// leeway to the token issued_at
iat_leeway_secs: uint32 = 5;

// leeway to the token not_before
nbf_leeway_secs: uint32 = 5;

// ssl verification for the signing key download url
verify: bool = true;
Expand Down
67 changes: 59 additions & 8 deletions src/auth_manager/tests/AuthTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ class MockAuthManager : public AuthManager {
public:
using AuthManager::AuthManager;
MOCK_METHOD(std::string, download_key, (const std::string&), (const));
AuthVerifyStatus verify(const std::string& token) {
std::string msg;
return AuthManager::verify(token, msg);
}
std::string msg;
AuthVerifyStatus verify(const std::string& token) { return verify(token, msg); }
AuthVerifyStatus verify(const std::string& token, std::string& msg) { return AuthManager::verify(token, msg); }
};

class AuthTest : public ::testing::Test {
Expand All @@ -57,7 +56,6 @@ class AuthTest : public ::testing::Test {
SECURITY_SETTINGS_FACTORY().modifiable_settings([](auto& s) {
s.auth_manager->auth_allowed_apps = "app1, testapp, app2";
s.auth_manager->tf_token_url = "http://127.0.0.1";
s.auth_manager->leeway = 0;
s.auth_manager->issuer = "trustfabric";
});
SECURITY_SETTINGS_FACTORY().save();
Expand Down Expand Up @@ -140,6 +138,48 @@ TEST_F(AuthTest, reject_unauthorized_app) {
EXPECT_EQ(mock_auth_mgr->verify(token.sign_rs256()), AuthVerifyStatus::FORBIDDEN);
}

TEST_F(AuthTest, leeway_test) {
auto test_token = TestToken();
auto& trf_token = test_token.get_token();

// default leeway is 0 seconds for exp
trf_token.set_expires_at(std::chrono::system_clock::now() + std::chrono::seconds(1));
// default leeway is 5 seconds for iat and nbf
trf_token.set_issued_at(std::chrono::system_clock::now() + std::chrono::seconds(4));
trf_token.set_not_before(std::chrono::system_clock::now() + std::chrono::seconds(4));
auto raw_token = test_token.sign_rs256();

EXPECT_CALL(*mock_auth_mgr, download_key(_)).Times(1).WillOnce(Return(rsa_pub_key));
EXPECT_EQ(mock_auth_mgr->verify(raw_token), AuthVerifyStatus::OK);

std::string unauth_msg;
// token expired
trf_token.set_expires_at(std::chrono::system_clock::now() - std::chrono::seconds(1));
raw_token = test_token.sign_rs256();
EXPECT_CALL(*mock_auth_mgr, download_key(_)).Times(0);
EXPECT_EQ(mock_auth_mgr->verify(raw_token, unauth_msg), AuthVerifyStatus::UNAUTH);
EXPECT_EQ(unauth_msg, "token verification failed: token expired");

unauth_msg.clear();
// iat expired
trf_token.set_expires_at(std::chrono::system_clock::now() + std::chrono::seconds(1));
trf_token.set_issued_at(std::chrono::system_clock::now() + std::chrono::seconds(6));
trf_token.set_key_id("new_key_id");
raw_token = test_token.sign_rs256();
EXPECT_CALL(*mock_auth_mgr, download_key(_)).Times(1).WillOnce(Return(rsa_pub_key));
EXPECT_EQ(mock_auth_mgr->verify(raw_token, unauth_msg), AuthVerifyStatus::UNAUTH);
EXPECT_EQ(unauth_msg, "token verification failed: token expired");

unauth_msg.clear();
// nbf expired
trf_token.set_issued_at(std::chrono::system_clock::now() - std::chrono::seconds(1));
trf_token.set_not_before(std::chrono::system_clock::now() + std::chrono::seconds(6));
raw_token = test_token.sign_rs256();
EXPECT_CALL(*mock_auth_mgr, download_key(_)).Times(0);
EXPECT_EQ(mock_auth_mgr->verify(raw_token, unauth_msg), AuthVerifyStatus::UNAUTH);
EXPECT_EQ(unauth_msg, "token verification failed: token expired");
}

// Testing trf client
class MockTrfClient : public TrfClient {
public:
Expand Down Expand Up @@ -169,7 +209,7 @@ static void load_trf_settings() {
s.trf_client->grant_path = grant_path;
s.trf_client->server = "127.0.0.1:12346/token";
s.auth_manager->verify = false;
s.auth_manager->leeway = 30;
s.auth_manager->expiry_leeway_secs = 30;
});
SECURITY_SETTINGS_FACTORY().save();
}
Expand Down Expand Up @@ -200,7 +240,8 @@ TEST_F(AuthTest, trf_allow_valid_token) {
const auto raw_token{TestToken().sign_rs256()};
// mock_trf_client is expected to be called twice
// 1. First time when access_token is empty
// 2. When token is set to be expired
// 2. When expiry - leeway is less than current time
// 3. When access_token is expired
EXPECT_CALL(mock_trf_client, request_with_grant_token()).Times(2);
ON_CALL(mock_trf_client, request_with_grant_token())
.WillByDefault(
Expand All @@ -212,6 +253,8 @@ TEST_F(AuthTest, trf_allow_valid_token) {
// use the acces_token saved from the previous call
EXPECT_CALL(*mock_auth_mgr, download_key(_)).Times(0);
EXPECT_EQ(mock_auth_mgr->verify(mock_trf_client.get_token()), AuthVerifyStatus::OK);
mock_trf_client.set_expiry(std::chrono::system_clock::now() + std::chrono::seconds(25));
EXPECT_EQ(mock_auth_mgr->verify(mock_trf_client.get_token()), AuthVerifyStatus::OK);

// set token to be expired invoking request_with_grant_token
mock_trf_client.set_expiry(std::chrono::system_clock::now() - std::chrono::seconds(100));
Expand All @@ -222,9 +265,11 @@ TEST_F(AuthTest, trf_allow_valid_token) {
static const std::string trf_token_server_ip{"127.0.0.1"};
static const uint32_t trf_token_server_port{12346};
static std::string token_response;
static uint32_t token_expiry{4000};
static void set_token_response(const std::string& raw_token) {
token_response = "{\"access_token\":\"" + raw_token +
"\",\"token_type\":\"Bearer\",\"expires_in\":2000,\"refresh_token\":\"dummy_refresh_token\"}\n";
"\",\"token_type\":\"Bearer\",\"expires_in\":" + std::to_string(token_expiry) +
",\"refresh_token\":\"dummy_refresh_token\"}\n";
}

class TokenApiImpl : public TokenApi {
Expand Down Expand Up @@ -294,6 +339,9 @@ TEST_F(TrfClientTest, request_with_grant_token) {
mock_trf_client.__request_with_grant_token();
}));
mock_trf_client.get_token();
auto time_to_expiry = std::chrono::duration_cast< std::chrono::seconds >(mock_trf_client.get_expiry() -
std::chrono::system_clock::now());
EXPECT_LT(time_to_expiry, std::chrono::seconds{token_expiry} / 2);
EXPECT_EQ(raw_token, mock_trf_client.get_access_token());
EXPECT_EQ("Bearer", mock_trf_client.get_token_type());
}
Expand All @@ -309,6 +357,9 @@ TEST(TrfClientParseTest, parse_token) {
EXPECT_EQ(raw_token, mock_trf_client.get_access_token());
EXPECT_EQ("Bearer", mock_trf_client.get_token_type());
EXPECT_TRUE(mock_trf_client.get_expiry() > std::chrono::system_clock::now());
auto time_to_expiry = std::chrono::duration_cast< std::chrono::seconds >(mock_trf_client.get_expiry() -
std::chrono::system_clock::now());
EXPECT_LT(time_to_expiry, std::chrono::seconds{token_expiry} / 2);
remove_grant_path();
}
} // namespace sisl::testing
Expand Down
9 changes: 7 additions & 2 deletions src/auth_manager/trf_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ bool TrfClient::get_file_contents(const std::string& file_path, std::string& con
return false;
}

static std::chrono::seconds get_expiry(const std::chrono::seconds& token_expiry) {
// refresh after half the expiry time
return token_expiry / 2;
}

void TrfClient::request_with_grant_token() {
std::string grant_token;
if (!get_file_contents(SECURITY_DYNAMIC_CONFIG(trf_client->grant_path), grant_token)) {
Expand Down Expand Up @@ -72,7 +77,7 @@ void TrfClient::request_with_grant_token() {

try {
const nlohmann::json resp_json = nlohmann::json::parse(resp.text);
m_expiry = std::chrono::system_clock::now() + std::chrono::seconds(resp_json["expires_in"]);
m_expiry = std::chrono::system_clock::now() + get_expiry(std::chrono::seconds(resp_json["expires_in"]));
m_access_token = resp_json["access_token"];
m_token_type = resp_json["token_type"];
} catch ([[maybe_unused]] const nlohmann::detail::exception& e) {
Expand All @@ -91,7 +96,7 @@ void TrfClient::parse_response(const std::string& resp) {
if (m_token_type = get_quoted_string(resp, token2); m_access_token.empty()) { return; }
auto expiry_str = get_string(resp, token3);
if (expiry_str.empty()) { return; }
m_expiry = std::chrono::system_clock::now() + std::chrono::seconds(std::stol(expiry_str));
m_expiry = std::chrono::system_clock::now() + get_expiry(std::chrono::seconds(std::stol(expiry_str)));
} catch (const std::exception& e) { LOGERROR("failed to parse response: {}, what: {}", resp, e.what()); }
}

Expand Down
42 changes: 19 additions & 23 deletions src/grpc/tests/unit/auth_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ static void load_auth_settings() {
SECURITY_SETTINGS_FACTORY().modifiable_settings([](auto& s) {
s.auth_manager->auth_allowed_apps = "app1, testapp, app2";
s.auth_manager->tf_token_url = "http://127.0.0.1";
s.auth_manager->leeway = 0;
s.auth_manager->expiry_leeway_secs = 0;
s.auth_manager->issuer = "trustfabric";
s.trf_client->grant_path = grant_path;
s.trf_client->server = fmt::format("{}:{}/token", trf_token_server_ip, trf_token_server_port);
Expand Down Expand Up @@ -401,34 +401,34 @@ TEST(GenericServiceDeathTest, basic_test) {
auto g_grpc_server = GrpcServer::make("0.0.0.0:56789", nullptr, 1, "", "");
// register rpc before generic service is registered
#ifndef NDEBUG
ASSERT_DEATH(g_grpc_server->register_generic_rpc(
"method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }),
"Assertion .* failed");
ASSERT_DEATH(
g_grpc_server->register_generic_rpc("method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }),
"Assertion .* failed");
#else
EXPECT_FALSE(g_grpc_server->register_generic_rpc(
"method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
EXPECT_FALSE(
g_grpc_server->register_generic_rpc("method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
#endif

ASSERT_TRUE(g_grpc_server->register_async_generic_service());
// duplicate register
EXPECT_FALSE(g_grpc_server->register_async_generic_service());
// register rpc before server is run
#ifndef NDEBUG
ASSERT_DEATH(g_grpc_server->register_generic_rpc(
"method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }),
"Assertion .* failed");
ASSERT_DEATH(
g_grpc_server->register_generic_rpc("method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }),
"Assertion .* failed");
#else
EXPECT_FALSE(g_grpc_server->register_generic_rpc(
"method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
EXPECT_FALSE(
g_grpc_server->register_generic_rpc("method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
#endif
g_grpc_server->run();
EXPECT_TRUE(g_grpc_server->register_generic_rpc(
"method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
EXPECT_TRUE(g_grpc_server->register_generic_rpc(
"method2", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
EXPECT_TRUE(
g_grpc_server->register_generic_rpc("method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
EXPECT_TRUE(
g_grpc_server->register_generic_rpc("method2", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
// re-register method 1
EXPECT_FALSE(g_grpc_server->register_generic_rpc(
"method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));
EXPECT_FALSE(
g_grpc_server->register_generic_rpc("method1", [](boost::intrusive_ptr< GenericRpcData >&) { return true; }));

auto client = std::make_unique< GrpcAsyncClient >("0.0.0.0:56789", "", "");
client->init();
Expand All @@ -437,15 +437,11 @@ TEST(GenericServiceDeathTest, basic_test) {
::grpc::ByteBuffer cli_buf;
generic_stub->call_unary(
cli_buf, "method1",
[method = "method1"](::grpc::ByteBuffer&, ::grpc::Status& status) {
validate_generic_reply(method, status);
},
[method = "method1"](::grpc::ByteBuffer&, ::grpc::Status& status) { validate_generic_reply(method, status); },
1);
generic_stub->call_unary(
cli_buf, "method2",
[method = "method2"](::grpc::ByteBuffer&, ::grpc::Status& status) {
validate_generic_reply(method, status);
},
[method = "method2"](::grpc::ByteBuffer&, ::grpc::Status& status) { validate_generic_reply(method, status); },
1);
generic_stub->call_unary(
cli_buf, "method_unknown",
Expand Down

0 comments on commit ccabd78

Please sign in to comment.