diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 0f04b0c..b0ea4bf 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "net" "sync" "time" @@ -20,6 +21,7 @@ type Discovery struct { discovered map[string]*discoveredPeer mu sync.RWMutex stopCh chan struct{} + stopOnce sync.Once waitGroup sync.WaitGroup onAddUpstream func(url string, priority int) onRemoveUpstream func(url string) @@ -71,6 +73,7 @@ func (d *Discovery) Start(ctx context.Context) error { if err := d.resolver.Browse(ctx, d.cfg.ServiceName, d.cfg.Domain, entries); err != nil { close(entries) + d.stopOnce.Do(func() { close(d.stopCh) }) d.waitGroup.Wait() return fmt.Errorf("browse services: %w", err) } @@ -85,7 +88,7 @@ func (d *Discovery) Start(ctx context.Context) error { // Stops the discovery process. func (d *Discovery) Stop() { - close(d.stopCh) + d.stopOnce.Do(func() { close(d.stopCh) }) d.waitGroup.Wait() } @@ -122,8 +125,7 @@ func (d *Discovery) handleEntry(_ context.Context, entry *zeroconf.ServiceEntry) addr = entry.AddrIPv6[0].String() } - port := entry.Port - url := fmt.Sprintf("http://%s:%d", addr, port) + url := "http://" + net.JoinHostPort(addr, fmt.Sprintf("%d", entry.Port)) key := fmt.Sprintf("%s@%s", entry.Instance, entry.HostName) d.mu.Lock() @@ -184,6 +186,9 @@ func (d *Discovery) cleanupPeers() { // TTL is the discovery response time; peers should re-announce periodically. // Use 3x TTL as the expiration window. expiration := d.cfg.DiscoveryTime.Duration * 3 + if expiration == 0 { + expiration = time.Second + } for key, peer := range d.discovered { if now.Sub(peer.lastSeen) > expiration {