diff --git a/config/samples/unifi_v1beta1_firewallgroup.yaml b/config/samples/unifi_v1beta1_firewallgroup.yaml index 7a99241..d673c5e 100644 --- a/config/samples/unifi_v1beta1_firewallgroup.yaml +++ b/config/samples/unifi_v1beta1_firewallgroup.yaml @@ -9,9 +9,4 @@ spec: name: Test manualAddresses: - 192.168.1.153 - - 192.168.1.154 - - 192.168.1.155 - - 2a01::3 - - 2a01:0::5 - - 2a01:2a01::/32 # TODO(user): Add fields here diff --git a/internal/controller/firewallgroup_controller.go b/internal/controller/firewallgroup_controller.go index b194c06..d16069f 100644 --- a/internal/controller/firewallgroup_controller.go +++ b/internal/controller/firewallgroup_controller.go @@ -21,6 +21,7 @@ import ( "fmt" "net" "reflect" + "strings" "k8s.io/apimachinery/pkg/runtime" ctrl "sigs.k8s.io/controller-runtime" @@ -90,6 +91,10 @@ func (r *FirewallGroupReconciler) Reconcile(ctx context.Context, req ctrl.Reques } } } + err := r.UnifiClient.Reauthenticate() + if err != nil { + return ctrl.Result{}, err + } firewall_groups, err := r.UnifiClient.Client.ListFirewallGroup(context.Background(), r.UnifiClient.SiteID) if err != nil { log.Error(err, "Could not list network objects") @@ -105,8 +110,21 @@ func (r *FirewallGroupReconciler) Reconcile(ctx context.Context, req ctrl.Reques log.Info(fmt.Sprintf("Delete %s", ipv4_name)) err := r.UnifiClient.Client.DeleteFirewallGroup(context.Background(), r.UnifiClient.SiteID, firewall_group.ID) if err != nil { - log.Error(err, "Could not delete firewall group") - return ctrl.Result{}, err + msg := strings.ToLower(err.Error()) + log.Info(msg) + if strings.Contains(msg, "api.err.objectreferredby") { + log.Info("Firewall group is in use. Invoking workaround...!") + firewall_group.GroupMembers = []string{"127.0.0.1"} + firewall_group.Name = firewall_group.Name + "-deleted" + _, updateerr := r.UnifiClient.Client.UpdateFirewallGroup(context.Background(), r.UnifiClient.SiteID, &firewall_group) + if updateerr != nil { + log.Error(updateerr, "Could neither delete or rename firewall group") + return ctrl.Result{}, updateerr + } + } else { + log.Error(err, "Could not delete firewall group") + return ctrl.Result{}, err + } } ipv4_done = true } else { @@ -127,8 +145,21 @@ func (r *FirewallGroupReconciler) Reconcile(ctx context.Context, req ctrl.Reques log.Info(fmt.Sprintf("Delete %s", ipv6_name)) err := r.UnifiClient.Client.DeleteFirewallGroup(context.Background(), r.UnifiClient.SiteID, firewall_group.ID) if err != nil { - log.Error(err, "Could not delete firewall group") - return ctrl.Result{}, err + msg := strings.ToLower(err.Error()) + log.Info(msg) + if strings.Contains(msg, "api.err.objectreferredby") { + log.Info("Firewall group is in use. Invoking workaround...!") + firewall_group.GroupMembers = []string{"::1"} + firewall_group.Name = firewall_group.Name + "-deleted" + _, updateerr := r.UnifiClient.Client.UpdateFirewallGroup(context.Background(), r.UnifiClient.SiteID, &firewall_group) + if updateerr != nil { + log.Error(updateerr, "Could neither delete or rename firewall group") + return ctrl.Result{}, updateerr + } + } else { + log.Error(err, "Could not delete firewall group") + return ctrl.Result{}, err + } } ipv6_done = true } else { @@ -144,6 +175,28 @@ func (r *FirewallGroupReconciler) Reconcile(ctx context.Context, req ctrl.Reques ipv6_done = true } } + if firewall_group.Name == ipv4_name+"-deleted" && len(ipv4) > 0 { + firewall_group.Name = ipv4_name + firewall_group.GroupMembers = ipv4 + log.Info(fmt.Sprintf("Creating %s (from previously deleted)", ipv4_name)) + _, err := r.UnifiClient.Client.UpdateFirewallGroup(context.Background(), r.UnifiClient.SiteID, &firewall_group) + if err != nil { + log.Error(err, "Could not update firewall group") + return ctrl.Result{}, err + } + ipv4_done = true + } + if firewall_group.Name == ipv6_name+"-deleted" && len(ipv6) > 0 { + firewall_group.Name = ipv6_name + firewall_group.GroupMembers = ipv6 + log.Info(fmt.Sprintf("Creating %s (from previously deleted)", ipv6_name)) + _, err := r.UnifiClient.Client.UpdateFirewallGroup(context.Background(), r.UnifiClient.SiteID, &firewall_group) + if err != nil { + log.Error(err, "Could not update firewall group") + return ctrl.Result{}, err + } + ipv6_done = true + } } if len(ipv4) > 0 && !ipv4_done { log.Info(fmt.Sprintf("Creating %s", ipv4_name)) diff --git a/internal/unifi/unifi.go b/internal/unifi/unifi.go index 819dbfa..7aa812e 100644 --- a/internal/unifi/unifi.go +++ b/internal/unifi/unifi.go @@ -10,13 +10,19 @@ import ( "net/http" "net/http/cookiejar" "os" + "strings" + "sync" "github.com/vegardengen/go-unifi/unifi" ) type UnifiClient struct { - Client *unifi.Client - SiteID string + Client *unifi.Client + SiteID string + mutex sync.Mutex + controller string + username string + password string } func CreateUnifiClient() (*UnifiClient, error) { @@ -64,9 +70,57 @@ func CreateUnifiClient() (*UnifiClient, error) { } unifiClient := &UnifiClient{ - Client: client, - SiteID: siteID, + Client: client, + SiteID: siteID, + controller: unifiURL, + username: username, + password: password, } return unifiClient, nil } + +func (s *UnifiClient) WithSession(action func(c *unifi.Client) error) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + err := action(s.Client) + if err == nil { + return nil + } + + if IsSessionExpired(err) { + if loginErr := s.Client.Login(context.Background(), s.username, s.password); loginErr != nil { + return fmt.Errorf("re-login to Unifi failed: %w", loginErr) + } + + return action(s.Client) + } + return err +} + +func (uClient *UnifiClient) Reauthenticate() error { + _, err := uClient.Client.ListSites(context.Background()) + if err == nil { + return nil + } + + if IsSessionExpired(err) { + if loginErr := uClient.Client.Login(context.Background(), uClient.username, uClient.password); loginErr != nil { + return fmt.Errorf("re-login to Unifi failed: %w", loginErr) + } + } + return nil +} + +func IsSessionExpired(err error) bool { + if err == nil { + return false + } + + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "unauthorized") || + strings.Contains(msg, "authentication") || + strings.Contains(msg, "login required") || + strings.Contains(msg, "token") +}