// Package server implements the API server.
package server
import (
"cmp"
"context"
"fmt"
"log"
"strings"
"gitlab.com/lyda/external-dns-openwrt/internal/api"
"gitlab.com/lyda/external-dns-openwrt/internal/openwrt"
"golang.org/x/exp/slices"
)
// Clients is a registry that routes DNS names to the OpenWrt DNS client
// responsible for them.
//
// It keeps every configured client (for operations that must reach all of
// them) plus a domain-to-client index consulted by PickClient.
type Clients struct {
	clients   openwrt.DNSClients            // every client, in the order added
	clientMap map[string]*openwrt.DNSClient // domain -> client that serves it
	domains   []string                      // registered domains, longest first (see AddClient)
	debug     bool                          // enables verbose logging
}
// NewClients returns an empty Clients registry. The debug flag is handed to
// every client later created by AddClient and also gates extra logging in
// the API handlers.
func NewClients(debug bool) *Clients {
	return &Clients{
		clientMap: make(map[string]*openwrt.DNSClient, 1),
		debug:     debug,
	}
}
// AddClient creates a DNS client for server and registers it as the owner of
// each of the given domains.
//
// A domain that is already registered is re-pointed at the new client (the
// last registration wins) and a warning is logged. The domain is not listed
// twice, so Negotiate does not advertise duplicate filters. The domain list
// is kept sorted longest first so that PickClient tries the most specific
// domain first. Ties are broken alphabetically so the order is deterministic.
func (ow *Clients) AddClient(server string, domains []string) {
	c := openwrt.New(server, domains, ow.debug)
	ow.clients = append(ow.clients, c)
	for _, domain := range domains {
		if _, seen := ow.clientMap[domain]; seen {
			log.Printf("domain %q registered more than once; using server %s", domain, server)
		} else {
			ow.domains = append(ow.domains, domain)
		}
		ow.clientMap[domain] = c
	}
	slices.SortFunc(ow.domains, func(a, b string) int {
		if byLen := cmp.Compare(len(b), len(a)); byLen != 0 {
			return byLen
		}
		return cmp.Compare(a, b)
	})
}
// PickClient returns the client responsible for host.
//
// host belongs to a configured domain when it equals the domain or ends with
// "." followed by the domain. So "example.com" covers "www.example.com" but
// not "badexample.com". A domain written with a leading dot (".example.com"),
// or the empty domain, keeps plain suffix matching. Comparisons ignore case
// and a trailing root dot.
//
// Domains are tried in the order Clients keeps them (longest first), so the
// most specific match wins. An error is returned when no domain matches.
func (ow *Clients) PickClient(host string) (*openwrt.DNSClient, error) {
	for _, d := range ow.domains {
		if hostInDomain(host, d) {
			return ow.clientMap[d], nil
		}
	}
	return nil, fmt.Errorf("unable to find server for %s", host)
}

// hostInDomain reports whether host lies within domain on a DNS label
// boundary, ignoring case and a trailing root dot on either name.
func hostInDomain(host, domain string) bool {
	host = strings.ToLower(strings.TrimSuffix(host, "."))
	domain = strings.ToLower(strings.TrimSuffix(domain, "."))
	if domain == "" || strings.HasPrefix(domain, ".") {
		// Explicit "subdomains only" form (or catch-all): suffix is already
		// anchored on a label boundary.
		return strings.HasSuffix(host, domain)
	}
	return host == domain || strings.HasSuffix(host, "."+domain)
}
// GetClients returns every registered client in registration order. The
// returned slice shares its backing array with the registry.
func (ow *Clients) GetClients() openwrt.DNSClients {
	all := ow.clients
	return all
}
// Finalise reloads the dnsmasq servers by calling Finalise on every
// registered client. A failure is logged and does not prevent the remaining
// clients from being finalised.
func (ow *Clients) Finalise() {
	for _, c := range ow.clients {
		if err := c.Finalise(); err != nil {
			log.Printf("error finalising client: %v", err)
		}
	}
}
// API implements the external-dns webhook handlers on top of a Clients
// registry.
type API struct {
	clients *Clients // routes each DNS name to the OpenWrt client that owns it
}
// NewAPI returns an API whose handlers operate on clients.
func NewAPI(clients *Clients) *API {
	a := new(API)
	a.clients = clients
	return a
}
// endpointsAddOrDelete applies every record in endpoints to the client that
// owns its DNS name. It adds the records when add is true and deletes them
// otherwise; each target of an endpoint is added or deleted individually.
//
// Endpoints without a name, record type or targets are skipped, as are names
// no client is configured for. Per-record failures are logged and do not stop
// the remaining records. A nil endpoints pointer is treated as empty.
func (a *API) endpointsAddOrDelete(endpoints *api.Endpoints, add bool) {
	if endpoints == nil {
		return
	}
	verb := "delete"
	if add {
		verb = "add"
	}
	for _, endpoint := range *endpoints {
		if endpoint.DnsName == nil || endpoint.RecordType == nil ||
			endpoint.Targets == nil || len(*endpoint.Targets) == 0 {
			continue
		}
		name, rrType := *endpoint.DnsName, *endpoint.RecordType
		// The owning client depends only on the name, so look it up once per
		// endpoint rather than once per target.
		c, err := a.clients.PickClient(name)
		if err != nil {
			log.Printf("no client for '%s': %s\n", name, err)
			continue
		}
		for _, target := range *endpoint.Targets {
			if a.clients.debug {
				log.Printf("trying to %s '%s' '%s'(%s)\n", verb, rrType, name, target)
			}
			if add {
				err = c.Add(rrType, name, target)
			} else {
				err = c.Delete(rrType, name, target)
			}
			if err != nil {
				log.Printf("error trying to %s '%s' '%s'(%s) [%s]\n",
					verb, rrType, name, target, err)
			}
		}
	}
}
// AdjustRecords implements the POST:/adjustrecords handler.
//
// Every endpoint in the request body is added through its owning client and
// the clients are then finalised. The endpoints are returned unchanged, and a
// missing body is treated as an empty list.
//
// NOTE(review): in the external-dns webhook protocol this endpoint is meant
// to let the provider canonicalise the desired endpoints before planning,
// not to apply them. Confirm that writing records here is intentional.
func (a *API) AdjustRecords(_ context.Context, req api.AdjustRecordsRequestObject) (api.AdjustRecordsResponseObject, error) {
	endpoints := &api.Endpoints{}
	if body := (*api.Endpoints)(req.Body); body != nil {
		endpoints = body
	}
	a.endpointsAddOrDelete(endpoints, true)
	a.clients.Finalise()
	response := api.AdjustRecords200ApplicationExternalDNSWebhookPlusJSONVersion1Response(*endpoints)
	return response, nil
}
// GetRecords implements the GET:records handler.
//
// It lists the records held by every client. Entries that share a DNS name and
// record type within one client are folded into a single endpoint with several
// targets, in order of first appearance. The client listing carries no TTL, so
// every endpoint is reported with a fixed TTL of 60. If any client fails to
// list its records, the error is logged and a 500 response is returned.
func (a *API) GetRecords(_ context.Context, _ api.GetRecordsRequestObject) (api.GetRecordsResponseObject, error) {
	const recordTTL = int64(60)
	type recordKey struct{ domain, rr string }

	endpoints := api.Endpoints{}
	for _, client := range a.clients.GetClients() {
		entries, err := client.List()
		if err != nil {
			log.Printf("error listing records: %s", err)
			return api.GetRecords500Response{}, nil
		}
		// Group by (name, type) through an index instead of relying on List
		// returning related entries next to each other; unsorted output would
		// otherwise yield duplicate endpoints.
		seen := make(map[recordKey]int, len(entries))
		for _, entry := range entries {
			key := recordKey{domain: entry.Domain, rr: entry.RR}
			if idx, ok := seen[key]; ok {
				targets := endpoints[idx].Targets
				*targets = append(*targets, entry.Target)
				continue
			}
			domain, rr, ttl := entry.Domain, entry.RR, recordTTL
			seen[key] = len(endpoints)
			endpoints = append(endpoints, api.Endpoint{
				RecordType: &rr,
				DnsName:    &domain,
				RecordTTL:  &ttl,
				Targets:    &[]string{entry.Target},
			})
		}
	}
	return api.GetRecords200ApplicationExternalDNSWebhookPlusJSONVersion1Response(endpoints),
		nil
}
// Healthz implements the GET:/healthz handler. It always answers "ok" and
// does not contact any OpenWrt client.
func (a *API) Healthz(_ context.Context, _ api.HealthzRequestObject) (api.HealthzResponseObject, error) {
	const healthy = "ok"
	return api.Healthz200TextResponse(healthy), nil
}
// Negotiate implements the GET:/ handler. It advertises the registered
// domains, in the order Clients keeps them, as this webhook's domain filters.
func (a *API) Negotiate(_ context.Context, _ api.NegotiateRequestObject) (api.NegotiateResponseObject, error) {
	var resp api.Negotiate200ApplicationExternalDNSWebhookPlusJSONVersion1Response
	filters := a.clients.domains
	resp.Filters = &filters
	return resp, nil
}
// SetRecords implements the POST:/records handler.
//
// The change set is flattened by simplifyChanges into records to delete and
// records to create (an update becomes delete(old) + create(new)). Deletions
// are applied first, then creations, and finally the clients are finalised
// so the changes take effect.
//
// A missing body, or a change set that simplifyChanges rejects (it returns
// nil lists for inconsistent updates), results in a 500 response.
func (a *API) SetRecords(_ context.Context, req api.SetRecordsRequestObject) (api.SetRecordsResponseObject, error) {
	changes := (*api.Changes)(req.Body)
	if changes == nil {
		log.Print("SetRecords: empty change sent")
		return api.SetRecords500Response{}, nil
	}
	if a.clients.debug {
		log.Printf("SetRecords changes: %+v\n", changes)
	}
	toCreate, toDelete := simplifyChanges(changes)
	if toCreate == nil || toDelete == nil {
		log.Print("SetRecords: inconsistent change set rejected")
		return api.SetRecords500Response{}, nil
	}
	// Records are deleted per target. If creations ran first, an update that
	// keeps some targets would have those shared targets deleted again
	// afterwards.
	a.endpointsAddOrDelete(toDelete, false)
	a.endpointsAddOrDelete(toCreate, true)
	a.clients.Finalise()
	return api.SetRecords204Response{}, nil
}
// Package server implements the API server. The binder in this file is
// adapted from the echo framework's bind.go.
//
// It will no longer be needed if https://github.com/labstack/echo/pull/2572
// is accepted.
package server
import (
"encoding"
"errors"
"mime"
"net/http"
"reflect"
"strconv"
"strings"
"github.com/labstack/echo/v4"
)
// WebhookBinder is the echo binder for the external-dns webhook content
// type. Unlike echo's default binder, it decodes request bodies of any JSON
// media type: application/json, or any type with a "+json" suffix.
type WebhookBinder struct{}
// BindPathParams copies the route's path parameters into i, matching fields
// by their `param` tag. A binding failure is returned as an HTTP 400 error.
func (b *WebhookBinder) BindPathParams(c echo.Context, i interface{}) error {
	values := c.ParamValues()
	params := make(map[string][]string, len(values))
	for idx, name := range c.ParamNames() {
		params[name] = []string{values[idx]}
	}
	err := b.bindData(i, params, "param")
	if err == nil {
		return nil
	}
	return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
// BindQueryParams copies URL query parameters into i, matching fields by
// their `query` tag. A binding failure is returned as an HTTP 400 error.
func (b *WebhookBinder) BindQueryParams(c echo.Context, i interface{}) error {
	err := b.bindData(i, c.QueryParams(), "query")
	if err == nil {
		return nil
	}
	return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
// contentTypeIsJSON reports whether a Content-Type header value names a JSON
// media type: exactly application/json, or any type with a "+json" suffix.
// Parameters such as charset or version are ignored. A value that cannot be
// parsed is treated as not JSON.
func contentTypeIsJSON(mediaType string) bool {
	base, _, err := mime.ParseMediaType(mediaType)
	switch {
	case err != nil:
		return false
	case base == "application/json":
		return true
	default:
		return strings.HasSuffix(base, "+json")
	}
}
// BindBody decodes the request body into i.
//
// Only JSON bodies are supported. The Content-Type must be application/json
// or carry a "+json" suffix (as the external-dns webhook media type does);
// anything else is rejected with echo.ErrUnsupportedMediaType. A request that
// declares a zero Content-Length is left unbound and accepted.
//
// A decoding error that is, or wraps, an *echo.HTTPError is returned as that
// HTTPError. Any other decoding error becomes an HTTP 400.
func (b *WebhookBinder) BindBody(c echo.Context, i interface{}) (err error) {
	req := c.Request()
	if req.ContentLength == 0 {
		return nil
	}
	if !contentTypeIsJSON(req.Header.Get(echo.HeaderContentType)) {
		return echo.ErrUnsupportedMediaType
	}
	if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil {
		// Return the HTTPError itself: echo's error handler type-asserts on
		// *echo.HTTPError, so a wrapper would be reported as a 500.
		var httpErr *echo.HTTPError
		if errors.As(err, &httpErr) {
			return httpErr
		}
		return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
	}
	return nil
}
// BindHeaders copies request headers into i, matching fields by their
// `header` tag (falling back to a case-insensitive name match). A binding
// failure is returned as an HTTP 400 error.
func (b *WebhookBinder) BindHeaders(c echo.Context, i interface{}) error {
	headers := c.Request().Header
	err := b.bindData(i, headers, "header")
	if err == nil {
		return nil
	}
	return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
// Bind implements echo's Binder interface. It binds, in order: path
// parameters, query parameters (only for GET, DELETE and HEAD requests) and
// the request body. Later sources may overwrite values set by earlier ones.
// Use BindPathParams, BindQueryParams or BindBody to bind from a single source.
func (b *WebhookBinder) Bind(i interface{}, c echo.Context) (err error) {
	if err = b.BindPathParams(c, i); err != nil {
		return err
	}
	// Query parameters are skipped for methods that normally carry a body, so
	// a URL such as `&id=1&lang=en` cannot compete with a body such as
	// `{"id":100,"lang":"de"}`. This restores echo's pre-v4.1.11 behaviour
	// (echo issue #1670).
	switch c.Request().Method {
	case http.MethodGet, http.MethodDelete, http.MethodHead:
		if err = b.BindQueryParams(c, i); err != nil {
			return err
		}
	}
	return b.BindBody(c, i)
}
// bindData binds values from data ONLY into the fields of destination that
// carry an EXPLICIT struct tag named tag (e.g. `query:"id"` when tag is
// "query").
//
// destination must be a pointer. The pointee may be a struct, or a map with
// string keys whose values are []string, string or interface{}. Untagged
// fields that are plain structs (not echo.BindUnmarshaler) are descended into
// recursively. Tagged names are matched exactly first, then case-insensitively.
// A nil destination or empty data is a no-op.
//
// For tags "param", "query" and "header", a destination that is neither a
// supported map nor a struct is silently skipped. For any other tag it is an
// error.
func (b *WebhookBinder) bindData(destination interface{}, data map[string][]string, tag string) error {
	if destination == nil || len(data) == 0 {
		return nil
	}
	// Elem() on both the type and the value assumes destination is a pointer.
	typ := reflect.TypeOf(destination).Elem()
	val := reflect.ValueOf(destination).Elem()
	// Support binding to limited Map destinations:
	// - map[string][]string,
	// - map[string]string <-- (binds first value from data slice)
	// - map[string]interface{}
	// You are better off binding to a struct, but some users want this map feature. The sources of data for these cases are
	// params, query, header and form, as these sources produce string values (most of the time, slices of strings).
	if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String {
		k := typ.Elem().Kind()
		isElemInterface := k == reflect.Interface
		isElemString := k == reflect.String
		isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String
		if !(isElemSliceOfStrings || isElemString || isElemInterface) {
			// Unsupported map element type: leave the map untouched.
			return nil
		}
		// NOTE(review): SetMapIndex panics if the destination map is nil, and
		// v[0] panics for an empty value slice. Confirm callers never pass either.
		for k, v := range data {
			if isElemString {
				val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
			} else {
				// []string and interface{} elements both receive the whole slice.
				val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v))
			}
		}
		return nil
	}
	// !struct
	if typ.Kind() != reflect.Struct {
		if tag == "param" || tag == "query" || tag == "header" {
			// incompatible type, data is probably to be found in the body
			return nil
		}
		return errors.New("binding element must be a struct")
	}
	for i := 0; i < typ.NumField(); i++ {
		typeField := typ.Field(i)
		structField := val.Field(i)
		// Embedded pointer fields are followed to the value they point at. A
		// nil pointer yields an invalid Value, which fails CanSet below.
		if typeField.Anonymous {
			if structField.Kind() == reflect.Ptr {
				structField = structField.Elem()
			}
		}
		// Unexported (or otherwise unsettable) fields are skipped.
		if !structField.CanSet() {
			continue
		}
		structFieldKind := structField.Kind()
		inputFieldName := typeField.Tag.Get(tag)
		if typeField.Anonymous && structField.Kind() == reflect.Struct && inputFieldName != "" {
			// if anonymous struct with query/param/form tags, report an error
			return errors.New("query/param/form tags are not allowed with anonymous struct field")
		}
		if inputFieldName == "" {
			// No tag: descend into plain struct fields (which may contain tagged fields of their own).
			// Structs that implement BindUnmarshaler are only bound when they have an explicit tag.
			if _, ok := structField.Addr().Interface().(echo.BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
				if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
					return err
				}
			}
			// does not have explicit tag and is not an ordinary struct - so move to next field
			continue
		}
		inputValue, exists := data[inputFieldName]
		if !exists {
			// Go json.Unmarshal supports case insensitive binding. However the
			// url params are bound case sensitive which is inconsistent. To
			// fix this we must check all of the map values in a
			// case-insensitive search.
			for k, v := range data {
				if strings.EqualFold(k, inputFieldName) {
					inputValue = v
					exists = true
					break
				}
			}
		}
		if !exists {
			continue
		}
		// Call this first, in case we're dealing with an alias to an array type
		if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
			if err != nil {
				return err
			}
			continue
		}
		// Slices get one element per supplied value, each parsed by element kind.
		// Every other field kind uses only the first supplied value.
		numElems := len(inputValue)
		if structFieldKind == reflect.Slice && numElems > 0 {
			sliceOf := structField.Type().Elem().Kind()
			slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
			for j := 0; j < numElems; j++ {
				if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
					return err
				}
			}
			val.Field(i).Set(slice)
		} else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
			return err
		}
	}
	return nil
}
// setWithProperType parses val according to valueKind and stores it in
// structField.
//
// A BindUnmarshaler or TextUnmarshaler implementation on the field takes
// precedence; this also covers slice elements. Pointers are followed to their
// pointee (unmarshalField allocates a nil pointer first). Integer, unsigned and
// float kinds are parsed with the matching bit size. Bools and strings are set
// directly. Any other kind yields an "unknown type" error.
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
	if handled, err := unmarshalField(valueKind, val, structField); handled {
		return err
	}

	// Bit size for the strconv parsers. Plain int/uint keep 0, which strconv
	// interprets as the platform's native size.
	bits := 0
	switch valueKind {
	case reflect.Int8, reflect.Uint8:
		bits = 8
	case reflect.Int16, reflect.Uint16:
		bits = 16
	case reflect.Int32, reflect.Uint32, reflect.Float32:
		bits = 32
	case reflect.Int64, reflect.Uint64, reflect.Float64:
		bits = 64
	}

	switch valueKind {
	case reflect.Ptr:
		pointee := structField.Elem()
		return setWithProperType(pointee.Kind(), val, pointee)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return setIntField(val, bits, structField)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return setUintField(val, bits, structField)
	case reflect.Float32, reflect.Float64:
		return setFloatField(val, bits, structField)
	case reflect.Bool:
		return setBoolField(val, structField)
	case reflect.String:
		structField.SetString(val)
		return nil
	default:
		return errors.New("unknown type")
	}
}
// unmarshalField lets field decode val itself when it implements
// echo.BindUnmarshaler or encoding.TextUnmarshaler. For pointer kinds the
// pointee is used, and a nil pointer is allocated first even if no
// unmarshaler is found. It reports whether an unmarshaler was found, together
// with that unmarshaler's error.
func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
	if valueKind == reflect.Ptr {
		return unmarshalFieldPtr(val, field)
	}
	return unmarshalFieldNonPtr(val, field)
}
// unmarshalFieldNonPtr decodes value into field through the field's address,
// preferring echo.BindUnmarshaler over encoding.TextUnmarshaler. It reports
// false (and no error) when the field implements neither.
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
	switch u := field.Addr().Interface().(type) {
	case echo.BindUnmarshaler:
		return true, u.UnmarshalParam(value)
	case encoding.TextUnmarshaler:
		return true, u.UnmarshalText([]byte(value))
	default:
		return false, nil
	}
}
// unmarshalFieldPtr is unmarshalFieldNonPtr applied to the pointee of the
// pointer field. A nil pointer is first set to a freshly allocated zero
// value, so there is something to decode into.
func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
	if field.IsNil() {
		fresh := reflect.New(field.Type().Elem())
		field.Set(fresh)
	}
	target := field.Elem()
	return unmarshalFieldNonPtr(value, target)
}
// setIntField parses value as a base-10 signed integer that fits in bitSize
// bits (0 means the size of int) and stores it in field. An empty value is
// treated as "0". On a parse error, field is left unchanged and the strconv
// error is returned.
func setIntField(value string, bitSize int, field reflect.Value) error {
	text := value
	if text == "" {
		text = "0"
	}
	n, err := strconv.ParseInt(text, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetInt(n)
	return nil
}
// setUintField parses value as a base-10 unsigned integer that fits in
// bitSize bits (0 means the size of uint) and stores it in field. An empty
// value is treated as "0". On a parse error, field is left unchanged and the
// strconv error is returned.
func setUintField(value string, bitSize int, field reflect.Value) error {
	text := value
	if text == "" {
		text = "0"
	}
	n, err := strconv.ParseUint(text, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetUint(n)
	return nil
}
// setBoolField parses value with strconv.ParseBool and stores it in field. An
// empty value is treated as "false". On a parse error, field is left
// unchanged and the strconv error is returned.
func setBoolField(value string, field reflect.Value) error {
	text := value
	if text == "" {
		text = "false"
	}
	b, err := strconv.ParseBool(text)
	if err != nil {
		return err
	}
	field.SetBool(b)
	return nil
}
// setFloatField parses value as a floating-point number of bitSize precision
// and stores it in field. An empty value is treated as "0.0". On a parse
// error, field is left unchanged and the strconv error is returned.
func setFloatField(value string, bitSize int, field reflect.Value) error {
	text := value
	if text == "" {
		text = "0.0"
	}
	f, err := strconv.ParseFloat(text, bitSize)
	if err != nil {
		return err
	}
	field.SetFloat(f)
	return nil
}
// Package server implements the API server.
package server
import (
"log"
"gitlab.com/lyda/external-dns-openwrt/internal/api"
"golang.org/x/exp/slices"
)
// simplifyChanges flattens an external-dns change set into two lists: the
// endpoints to create and the endpoints to delete.
//
// Create and Delete are copied as-is. UpdateOld/UpdateNew are paired by index;
// each pair that compare considers different adds the old endpoint to the
// delete list and the new one to the create list, while identical pairs are
// dropped. If only one of the update lists is present, or their lengths
// differ, the problem is logged and (nil, nil) is returned. Callers must
// check for that before dereferencing.
func simplifyChanges(changes *api.Changes) (*api.Endpoints, *api.Endpoints) {
	oldUpdates, newUpdates := changes.UpdateOld, changes.UpdateNew
	pairedUpdates := oldUpdates != nil && newUpdates != nil &&
		len(*oldUpdates) == len(*newUpdates)

	if !pairedUpdates && (oldUpdates != nil || newUpdates != nil) {
		switch {
		case oldUpdates != nil && newUpdates != nil:
			log.Printf("changes have mismatched updates for old(%d) and new(%d)",
				len(*oldUpdates), len(*newUpdates))
		case oldUpdates != nil:
			log.Print("changes only have old updates")
		default:
			log.Print("changes only have new updates")
		}
		return nil, nil
	}

	var creates, deletes api.Endpoints
	if changes.Create != nil {
		creates = make(api.Endpoints, len(*changes.Create))
		copy(creates, *changes.Create)
	}
	if changes.Delete != nil {
		deletes = make(api.Endpoints, len(*changes.Delete))
		copy(deletes, *changes.Delete)
	}
	if pairedUpdates {
		for idx := range *oldUpdates {
			before, after := (*oldUpdates)[idx], (*newUpdates)[idx]
			if compare(before, after) {
				continue
			}
			deletes = append(deletes, before)
			creates = append(creates, after)
		}
	}
	return &creates, &deletes
}
// compare reports whether two endpoints describe the same records: the same
// DNS name, the same record type and the same targets in the same order.
// RecordTTL and any other fields are not compared.
//
// The name and type are compared by the values they point to, not by pointer
// identity. Endpoints decoded from different JSON arrays never share
// pointers, so a pointer comparison would treat every update as a change. A
// nil pointer only equals another nil pointer, and nil Targets are treated as
// an empty target list.
func compare(a api.Endpoint, b api.Endpoint) bool {
	sameName := a.DnsName == b.DnsName ||
		(a.DnsName != nil && b.DnsName != nil && *a.DnsName == *b.DnsName)
	sameType := a.RecordType == b.RecordType ||
		(a.RecordType != nil && b.RecordType != nil && *a.RecordType == *b.RecordType)
	if !sameName || !sameType {
		return false
	}
	var aTargets, bTargets api.Targets
	if a.Targets != nil {
		aTargets = *a.Targets
	}
	if b.Targets != nil {
		bTargets = *b.Targets
	}
	return slices.Equal(aTargets, bTargets)
}
// Package server implements the API server.
package server
import (
"fmt"
"mime"
"strings"
"github.com/labstack/echo-contrib/echoprometheus"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/cobra"
"gitlab.com/lyda/external-dns-openwrt/internal/api"
)
// Run reads the command-line flags, builds the OpenWrt clients and serves the
// webhook API (plus Prometheus metrics on /metrics) until the server stops.
//
// Flags read:
//   - "listen": the address passed to echo's Start.
//   - "debug": enables verbose client logging, request/response body dumps
//     and echo debug mode.
//   - "client" (repeatable): a value of the form "server,domain[,domain...]".
//
// Whitespace around each element of a client value is trimmed and empty
// domains are ignored, so a stray comma does not register the empty domain
// (which PickClient would treat as matching every name). An empty server is
// an error. Configuration errors abort through cobra.CheckErr.
func Run(cmd *cobra.Command, _ []string) {
	// Get config values.
	host, err := cmd.Flags().GetString("listen")
	cobra.CheckErr(err)
	debug, err := cmd.Flags().GetBool("debug")
	cobra.CheckErr(err)

	// Initialise all components.
	clients := NewClients(debug)
	clientArgs, err := cmd.Flags().GetStringArray("client")
	cobra.CheckErr(err)
	for _, arg := range clientArgs {
		fields := strings.Split(arg, ",")
		server := strings.TrimSpace(fields[0])
		if server == "" {
			cobra.CheckErr(fmt.Errorf("--client %q: missing server", arg))
		}
		var domains []string
		for _, d := range fields[1:] {
			if d = strings.TrimSpace(d); d != "" {
				domains = append(domains, d)
			}
		}
		clients.AddClient(server, domains)
	}
	apiHandler := api.NewStrictHandler(NewAPI(clients), nil)

	err = mime.AddExtensionType(".js", "application/javascript")
	cobra.CheckErr(err)
	err = mime.AddExtensionType(".json", "application/json")
	cobra.CheckErr(err)

	e := echo.New()
	logFormat := `${remote_ip} - - [${time_custom}] "${method} ${path} ${protocol}" ${status} ${bytes_out} "${referer}" "${user_agent}"` + "\n"
	customTimeFormat := "2/Jan/2006:15:04:05 -0700"
	logMiddleware := middleware.LoggerWithConfig(middleware.LoggerConfig{
		Format:           logFormat,
		CustomTimeFormat: customTimeFormat,
	})

	// Register components with echo server.
	api.RegisterHandlers(e, apiHandler)
	e.Use(logMiddleware)
	if debug {
		e.Use(middleware.BodyDump(func(c echo.Context, reqBody []byte, resBody []byte) {
			fmt.Printf("Request: %s\n", string(reqBody))
			fmt.Printf("Response: %s\n", string(resBody))
		}))
	}
	// Accept the external-dns webhook "+json" media type for request bodies.
	e.Binder = &WebhookBinder{}
	e.Use(echoprometheus.NewMiddleware("openwrt"))
	e.GET("/metrics", echo.WrapHandler(promhttp.Handler()))
	e.Debug = debug
	e.Logger.Fatal(e.Start(host))
}