// Package server implements the API server. package server import ( "cmp" "context" "fmt" "log" "strings" "gitlab.com/lyda/external-dns-openwrt/internal/api" "gitlab.com/lyda/external-dns-openwrt/internal/openwrt" "golang.org/x/exp/slices" ) // Clients is a map of Client objects. type Clients struct { clients openwrt.DNSClients clientMap map[string]*openwrt.DNSClient domains []string debug bool } // NewClients creates a new Clients instance. func NewClients(debug bool) *Clients { c := &Clients{debug: debug} c.clientMap = make(map[string]*openwrt.DNSClient, 1) return c } // AddClient adds a server and domains. func (ow *Clients) AddClient(server string, domains []string) { c := openwrt.New(server, domains, ow.debug) ow.clients = append(ow.clients, c) for _, domain := range domains { ow.clientMap[domain] = c ow.domains = append(ow.domains, domain) } slices.SortFunc(ow.domains, func(a, b string) int { return cmp.Compare(len(b), len(a)) }) } // PickClient picks a client. func (ow *Clients) PickClient(host string) (*openwrt.DNSClient, error) { for _, d := range ow.domains { if strings.HasSuffix(host, d) { return ow.clientMap[d], nil } } return nil, fmt.Errorf("unable to find server for %s", host) } // GetClients gets a list of clients. func (ow *Clients) GetClients() openwrt.DNSClients { return ow.clients } // Finalise reloads the dnsmasq servers. func (ow *Clients) Finalise() { for _, c := range ow.clients { _ = c.Finalise() } } // API implements the api server. type API struct { clients *Clients } // NewAPI creates the API server. 
func NewAPI(clients *Clients) *API { return &API{clients: clients} } func (a *API) endpointsAddOrDelete(endpoints *api.Endpoints, add bool) { for _, endpoint := range *endpoints { if endpoint.DnsName == nil { continue } if endpoint.RecordType != nil { targets := endpoint.Targets if targets == nil { targets = &api.Targets{} } for _, target := range *targets { c, err := a.clients.PickClient(*endpoint.DnsName) if err != nil { log.Printf("no client for '%s'(%s)\n", *endpoint.DnsName, target) continue } if a.clients.debug { log.Printf("trying to add '%s' '%s'(%s)\n", *endpoint.RecordType, *endpoint.DnsName, target) } if add { err = c.Add(*endpoint.RecordType, *endpoint.DnsName, target) } else { err = c.Delete(*endpoint.RecordType, *endpoint.DnsName, target) } if err != nil { log.Printf("error adding '%s' '%s'(%s) [%s]\n", *endpoint.RecordType, *endpoint.DnsName, target, err) continue } } } } } // AdjustRecords implements the POST:/adjustrecords handler. func (a *API) AdjustRecords(_ context.Context, req api.AdjustRecordsRequestObject) (api.AdjustRecordsResponseObject, error) { endpoints := (*api.Endpoints)(req.Body) if endpoints == nil { endpoints = &api.Endpoints{} } a.endpointsAddOrDelete(endpoints, true) a.clients.Finalise() return api.AdjustRecords200ApplicationExternalDNSWebhookPlusJSONVersion1Response(*endpoints), nil } // GetRecords implements the GET:records handler. 
func (a *API) GetRecords(_ context.Context, _ api.GetRecordsRequestObject) (api.GetRecordsResponseObject, error) { endpoints := api.Endpoints{} for _, client := range a.clients.GetClients() { entries, err := client.List() if err != nil { return api.GetRecords500Response{}, nil } lastDomain := "" lastRR := "" for _, entry := range entries { domain := entry.Domain rr := entry.RR if domain == lastDomain && rr == lastRR { *endpoints[len(endpoints)-1].Targets = append(*endpoints[len(endpoints)-1].Targets, entry.Target) } else { targets := &[]string{entry.Target} endpoint := api.Endpoint{ RecordType: &rr, DnsName: &domain, RecordTTL: new(int64), Targets: targets, } *endpoint.RecordTTL = int64(60) endpoints = append(endpoints, endpoint) } lastDomain = domain lastRR = rr } } return api.GetRecords200ApplicationExternalDNSWebhookPlusJSONVersion1Response(endpoints), nil } // Healthz implements the GET:/healthz handler. func (a *API) Healthz(_ context.Context, _ api.HealthzRequestObject) (api.HealthzResponseObject, error) { return api.Healthz200TextResponse("ok"), nil } // Negotiate implements the GET:/ handler. func (a *API) Negotiate(_ context.Context, _ api.NegotiateRequestObject) (api.NegotiateResponseObject, error) { filters := a.clients.domains return api.Negotiate200ApplicationExternalDNSWebhookPlusJSONVersion1Response{ Filters: &filters, }, nil } // SetRecords implements the POST:/records handler. func (a *API) SetRecords(_ context.Context, req api.SetRecordsRequestObject) (api.SetRecordsResponseObject, error) { changes := (*api.Changes)(req.Body) if changes == nil { log.Print("SetRecords: empty change sent") return api.SetRecords500Response{}, nil } if a.clients.debug { log.Printf("SetRecords changes: %+v\n", changes) } toCreate, toDelete := simplifyChanges(changes) a.endpointsAddOrDelete(toCreate, true) a.endpointsAddOrDelete(toDelete, false) return api.SetRecords204Response{}, nil }
// Package server implements the server. This part is originally from
// the echo framework. Specifically from bind.go.
//
// The binder here differs from the upstream one in how it binds request
// bodies: it accepts any JSON media type ("application/json" or anything
// with a "+json" suffix) and rejects every other content type.
//
// Won't be needed if https://github.com/labstack/echo/pull/2572 is
// accepted.
package server

import (
	"encoding"
	"errors"
	"mime"
	"net/http"
	"reflect"
	"strconv"
	"strings"

	"github.com/labstack/echo/v4"
)

// WebhookBinder is the echo binder for the external-dns webhook content type.
// It is stateless; the server installs it as the echo instance's Binder.
type WebhookBinder struct{}

// BindPathParams binds the route's path parameters into i, using struct
// fields tagged `param`. Binding failures become a 400 echo.HTTPError.
func (b *WebhookBinder) BindPathParams(c echo.Context, i interface{}) error {
	names := c.ParamNames()
	values := c.ParamValues()
	params := map[string][]string{}
	// Inside this loop i is the index and shadows the destination parameter;
	// bindData below receives the destination again.
	for i, name := range names {
		params[name] = []string{values[i]}
	}
	if err := b.bindData(i, params, "param"); err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
	}
	return nil
}

// BindQueryParams binds the URL query parameters into i, using struct fields
// tagged `query`. Binding failures become a 400 echo.HTTPError.
func (b *WebhookBinder) BindQueryParams(c echo.Context, i interface{}) error {
	if err := b.bindData(i, c.QueryParams(), "query"); err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
	}
	return nil
}

// contentTypeIsJSON reports whether mediaType is "application/json" or has a
// "+json" suffix (such as the external-dns webhook media type). Parameters
// such as charset are ignored, and an unparsable value is not JSON.
func contentTypeIsJSON(mediaType string) bool {
	parsed, _, err := mime.ParseMediaType(mediaType)
	if err != nil {
		return false
	}
	return parsed == "application/json" || strings.HasSuffix(parsed, "+json")
}

// BindBody binds the request body into i.
//
// A request with Content-Length 0 binds nothing and succeeds. A body of
// unknown length (-1) is still read. A JSON body (see contentTypeIsJSON) is
// decoded with the echo instance's JSONSerializer. Decode errors that are
// already an *echo.HTTPError pass through; anything else is wrapped in a 400.
// Every other content type, including forms, yields
// echo.ErrUnsupportedMediaType.
func (b *WebhookBinder) BindBody(c echo.Context, i interface{}) (err error) {
	req := c.Request()
	if req.ContentLength == 0 {
		return
	}

	ctype := req.Header.Get(echo.HeaderContentType)
	if contentTypeIsJSON(ctype) {
		if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil {
			switch err.(type) {
			case *echo.HTTPError:
				return err
			default:
				return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
			}
		}
	} else {
		return echo.ErrUnsupportedMediaType
	}
	return nil
}

// BindHeaders binds HTTP request headers into i, using struct fields tagged
// `header`. Binding failures become a 400 echo.HTTPError.
func (b *WebhookBinder) BindHeaders(c echo.Context, i interface{}) error {
	if err := b.bindData(i, c.Request().Header, "header"); err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
	}
	return nil
}

// Bind implements the `Binder#Bind` function.
// Binding is done in the following order: 1) path params; 2) query params
// (GET, DELETE and HEAD requests only); 3) request body. Each step COULD
// override values bound by a previous step. For single-source binding use
// BindBody, BindQueryParams or BindPathParams directly.
func (b *WebhookBinder) Bind(i interface{}, c echo.Context) (err error) {
	if err := b.BindPathParams(c, i); err != nil {
		return err
	}
	// Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body.
	// For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues.
	// The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670)
	method := c.Request().Method
	if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead {
		if err = b.BindQueryParams(c, i); err != nil {
			return err
		}
	}
	return b.BindBody(c, i)
}

// bindData binds data into destination, which must be a pointer (reflect's
// Elem is applied to it). Two destination shapes are supported:
//
//   - a map with string keys whose values are []string, string (the first
//     value of each entry) or interface{} (the whole []string);
//   - a struct, where ONLY fields with an explicit `tag` struct tag are bound.
//     Untagged struct-typed fields that do not implement
//     echo.BindUnmarshaler are recursed into.
//
// Map destinations with any other value type are silently ignored. For the
// param/query/header sources a destination that is neither a map nor a
// struct is also ignored; for any other tag it is an error. Field names are
// matched case-sensitively first and then case-insensitively.
//
// NOTE(review): SetMapIndex panics on a nil destination map, and the string
// map case reads v[0], so each entry is assumed to hold at least one value.
func (b *WebhookBinder) bindData(destination interface{}, data map[string][]string, tag string) error {
	if destination == nil || len(data) == 0 {
		return nil
	}
	typ := reflect.TypeOf(destination).Elem()
	val := reflect.ValueOf(destination).Elem()

	// Support binding to limited Map destinations:
	// - map[string][]string,
	// - map[string]string <-- (binds first value from data slice)
	// - map[string]interface{}
	// You are better off binding to a struct, but some users want this map feature. The data for these cases comes from
	// params, query, header and form sources, which all produce string values, most of the time a slice of strings.
	if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String {
		k := typ.Elem().Kind()
		isElemInterface := k == reflect.Interface
		isElemString := k == reflect.String
		isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String
		if !(isElemSliceOfStrings || isElemString || isElemInterface) {
			return nil
		}
		for k, v := range data {
			if isElemString {
				val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
			} else {
				val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v))
			}
		}
		return nil
	}

	// !struct
	if typ.Kind() != reflect.Struct {
		if tag == "param" || tag == "query" || tag == "header" {
			// incompatible type, data is probably to be found in the body
			return nil
		}
		return errors.New("binding element must be a struct")
	}

	for i := 0; i < typ.NumField(); i++ {
		typeField := typ.Field(i)
		structField := val.Field(i)
		// Embedded pointers are followed so their fields can be bound.
		if typeField.Anonymous {
			if structField.Kind() == reflect.Ptr {
				structField = structField.Elem()
			}
		}
		// Skips unexported fields and unusable (e.g. nil embedded pointer) values.
		if !structField.CanSet() {
			continue
		}
		structFieldKind := structField.Kind()
		inputFieldName := typeField.Tag.Get(tag)
		if typeField.Anonymous && structField.Kind() == reflect.Struct && inputFieldName != "" {
			// if anonymous struct with query/param/form tags, report an error
			return errors.New("query/param/form tags are not allowed with anonymous struct field")
		}

		if inputFieldName == "" {
			// If the tag is empty, check whether the field is a struct that does not implement BindUnmarshaler and try to bind data into it (it might contain fields with tags).
			// Structs that implement BindUnmarshaler are bound only when they have an explicit tag.
			if _, ok := structField.Addr().Interface().(echo.BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
				if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
					return err
				}
			}
			// does not have explicit tag and is not an ordinary struct - so move to next field
			continue
		}

		inputValue, exists := data[inputFieldName]
		if !exists {
			// Go json.Unmarshal supports case insensitive binding. However the
			// url params are bound case sensitive which is inconsistent. To
			// fix this we must check all of the map values in a
			// case-insensitive search.
			for k, v := range data {
				if strings.EqualFold(k, inputFieldName) {
					inputValue = v
					exists = true
					break
				}
			}
		}

		if !exists {
			continue
		}

		// Call this first, in case we're dealing with an alias to an array type
		if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
			if err != nil {
				return err
			}
			continue
		}

		// Slice fields take every value; all other fields take only the first.
		numElems := len(inputValue)
		if structFieldKind == reflect.Slice && numElems > 0 {
			sliceOf := structField.Type().Elem().Kind()
			slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
			for j := 0; j < numElems; j++ {
				if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
					return err
				}
			}
			val.Field(i).Set(slice)
		} else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
			return err

		}
	}
	return nil
}

// setWithProperType parses val according to valueKind and stores it in
// structField. Values implementing echo.BindUnmarshaler or
// encoding.TextUnmarshaler decode themselves. Pointers are followed;
// unmarshalField allocates a nil pointer first. Ints, uints, bools, floats
// and strings are parsed with strconv. Any other kind yields an
// "unknown type" error.
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
	// But also call it here, in case we're dealing with an array of BindUnmarshalers
	if ok, err := unmarshalField(valueKind, val, structField); ok {
		return err
	}

	switch valueKind {
	case reflect.Ptr:
		return setWithProperType(structField.Elem().Kind(), val, structField.Elem())
	case reflect.Int:
		return setIntField(val, 0, structField)
	case reflect.Int8:
		return setIntField(val, 8, structField)
	case reflect.Int16:
		return setIntField(val, 16, structField)
	case reflect.Int32:
		return setIntField(val, 32, structField)
	case reflect.Int64:
		return setIntField(val, 64, structField)
	case reflect.Uint:
		return setUintField(val, 0, structField)
	case reflect.Uint8:
		return setUintField(val, 8, structField)
	case reflect.Uint16:
		return setUintField(val, 16, structField)
	case reflect.Uint32:
		return setUintField(val, 32, structField)
	case reflect.Uint64:
		return setUintField(val, 64, structField)
	case reflect.Bool:
		return setBoolField(val, structField)
	case reflect.Float32:
		return setFloatField(val, 32, structField)
	case reflect.Float64:
		return setFloatField(val, 64, structField)
	case reflect.String:
		structField.SetString(val)
	default:
		return errors.New("unknown type")
	}
	return nil
}

// unmarshalField lets field decode val itself. It reports whether field
// implements echo.BindUnmarshaler or encoding.TextUnmarshaler (checked in that
// order) and returns that unmarshaler's error. For pointer kinds, a nil
// pointer is allocated first, even when no unmarshaler is found.
func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
	switch valueKind {
	case reflect.Ptr:
		return unmarshalFieldPtr(val, field)
	default:
		return unmarshalFieldNonPtr(val, field)
	}
}

// unmarshalFieldNonPtr tries the unmarshalers on field's address; see
// unmarshalField.
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
	fieldIValue := field.Addr().Interface()
	if unmarshaler, ok := fieldIValue.(echo.BindUnmarshaler); ok {
		return true, unmarshaler.UnmarshalParam(value)
	}
	if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok {
		return true, unmarshaler.UnmarshalText([]byte(value))
	}

	return false, nil
}

// unmarshalFieldPtr allocates field if it is a nil pointer, then tries the
// unmarshalers on the pointed-to value.
func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
	if field.IsNil() {
		// Initialize the pointer to a nil value
		field.Set(reflect.New(field.Type().Elem()))
	}
	return unmarshalFieldNonPtr(value, field.Elem())
}

// setIntField parses value as a base-10 signed integer of bitSize bits
// (0 means int) into field. An empty value is treated as "0".
func setIntField(value string, bitSize int, field reflect.Value) error {
	if value == "" {
		value = "0"
	}
	intVal, err := strconv.ParseInt(value, 10, bitSize)
	if err == nil {
		field.SetInt(intVal)
	}
	return err
}

// setUintField parses value as a base-10 unsigned integer of bitSize bits
// (0 means uint) into field. An empty value is treated as "0".
func setUintField(value string, bitSize int, field reflect.Value) error {
	if value == "" {
		value = "0"
	}
	uintVal, err := strconv.ParseUint(value, 10, bitSize)
	if err == nil {
		field.SetUint(uintVal)
	}
	return err
}

// setBoolField parses value with strconv.ParseBool into field. An empty
// value is treated as "false".
func setBoolField(value string, field reflect.Value) error {
	if value == "" {
		value = "false"
	}
	boolVal, err := strconv.ParseBool(value)
	if err == nil {
		field.SetBool(boolVal)
	}
	return err
}

// setFloatField parses value as a float of bitSize bits into field. An empty
// value is treated as "0.0".
func setFloatField(value string, bitSize int, field reflect.Value) error {
	if value == "" {
		value = "0.0"
	}
	floatVal, err := strconv.ParseFloat(value, bitSize)
	if err == nil {
		field.SetFloat(floatVal)
	}
	return err
}
// Package server implements the API server. package server import ( "log" "gitlab.com/lyda/external-dns-openwrt/internal/api" "golang.org/x/exp/slices" ) func simplifyChanges(changes *api.Changes) (*api.Endpoints, *api.Endpoints) { var toCreate, toDelete api.Endpoints if changes.Create != nil { toCreate = make(api.Endpoints, len(*changes.Create)) copy(toCreate, *changes.Create) } if changes.Delete != nil { toDelete = make(api.Endpoints, len(*changes.Delete)) copy(toDelete, *changes.Delete) } if changes.UpdateOld != nil && changes.UpdateNew != nil && len(*changes.UpdateOld) == len(*changes.UpdateNew) { for i, updateOldEndpoint := range *changes.UpdateOld { updateNewEndpoint := (*changes.UpdateNew)[i] if !compare(updateOldEndpoint, updateNewEndpoint) { toDelete = append(toDelete, updateOldEndpoint) toCreate = append(toCreate, updateNewEndpoint) } } } else if changes.UpdateOld != nil || changes.UpdateNew != nil { if changes.UpdateOld != nil && changes.UpdateNew != nil { log.Printf("changes have mismatched updates for old(%d) and new(%d)", len(*changes.UpdateOld), len(*changes.UpdateNew)) } else if changes.UpdateOld != nil { log.Print("changes only have old updates") } else if changes.UpdateNew != nil { log.Print("changes only have new updates") } return nil, nil } return &toCreate, &toDelete } func compare(a api.Endpoint, b api.Endpoint) bool { return a.DnsName == b.DnsName && a.RecordType == b.RecordType && slices.Compare(*a.Targets, *b.Targets) == 0 }
// Package server implements the API server. package server import ( "fmt" "mime" "strings" "github.com/labstack/echo-contrib/echoprometheus" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" "gitlab.com/lyda/external-dns-openwrt/internal/api" ) // Run creates and runs the server. func Run(cmd *cobra.Command, _ []string) { // Get config values. host, err := cmd.Flags().GetString("listen") cobra.CheckErr(err) debug, err := cmd.Flags().GetBool("debug") cobra.CheckErr(err) // Initialise all components. clients := NewClients(debug) clientArgs, err := cmd.Flags().GetStringArray("client") cobra.CheckErr(err) for _, client := range clientArgs { tmp := strings.Split(client, ",") clients.AddClient(tmp[0], tmp[1:]) } apiHandler := api.NewStrictHandler(NewAPI(clients), nil) err = mime.AddExtensionType(".js", "application/javascript") cobra.CheckErr(err) err = mime.AddExtensionType(".json", "application/json") cobra.CheckErr(err) e := echo.New() logFormat := `${remote_ip} - - [${time_custom}] "${method} ${path} ${protocol}" ${status} ${bytes_out} "${referer}" "${user_agent}"` + "\n" customTimeFormat := "2/Jan/2006:15:04:05 -0700" logMiddleware := middleware.LoggerWithConfig(middleware.LoggerConfig{ Format: logFormat, CustomTimeFormat: customTimeFormat, }) // Register components with echo server. api.RegisterHandlers(e, apiHandler) e.Use(logMiddleware) if debug { e.Use(middleware.BodyDump(func(c echo.Context, reqBody []byte, resBody []byte) { fmt.Printf("Request: %s\n", string(reqBody)) fmt.Printf("Response: %s\n", string(resBody)) })) } e.Binder = &WebhookBinder{} e.Use(echoprometheus.NewMiddleware("openwrt")) e.GET("/metrics", echo.WrapHandler(promhttp.Handler())) e.Debug = debug e.Logger.Fatal(e.Start(host)) }