diff --git a/README.md b/README.md index 9fcf2556..63b2c3a2 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,8 @@ Configuration file is supported in two different format: json and yaml. Use the `$ arduino-cloud-cli config init --config-format json` +It is also possible to specify credentials directly in `ARDUINO_CLOUD_CLIENT` and `ARDUINO_CLOUD_SECRET` environment variables. Credentials specified in environment variables have higher priority than the ones specified in config files. + ## Device provisioning When provisioning a device, you can optionally specify the port to which the device is connected to and its fqbn. If they are not given, then the first device found will be provisioned. diff --git a/cli/config/init.go b/cli/config/init.go index 3002a7ee..82e51046 100644 --- a/cli/config/init.go +++ b/cli/config/init.go @@ -34,11 +34,6 @@ import ( "github.com/spf13/viper" ) -const ( - clientIDLen = 32 - clientSecretLen = 64 -) - var initFlags struct { destDir string overwrite bool @@ -122,7 +117,7 @@ func paramsPrompt() (id, key string, err error) { prompt := promptui.Prompt{ Label: "Please enter the Client ID", Validate: func(s string) error { - if len(s) != clientIDLen { + if len(s) != config.ClientIDLen { return errors.New("client-id not valid") } return nil @@ -137,7 +132,7 @@ func paramsPrompt() (id, key string, err error) { Mask: '*', Label: "Please enter the Client Secret", Validate: func(s string) error { - if len(s) != clientSecretLen { + if len(s) != config.ClientSecretLen { return errors.New("client secret not valid") } return nil diff --git a/internal/config/config.go b/internal/config/config.go index 1d8f212f..57536820 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,6 +25,17 @@ import ( "github.com/spf13/viper" ) +const ( + // ClientIDLen specifies the length of Arduino IoT Cloud client ids. + ClientIDLen = 32 + // ClientSecretLen specifies the length of Arduino IoT Cloud client secrets. + ClientSecretLen = 64 + + // EnvPrefix is the prefix environment variables should have to be + // fetched as config parameters during the config retrieval. + EnvPrefix = "ARDUINO_CLOUD" +) + // Config contains all the configuration parameters // known by arduino-cloud-cli. type Config struct { @@ -32,53 +43,164 @@ type Config struct { Secret string `map-structure:"secret"` // Secret ID of the user, unique for each Client ID } -// Retrieve returns the actual parameters contained in the -// configuration file, if any. Returns error if no config file is found. +// Validate the config. +// If config is not valid, it returns an error explaining the reason. +func (c *Config) Validate() error { + if len(c.Client) != ClientIDLen { + return fmt.Errorf( + "client id not valid, expected len %d but got %d", + ClientIDLen, + len(c.Client), + ) + } + if len(c.Secret) != ClientSecretLen { + return fmt.Errorf( + "client secret not valid, expected len %d but got %d", + ClientSecretLen, + len(c.Secret), + ) + } + return nil +} + +// IsEmpty checks if config has no params set. +func (c *Config) IsEmpty() bool { + return len(c.Client) == 0 && len(c.Secret) == 0 +} + +// Retrieve looks for configuration parameters in +// environment variables or in configuration file. +// Returns error if no config is found. func Retrieve() (*Config, error) { + // Config extracted from environment has highest priority + c, err := fromEnv() + if err != nil { + return nil, fmt.Errorf("reading config from environment variables: %w", err) + } + // Return the config only if it has been found + if c != nil { + return c, nil + } + + c, err = fromFile() + if err != nil { + return nil, fmt.Errorf("reading config from file: %w", err) + } + if c != nil { + return c, nil + } + + return nil, fmt.Errorf( + "config has not been found neither in environment variables " + + "nor in the current directory, its parents or in arduino15", + ) +} + +// fromFile looks for a configuration file. +// If a config file is not found, it returns a nil config without raising errors. +// If invalid config file is found, it returns an error. +func fromFile() (*Config, error) { + // Looks for a configuration file configDir, err := searchConfigDir() if err != nil { return nil, fmt.Errorf("can't get config directory: %w", err) } + // Return nil config if no config file is found + if configDir == nil { + return nil, nil + } v := viper.New() v.SetConfigName(Filename) - v.AddConfigPath(configDir) + v.AddConfigPath(*configDir) err = v.ReadInConfig() if err != nil { - err = fmt.Errorf("%s: %w", "retrieving config file", err) + err = fmt.Errorf( + "config file found at %s but cannot read its content: %w", + *configDir, + err, + ) return nil, err } conf := &Config{} - v.Unmarshal(conf) + err = v.Unmarshal(conf) + if err != nil { + return nil, fmt.Errorf( + "config file found at %s but cannot unmarshal it: %w", + *configDir, + err, + ) + } + if err = conf.Validate(); err != nil { + return nil, fmt.Errorf( + "config file found at %s but is not valid: %w", + *configDir, + err, + ) + } return conf, nil } -func searchConfigDir() (string, error) { +// fromEnv looks for configuration credentials in environment variables. +// If credentials are not found, it returns a nil config without raising errors. +// If invalid credentials are found, it returns an error. +func fromEnv() (*Config, error) { + v := viper.New() + SetDefaults(v) + v.SetEnvPrefix(EnvPrefix) + v.AutomaticEnv() + + conf := &Config{} + err := v.Unmarshal(conf) + if err != nil { + return nil, fmt.Errorf("cannot unmarshal config from environment variables: %w", err) + } + + if conf.IsEmpty() { + return nil, nil + } + + if err = conf.Validate(); err != nil { + return nil, fmt.Errorf( + "config retrieved from environment variables with prefix '%s' are not valid: %w", + EnvPrefix, + err, + ) + } + return conf, nil +} + +// searchConfigDir configuration file in different directories in the following order: +// current working directory, parents of the current working directory, arduino15 default directory. +// Returns a nil string if no config file has been found, without raising errors. +// Returns an error if any problem is encountered during the file research which prevents +// to understand whether a config file exists or not. +func searchConfigDir() (*string, error) { // Search in current directory and its parents. cwd, err := paths.Getwd() if err != nil { - return "", err + return nil, err } // Don't let bad naming mislead you, cwd.Parents()[0] is cwd itself so // we look in the current directory first and then on its parents. for _, path := range cwd.Parents() { if path.Join(Filename+".yaml").Exist() || path.Join(Filename+".json").Exist() { - return path.String(), nil + p := path.String() + return &p, nil } } // Search in arduino's default data directory. arduino15, err := arduino.DataDir() if err != nil { - return "", err + return nil, err } if arduino15.Join(Filename+".yaml").Exist() || arduino15.Join(Filename+".json").Exist() { - return arduino15.String(), nil + p := arduino15.String() + return &p, nil } - return "", fmt.Errorf( - "didn't find config file in the current directory, its parents or in %s", - arduino15.String(), - ) + // Didn't find config file in the current directory, its parents or in arduino15" + return nil, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 00000000..2fb7b47f --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,247 @@ +// This file is part of arduino-cloud-cli. +// +// Copyright (C) 2021 ARDUINO SA (http://www.arduino.cc/) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package config + +import ( + "os" + "testing" + + "encoding/json" + + "github.com/google/go-cmp/cmp" +) + +func TestRetrieve(t *testing.T) { + var ( + validSecret = "qaRZGEbnQNNvmaeTLqy8Bxs22wLZ6H7obIiNSveTLPdoQuylANnuy6WBOw16XoqH" + validClient = "CQ4iZ5sebOfhGRwUn3IV0r1YFMNrMTIx" + validConfig = &Config{validClient, validSecret} + invalidConfig = &Config{"", validSecret} + clientEnv = EnvPrefix + "_CLIENT" + secretEnv = EnvPrefix + "_SECRET" + ) + + tests := []struct { + name string + pre func() + post func() + wantedConfig *Config + wantedErr bool + }{ + { + name: "valid config written in env", + pre: func() { + os.Setenv(clientEnv, validConfig.Client) + os.Setenv(secretEnv, validConfig.Secret) + }, + post: func() { + os.Unsetenv(clientEnv) + os.Unsetenv(secretEnv) + }, + wantedConfig: validConfig, + wantedErr: false, + }, + + { + name: "invalid config written in env", + pre: func() { + os.Setenv(clientEnv, validConfig.Client) + os.Setenv(secretEnv, "") + }, + post: func() { + os.Unsetenv(clientEnv) + os.Unsetenv(secretEnv) + }, + wantedConfig: nil, + wantedErr: true, + }, + + { + name: "valid config written in parent of cwd", + pre: func() { + parent := "test-parent" + cwd := "test-parent/test-cwd" + os.MkdirAll(cwd, os.FileMode(0777)) + // Write valid config in parent dir + os.Chdir(parent) + b, _ := json.Marshal(validConfig) + os.WriteFile(Filename+".json", b, os.FileMode(0777)) + // Cwd has no config file + os.Chdir("test-cwd") + }, + post: func() { + os.Chdir("../..") + os.RemoveAll("test-parent") + }, + wantedConfig: validConfig, + wantedErr: false, + }, + + { + name: "invalid config written in cwd, ignore config of parent dir", + pre: func() { + parent := "test-parent" + cwd := "test-parent/test-cwd" + os.MkdirAll(cwd, os.FileMode(0777)) + // Write valid config in parent dir + os.Chdir(parent) + b, _ := json.Marshal(validConfig) + os.WriteFile(Filename+".json", b, os.FileMode(0777)) + // Write invalid config in cwd + os.Chdir("test-cwd") + b, _ = json.Marshal(invalidConfig) + os.WriteFile(Filename+".json", b, os.FileMode(0777)) + }, + post: func() { + os.Chdir("../..") + os.RemoveAll("test-parent") + os.Unsetenv(clientEnv) + os.Unsetenv(secretEnv) + }, + wantedConfig: nil, + wantedErr: true, + }, + + { + name: "invalid config written in env, ignore valid config of cwd", + pre: func() { + cwd := "test-cwd" + os.MkdirAll(cwd, os.FileMode(0777)) + // Write valid config in cwd + os.Chdir(cwd) + b, _ := json.Marshal(validConfig) + os.WriteFile(Filename+".json", b, os.FileMode(0777)) + // Write invalid config in env + os.Setenv(clientEnv, validConfig.Client) + os.Setenv(secretEnv, "") + }, + post: func() { + os.Chdir("..") + os.RemoveAll("test-cwd") + }, + wantedConfig: nil, + wantedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.pre() + got, err := Retrieve() + tt.post() + + if tt.wantedErr && err == nil { + t.Errorf("Expected an error, but got nil") + } + if !tt.wantedErr && err != nil { + t.Errorf("Expected nil error, but got: %v", err) + } + + if !cmp.Equal(got, tt.wantedConfig) { + t.Errorf("Wrong config received, diff:\n%s", cmp.Diff(tt.wantedConfig, got)) + } + }) + } +} + +func TestValidate(t *testing.T) { + var ( + validSecret = "qaRZGEbnQNNvmaeTLqy8Bxs22wLZ6H7obIiNSveTLPdoQuylANnuy6WBOw16XoqH" + validClient = "CQ4iZ5sebOfhGRwUn3IV0r1YFMNrMTIx" + ) + tests := []struct { + name string + config *Config + valid bool + }{ + { + name: "valid config", + config: &Config{ + Client: validClient, + Secret: validSecret, + }, + valid: true, + }, + { + name: "invalid client id", + config: &Config{Client: "", Secret: validSecret}, + valid: false, + }, + { + name: "invalid client secret", + config: &Config{Client: validClient, Secret: ""}, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.valid && err != nil { + t.Errorf( + "Wrong validation, the config was correct but an error was received: \nconfig: %v\nerr: %v", + tt.config, + err, + ) + } + if !tt.valid && err == nil { + t.Errorf( + "Wrong validation, the config was invalid but no error was received: \nconfig: %v", + tt.config, + ) + } + }) + } +} + +func TestIsEmpty(t *testing.T) { + var ( + validSecret = "qaRZGEbnQNNvmaeTLqy8Bxs22wLZ6H7obIiNSveTLPdoQuylANnuy6WBOw16XoqH" + validClient = "CQ4iZ5sebOfhGRwUn3IV0r1YFMNrMTIx" + ) + tests := []struct { + name string + config *Config + want bool + }{ + { + name: "empty config", + config: &Config{Client: "", Secret: ""}, + want: true, + }, + { + name: "config without id", + config: &Config{Client: "", Secret: validSecret}, + want: false, + }, + { + name: "config without secret", + config: &Config{Client: validClient, Secret: ""}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.config.IsEmpty() + if got != tt.want { + t.Errorf("Expected %v but got %v, with config: %v", tt.want, got, tt.config) + } + }) + } +} diff --git a/internal/config/default.go b/internal/config/default.go index ee0698a3..d593d317 100644 --- a/internal/config/default.go +++ b/internal/config/default.go @@ -26,7 +26,7 @@ var ( // SetDefaults sets the default values for configuration keys. func SetDefaults(settings *viper.Viper) { // Client ID - settings.SetDefault("client", "xxxxxxxxxxxxxx") + settings.SetDefault("client", "") // Secret - settings.SetDefault("secret", "xxxxxxxxxxxxxx") + settings.SetDefault("secret", "") } diff --git a/internal/iot/token.go b/internal/iot/token.go index 5361e48a..883bc11d 100644 --- a/internal/iot/token.go +++ b/internal/iot/token.go @@ -19,7 +19,9 @@ package iot import ( "context" + "fmt" "net/url" + "strings" "golang.org/x/oauth2" cc "golang.org/x/oauth2/clientcredentials" @@ -37,5 +39,9 @@ func token(client, secret string) (*oauth2.Token, error) { EndpointParams: additionalValues, } // Get the access token in exchange of client_id and client_secret - return config.Token(context.Background()) + t, err := config.Token(context.Background()) + if err != nil && strings.Contains(err.Error(), "401") { + return nil, fmt.Errorf("wrong credentials") + } + return t, err }