Skip to content

Commit

Permalink
Add back 250 year max duration cap (#7337)
Browse files Browse the repository at this point in the history
## What changed?
<!-- Describe what has changed in this PR -->
Adding back a cap for durations at 250 years.

## Why?
<!-- Tell your future self why you have made these changes -->
To prevent serialization errors for very large proto durations.

## How did you test it?
<!-- How have you verified this change? Tested locally? Added a unit
test? Checked in staging env? -->
New unit test
  • Loading branch information
pdoerner authored and rodrigozhou committed Feb 13, 2025
1 parent e493e55 commit 9889683
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 32 deletions.
20 changes: 16 additions & 4 deletions common/primitives/timestamp/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@ import (
"google.golang.org/protobuf/types/known/durationpb"
)

// maxAllowedDuration is the ceiling applied to durations: 250 years, counted as
// 365-day years. A Go time.Duration is an int64 count of nanoseconds, so the
// largest value it can hold is roughly 290 years; the cap stays below that.
const maxAllowedDuration = 250 * 365 * 24 * time.Hour

var (
	errNegativeDuration = fmt.Errorf("negative duration")
	errMismatchedSigns  = fmt.Errorf("duration has seconds and nanos with different signs")

	// maxSeconds is maxAllowedDuration expressed in whole seconds.
	maxSeconds = int64(maxAllowedDuration / time.Second)
)

func DurationValue(d *durationpb.Duration) time.Duration {
Expand Down Expand Up @@ -72,15 +77,17 @@ func durationMultipleOf(amt int64, mult time.Duration) *durationpb.Duration {
return DurationPtr(time.Duration(amt) * mult)
}

// ValidateProtoDuration checks protobuf durations for two conditions:
// ValidateAndCapProtoDuration validates protobuf durations for two conditions:
// 1. the seconds and nanos fields have the same sign (to avoid serialization issues)
// 2. the golang representation of the duration is not negative
//
// Durations are capped to 250 years to prevent overflow and serialization errors.
// NB: to cap durations, the proto Seconds and Nanos fields are modified!
//
// nil durations are considered valid because they will be treated as the zero value.
// durationpb.CheckValid cannot be used directly because it will return an error for
// very large durations but we are okay with truncating these. durationpb.AsDuration()
// caps the upper bound for timers at 10,000 years to prevent overflow.
func ValidateProtoDuration(d *durationpb.Duration) error {
// very large durations, but we are okay with truncating these.
func ValidateAndCapProtoDuration(d *durationpb.Duration) error {
if d == nil {
// nil durations are converted to 0 value
return nil
Expand All @@ -95,5 +102,10 @@ func ValidateProtoDuration(d *durationpb.Duration) error {
return errNegativeDuration
}

if d.AsDuration() > maxAllowedDuration {
d.Seconds = maxSeconds
d.Nanos = 0 // A year is always a round number of seconds.
}

return nil
}
10 changes: 8 additions & 2 deletions common/primitives/timestamp/duration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"google.golang.org/protobuf/types/known/durationpb"
)

func TestValidateProtoDuration(t *testing.T) {
func TestValidateAndCapProtoDuration(t *testing.T) {
t.Parallel()

testCases := []struct {
Expand Down Expand Up @@ -69,13 +69,19 @@ func TestValidateProtoDuration(t *testing.T) {
expectedErr: errMismatchedSigns,
expectedTimerDuration: nil,
},
{
name: "large duration",
timerDuration: durationpb.New(280 * 365 * 24 * time.Hour),
expectedErr: nil,
expectedTimerDuration: durationpb.New(maxAllowedDuration),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

actualErr := ValidateProtoDuration(tc.timerDuration)
actualErr := ValidateAndCapProtoDuration(tc.timerDuration)

assert.Equal(t, tc.expectedErr, actualErr)
if tc.expectedErr == nil {
Expand Down
4 changes: 2 additions & 2 deletions common/retrypolicy/retry_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ func Validate(policy *commonpb.RetryPolicy) error {
// rest of the arguments is pointless
return nil
}
if err := timestamp.ValidateProtoDuration(policy.GetInitialInterval()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(policy.GetInitialInterval()); err != nil {
return serviceerror.NewInvalidArgument(fmt.Sprintf("invalid InitialInterval set on retry policy: %v", err))
}
if policy.GetBackoffCoefficient() < 1 {
return serviceerror.NewInvalidArgument("BackoffCoefficient cannot be less than 1 on retry policy.")
}
if err := timestamp.ValidateProtoDuration(policy.GetMaximumInterval()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(policy.GetMaximumInterval()); err != nil {
return serviceerror.NewInvalidArgument(fmt.Sprintf("invalid MaximumInterval set on retry policy: %v", err))
}
if timestamp.DurationValue(policy.GetMaximumInterval()) > 0 && timestamp.DurationValue(policy.GetMaximumInterval()) < timestamp.DurationValue(policy.GetInitialInterval()) {
Expand Down
2 changes: 1 addition & 1 deletion components/nexusoperations/workflow/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (ch *commandHandler) HandleScheduleCommand(
}
}

if err := timestamp.ValidateProtoDuration(attrs.ScheduleToCloseTimeout); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attrs.ScheduleToCloseTimeout); err != nil {
return workflow.FailWorkflowTaskError{
Cause: enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_SCHEDULE_NEXUS_OPERATION_ATTRIBUTES,
Message: fmt.Sprintf(
Expand Down
2 changes: 1 addition & 1 deletion service/frontend/namespace_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ func (d *namespaceHandler) maybeUpdateFailoverHistory(

// validateRetentionDuration ensures that retention duration can't be set below a sane minimum.
func validateRetentionDuration(retention *durationpb.Duration, isGlobalNamespace bool) error {
if err := timestamp.ValidateProtoDuration(retention); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(retention); err != nil {
return errInvalidRetentionPeriod
}

Expand Down
14 changes: 7 additions & 7 deletions service/frontend/workflow_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5479,15 +5479,15 @@ func validateRequestId(requestID *string, lenLimit int) error {
func (wh *WorkflowHandler) validateStartWorkflowTimeouts(
request *workflowservice.StartWorkflowExecutionRequest,
) error {
if err := timestamp.ValidateProtoDuration(request.GetWorkflowExecutionTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowExecutionTimeout()); err != nil {
return fmt.Errorf("%w cause: %v", errInvalidWorkflowExecutionTimeoutSeconds, err)
}

if err := timestamp.ValidateProtoDuration(request.GetWorkflowRunTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowRunTimeout()); err != nil {
return fmt.Errorf("%w cause: %v", errInvalidWorkflowRunTimeoutSeconds, err)
}

if err := timestamp.ValidateProtoDuration(request.GetWorkflowTaskTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowTaskTimeout()); err != nil {
return fmt.Errorf("%w cause: %v", errInvalidWorkflowTaskTimeoutSeconds, err)
}

Expand All @@ -5497,15 +5497,15 @@ func (wh *WorkflowHandler) validateStartWorkflowTimeouts(
func (wh *WorkflowHandler) validateSignalWithStartWorkflowTimeouts(
request *workflowservice.SignalWithStartWorkflowExecutionRequest,
) error {
if err := timestamp.ValidateProtoDuration(request.GetWorkflowExecutionTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowExecutionTimeout()); err != nil {
return fmt.Errorf("%w cause: %v", errInvalidWorkflowExecutionTimeoutSeconds, err)
}

if err := timestamp.ValidateProtoDuration(request.GetWorkflowRunTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowRunTimeout()); err != nil {
return fmt.Errorf("%w cause: %v", errInvalidWorkflowRunTimeoutSeconds, err)
}

if err := timestamp.ValidateProtoDuration(request.GetWorkflowTaskTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowTaskTimeout()); err != nil {
return fmt.Errorf("%w cause: %v", errInvalidWorkflowTaskTimeoutSeconds, err)
}

Expand All @@ -5520,7 +5520,7 @@ func (wh *WorkflowHandler) validateWorkflowStartDelay(
return errCronAndStartDelaySet
}

if err := timestamp.ValidateProtoDuration(startDelay); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(startDelay); err != nil {
return fmt.Errorf("%w cause: %v", errInvalidWorkflowStartDelaySeconds, err)
}

Expand Down
22 changes: 11 additions & 11 deletions service/history/api/command_attr_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,16 @@ func (v *CommandAttrValidator) ValidateActivityScheduleAttributes(
}

// Attempt to deduce and fill in unspecified timeouts only when all timeouts are non-negative.
if err := timestamp.ValidateProtoDuration(attributes.GetScheduleToCloseTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetScheduleToCloseTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid ScheduleToCloseTimeout for ScheduleActivityTaskCommand: %v. ActivityId=%s ActivityType=%s", err, activityID, activityType))
}
if err := timestamp.ValidateProtoDuration(attributes.GetScheduleToStartTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetScheduleToStartTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid ScheduleToStartTimeout for ScheduleActivityTaskCommand: %v. ActivityId=%s ActivityType=%s", err, activityID, activityType))
}
if err := timestamp.ValidateProtoDuration(attributes.GetStartToCloseTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetStartToCloseTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid StartToCloseTimeout for ScheduleActivityTaskCommand: %v. ActivityId=%s ActivityType=%s", err, activityID, activityType))
}
if err := timestamp.ValidateProtoDuration(attributes.GetHeartbeatTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetHeartbeatTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid HeartbeatTimeout for ScheduleActivityTaskCommand: %v. ActivityId=%s ActivityType=%s", err, activityID, activityType))
}

Expand Down Expand Up @@ -222,7 +222,7 @@ func (v *CommandAttrValidator) ValidateTimerScheduleAttributes(
if err := common.ValidateUTF8String("TimerId", timerID); err != nil {
return failedCause, err
}
if err := timestamp.ValidateProtoDuration(attributes.GetStartToFireTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetStartToFireTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("An invalid StartToFireTimeout is set on StartTimerCommand: %v. TimerId=%s", err, timerID))
}
return enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNSPECIFIED, nil
Expand Down Expand Up @@ -483,15 +483,15 @@ func (v *CommandAttrValidator) ValidateContinueAsNewWorkflowExecutionAttributes(
return failedCause, fmt.Errorf("error validating ContinueAsNewWorkflowExecutionCommand TaskQueue: %w. WorkflowType=%s TaskQueue=%s", err, wfType, attributes.TaskQueue)
}

if err := timestamp.ValidateProtoDuration(attributes.GetWorkflowRunTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetWorkflowRunTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid WorkflowRunTimeout on ContinueAsNewWorkflowExecutionCommand: %v. WorkflowType=%s TaskQueue=%s", err, wfType, attributes.TaskQueue))
}

if err := timestamp.ValidateProtoDuration(attributes.GetWorkflowTaskTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetWorkflowTaskTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid WorkflowTaskTimeout on ContinueAsNewWorkflowExecutionCommand: %v. WorkflowType=%s TaskQueue=%s", err, wfType, attributes.TaskQueue))
}

if err := timestamp.ValidateProtoDuration(attributes.GetBackoffStartInterval()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetBackoffStartInterval()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid BackoffStartInterval on ContinueAsNewWorkflowExecutionCommand: %v. WorkflowType=%s TaskQueue=%s", err, wfType, attributes.TaskQueue))
}

Expand Down Expand Up @@ -574,15 +574,15 @@ func (v *CommandAttrValidator) ValidateStartChildExecutionAttributes(
return failedCause, err
}

if err := timestamp.ValidateProtoDuration(attributes.GetWorkflowExecutionTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetWorkflowExecutionTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid WorkflowExecutionTimeout on StartChildWorkflowExecutionCommand: %v. WorkflowId=%s WorkflowType=%s Namespace=%s", err, wfID, wfType, ns))
}

if err := timestamp.ValidateProtoDuration(attributes.GetWorkflowRunTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetWorkflowRunTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid WorkflowRunTimeout on StartChildWorkflowExecutionCommand: %v. WorkflowId=%s WorkflowType=%s Namespace=%s", err, wfID, wfType, ns))
}

if err := timestamp.ValidateProtoDuration(attributes.GetWorkflowTaskTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(attributes.GetWorkflowTaskTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid WorkflowTaskTimeout on StartChildWorkflowExecutionCommand: %v. WorkflowId=%s WorkflowType=%s Namespace=%s", err, wfID, wfType, ns))
}

Expand Down
6 changes: 3 additions & 3 deletions service/history/api/create_workflow_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,13 @@ func ValidateStartWorkflowExecutionRequest(
if len(request.GetRequestId()) == 0 {
return serviceerror.NewInvalidArgument("Missing request ID.")
}
if err := timestamp.ValidateProtoDuration(request.GetWorkflowExecutionTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowExecutionTimeout()); err != nil {
return serviceerror.NewInvalidArgument(fmt.Sprintf("invalid WorkflowExecutionTimeoutSeconds: %s", err.Error()))
}
if err := timestamp.ValidateProtoDuration(request.GetWorkflowRunTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowRunTimeout()); err != nil {
return serviceerror.NewInvalidArgument(fmt.Sprintf("invalid WorkflowRunTimeoutSeconds: %s", err.Error()))
}
if err := timestamp.ValidateProtoDuration(request.GetWorkflowTaskTimeout()); err != nil {
if err := timestamp.ValidateAndCapProtoDuration(request.GetWorkflowTaskTimeout()); err != nil {
return serviceerror.NewInvalidArgument(fmt.Sprintf("invalid WorkflowTaskTimeoutSeconds: %s", err.Error()))
}
if request.TaskQueue == nil || request.TaskQueue.GetName() == "" {
Expand Down
2 changes: 1 addition & 1 deletion service/worker/scheduler/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func validateInterval(i *schedulepb.IntervalSpec) error {
if i == nil {
return errors.New("interval is nil")
}
// TODO: use timestamp.ValidateProtoDuration after switching to state machine based implementation.
// TODO: use timestamp.ValidateAndCapProtoDuration after switching to state machine based implementation.
// Not adding it to workflow based implementation to avoid potential non-determinism errors.
iv, phase := timestamp.DurationValue(i.Interval), timestamp.DurationValue(i.Phase)
if iv < time.Second {
Expand Down

0 comments on commit 9889683

Please sign in to comment.