Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: extracts certificate validation into a separate class #764

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/provider-proxy/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
},
"dependencies": {
"@akashnetwork/net": "*",
"async-sema": "^3.1.1",
"bech32": "^2.0.0",
"cors": "^2.8.5",
"express": "^4.18.2",
Expand Down
11 changes: 9 additions & 2 deletions apps/provider-proxy/src/container.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { netConfig, SupportedChainNetworks } from "@akashnetwork/net";

import { CertificateValidator, createCertificateValidatorInstrumentation } from "./services/CertificateValidator";
import { ProviderProxy } from "./services/ProviderProxy";
import { ProviderService } from "./services/ProviderService";
import { WebsocketStats } from "./services/WebsocketStats";
Expand All @@ -13,9 +14,15 @@ const providerService = new ProviderService((network: SupportedChainNetworks) =>
// @see https://github.com/mswjs/msw/discussions/2416
return process.env.TEST_CHAIN_NETWORK_URL || netConfig.getBaseAPIUrl(network);
}, fetch);
const providerProxy = new ProviderProxy(Date.now, providerService);
const certificateValidator = new CertificateValidator(
stalniy marked this conversation as resolved.
Show resolved Hide resolved
Date.now,
providerService,
process.env.NODE_ENV === "test" ? undefined : createCertificateValidatorInstrumentation(console)
);
const providerProxy = new ProviderProxy(certificateValidator);

export const container = {
wsStats,
providerProxy
providerProxy,
certificateValidator
};
152 changes: 152 additions & 0 deletions apps/provider-proxy/src/services/CertificateValidator.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import { SupportedChainNetworks } from "@akashnetwork/net";
import { Sema } from "async-sema";
import { bech32 } from "bech32";
import { X509Certificate } from "crypto";
import { LRUCache } from "lru-cache";

import { ProviderService } from "./ProviderService";

export class CertificateValidator {
private readonly knownCertificatesCache = new LRUCache<string, X509Certificate | null>({
max: 100_000,
ttl: 30 * 60 * 1000
});
private readonly locks: Record<string, Sema> = {};

constructor(
private readonly now: () => number,
private readonly providerService: ProviderService,
private readonly instrumentation?: CertificateValidatorIntrumentation
) {}

async validate(certificate: X509Certificate, network: SupportedChainNetworks, providerAddress: string): Promise<CertValidationResult> {
const now = this.now();
const validationResult = validateCertificateAttrs(certificate, now);

if (validationResult.ok === false) {
this.instrumentation?.onInvalidAttrs?.(certificate, network, providerAddress, now, validationResult);
return validationResult;
}

const providerCertificate = await this.getProviderCertificate(certificate, network, providerAddress);
if (!providerCertificate) {
this.instrumentation?.onUnknownCert?.(certificate, network, providerAddress);
return { ok: false, code: "unknownCertificate" };
}

if (providerCertificate.fingerprint256 !== certificate.fingerprint256) {
this.instrumentation?.onInvalidFingerprint?.(certificate, network, providerAddress, providerCertificate);
return { ok: false, code: "fingerprintMismatch" };
}

this.instrumentation?.onValidationSuccess?.(certificate, network, providerAddress, now);

return { ok: true };
}

private async getProviderCertificate(cert: X509Certificate, network: SupportedChainNetworks, providerAddress: string): Promise<X509Certificate | null> {
const key = `${network}.${providerAddress}.${cert.serialNumber}`;

this.locks[key] ??= new Sema(1);

try {
await this.locks[key].acquire();
if (!this.knownCertificatesCache.has(key)) {
const certificate = await this.providerService.getCertificate(network, providerAddress, cert.serialNumber);
this.knownCertificatesCache.set(key, certificate);
return certificate;
}

return this.knownCertificatesCache.get(key);
} finally {
this.locks[key].release();
delete this.locks[key];
}
}
}

export type CertValidationResult = { ok: true } | CertValidationResultError;
export type CertValidationResultError = {
ok: false;
code: "validInFuture" | "expired" | "invalidSerialNumber" | "notSelfSigned" | "CommonNameIsNotBech32" | "unknownCertificate" | "fingerprintMismatch";
};
export interface CertificateValidatorIntrumentation {
onValidationSuccess?(certificate: X509Certificate, network: SupportedChainNetworks, providerAddress: string, now: number): void;
onInvalidAttrs?(
certificate: X509Certificate,
network: SupportedChainNetworks,
providerAddress: string,
now: number,
validationResult: CertValidationResultError
): void;
onUnknownCert?(certificate: X509Certificate, network: SupportedChainNetworks, providerAddress: string): void;
onInvalidFingerprint?(certificate: X509Certificate, network: SupportedChainNetworks, providerAddress: string, providerCertificate: X509Certificate): void;
}

function validateCertificateAttrs(cert: X509Certificate, now: number): CertValidationResult {
if (new Date(cert.validFrom).getTime() > now) {
return {
ok: false,
code: "validInFuture"
};
}

if (new Date(cert.validTo).getTime() < now) {
return {
ok: false,
code: "expired"
};
}

if (!cert.serialNumber?.trim()) {
return {
ok: false,
code: "invalidSerialNumber"
};
}

if (cert.issuer !== cert.subject) {
return {
ok: false,
code: "notSelfSigned"
};
}

const commonName = parseCertSubject(cert.subject, "CN");
if (!commonName || !bech32.decodeUnsafe(commonName)) {
return {
ok: false,
code: "CommonNameIsNotBech32"
};
}

return { ok: true };
}

function parseCertSubject(subject: string, attr: string): string | null {
const attrPrefix = `${attr}=`;
const index = subject.indexOf(attrPrefix);
if (index === -1) return null;

const endIndex = subject.indexOf("\n", index);
if (endIndex === -1) return subject.slice(index);

return subject.slice(index + attrPrefix.length, endIndex);
}

export const createCertificateValidatorInstrumentation = (logger: typeof console): CertificateValidatorIntrumentation => ({
onValidationSuccess(certificate, network, providerAddress, now) {
logger.log(`Successfully validated ${certificate.serialNumber} in ${network} for "${providerAddress}" at ${now}`);
},
onInvalidAttrs(certificate, network, providerAddress, now, result) {
logger.log(`Certificate ${certificate.serialNumber} is invalid in ${network} for "${providerAddress}" because ${result.code} at ${now}`);
},
onInvalidFingerprint(certificate, network, providerAddress, providerCertificate) {
logger.log(
`Certificate ${certificate.serialNumber} (${certificate.fingerprint256}) fingerprint does not match fingerprint in ${network} for ${providerAddress}: ${providerCertificate.fingerprint256}`
);
},
onUnknownCert(certificate, network, providerAddress) {
logger.log(`Certificate ${certificate.serialNumber} does not have corresponding certificate in ${network} for ${providerAddress}`);
}
});
78 changes: 34 additions & 44 deletions apps/provider-proxy/src/services/ProviderProxy.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,31 @@
import { SupportedChainNetworks } from "@akashnetwork/net";
import { X509Certificate } from "crypto";
import { IncomingMessage } from "http";
import https, { RequestOptions } from "https";
import { LRUCache } from "lru-cache";
import { TLSSocket } from "tls";

import { CertValidationResultError, validateCertificate } from "../utils/validateCertificate";
import { ProviderService } from "./ProviderService";
import { CertificateValidator, CertValidationResultError } from "./CertificateValidator";

export class ProviderProxy {
private readonly knownCertificatesCache = new LRUCache<string, boolean>({
max: 100_000,
ttl: 30 * 60 * 1000
});
/**
* Cache agents in order to control TLS session resumption
*/
private readonly agentsCache = new LRUCache<string, https.Agent>({
max: 100_000
max: 1_000_000
});

constructor(
private readonly now: () => number,
private readonly providerService: ProviderService
) {}
constructor(private readonly certificateValidator: CertificateValidator) {}

connect(url: string, options: ProxyConnectOptions): Promise<ProxyConnectionResult> {
const agent = this.getHttpsAgent(options.network, options.providerAddress, {
const agentOptions: TLSChainAgentOptions = {
timeout: options.timeout,
rejectUnauthorized: false,
cert: options.cert,
key: options.key,
rejectUnauthorized: false
});
chainNetwork: options.network,
providerAddress: options.providerAddress
};
const agent = this.getHttpsAgent(agentOptions);
return new Promise<ProxyConnectionResult>((resolve, reject) => {
const req = https.request(
url,
Expand Down Expand Up @@ -59,18 +56,15 @@ export class ProviderProxy {
const didHandshake = !!serverCert;

if (didHandshake && options.network && options.providerAddress) {
const validationResult = validateCertificate(serverCert, this.now());
const validationResult = await this.certificateValidator.validate(serverCert, options.network, options.providerAddress);
if (validationResult.ok === false) {
// remove agent from cache to destroy TLS session to force TLS handshake on the next call
this.agentsCache.delete(genAgentsCacheKey(agentOptions));
resolve({ ok: false, code: "invalidCertificate", reason: validationResult.code });
req.off("error", reject);
req.destroy();
this.agentsCache.delete(`${options.network}.${options.providerAddress}`);
return resolve({ ok: false, code: "invalidCertificate", reason: validationResult.code });
}

const isKnown = await this.isKnownCertificate(serverCert, options.network, options.providerAddress);
if (!isKnown) {
req.destroy();
this.agentsCache.delete(`${options.network}.${options.providerAddress}`);
return resolve({ ok: false, code: "invalidCertificate", reason: "unknownCertificate" });
agent.destroy();
return;
}
}

Expand All @@ -82,9 +76,7 @@ export class ProviderProxy {
);

if (!req.reusedSocket) {
req.on("error", (error: (Error & { code: string }) | undefined) => {
reject(error);
});
req.on("error", reject);
req.on("timeout", () => {
// here we are just notified that response take more than specified in request options timeout
// then we manually destroy request and it drops connection and
Expand All @@ -98,23 +90,12 @@ export class ProviderProxy {
});
}

private async isKnownCertificate(cert: X509Certificate, network: SupportedChainNetworks, providerAddress: string): Promise<boolean> {
const key = `${network}.${providerAddress}.${cert.serialNumber}`;

if (!this.knownCertificatesCache.has(key)) {
const hasCertificate = await this.providerService.hasCertificate(network, providerAddress, cert.serialNumber);
this.knownCertificatesCache.set(key, hasCertificate);
return hasCertificate;
}

return this.knownCertificatesCache.get(key);
}

private getHttpsAgent(network: SupportedChainNetworks, providerAddress: string, options: https.AgentOptions): https.Agent {
const key = `${network}.${providerAddress}`;
private getHttpsAgent(options: TLSChainAgentOptions): https.Agent {
const key = genAgentsCacheKey(options);

if (!this.agentsCache.has(key)) {
const agent = new https.Agent(options);
const { chainNetwork, providerAddress, ...agentOptions } = options;
const agent = new https.Agent(agentOptions);
this.agentsCache.set(key, agent);
return agent;
}
Expand All @@ -123,6 +104,10 @@ export class ProviderProxy {
}
}

function genAgentsCacheKey(options: TLSChainAgentOptions): string {
return `${options.chainNetwork}:${options.providerAddress}:${options.cert}:${options.key}`;
}

export interface ProxyConnectOptions extends Pick<RequestOptions, "cert" | "key" | "method"> {
body?: BodyInit;
headers?: Record<string, string>;
Expand All @@ -140,5 +125,10 @@ interface ProxyConnectionResultSuccess {
}

type ProxyConnectionResultError =
| { ok: false; code: "invalidCertificate"; reason: CertValidationResultError["code"] | "unknownCertificate" }
| { ok: false; code: "invalidCertificate"; reason: CertValidationResultError["code"] }
| { ok: false; code: "insecureConnection" };

interface TLSChainAgentOptions extends https.AgentOptions {
chainNetwork: SupportedChainNetworks;
providerAddress: string;
}
13 changes: 9 additions & 4 deletions apps/provider-proxy/src/services/ProviderService.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { SupportedChainNetworks } from "@akashnetwork/net";
import { X509Certificate } from "crypto";

import { httpRetry } from "../utils/retry";

Expand All @@ -8,7 +9,7 @@ export class ProviderService {
private readonly fetch: typeof global.fetch
) {}

async hasCertificate(network: SupportedChainNetworks, providerAddress: string, serialNumber: string): Promise<boolean> {
async getCertificate(network: SupportedChainNetworks, providerAddress: string, serialNumber: string): Promise<X509Certificate | null> {
const queryParams = new URLSearchParams({
"filter.state": "valid",
"filter.owner": providerAddress,
Expand All @@ -22,13 +23,17 @@ export class ProviderService {

if (response.status >= 200 && response.status < 300) {
const body = (await response.json()) as KnownCertificatesResponseBody;
return body.certificates.length === 1;
return body.certificates.length === 1 ? new X509Certificate(atob(body.certificates[0].certificate.cert)) : null;
}

return false;
return null;
}
}

interface KnownCertificatesResponseBody {
certificates: unknown[];
certificates: Array<{
certificate: {
cert: string;
};
}>;
}
2 changes: 1 addition & 1 deletion apps/provider-proxy/src/utils/retry.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { setTimeout } from "timers/promises";

export const httpRetry = <T>(callback: () => Promise<T>, options: HttpRetryOptions<T>): Promise<T> => {
return retryWithBackoff(callback, options.retryIf, options.maxRetries || 5, 0);
return retryWithBackoff(callback, options.retryIf, options.maxRetries || 3, 0);
};

export interface HttpRetryOptions<T> {
Expand Down
Loading
Loading