From 241a712b185756b370e8fbde23acf3d5c6cc6673 Mon Sep 17 00:00:00 2001 From: VaibhavMalik4187 Date: Thu, 8 Feb 2024 13:52:45 +0530 Subject: [PATCH] feat: support many:1 auth:provider mapping Added the ability to add multiple configurations for the same backend provider. One to one auth:provider mapping had many bugs and users were requesting many to one auth:provider mapping functionality for more flexibility with backend provider configurations. This commit also includes the addition of an optional `--config-name` flag to allow users to specify the config-name while configuring backend AI providers. Changes exclusive to individual subcommands are: 1. default - `--config-name` flag has been added to allow users to update the default configuration for a backend AI provider. - Setting the `--config-name` flag will only update the default configuration for the specified backend, it will not update the `configAI.DefaultProvider`. - To update the `configAI.DefaultProvider`, user must leave the `--config-name` flag unset. 2. list - The list subcommand will now respect the value of `userInput` set when the user is asked to show password or not. - The command's output has been modified to display the backend providers list along with their configuration names. - Default configs for each backend AI provider are marked with the `(Default Config)` string in bright yellow color. 3. remove - Users can specify either single or multiple backends when removing the configurations. - If the `--config-name` flag is not set, the "default" configuration for the specified backend(s) will be removed(if it is present). - Setting the `--config-name` flag will specify the configuration name to delete for each of the specified backends. 4. Update - Now the user can only specify a single backend provider at a time. This has been done because updating multiple backend configs with same values of parameters doesn't make any sense. - Removal of the `Args` function because update subcommand can have multiple args to set different parameters for a backend config. - Users can also update the configuration name for an already existing config. Addresses: - https://github.com/k8sgpt-ai/k8sgpt/issues/936 - https://github.com/k8sgpt-ai/k8sgpt/issues/911 - https://github.com/k8sgpt-ai/k8sgpt/issues/905 - https://github.com/k8sgpt-ai/k8sgpt/issues/900 - https://github.com/k8sgpt-ai/k8sgpt/issues/843 Signed-off-by: VaibhavMalik4187 --- cmd/auth/add.go | 214 +++++++++++++++++++++++---------------- cmd/auth/auth.go | 1 + cmd/auth/default.go | 114 ++++++++++++++------- cmd/auth/list.go | 129 ++++++++++++----------- cmd/auth/remove.go | 109 +++++++++++++------- cmd/auth/update.go | 127 ++++++++++++++--------- cmd/serve/serve.go | 26 +++-- pkg/ai/iai.go | 24 +++-- pkg/analysis/analysis.go | 8 +- 9 files changed, 462 insertions(+), 290 deletions(-) diff --git a/cmd/auth/add.go b/cmd/auth/add.go index 2e195d02b9..2dab767908 100644 --- a/cmd/auth/add.go +++ b/cmd/auth/add.go @@ -28,120 +28,164 @@ import ( const ( defaultBackend = "openai" + defaultConfig = "default" defaultModel = "gpt-3.5-turbo" ) -var addCmd = &cobra.Command{ - Use: "add", - Short: "Add new provider", - Long: "The add command allows to configure a new backend AI provider", - PreRun: func(cmd *cobra.Command, args []string) { - backend, _ := cmd.Flags().GetString("backend") - if strings.ToLower(backend) == "azureopenai" { - _ = cmd.MarkFlagRequired("engine") - _ = cmd.MarkFlagRequired("baseurl") - } - if strings.ToLower(backend) == "amazonsagemaker" { - _ = cmd.MarkFlagRequired("endpointname") - _ = cmd.MarkFlagRequired("providerRegion") +func runAddCommand(cmd *cobra.Command, args []string) { + // 1. Get ai configuration + err := viper.UnmarshalKey("ai", &configAI) + if err != nil { + color.Red("Error: %v", err) + os.Exit(1) + } + + // 2. Validate input values upfront and set default values if the inputs are empty + // check if backend is not empty and a valid value + validBackend := func(validBackends []string, backend string) bool { + for _, b := range validBackends { + if b == backend { + return true + } } - }, - Run: func(cmd *cobra.Command, args []string) { + return false + } - // get ai configuration - err := viper.UnmarshalKey("ai", &configAI) - if err != nil { - color.Red("Error: %v", err) + if backend == "" { + // Set the default value of the backend provider + color.Yellow(fmt.Sprintf("Warning: backend input is empty, will use the default value: %s", defaultBackend)) + backend = defaultBackend + } else { + // Check if the given provider is valid or not. + if !validBackend(ai.Backends, backend) { + color.Red("Error: Backend AI accepted values are '%v'", strings.Join(ai.Backends, ", ")) os.Exit(1) } + } - // search for provider with same name - providerIndex := -1 - for i, provider := range configAI.Providers { - if backend == provider.Name { - providerIndex = i - break - } - } + // Set the value of config-name if it is not provided by the user. + if configName == "" { + color.Yellow(fmt.Sprintf("Warning: config-name input is empty, will use the default value: %s", defaultConfig)) + configName = defaultConfig + } + + // 3. Find existing provider index + // search for provider with same backend + providerIndex := -1 + configIndex := -1 + for i, provider := range configAI.Providers { + if backend == provider.Backend { + providerIndex = i - validBackend := func(validBackends []string, backend string) bool { - for _, b := range validBackends { - if b == backend { - return true + // Iterate over all the configs of this provider + // and check if a config with the same name already exists + for index, config := range provider.Configs { + if configName == config.Name { + configIndex = index + break } } - return false - } - // check if backend is not empty and a valid value - if backend == "" { - color.Yellow(fmt.Sprintf("Warning: backend input is empty, will use the default value: %s", defaultBackend)) - backend = defaultBackend - } else { - if !validBackend(ai.Backends, backend) { - color.Red("Error: Backend AI accepted values are '%v'", strings.Join(ai.Backends, ", ")) - os.Exit(1) + if configIndex != -1 { + break } } + } - // check if model is not empty - if model == "" { - model = defaultModel - color.Yellow(fmt.Sprintf("Warning: model input is empty, will use the default value: %s", defaultModel)) - } - if temperature > 1.0 || temperature < 0.0 { - color.Red("Error: temperature ranges from 0 to 1.") - os.Exit(1) - } - if topP > 1.0 || topP < 0.0 { - color.Red("Error: topP ranges from 0 to 1.") + // Quit if the config already exists + if configIndex != -1 { + color.Red("Provider with same config already exists.") + os.Exit(1) + } + + // Handle input sanitization for config. + if model == "" { + model = defaultModel + color.Yellow(fmt.Sprintf("Warning: model input is empty, will use the default value: %s", defaultModel)) + } + if temperature > 1.0 || temperature < 0.0 { + color.Red("Error: temperature ranges from 0 to 1.") + os.Exit(1) + } + if topP > 1.0 || topP < 0.0 { + color.Red("Error: topP ranges from 0 to 1.") + os.Exit(1) + } + + if ai.NeedPassword(backend) && password == "" { + fmt.Printf("Enter %s Key: ", backend) + bytePassword, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + color.Red("Error reading %s Key from stdin: %s", backend, + err.Error()) os.Exit(1) } + password = strings.TrimSpace(string(bytePassword)) + } - if ai.NeedPassword(backend) && password == "" { - fmt.Printf("Enter %s Key: ", backend) - bytePassword, err := term.ReadPassword(int(syscall.Stdin)) - if err != nil { - color.Red("Error reading %s Key from stdin: %s", backend, - err.Error()) - os.Exit(1) - } - password = strings.TrimSpace(string(bytePassword)) - } + // Create a new provider config + config := ai.AIProviderConfig{ + Name: configName, + Model: model, + Password: password, + BaseURL: baseURL, + EndpointName: endpointName, + Engine: engine, + Temperature: temperature, + ProviderRegion: providerRegion, + TopP: topP, + MaxTokens: maxTokens, + } - // create new provider object + // Create a new provider if the providerIndex is -1 + if providerIndex == -1 { + // Instantiate a new provider if it is not already present. newProvider := ai.AIProvider{ - Name: backend, - Model: model, - Password: password, - BaseURL: baseURL, - EndpointName: endpointName, - Engine: engine, - Temperature: temperature, - ProviderRegion: providerRegion, - TopP: topP, - MaxTokens: maxTokens, + Backend: backend, + Configs: []ai.AIProviderConfig{ + config, + }, + DefaultConfig: 0, } - if providerIndex == -1 { - // provider with same name does not exist, add new provider to list - configAI.Providers = append(configAI.Providers, newProvider) - viper.Set("ai", configAI) - if err := viper.WriteConfig(); err != nil { - color.Red("Error writing config file: %s", err.Error()) - os.Exit(1) - } - color.Green("%s added to the AI backend provider list", backend) - } else { - // provider with same name exists, update provider info - color.Yellow("Provider with same name already exists.") + // provider with this backend name does not exist, add new provider to list + configAI.Providers = append(configAI.Providers, newProvider) + } else { + // Append this config in the configs of the ai provider + configAI.Providers[providerIndex].Configs = append(configAI.Providers[providerIndex].Configs, config) + } + + viper.Set("ai", configAI) + if err := viper.WriteConfig(); err != nil { + color.Red("Error writing config file: %s", err.Error()) + os.Exit(1) + } + color.Green("%s added to the AI backend provider list", backend) +} + +var addCmd = &cobra.Command{ + Use: "add", + Short: "Add new provider", + Long: "The add command allows to configure a new backend AI provider", + PreRun: func(cmd *cobra.Command, args []string) { + backend, _ := cmd.Flags().GetString("backend") + if strings.ToLower(backend) == "azureopenai" { + _ = cmd.MarkFlagRequired("engine") + _ = cmd.MarkFlagRequired("baseurl") + } + if strings.ToLower(backend) == "amazonsagemaker" { + _ = cmd.MarkFlagRequired("endpointname") + _ = cmd.MarkFlagRequired("providerRegion") } }, + Run: runAddCommand, } func init() { // add flag for backend addCmd.Flags().StringVarP(&backend, "backend", "b", defaultBackend, "Backend AI provider") + // add flag for config-name + addCmd.Flags().StringVarP(&configName, "config-name", "", defaultConfig, "Backend AI provider") // add flag for model addCmd.Flags().StringVarP(&model, "model", "m", defaultModel, "Backend AI model") // add flag for password diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index f252f8a8d3..d02d8ec1c4 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -20,6 +20,7 @@ import ( var ( backend string + configName string password string baseURL string endpointName string diff --git a/cmd/auth/default.go b/cmd/auth/default.go index 8b58293066..879ae5b1c5 100644 --- a/cmd/auth/default.go +++ b/cmd/auth/default.go @@ -26,54 +26,94 @@ var ( providerName string ) -var defaultCmd = &cobra.Command{ - Use: "default", - Short: "Set your default AI backend provider", - Long: "The command to set your new default AI backend provider (default is openai)", - Run: func(cmd *cobra.Command, args []string) { - err := viper.UnmarshalKey("ai", &configAI) - if err != nil { - color.Red("Error: %v", err) - os.Exit(1) +func runDefaultCommand(cmd *cobra.Command, args []string) { + // 1. Get the ai configurations + err := viper.UnmarshalKey("ai", &configAI) + if err != nil { + color.Red("Error: %v", err) + os.Exit(1) + } + + // 2. Validate the input values and set defaults if necessary + if providerName == "" { + if configAI.DefaultProvider != "" { + color.Yellow("Your default provider is \"%s\"", configAI.DefaultProvider) + } else { + color.Yellow("Your default provider is openai") } - if providerName == "" { - if configAI.DefaultProvider != "" { - color.Yellow("Your default provider is %s", configAI.DefaultProvider) - } else { - color.Yellow("Your default provider is openai") + os.Exit(0) + } + + // lowercase the provider name + providerName = strings.ToLower(providerName) + + // Check if the provider is in the provider list + providerIndex := -1 + configIndex := -1 + for i, provider := range configAI.Providers { + if providerName == provider.Backend { + providerIndex = i + + // Iterate over all the configs of this provider + // and check if a config with the same name exists + if configName != "" { + for index, config := range provider.Configs { + if configName == config.Name { + configIndex = index + break + } + } } - os.Exit(0) - } - // lowercase the provider name - providerName = strings.ToLower(providerName) - - // Check if the provider is in the provider list - providerExists := false - for _, provider := range configAI.Providers { - if provider.Name == providerName { - providerExists = true + + if configIndex != -1 { + break } } - if !providerExists { - color.Red("Error: Provider %s does not exist", providerName) - os.Exit(1) - } + } + + if providerIndex == -1 { + color.Red("Error: Provider \"%s\" does not exist", providerName) + os.Exit(1) + } + + if configIndex == -1 && configName != "" { + color.Red("Error: The backend provider \"%s\" does not have a configuration with the name \"%s\"", backend, configName) + os.Exit(1) + } + + if configName != "" { + // Set the default config + configAI.Providers[providerIndex].DefaultConfig = configIndex + } else { // Set the default provider configAI.DefaultProvider = providerName + } - viper.Set("ai", configAI) - // Viper write config - err = viper.WriteConfig() - if err != nil { - color.Red("Error: %v", err) - os.Exit(1) - } - // Print acknowledgement + viper.Set("ai", configAI) + // Viper write config + err = viper.WriteConfig() + if err != nil { + color.Red("Error: %v", err) + os.Exit(1) + } + + // Print acknowledgement + if configName != "" { + color.Green("Default config for %s set to %s", providerName, configName) + } else { color.Green("Default provider set to %s", providerName) - }, + } +} + +var defaultCmd = &cobra.Command{ + Use: "default", + Short: "Set your default AI backend provider and provider config", + Long: "The command to set your new default AI backend provider (default is openai)", + Run: runDefaultCommand, } func init() { // provider name flag defaultCmd.Flags().StringVarP(&providerName, "provider", "p", "", "The name of the provider to set as default") + defaultCmd.Flags().StringVarP(&configName, "config-name", "", "", "The name of the config to set as default for a provider") } diff --git a/cmd/auth/list.go b/cmd/auth/list.go index af9472ed45..1eba35a859 100644 --- a/cmd/auth/list.go +++ b/cmd/auth/list.go @@ -16,6 +16,7 @@ package auth import ( "fmt" "os" + "strings" "github.com/fatih/color" "github.com/k8sgpt-ai/k8sgpt/pkg/ai" @@ -26,79 +27,87 @@ import ( var details bool var userInput string -var listCmd = &cobra.Command{ - Use: "list", - Short: "List configured providers", - Long: "The list command displays a list of configured providers", - Run: func(cmd *cobra.Command, args []string) { +func runListCommand(cmd *cobra.Command, args []string) { + // Get the ai configurations + err := viper.UnmarshalKey("ai", &configAI) + if err != nil { + color.Red("Error: %v", err) + os.Exit(1) + } - // get ai configuration - err := viper.UnmarshalKey("ai", &configAI) - if err != nil { - color.Red("Error: %v", err) - os.Exit(1) - } + if details { + fmt.Println("Show password ? (y/n)") + fmt.Scan(&userInput) + } - if details { - fmt.Println("Show password ? (y/n)") - fmt.Scan(&userInput) - } + // Print the default if it is set + fmt.Print(color.YellowString("Default: \n")) + if configAI.DefaultProvider != "" { + fmt.Printf("> %s\n", color.BlueString(configAI.DefaultProvider)) + } else { + fmt.Printf("> %s\n", color.BlueString("openai")) + } - // Print the default if it is set - fmt.Print(color.YellowString("Default: \n")) - if configAI.DefaultProvider != "" { - fmt.Printf("> %s\n", color.BlueString(configAI.DefaultProvider)) - } else { - fmt.Printf("> %s\n", color.BlueString("openai")) - } + // Get list of all AI Backends and only print them if they are not in the provider list + fmt.Print(color.YellowString("Active: \n")) + for _, provider := range configAI.Providers { + fmt.Printf("> %s\n", color.GreenString(provider.Backend)) + fmt.Println(" > " + color.HiCyanString("Configs")) - // Get list of all AI Backends and only print them if they are not in the provider list - fmt.Print(color.YellowString("Active: \n")) - for _, aiBackend := range ai.Backends { - providerExists := false - for _, provider := range configAI.Providers { - if provider.Name == aiBackend { - providerExists = true - } + for index, config := range provider.Configs { + if index == provider.DefaultConfig { + fmt.Printf(" %d. %s "+color.HiYellowString("(Default Config)\n"), index+1, config.Name) + } else { + fmt.Printf(" %d. %s\n", index+1, config.Name) } - if providerExists { - fmt.Printf("> %s\n", color.GreenString(aiBackend)) - if details { - for _, provider := range configAI.Providers { - if provider.Name == aiBackend { - printDetails(provider, userInput) - } - } - } + if details { + printDetails(provider, userInput, index) } } - fmt.Print(color.YellowString("Unused: \n")) - for _, aiBackend := range ai.Backends { - providerExists := false - for _, provider := range configAI.Providers { - if provider.Name == aiBackend { - providerExists = true - } - } - if !providerExists { - fmt.Printf("> %s\n", color.RedString(aiBackend)) + } + + fmt.Print(color.YellowString("Unused: \n")) + for _, aiBackend := range ai.Backends { + providerExists := false + for _, provider := range configAI.Providers { + if provider.Backend == aiBackend { + providerExists = true } } - }, + if !providerExists { + fmt.Printf("> %s\n", color.RedString(aiBackend)) + } + } } -func init() { - listCmd.Flags().BoolVar(&details, "details", false, "Print active provider configuration details") -} +func printDetails(provider ai.AIProvider, userInput string, index int) { + if provider.Configs[index].Model != "" { + fmt.Printf(" - Model: %s\n", provider.Configs[index].Model) + } -func printDetails(provider ai.AIProvider, userInput string) { - if provider.Model != "" { - fmt.Printf(" - Model: %s\n", provider.Model) + userInput = strings.ToLower(userInput) + if userInput == "y" || userInput == "yes" { + if provider.Configs[index].Password != "" { + fmt.Printf(" - Password: %s\n", provider.Configs[index].Password) + } } - if provider.Engine != "" { - fmt.Printf(" - Engine: %s\n", provider.Engine) + + if provider.Configs[index].Engine != "" { + fmt.Printf(" - Engine: %s\n", provider.Configs[index].Engine) } - if provider.BaseURL != "" { - fmt.Printf(" - BaseURL: %s\n", provider.BaseURL) + + if provider.Configs[index].BaseURL != "" { + fmt.Printf(" - BaseURL: %s\n", provider.Configs[index].BaseURL) } } + +var listCmd = &cobra.Command{ + Use: "list", + Short: "List configured providers", + Long: "The list command displays a list of configured providers", + Run: runListCommand, +} + +func init() { + listCmd.Flags().BoolVar(&details, "details", false, "Print active provider configuration details") +} diff --git a/cmd/auth/remove.go b/cmd/auth/remove.go index c18066545c..badf54d066 100644 --- a/cmd/auth/remove.go +++ b/cmd/auth/remove.go @@ -22,56 +22,91 @@ import ( "github.com/spf13/viper" ) -var removeCmd = &cobra.Command{ - Use: "remove", - Short: "Remove provider(s)", - Long: "The command to remove AI backend provider(s)", - PreRun: func(cmd *cobra.Command, args []string) { - _ = cmd.MarkFlagRequired("backends") - }, - Run: func(cmd *cobra.Command, args []string) { - if backend == "" { - color.Red("Error: backends must be set.") - _ = cmd.Help() - return - } - inputBackends := strings.Split(backend, ",") +func runRemoveCommand(cmd *cobra.Command, args []string) { + // Get the ai configurations + err := viper.UnmarshalKey("ai", &configAI) + if err != nil { + color.Red("Error: %v", err) + os.Exit(1) + } - err := viper.UnmarshalKey("ai", &configAI) - if err != nil { - color.Red("Error: %v", err) - os.Exit(1) - } + // Check if the backend flag is set. + if backend == "" { + color.Red("Error: backends must be set.") + _ = cmd.Help() + return + } - for _, b := range inputBackends { - foundBackend := false - for i, provider := range configAI.Providers { - if b == provider.Name { - foundBackend = true - configAI.Providers = append(configAI.Providers[:i], configAI.Providers[i+1:]...) - if configAI.DefaultProvider == b { - configAI.DefaultProvider = "openai" + inputBackends := strings.Split(backend, ",") + + if configName == "" { + color.Yellow("Warning: No config is specified therefore the config named \"default\" config will be removed") + configName = defaultConfig + } + + // Now, iterate over each backend + for _, backendName := range inputBackends { + foundBackend := false + for i, provider := range configAI.Providers { + // Check if the input backend is present in the list of providers stored + // in the config file. + if backendName == provider.Backend { + foundBackend = true + + // Now, start iterating over the configs stored in the backend + deletedConfigIndex := -1 + for index, config := range provider.Configs { + if configName == config.Name { + deletedConfigIndex = index + // Remove the config if it is found. + configAI.Providers[i].Configs = append(configAI.Providers[i].Configs[:index], configAI.Providers[i].Configs[index+1:]...) + color.Green("Config: \"%s\" deleted for the AI backend provider: \"%s\"", configName, backendName) } - color.Green("%s deleted from the AI backend provider list", b) - break } - } - if !foundBackend { - color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", b) - os.Exit(1) + + if deletedConfigIndex == -1 { + color.Red("Error: Backend provider \"%s\" didn't have any config with name \"%s\". Aborting!", backendName, configName) + os.Exit(1) + } + + // Now, check if there are any configs left for this backend provider. + if len(configAI.Providers[i].Configs) == 0 { + // Delete this backend provider. + configAI.Providers = append(configAI.Providers[:i], configAI.Providers[i+1:]...) + } else if deletedConfigIndex == configAI.Providers[i].DefaultConfig { + // Update the default config for this backend provider to config at 0th index. + configAI.Providers[i].DefaultConfig = 0 + color.Yellow("After deleting the config \"%s\", the default config for backend provider \"%s\" has changed to \"%s\"", configName, backendName, configAI.Providers[i].Configs[0].Name) + } + + break } } - - viper.Set("ai", configAI) - if err := viper.WriteConfig(); err != nil { - color.Red("Error writing config file: %s", err.Error()) + if !foundBackend { + color.Red("Error: \"%s\" does not exist in configuration file. Please use k8sgpt auth new.", backendName) os.Exit(1) } + } + viper.Set("ai", configAI) + if err := viper.WriteConfig(); err != nil { + color.Red("Error writing config file: %s", err.Error()) + os.Exit(1) + } +} + +var removeCmd = &cobra.Command{ + Use: "remove", + Short: "Remove provider(s)", + Long: "The command to remove AI backend provider(s)", + PreRun: func(cmd *cobra.Command, args []string) { + _ = cmd.MarkFlagRequired("backends") }, + Run: runRemoveCommand, } func init() { // add flag for backends removeCmd.Flags().StringVarP(&backend, "backends", "b", "", "Backend AI providers to remove (separated by a comma)") + removeCmd.Flags().StringVarP(&configName, "config-name", "", "", "Name of the config to remove") } diff --git a/cmd/auth/update.go b/cmd/auth/update.go index eb9a0e79ef..28f3242436 100644 --- a/cmd/auth/update.go +++ b/cmd/auth/update.go @@ -22,84 +22,115 @@ import ( "github.com/spf13/viper" ) -var updateCmd = &cobra.Command{ - Use: "update", - Short: "Update a backend provider", - Long: "The command to update an AI backend provider", - Args: cobra.ExactArgs(1), - PreRun: func(cmd *cobra.Command, args []string) { - backend, _ := cmd.Flags().GetString("backend") - if strings.ToLower(backend) == "azureopenai" { - _ = cmd.MarkFlagRequired("engine") - _ = cmd.MarkFlagRequired("baseurl") - } - }, - Run: func(cmd *cobra.Command, args []string) { +var ( + newConfigName string +) - // get ai configuration - err := viper.UnmarshalKey("ai", &configAI) - if err != nil { - color.Red("Error: %v", err) - os.Exit(1) - } +func runUpdateCommand(cmd *cobra.Command, args []string) { + // Get the ai configurations + err := viper.UnmarshalKey("ai", &configAI) + if err != nil { + color.Red("Error: %v", err) + os.Exit(1) + } - inputBackends := strings.Split(args[0], ",") + // Check if the backend flag is set. + if backend == "" { + color.Red("Error: backend must be set.") + _ = cmd.Help() + return + } - if len(inputBackends) == 0 { - color.Red("Error: backend must be set.") - os.Exit(1) - } - if temperature > 1.0 || temperature < 0.0 { - color.Red("Error: temperature ranges from 0 to 1.") - os.Exit(1) - } + // Validate the temperature range. + if temperature > 1.0 || temperature < 0.0 { + color.Red("Error: temperature ranges from 0 to 1.") + os.Exit(1) + } - for _, b := range inputBackends { - foundBackend := false - for i, provider := range configAI.Providers { - if b == provider.Name { - foundBackend = true - if backend != "" { - configAI.Providers[i].Name = backend - color.Blue("Backend name updated successfully") + // Iterate over all the providers present in the config file. + for i, provider := range configAI.Providers { + if backend == provider.Backend { + if configName == "" { + // Modify the default config if the config name is not specified. + configName = defaultConfig + color.Yellow("Since no config name was specified, changes will be made to the \"default\" config") + } + + // Iterate over all the configs present in that backend provider. + configFound := false + for index, config := range provider.Configs { + // Check if the config to be updated exists or not. + if configName == config.Name { + configFound = true + + // Config exists, now update the parameters + if newConfigName != "" { + configAI.Providers[i].Configs[index].Name = newConfigName + color.Blue("Config name updated successfully") } if model != "" { - configAI.Providers[i].Model = model + configAI.Providers[i].Configs[index].Model = model color.Blue("Model updated successfully") } if password != "" { - configAI.Providers[i].Password = password + configAI.Providers[i].Configs[index].Password = password color.Blue("Password updated successfully") } if baseURL != "" { - configAI.Providers[i].BaseURL = baseURL + configAI.Providers[i].Configs[index].BaseURL = baseURL color.Blue("Base URL updated successfully") } if engine != "" { - configAI.Providers[i].Engine = engine + configAI.Providers[i].Configs[index].Engine = engine + color.Blue("Engine updated successfully") } - configAI.Providers[i].Temperature = temperature - color.Green("%s updated in the AI backend provider list", b) + configAI.Providers[i].Configs[index].Temperature = temperature } } - if !foundBackend { - color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0]) + + if !configFound { + color.Red("Error: The backend provider \"%s\" does not have a configuration with the name \"%s\"", backend, configName) os.Exit(1) + } else { + color.Green("Config \"%s\" for the backend provider \"%s\" has been successfully updated", configName, backend) } + // Break out of the loop if the desired backend provider has been updated. + break } + } - viper.Set("ai", configAI) - if err := viper.WriteConfig(); err != nil { - color.Red("Error writing config file: %s", err.Error()) - os.Exit(1) + // Write the configuration to the config file. + viper.Set("ai", configAI) + if err := viper.WriteConfig(); err != nil { + color.Red("Error writing config file: %s", err.Error()) + os.Exit(1) + } +} + +var updateCmd = &cobra.Command{ + Use: "update", + Short: "Update a backend provider", + Long: "The command to update an AI backend provider", + // TODO: Why was this present in the first place? + // Args: cobra.ExactArgs(1), + PreRun: func(cmd *cobra.Command, args []string) { + backend, _ := cmd.Flags().GetString("backend") + if strings.ToLower(backend) == "azureopenai" { + _ = cmd.MarkFlagRequired("engine") + _ = cmd.MarkFlagRequired("baseurl") } }, + Run: runUpdateCommand, } func init() { // update flag for backend updateCmd.Flags().StringVarP(&backend, "backend", "b", "", "Update backend AI provider") + // update flag for config-name + updateCmd.Flags().StringVarP(&configName, "config-name", "", "", "Name of the configuration to update") + // update flag for config-name + updateCmd.Flags().StringVarP(&newConfigName, "name", "n", "", "New name for the configuration to update") // update flag for model updateCmd.Flags().StringVarP(&model, "model", "m", "", "Update backend AI model") // update flag for password diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go index 3c3e58fc1a..c766519715 100644 --- a/cmd/serve/serve.go +++ b/cmd/serve/serve.go @@ -78,12 +78,18 @@ var ServeCmd = &cobra.Command{ envIsSet := backend != "" || password != "" || model != "" if envIsSet { aiProvider = &ai.AIProvider{ - Name: backend, - Password: password, - Model: model, - BaseURL: baseURL, - Engine: engine, - Temperature: temperature(), + Backend: backend, + Configs: []ai.AIProviderConfig{ + { + Name: backend, + Password: password, + Model: model, + BaseURL: baseURL, + Engine: engine, + Temperature: temperature(), + }, + }, + DefaultConfig: 0, } configAI.Providers = append(configAI.Providers, *aiProvider) @@ -100,7 +106,7 @@ var ServeCmd = &cobra.Command{ } if aiProvider == nil { for _, provider := range configAI.Providers { - if backend == provider.Name { + if backend == provider.Backend { // the pointer to the range variable is not really an issue here, as there // is a break right after, but to prevent potential future issues, a temp // variable is assigned @@ -111,7 +117,7 @@ var ServeCmd = &cobra.Command{ } } - if aiProvider.Name == "" { + if aiProvider.Backend == "" { color.Red("Error: AI provider %s not specified in configuration. Please run k8sgpt auth", backend) os.Exit(1) } @@ -129,11 +135,11 @@ var ServeCmd = &cobra.Command{ }() server := k8sgptserver.Config{ - Backend: aiProvider.Name, + Backend: aiProvider.Backend, Port: port, MetricsPort: metricsPort, EnableHttp: enableHttp, - Token: aiProvider.Password, + Token: aiProvider.Configs[aiProvider.DefaultConfig].Password, Logger: logger, } go func() { diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 99de8e3a40..69f3ad548d 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -88,6 +88,12 @@ type AIConfiguration struct { } type AIProvider struct { + Backend string + Configs []AIProviderConfig + DefaultConfig int +} + +type AIProviderConfig struct { Name string `mapstructure:"name"` Model string `mapstructure:"model"` Password string `mapstructure:"password" yaml:"password,omitempty"` @@ -101,38 +107,38 @@ type AIProvider struct { } func (p *AIProvider) GetBaseURL() string { - return p.BaseURL + return p.Configs[p.DefaultConfig].BaseURL } func (p *AIProvider) GetEndpointName() string { - return p.EndpointName + return p.Configs[p.DefaultConfig].EndpointName } func (p *AIProvider) GetTopP() float32 { - return p.TopP + return p.Configs[p.DefaultConfig].TopP } func (p *AIProvider) GetMaxTokens() int { - return p.MaxTokens + return p.Configs[p.DefaultConfig].MaxTokens } func (p *AIProvider) GetPassword() string { - return p.Password + return p.Configs[p.DefaultConfig].Password } func (p *AIProvider) GetModel() string { - return p.Model + return p.Configs[p.DefaultConfig].Model } func (p *AIProvider) GetEngine() string { - return p.Engine + return p.Configs[p.DefaultConfig].Engine } func (p *AIProvider) GetTemperature() float32 { - return p.Temperature + return p.Configs[p.DefaultConfig].Temperature } func (p *AIProvider) GetProviderRegion() string { - return p.ProviderRegion + return p.Configs[p.DefaultConfig].ProviderRegion } var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock"} diff --git a/pkg/analysis/analysis.go b/pkg/analysis/analysis.go index 139d0e7505..8848e2ce05 100644 --- a/pkg/analysis/analysis.go +++ b/pkg/analysis/analysis.go @@ -131,22 +131,22 @@ func NewAnalysis( var aiProvider ai.AIProvider for _, provider := range configAI.Providers { - if backend == provider.Name { + if backend == provider.Backend { aiProvider = provider break } } - if aiProvider.Name == "" { + if aiProvider.Backend == "" { return nil, fmt.Errorf("AI provider %s not specified in configuration. Please run k8sgpt auth", backend) } - aiClient := ai.NewClient(aiProvider.Name) + aiClient := ai.NewClient(aiProvider.Backend) if err := aiClient.Configure(&aiProvider); err != nil { return nil, err } a.AIClient = aiClient - a.AnalysisAIProvider = aiProvider.Name + a.AnalysisAIProvider = aiProvider.Backend return a, nil }