Skip to content

Commit

Permalink
refactor: extracts certificate validation into a separate class
Browse files Browse the repository at this point in the history
  • Loading branch information
stalniy committed Jan 31, 2025
1 parent 9ab53ef commit f591481
Show file tree
Hide file tree
Showing 11 changed files with 409 additions and 186 deletions.
1 change: 1 addition & 0 deletions apps/provider-proxy/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
},
"dependencies": {
"@akashnetwork/net": "*",
"async-sema": "^3.1.1",
"bech32": "^2.0.0",
"cors": "^2.8.5",
"express": "^4.18.2",
Expand Down
11 changes: 9 additions & 2 deletions apps/provider-proxy/src/container.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { netConfig, SupportedChainNetworks } from "@akashnetwork/net";

import { CertificateValidator, createCertificateValidatorInstrumentation } from "./services/CertificateValidator";
import { ProviderProxy } from "./services/ProviderProxy";
import { ProviderService } from "./services/ProviderService";
import { WebsocketStats } from "./services/WebsocketStats";
Expand All @@ -13,9 +14,15 @@ const providerService = new ProviderService((network: SupportedChainNetworks) =>
// @see https://github.com/mswjs/msw/discussions/2416
return process.env.TEST_CHAIN_NETWORK_URL || netConfig.getBaseAPIUrl(network);
}, fetch);
const providerProxy = new ProviderProxy(Date.now, providerService);
const certificateValidator = new CertificateValidator(
Date.now,
providerService,
process.env.NODE_ENV === "test" ? undefined : createCertificateValidatorInstrumentation(console)
);
const providerProxy = new ProviderProxy(certificateValidator);

export const container = {
wsStats,
providerProxy
providerProxy,
certificateValidator
};
152 changes: 152 additions & 0 deletions apps/provider-proxy/src/services/CertificateValidator.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import { SupportedChainNetworks } from "@akashnetwork/net";
import { Sema } from "async-sema";
import { bech32 } from "bech32";
import { X509Certificate } from "crypto";
import { LRUCache } from "lru-cache";

import { ProviderService } from "./ProviderService";

export class CertificateValidator {
private readonly knownCertificatesCache = new LRUCache<string, X509Certificate | null>({
max: 100_000,
ttl: 30 * 60 * 1000
});
private readonly locks: Record<string, Sema> = {};

constructor(
private readonly now: () => number,
private readonly providerService: ProviderService,
private readonly instrumentation?: CertificateValidatorIntrumentation
) {}

async validate(certificate: X509Certificate, network: SupportedChainNetworks, providerAddress: string): Promise<CertValidationResult> {
const now = this.now();
const validationResult = validateCertificateAttrs(certificate, now);

if (validationResult.ok === false) {
this.instrumentation?.onInvalidAttrs?.(certificate, network, providerAddress, now, validationResult);
return validationResult;
}

const providerCertificate = await this.getProviderCertificate(certificate, network, providerAddress);
if (!providerCertificate) {
this.instrumentation?.onUnknownCert?.(certificate, network, providerAddress);
return { ok: false, code: "unknownCertificate" };
}

if (providerCertificate.fingerprint256 !== certificate.fingerprint256) {
this.instrumentation?.onInvalidFingerprint?.(certificate, network, providerAddress, providerCertificate);
return { ok: false, code: "fingerprintMismatch" };
}

this.instrumentation?.onValidationSuccess?.(certificate, network, providerAddress, now);

return { ok: true };
}

private async getProviderCertificate(cert: X509Certificate, network: SupportedChainNetworks, providerAddress: string): Promise<X509Certificate | null> {
const key = `${network}.${providerAddress}.${cert.serialNumber}`;

this.locks[key] ??= new Sema(1);

try {
await this.locks[key].acquire();
if (!this.knownCertificatesCache.has(key)) {
const certificate = await this.providerService.getCertificate(network, providerAddress, cert.serialNumber);
this.knownCertificatesCache.set(key, certificate);
return certificate;
}

return this.knownCertificatesCache.get(key);
} finally {
this.locks[key].release();
delete this.locks[key];
}
}
}

export type CertValidationResult = { ok: true } | CertValidationResultError;
export type CertValidationResultError = {
ok: false;
code: "validInFuture" | "expired" | "invalidSerialNumber" | "notSelfSigned" | "CommonNameIsNotBech32" | "unknownCertificate" | "fingerprintMismatch";
};
export interface CertificateValidatorIntrumentation {
onValidationSuccess?(certificate: X509Certificate, network: SupportedChainNetworks, providerAddress: string, now: number): void;
onInvalidAttrs?(
certificate: X509Certificate,
network: SupportedChainNetworks,
providerAddress: string,
now: number,
validationResult: CertValidationResultError
): void;
onUnknownCert?(certificate: X509Certificate, network: SupportedChainNetworks, providerAddress: string): void;
onInvalidFingerprint?(certificate: X509Certificate, network: SupportedChainNetworks, providerAddress: string, providerCertificate: X509Certificate): void;
}

function validateCertificateAttrs(cert: X509Certificate, now: number): CertValidationResult {
if (new Date(cert.validFrom).getTime() > now) {
return {
ok: false,
code: "validInFuture"
};
}

if (new Date(cert.validTo).getTime() < now) {
return {
ok: false,
code: "expired"
};
}

if (!cert.serialNumber?.trim()) {
return {
ok: false,
code: "invalidSerialNumber"
};
}

if (cert.issuer !== cert.subject) {
return {
ok: false,
code: "notSelfSigned"
};
}

const commonName = parseCertSubject(cert.subject, "CN");
if (!commonName || !bech32.decodeUnsafe(commonName)) {
return {
ok: false,
code: "CommonNameIsNotBech32"
};
}

return { ok: true };
}

function parseCertSubject(subject: string, attr: string): string | null {
const attrPrefix = `${attr}=`;
const index = subject.indexOf(attrPrefix);
if (index === -1) return null;

const endIndex = subject.indexOf("\n", index);
if (endIndex === -1) return subject.slice(index);

return subject.slice(index + attrPrefix.length, endIndex);
}

export const createCertificateValidatorInstrumentation = (logger: typeof console): CertificateValidatorIntrumentation => ({
onValidationSuccess(certificate, network, providerAddress, now) {
logger.log(`Successfully validated ${certificate.serialNumber} in ${network} for "${providerAddress}" at ${now}`);
},
onInvalidAttrs(certificate, network, providerAddress, now, result) {
logger.log(`Certificate ${certificate.serialNumber} is invalid in ${network} for "${providerAddress}" because ${result.code} at ${now}`);
},
onInvalidFingerprint(certificate, network, providerAddress, providerCertificate) {
logger.log(
`Certificate ${certificate.serialNumber} (${certificate.fingerprint256}) fingerprint does not match fingerprint in ${network} for ${providerAddress}: ${providerCertificate.fingerprint256}`
);
},
onUnknownCert(certificate, network, providerAddress) {
logger.log(`Certificate ${certificate.serialNumber} does not have corresponding certificate in ${network} for ${providerAddress}`);
}
});
78 changes: 34 additions & 44 deletions apps/provider-proxy/src/services/ProviderProxy.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,31 @@
import { SupportedChainNetworks } from "@akashnetwork/net";
import { X509Certificate } from "crypto";
import { IncomingMessage } from "http";
import https, { RequestOptions } from "https";
import { LRUCache } from "lru-cache";
import { TLSSocket } from "tls";

import { CertValidationResultError, validateCertificate } from "../utils/validateCertificate";
import { ProviderService } from "./ProviderService";
import { CertificateValidator, CertValidationResultError } from "./CertificateValidator";

export class ProviderProxy {
private readonly knownCertificatesCache = new LRUCache<string, boolean>({
max: 100_000,
ttl: 30 * 60 * 1000
});
/**
* Cache agents in order to control TLS session resumption
*/
private readonly agentsCache = new LRUCache<string, https.Agent>({
max: 100_000
max: 1_000_000
});

constructor(
private readonly now: () => number,
private readonly providerService: ProviderService
) {}
constructor(private readonly certificateValidator: CertificateValidator) {}

connect(url: string, options: ProxyConnectOptions): Promise<ProxyConnectionResult> {
const agent = this.getHttpsAgent(options.network, options.providerAddress, {
const agentOptions: TLSChainAgentOptions = {
timeout: options.timeout,
rejectUnauthorized: false,
cert: options.cert,
key: options.key,
rejectUnauthorized: false
});
chainNetwork: options.network,
providerAddress: options.providerAddress
};
const agent = this.getHttpsAgent(agentOptions);
return new Promise<ProxyConnectionResult>((resolve, reject) => {
const req = https.request(
url,
Expand Down Expand Up @@ -59,18 +56,15 @@ export class ProviderProxy {
const didHandshake = !!serverCert;

if (didHandshake && options.network && options.providerAddress) {
const validationResult = validateCertificate(serverCert, this.now());
const validationResult = await this.certificateValidator.validate(serverCert, options.network, options.providerAddress);
if (validationResult.ok === false) {
// remove agent from cache to destroy TLS session to force TLS handshake on the next call
this.agentsCache.delete(genAgentsCacheKey(agentOptions));
resolve({ ok: false, code: "invalidCertificate", reason: validationResult.code });
req.off("error", reject);
req.destroy();
this.agentsCache.delete(`${options.network}.${options.providerAddress}`);
return resolve({ ok: false, code: "invalidCertificate", reason: validationResult.code });
}

const isKnown = await this.isKnownCertificate(serverCert, options.network, options.providerAddress);
if (!isKnown) {
req.destroy();
this.agentsCache.delete(`${options.network}.${options.providerAddress}`);
return resolve({ ok: false, code: "invalidCertificate", reason: "unknownCertificate" });
agent.destroy();
return;
}
}

Expand All @@ -82,9 +76,7 @@ export class ProviderProxy {
);

if (!req.reusedSocket) {
req.on("error", (error: (Error & { code: string }) | undefined) => {
reject(error);
});
req.on("error", reject);
req.on("timeout", () => {
// here we are just notified that response take more than specified in request options timeout
// then we manually destroy request and it drops connection and
Expand All @@ -98,23 +90,12 @@ export class ProviderProxy {
});
}

private async isKnownCertificate(cert: X509Certificate, network: SupportedChainNetworks, providerAddress: string): Promise<boolean> {
const key = `${network}.${providerAddress}.${cert.serialNumber}`;

if (!this.knownCertificatesCache.has(key)) {
const hasCertificate = await this.providerService.hasCertificate(network, providerAddress, cert.serialNumber);
this.knownCertificatesCache.set(key, hasCertificate);
return hasCertificate;
}

return this.knownCertificatesCache.get(key);
}

private getHttpsAgent(network: SupportedChainNetworks, providerAddress: string, options: https.AgentOptions): https.Agent {
const key = `${network}.${providerAddress}`;
private getHttpsAgent(options: TLSChainAgentOptions): https.Agent {
const key = genAgentsCacheKey(options);

if (!this.agentsCache.has(key)) {
const agent = new https.Agent(options);
const { chainNetwork, providerAddress, ...agentOptions } = options;
const agent = new https.Agent(agentOptions);
this.agentsCache.set(key, agent);
return agent;
}
Expand All @@ -123,6 +104,10 @@ export class ProviderProxy {
}
}

function genAgentsCacheKey(options: TLSChainAgentOptions): string {
return `${options.chainNetwork}:${options.providerAddress}:${options.cert}:${options.key}`;
}

export interface ProxyConnectOptions extends Pick<RequestOptions, "cert" | "key" | "method"> {
body?: BodyInit;
headers?: Record<string, string>;
Expand All @@ -140,5 +125,10 @@ interface ProxyConnectionResultSuccess {
}

type ProxyConnectionResultError =
| { ok: false; code: "invalidCertificate"; reason: CertValidationResultError["code"] | "unknownCertificate" }
| { ok: false; code: "invalidCertificate"; reason: CertValidationResultError["code"] }
| { ok: false; code: "insecureConnection" };

interface TLSChainAgentOptions extends https.AgentOptions {
chainNetwork: SupportedChainNetworks;
providerAddress: string;
}
13 changes: 9 additions & 4 deletions apps/provider-proxy/src/services/ProviderService.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { SupportedChainNetworks } from "@akashnetwork/net";
import { X509Certificate } from "crypto";

import { httpRetry } from "../utils/retry";

Expand All @@ -8,7 +9,7 @@ export class ProviderService {
private readonly fetch: typeof global.fetch
) {}

async hasCertificate(network: SupportedChainNetworks, providerAddress: string, serialNumber: string): Promise<boolean> {
async getCertificate(network: SupportedChainNetworks, providerAddress: string, serialNumber: string): Promise<X509Certificate | null> {
const queryParams = new URLSearchParams({
"filter.state": "valid",
"filter.owner": providerAddress,
Expand All @@ -22,13 +23,17 @@ export class ProviderService {

if (response.status >= 200 && response.status < 300) {
const body = (await response.json()) as KnownCertificatesResponseBody;
return body.certificates.length === 1;
return body.certificates.length === 1 ? new X509Certificate(atob(body.certificates[0].certificate.cert)) : null;
}

return false;
return null;
}
}

interface KnownCertificatesResponseBody {
certificates: unknown[];
certificates: Array<{
certificate: {
cert: string;
};
}>;
}
2 changes: 1 addition & 1 deletion apps/provider-proxy/src/utils/retry.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { setTimeout } from "timers/promises";

export const httpRetry = <T>(callback: () => Promise<T>, options: HttpRetryOptions<T>): Promise<T> => {
return retryWithBackoff(callback, options.retryIf, options.maxRetries || 5, 0);
return retryWithBackoff(callback, options.retryIf, options.maxRetries || 3, 0);
};

export interface HttpRetryOptions<T> {
Expand Down
Loading

0 comments on commit f591481

Please sign in to comment.