Skip to content

Commit

Permalink
Refactor verifier for issue #418 (#419)
Browse files Browse the repository at this point in the history
* refactor verifier

* move rest_network_test to agent_test

* resolve token command comments in #375 

* refactor token cmd without depending on launcher agent

* decouple launcher agent from cloud logger

* extract agent common functions to package util

* extract getRestClient function

* extract getRegion function

* fix fake_oauth2_server

* use constants in the fake

* move util to internal

* replace fakeOauth2Credential with os.CreateTemp

* refactor principalFetcher

* add PrincipleFetcher unit test
  • Loading branch information
Ruide authored Mar 15, 2024
1 parent de26f21 commit cf9fd6d
Show file tree
Hide file tree
Showing 28 changed files with 453 additions and 314 deletions.
9 changes: 5 additions & 4 deletions cmd/attest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
tgtestclient "github.com/google/go-tdx-guest/testing/client"
"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/internal/test"
"github.com/google/go-tpm-tools/internal/util"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
)
Expand Down Expand Up @@ -205,8 +206,8 @@ func TestFormatFlagFail(t *testing.T) {
}

func TestMetadataPass(t *testing.T) {
var dummyInstance = Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mock, err := NewMetadataServer(dummyInstance)
var dummyInstance = util.Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mock, err := util.NewMetadataServer(dummyInstance)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -270,8 +271,8 @@ func TestAttestWithGCEAK(t *testing.T) {
}
defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(getIndex[op.keyAlgo]))

var dummyInstance = Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mock, err := NewMetadataServer(dummyInstance)
var dummyInstance = util.Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mock, err := util.NewMetadataServer(dummyInstance)
if err != nil {
t.Error(err)
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/fake_cloudlogging_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

logpb "cloud.google.com/go/logging/apiv2/loggingpb"
tspb "github.com/golang/protobuf/ptypes/timestamp"
"github.com/google/go-tpm-tools/internal/util"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -125,7 +126,7 @@ func (h *loggingHandler) WriteLogEntries(_ context.Context, req *logpb.WriteLogE
var logEntryPayload []map[string]interface{}
logEntryPayload = append(logEntryPayload, map[string]interface{}{"aud": "test", "iat": float64(1709752525), "exp": float64(1919752525)})
logEntryPayload = append(logEntryPayload, map[string]interface{}{"token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0IiwiaWF0IjoxNzA5NzUyNTI1LCJleHAiOjE5MTk3NTI1MjV9.EBLA2zX3c-Fu0l--J9Gey6LIXMO1TFRCoe3bzuPGc1k"})
logEntryPayload = append(logEntryPayload, map[string]interface{}{"Name": "projects/test-project/locations/us-central-1/challenges/" + fakeChallengeUUID, "Nonce": fakeTpmNonce, "ConnID": ""})
logEntryPayload = append(logEntryPayload, map[string]interface{}{"Name": "projects/test-project/locations/us-central-1/challenges/" + util.FakeChallengeUUID, "Nonce": util.FakeTpmNonce, "ConnID": ""})
attestationMapFields := []string{"TeeAttestation", "ak_pub", "quotes", "event_log", "ak_cert"}
for _, entry := range h.logs["projects/"+TestProjectID+"/logs/"+toolName] {
payload := entry.GetJsonPayload().AsMap()
Expand Down
32 changes: 0 additions & 32 deletions cmd/fake_oauth2_server.go

This file was deleted.

6 changes: 0 additions & 6 deletions cmd/testdata/credentials

This file was deleted.

108 changes: 39 additions & 69 deletions cmd/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,22 @@ import (
"errors"
"fmt"
"log"
"net/url"
"strings"
"time"

"cloud.google.com/go/compute/metadata"
"cloud.google.com/go/logging"
"github.com/containerd/containerd/namespaces"
"github.com/golang-jwt/jwt/v4"
"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/launcher/agent"
"github.com/google/go-tpm-tools/launcher/spec"
"github.com/google/go-tpm-tools/launcher/verifier"
"github.com/google/go-tpm-tools/launcher/verifier/rest"
"github.com/google/go-tpm-tools/internal/util"
"github.com/google/go-tpm-tools/verifier"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/spf13/cobra"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

var mdsClient *metadata.Client
var mockCloudLoggingServerAddress string

const toolName = "gotpm"
Expand All @@ -37,7 +31,7 @@ var tokenCmd = &cobra.Command{
Use: "token",
Short: "Attest and fetch an OIDC token from Google Attestation Verification Service.",
Long: `Gather attestation report and send it to Google Attestation Verification Service for an OIDC token.
The OIDC token includes claims regarding the GCE VM, which is verified by Attestation Verification Service. Note that Confidential Computing API needs to be enabled for your account to access Google Attestation Verification Service https://pantheon.corp.google.com/apis/api/confidentialcomputing.googleapis.com.
The OIDC token includes claims regarding the GCE VM, which is verified by Attestation Verification Service. Note that Confidential Computing API needs to be enabled for your account to access Google Attestation Verification Service https://console.cloud.google.com/apis/api/confidentialcomputing.googleapis.com.
--algo flag overrides the public key algorithm for the GCE TPM attestation key. If not provided then by default rsa is used.
`,
Args: cobra.NoArgs,
Expand All @@ -49,32 +43,13 @@ The OIDC token includes claims regarding the GCE VM, which is verified by Attest
defer rwc.Close()

// Metadata Server (MDS). A GCP specific client.
mdsClient = metadata.NewClient(nil)
mdsClient := metadata.NewClient(nil)

ctx := namespaces.WithNamespace(context.Background(), namespaces.Default)
// TODO: principalFetcher is copied from go-tpm-tools/launcher/container_runner.go, to be refactored
// Fetch GCP specific ID token with specific audience.
// See https://cloud.google.com/functions/docs/securing/authenticating#functions-bearer-token-example-go.
principalFetcher := func(audience string) ([][]byte, error) {
u := url.URL{
Path: "instance/service-accounts/default/identity",
RawQuery: url.Values{
"audience": {audience},
"format": {"full"},
}.Encode(),
}
idToken, err := mdsClient.Get(u.String())
if err != nil {
return nil, fmt.Errorf("failed to get principal tokens: %w", err)
}
fmt.Fprintf(debugOutput(), "GCP ID token fetched is: %s\n", idToken)
tokens := [][]byte{[]byte(idToken)}
return tokens, nil
}

fmt.Fprintf(debugOutput(), "Attestation Address is set to %s\n", asAddress)

region, err := getRegion(mdsClient)
region, err := util.GetRegion(mdsClient)
if err != nil {
return fmt.Errorf("failed to fetch Region from MDS, the tool is probably not running in a GCE VM: %v", err)
}
Expand All @@ -84,7 +59,7 @@ The OIDC token includes claims regarding the GCE VM, which is verified by Attest
return fmt.Errorf("failed to retrieve ProjectID from MDS: %v", err)
}

verifierClient, err := getRESTClient(ctx, asAddress, projectID, region)
verifierClient, err := util.NewRESTClient(ctx, asAddress, projectID, region)
if err != nil {
return fmt.Errorf("failed to create REST verifier client: %v", err)
}
Expand All @@ -104,7 +79,7 @@ The OIDC token includes claims regarding the GCE VM, which is verified by Attest
return err
}
if gceAK.Cert() == nil {
return errors.New("failed to find gceAKCert on this VM: try creating a new VM or verifying the VM has an EK cert using get-shielded-identity gcloud command. The used key algorithm is: " + usedKeyAlgo)
return errors.New("failed to find GCE AK Certificate on this VM: try creating a new VM or verifying the VM has an EK cert using get-shielded-identity gcloud command. The used key algorithm is: " + usedKeyAlgo)
}
gceAK.Close()

Expand Down Expand Up @@ -135,14 +110,41 @@ The OIDC token includes claims regarding the GCE VM, which is verified by Attest
}

key = "gceAK"
attestAgent := agent.CreateAttestationAgent(rwc, attestationKeys[key][keyAlgo], verifierClient, principalFetcher, nil, spec.LaunchSpec{}, nil, cloudLogger)

fmt.Fprintf(debugOutput(), "Fetching attestation verifier OIDC token\n")
token, err := attestAgent.Attest(ctx, agent.AttestAgentOpts{Aud: audience, TokenType: "OIDC"})

challenge, err := verifierClient.CreateChallenge(ctx)
if err != nil {
return fmt.Errorf("failed to retrieve attestation service token: %v", err)
return err
}

principalTokens, err := util.PrincipalFetcher(challenge.Name, mdsClient)
if err != nil {
return fmt.Errorf("failed to get principal tokens: %w", err)
}

attestation, err := util.FetchAttestation(rwc, attestationKeys[key][keyAlgo], challenge.Nonce)
if err != nil {
return err
}

req := verifier.VerifyAttestationRequest{
Challenge: challenge,
GcpCredentials: principalTokens,
Attestation: attestation,
TokenOptions: verifier.TokenOptions{CustomAudience: audience, TokenType: "OIDC"},
}

resp, err := verifierClient.VerifyAttestation(ctx, req)
if err != nil {
return err
}
if len(resp.PartialErrs) > 0 {
fmt.Fprintf(debugOutput(), "partial errors from VerifyAttestation: %v", resp.PartialErrs)
}

token := resp.ClaimsToken

// Get token expiration.
claims := &jwt.RegisteredClaims{}
_, _, err = jwt.NewParser().ParseUnverified(string(token), claims)
Expand Down Expand Up @@ -176,6 +178,8 @@ The OIDC token includes claims regarding the GCE VM, which is verified by Attest
}

if cloudLog {
cloudLogger.Log(logging.Entry{Payload: challenge})
cloudLogger.Log(logging.Entry{Payload: attestation})
cloudLogger.Log(logging.Entry{Payload: map[string]string{"token": string(token)}})
cloudLogger.Log(logging.Entry{Payload: mapClaims})
cloudLogClient.Close()
Expand All @@ -190,40 +194,6 @@ The OIDC token includes claims regarding the GCE VM, which is verified by Attest
},
}

// TODO: getRESTClient is copied from go-tpm-tools/launcher/container_runner.go, to be refactored.
// getRESTClient returns a REST verifier.Client that points to the given address.
// It defaults to the Attestation Verifier instance at
// https://confidentialcomputing.googleapis.com.
func getRESTClient(ctx context.Context, asAddr string, ProjectID string, Region string) (verifier.Client, error) {
httpClient, err := google.DefaultClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %v", err)
}

opts := []option.ClientOption{option.WithHTTPClient(httpClient)}
if asAddr != "" {
opts = append(opts, option.WithEndpoint(asAddr))
}

restClient, err := rest.NewClient(ctx, ProjectID, Region, opts...)
if err != nil {
return nil, err
}
return restClient, nil
}

func getRegion(client *metadata.Client) (string, error) {
zone, err := client.Zone()
if err != nil {
return "", fmt.Errorf("failed to retrieve zone from MDS: %v", err)
}
lastDash := strings.LastIndex(zone, "-")
if lastDash == -1 {
return "", fmt.Errorf("got malformed zone from MDS: %v", zone)
}
return zone[:lastDash], nil
}

func init() {
RootCmd.AddCommand(tokenCmd)
addOutputFlag(tokenCmd)
Expand Down
18 changes: 11 additions & 7 deletions cmd/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/internal/test"
"github.com/google/go-tpm-tools/internal/util"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -49,24 +50,27 @@ func TestTokenWithGCEAK(t *testing.T) {
defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(getIndex[op.algo]))
defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(getCertIndex[op.algo]))

var dummyMetaInstance = Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mockMdsServer, err := NewMetadataServer(dummyMetaInstance)
var dummyMetaInstance = util.Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mockMdsServer, err := util.NewMetadataServer(dummyMetaInstance)
if err != nil {
t.Error(err)
}
defer mockMdsServer.Stop()

mockOauth2Server := newMockOauth2Server()
mockOauth2Server, err := util.NewMockOauth2Server()
if err != nil {
t.Error(err)
}
defer mockOauth2Server.Stop()

// Endpoint is Google's OAuth 2.0 default endpoint. Change to mock server.
google.Endpoint = oauth2.Endpoint{
AuthURL: mockOauth2Server.server.URL + "/o/oauth2/auth",
TokenURL: mockOauth2Server.server.URL + "/token",
AuthURL: mockOauth2Server.Server.URL + "/o/oauth2/auth",
TokenURL: mockOauth2Server.Server.URL + "/token",
AuthStyle: oauth2.AuthStyleInParams,
}

mockAttestationServer, err := newMockAttestationServer()
mockAttestationServer, err := util.NewMockAttestationServer()
if err != nil {
t.Error(err)
}
Expand All @@ -77,7 +81,7 @@ func TestTokenWithGCEAK(t *testing.T) {
t.Error(err)
}

RootCmd.SetArgs([]string{"token", "--algo", op.algo, "--output", secretFile1, "--verifier-endpoint", mockAttestationServer.server.URL, "--cloud-log", "--audience", "https://api.test.com"})
RootCmd.SetArgs([]string{"token", "--algo", op.algo, "--output", secretFile1, "--verifier-endpoint", mockAttestationServer.Server.URL, "--cloud-log", "--audience", "https://api.test.com"})
if err := RootCmd.Execute(); err != nil {
t.Error(err)
}
Expand Down
5 changes: 3 additions & 2 deletions cmd/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
tgtestdata "github.com/google/go-tdx-guest/testing/testdata"
"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/internal/test"
"github.com/google/go-tpm-tools/internal/util"
pb "github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
Expand Down Expand Up @@ -93,8 +94,8 @@ func TestVerifyWithGCEAK(t *testing.T) {
}
defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(getIndex[op.keyAlgo]))

var dummyInstance = Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mock, err := NewMetadataServer(dummyInstance)
var dummyInstance = util.Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mock, err := util.NewMetadataServer(dummyInstance)
if err != nil {
t.Error(err)
}
Expand Down
Loading

0 comments on commit cf9fd6d

Please sign in to comment.