diff --git a/launcher/agent/agent.go b/launcher/agent/agent.go index 6b92a638..f8432d89 100644 --- a/launcher/agent/agent.go +++ b/launcher/agent/agent.go @@ -49,6 +49,8 @@ type AttestationAgent interface { Attest(context.Context, AttestAgentOpts) ([]byte, error) Refresh(context.Context) error Close() error + AddClient(client verifier.Client, verifier VerifierType) error + HasClient(verifier VerifierType) bool } type attestRoot interface { @@ -69,9 +71,17 @@ type AttestAgentOpts struct { TokenType string } +type VerifierType int + +const ( + GCA VerifierType = iota + 1 + ITA +) + // Clients contains clients for supported verifier services. type Clients struct { GCA verifier.Client + ITA verifier.Client } type agent struct { @@ -278,6 +288,33 @@ func convertOCIToContainerSignature(ociSig oci.Signature) (*verifier.ContainerSi }, nil } +// AddClient adds the given client for the provided verifier service. Returns error if a client +// already exists for the service. +func (a *agent) AddClient(client verifier.Client, verifier VerifierType) error { + if a.HasClient(verifier) { + return fmt.Errorf("client for verifier service %v already exists", verifier) + } + + switch verifier { + case GCA: + a.clients.GCA = client + case ITA: + a.clients.ITA = client + } + return nil +} + +// HasClient returns whether a client has been added for the given verifier. +func (a *agent) HasClient(verifier VerifierType) bool { + switch verifier { + case GCA: + return a.clients.GCA != nil + case ITA: + return a.clients.ITA != nil + } + return false +} + type tpmAttestRoot struct { tpmMu sync.Mutex fetchedAK *client.Key diff --git a/launcher/container_runner.go b/launcher/container_runner.go index 702fc05a..50d186ce 100644 --- a/launcher/container_runner.go +++ b/launcher/container_runner.go @@ -35,6 +35,7 @@ import ( "github.com/google/go-tpm-tools/launcher/registryauth" "github.com/google/go-tpm-tools/launcher/spec" "github.com/google/go-tpm-tools/launcher/teeserver" + "github.com/google/go-tpm-tools/verifier/ita" "github.com/google/go-tpm-tools/verifier/util" v1 "github.com/opencontainers/image-spec/specs-go/v1" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -216,12 +217,22 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To asAddr := launchSpec.AttestationServiceAddr - gcaClient, err := util.NewRESTClient(ctx, asAddr, launchSpec.ProjectID, launchSpec.Region) - if err != nil { - return nil, fmt.Errorf("failed to create REST verifier client: %v", err) - } + clients := &agent.Clients{} + if launchSpec.ITARegionalKey != "" { + itaClient, err := ita.NewClient(launchSpec.ITARegionalKey) + if err != nil { + return nil, fmt.Errorf("failed to create ITA verifier client: %v", err) + } + + clients.ITA = itaClient + } else { + gcaClient, err := util.NewRESTClient(ctx, asAddr, launchSpec.ProjectID, launchSpec.Region) + if err != nil { + return nil, fmt.Errorf("failed to create REST verifier client: %v", err) + } - clients := &agent.Clients{GCA: gcaClient} + clients.GCA = gcaClient + } // Create a new signaturediscovery client to fetch signatures. sdClient := getSignatureDiscoveryClient(cdClient, mdsClient, image.Target()) @@ -563,13 +574,16 @@ func (r *ContainerRunner) Run(ctx context.Context) error { return fmt.Errorf("failed to measure CEL events: %v", err) } - if err := r.fetchAndWriteToken(ctx); err != nil { - return fmt.Errorf("failed to fetch and write OIDC token: %v", err) + // Only create default token if GCA client was added. + if r.attestAgent.HasClient(agent.GCA) { + if err := r.fetchAndWriteToken(ctx); err != nil { + return fmt.Errorf("failed to fetch and write OIDC token: %v", err) + } } // create and start the TEE server r.logger.Info("EnableOnDemandAttestation is enabled: initializing TEE server.") - teeServer, err := teeserver.New(ctx, path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger) + teeServer, err := teeserver.New(ctx, path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger, r.launchSpec) if err != nil { return fmt.Errorf("failed to create the TEE server: %v", err) } diff --git a/launcher/spec/launch_spec.go b/launcher/spec/launch_spec.go index ed9fd3f1..96b27006 100644 --- a/launcher/spec/launch_spec.go +++ b/launcher/spec/launch_spec.go @@ -86,6 +86,7 @@ const ( monitoringEnable = "tee-monitoring-enable" devShmSizeKey = "tee-dev-shm-size-kb" mountKey = "tee-mount" + itaRegionAndKey = "ita-regional-key" ) const ( @@ -117,6 +118,7 @@ type LaunchSpec struct { MonitoringEnabled MonitoringType LogRedirect LogRedirectLocation Mounts []launchermount.Mount + ITARegionalKey string // DevShmSize is specified in kiB. DevShmSize int64 Experiments experiments.Experiments @@ -240,6 +242,10 @@ func (s *LaunchSpec) UnmarshalJSON(b []byte) error { } } + if val, ok := unmarshaledMap[itaRegionAndKey]; ok && val != "" { + s.ITARegionalKey = val + } + return nil } diff --git a/launcher/spec/launch_spec_test.go b/launcher/spec/launch_spec_test.go index ae8acb2f..5fb0ba81 100644 --- a/launcher/spec/launch_spec_test.go +++ b/launcher/spec/launch_spec_test.go @@ -26,7 +26,8 @@ func TestLaunchSpecUnmarshalJSONHappyCases(t *testing.T) { "tee-container-log-redirect":"true", "tee-monitoring-memory-enable":"true", "tee-dev-shm-size-kb":"234234", - "tee-mount":"type=tmpfs,source=tmpfs,destination=/tmpmount;type=tmpfs,source=tmpfs,destination=/sized,size=222" + "tee-mount":"type=tmpfs,source=tmpfs,destination=/tmpmount;type=tmpfs,source=tmpfs,destination=/sized,size=222", + "ita-regional-key":"US:test-api-key" }`, }, { @@ -43,7 +44,8 @@ func TestLaunchSpecUnmarshalJSONHappyCases(t *testing.T) { "tee-container-log-redirect":"true", "tee-monitoring-memory-enable":"TRUE", "tee-dev-shm-size-kb":"234234", - "tee-mount":"type=tmpfs,source=tmpfs,destination=/tmpmount;type=tmpfs,source=tmpfs,destination=/sized,size=222" + "tee-mount":"type=tmpfs,source=tmpfs,destination=/tmpmount;type=tmpfs,source=tmpfs,destination=/sized,size=222", + "ita-regional-key":"US:test-api-key" }`, }, } @@ -63,6 +65,7 @@ func TestLaunchSpecUnmarshalJSONHappyCases(t *testing.T) { Experiments: experiments.Experiments{ EnableTempFSMount: true, }, + ITARegionalKey: "US:test-api-key", } for _, testcase := range testCases { diff --git a/launcher/teeserver/tee_server.go b/launcher/teeserver/tee_server.go index aa323ff6..e334ab80 100644 --- a/launcher/teeserver/tee_server.go +++ b/launcher/teeserver/tee_server.go @@ -11,13 +11,16 @@ import ( "github.com/google/go-tpm-tools/launcher/agent" "github.com/google/go-tpm-tools/launcher/internal/logging" + "github.com/google/go-tpm-tools/launcher/spec" + "github.com/google/go-tpm-tools/verifier/util" ) type attestHandler struct { ctx context.Context attestAgent agent.AttestationAgent // defaultTokenFile string - logger logging.Logger + logger logging.Logger + launchSpec spec.LaunchSpec } type customTokenRequest struct { @@ -34,7 +37,7 @@ type TeeServer struct { } // New takes in a socket and start to listen to it, and create a server -func New(ctx context.Context, unixSock string, a agent.AttestationAgent, logger logging.Logger) (*TeeServer, error) { +func New(ctx context.Context, unixSock string, a agent.AttestationAgent, logger logging.Logger, launchSpec spec.LaunchSpec) (*TeeServer, error) { var err error nl, err := net.Listen("unix", unixSock) if err != nil { @@ -48,7 +51,8 @@ func New(ctx context.Context, unixSock string, a agent.AttestationAgent, logger ctx: ctx, attestAgent: a, // defaultTokenFile: filepath.Join(launcherfile.HostTmpPath, launcherfile.AttestationVerifierTokenFilename), - logger: logger, + logger: logger, + launchSpec: launchSpec, }).Handler(), }, } @@ -73,6 +77,20 @@ func (a *attestHandler) Handler() http.Handler { func (a *attestHandler) getToken(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") + // If the agent does not have a GCA client, create one. + if !a.attestAgent.HasClient(agent.GCA) { + gcaClient, err := util.NewRESTClient(a.ctx, a.launchSpec.AttestationServiceAddr, a.launchSpec.ProjectID, a.launchSpec.Region) + if err != nil { + errStr := fmt.Sprintf("failed to create REST verifier client: %v", err) + a.logger.Error(errStr) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(errStr)) + return + } + + a.attestAgent.AddClient(gcaClient, agent.GCA) + } + switch r.Method { case "GET": if err := a.attestAgent.Refresh(a.ctx); err != nil { diff --git a/launcher/teeserver/tee_server_test.go b/launcher/teeserver/tee_server_test.go index 56ff530d..400a21bc 100644 --- a/launcher/teeserver/tee_server_test.go +++ b/launcher/teeserver/tee_server_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/go-tpm-tools/cel" "github.com/google/go-tpm-tools/launcher/agent" "github.com/google/go-tpm-tools/launcher/internal/logging" + "github.com/google/go-tpm-tools/verifier" ) type fakeAttestationAgent struct { @@ -34,6 +35,13 @@ func (f fakeAttestationAgent) Refresh(_ context.Context) error { func (f fakeAttestationAgent) Close() error { return nil } +func (f fakeAttestationAgent) AddClient(client verifier.Client, verifier agent.VerifierType) error { + return nil +} + +func (f fakeAttestationAgent) HasClient(_ agent.VerifierType) bool { + return false +} func TestGetDefaultToken(t *testing.T) { testTokenContent := "test token"