Skip to content

Commit

Permalink
Apply retry logics in launcher
Browse files Browse the repository at this point in the history
  • Loading branch information
yawangwang committed Nov 11, 2024
1 parent c891518 commit 1adb6c2
Show file tree
Hide file tree
Showing 10 changed files with 570 additions and 51 deletions.
4 changes: 2 additions & 2 deletions client/pcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ var extends = map[tpm2.Algorithm][]struct {
{bytes.Repeat([]byte{0x02}, sha512.Size384)}},
}

func pcrExtend(alg tpm2.Algorithm, old, new []byte) ([]byte, error) {
func pcrExtend(alg tpm2.Algorithm, old, newBytes []byte) ([]byte, error) {
hCon, err := alg.Hash()
if err != nil {
return nil, fmt.Errorf("not a valid hash type: %v", alg)
}
h := hCon.New()
h.Write(old)
h.Write(new)
h.Write(newBytes)
return h.Sum(nil), nil
}

Expand Down
187 changes: 187 additions & 0 deletions go.work.sum

Large diffs are not rendered by default.

44 changes: 42 additions & 2 deletions launcher/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bytes"
"context"
"crypto"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -24,11 +25,16 @@ import (
pb "github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/verifier"
"github.com/google/go-tpm-tools/verifier/oci"
"github.com/google/go-tpm-tools/verifier/rest"
"github.com/google/go-tpm-tools/verifier/util"
"go.uber.org/multierr"
)

var defaultCELHashAlgo = []crypto.Hash{crypto.SHA256, crypto.SHA1}

// attestFunc is used for doAttest indirectly so that unit tests can stub it.
var attestFunc = doAttest

type principalIDTokenFetcher func(audience string) ([][]byte, error)

// AttestationAgent is an agent that interacts with GCE's Attestation Service
Expand Down Expand Up @@ -101,10 +107,44 @@ func (a *agent) MeasureEvent(event cel.Content) error {
return a.cosCel.AppendEvent(a.tpm, cel.CosEventPCR, defaultCELHashAlgo, event)
}

// Attest fetches the nonce and connection ID from the Attestation Service,
// Attest is a thin wrapper of AttestWithRetries with defaultRetryPolicy.
func (a *agent) Attest(ctx context.Context, opts AttestAgentOpts) ([]byte, error) {
return a.AttestWithRetries(ctx, opts, defaultRetryPolicy)
}

// Attest executes doAttest with retries when 500 errors originate from VerifyAttestation API.
func (a *agent) AttestWithRetries(ctx context.Context, opts AttestAgentOpts, retry func() backoff.BackOff) ([]byte, error) {
var token []byte
var err error

retryErr := backoff.Retry(
func() error {
var doErr error
token, doErr = attestFunc(ctx, a, opts)
var verifyErr *rest.VerifyAttestationError
// Retry for VerifyAttestation 500 errors.
if errors.As(doErr, &verifyErr) && verifyErr.StatusCode() == http.StatusInternalServerError {
return verifyErr
}

// Otherwise, save the error and exit the retry.
err = doErr
return nil
},
retry(),
)

if retryErr != nil || err != nil {
return nil, multierr.Append(retryErr, err)
}

return token, nil
}

// doAttest fetches the nonce and connection ID from the Attestation Service,
// creates an attestation message, and returns the resultant
// principalIDTokens and Metadata Server-generated ID tokens for the instance.
func (a *agent) Attest(ctx context.Context, opts AttestAgentOpts) ([]byte, error) {
func doAttest(ctx context.Context, a *agent, opts AttestAgentOpts) ([]byte, error) {
challenge, err := a.client.CreateChallenge(ctx)
if err != nil {
return nil, err
Expand Down
75 changes: 75 additions & 0 deletions launcher/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/base64"
"fmt"
"math"
"net/http"
"runtime"
"sync"
"testing"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/google/go-tpm-tools/verifier/oci/cosign"
"github.com/google/go-tpm-tools/verifier/rest"
"golang.org/x/oauth2/google"
"google.golang.org/api/googleapi"
"google.golang.org/api/option"
"google.golang.org/protobuf/encoding/protojson"
)
Expand Down Expand Up @@ -78,6 +80,79 @@ func TestAttestRacing(t *testing.T) {
agent.Close()
}

func TestAttestWithRetries(t *testing.T) {
testCases := []struct {
name string
fn func(int) ([]byte, error)
wantPass bool
wantAttempts int
}{
{
name: "success",
fn: func(int) ([]byte, error) {
return []byte("test token"), nil
},
wantPass: true,
wantAttempts: 1,
},
{
name: "failed with 500, then success",
fn: func(attempts int) ([]byte, error) {
if attempts == 1 {
return nil, rest.NewVerifyAttestationError(nil, &googleapi.Error{Code: http.StatusInternalServerError})
}
return []byte("test token"), nil
},
wantPass: true,
wantAttempts: 2,
},
{
name: "failed with 500 after attempts exceed",
fn: func(int) ([]byte, error) {
return nil, rest.NewVerifyAttestationError(nil, &googleapi.Error{Code: http.StatusInternalServerError})
},
wantPass: false,
wantAttempts: 4,
},
{
name: "failed with non-500 error",
fn: func(int) ([]byte, error) {
return nil, fmt.Errorf("other error")
},
wantPass: false,
wantAttempts: 1,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Reset stub after test case is done.
af := attestFunc
t.Cleanup(func() { attestFunc = af })

attempts := 0
// Stub attestFunc.
attestFunc = func(context.Context, *agent, AttestAgentOpts) ([]byte, error) {
attempts++
return tc.fn(attempts)
}

a := &agent{}
testRetryPolicy := func() backoff.BackOff {
return backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Millisecond), 3)
}
_, err := a.AttestWithRetries(context.Background(), AttestAgentOpts{}, testRetryPolicy)
if gotPass := (err == nil); gotPass != tc.wantPass {
t.Errorf("AttestWithRetries failed, gotPass %v, but wantPass %v", gotPass, tc.wantPass)
}

if gotAttempts := attempts; gotAttempts != tc.wantAttempts {
t.Errorf("AttestWithRetries failed, gotAttempts %v, but wantAttempts %v", gotAttempts, tc.wantAttempts)
}
})
}
}

func TestAttest(t *testing.T) {
ctx := context.Background()
testCases := []struct {
Expand Down
40 changes: 36 additions & 4 deletions launcher/container_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,12 @@ func getSignatureDiscoveryClient(cdClient *containerd.Client, mdsClient *metadat
return registryauth.RefreshResolver(ctx, mdsClient)
}
imageFetcher := func(ctx context.Context, imageRef string, opts ...containerd.RemoteOpt) (containerd.Image, error) {
image, err := cdClient.Pull(ctx, imageRef, opts...)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, imageRef, opts...)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull signature objects from the signature image [%s]: %w", imageRef, err)
}
Expand Down Expand Up @@ -529,6 +534,11 @@ func defaultRetryPolicy() *backoff.ExponentialBackOff {
return expBack
}

func pullImageBackoffPolicy() backoff.BackOff {
b := backoff.NewConstantBackOff(time.Millisecond * 500)
return backoff.WithMaxRetries(b, 3)
}

// Run the container
// Container output will always be redirected to logger writer for now
func (r *ContainerRunner) Run(ctx context.Context) error {
Expand Down Expand Up @@ -621,17 +631,39 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
return nil
}

func pullImageWithRetries(f func() (containerd.Image, error), retry func() backoff.BackOff) (containerd.Image, error) {
var err error
var image containerd.Image
err = backoff.Retry(func() error {
image, err = f()
return err
}, retry())
if err != nil {
return nil, fmt.Errorf("failed to pull image with retries, the last error is: %w", err)
}
return image, nil
}

func initImage(ctx context.Context, cdClient *containerd.Client, launchSpec spec.LaunchSpec, token oauth2.Token) (containerd.Image, error) {
if token.Valid() {
remoteOpt := containerd.WithResolver(registryauth.Resolver(token.AccessToken))

image, err := cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack, remoteOpt)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack, remoteOpt)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull the image: %w", err)
}
return image, nil
}
image, err := cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack)
image, err := pullImageWithRetries(
func() (containerd.Image, error) {
return cdClient.Pull(ctx, launchSpec.ImageRef, containerd.WithPullUnpack)
},
pullImageBackoffPolicy,
)
if err != nil {
return nil, fmt.Errorf("cannot pull the image (no token, only works for a public image): %w", err)
}
Expand Down
51 changes: 51 additions & 0 deletions launcher/container_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,57 @@ func TestMeasureCELEvents(t *testing.T) {
}
}

func TestPullImageWithRetries(t *testing.T) {
testCases := []struct {
name string
imagePuller func(int) (containerd.Image, error)
wantPass bool
}{
{
name: "success with single attempt",
imagePuller: func(int) (containerd.Image, error) { return &fakeImage{}, nil },
wantPass: true,
},
{
name: "failure then success",
imagePuller: func(attempts int) (containerd.Image, error) {
if attempts%2 == 1 {
return nil, errors.New("fake error")
}
return &fakeImage{}, nil
},
wantPass: true,
},
{
name: "failure with attempts exceeded",
imagePuller: func(int) (containerd.Image, error) {
return nil, errors.New("fake error")
},
wantPass: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
retryPolicy := func() backoff.BackOff {
b := backoff.NewExponentialBackOff()
return backoff.WithMaxRetries(b, 2)
}

attempts := 0
_, err := pullImageWithRetries(
func() (containerd.Image, error) {
attempts++
return tc.imagePuller(attempts)
},
retryPolicy)
if gotPass := (err == nil); gotPass != tc.wantPass {
t.Errorf("pullImageWithRetries failed, got %v, but want %v", gotPass, tc.wantPass)
}
})
}
}

// This ensures fakeContainer implements containerd.Container interface.
var _ containerd.Container = &fakeContainer{}

Expand Down
Loading

0 comments on commit 1adb6c2

Please sign in to comment.