package v6

import (
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/iancoleman/strcase"
	"github.com/scylladb/go-set/strset"

	"github.com/anchore/go-logger"
	"github.com/anchore/grype/grype/db/v6/name"
	"github.com/anchore/grype/grype/pkg"
	"github.com/anchore/grype/grype/search"
	"github.com/anchore/grype/grype/version"
	"github.com/anchore/grype/grype/vulnerability"
	"github.com/anchore/grype/internal/log"
	"github.com/anchore/syft/syft/cpe"
	syftPkg "github.com/anchore/syft/syft/pkg"
)

// compile-time checks that *vulnerabilityProvider satisfies the interfaces it is handed out as
var _ vulnerability.Provider = (*vulnerabilityProvider)(nil)
var _ vulnerability.StoreMetadataProvider = (*vulnerabilityProvider)(nil)

// NewVulnerabilityProvider returns a vulnerability.Provider that answers every query by reading from rdr.
func NewVulnerabilityProvider(rdr Reader) vulnerability.Provider {
	provider := &vulnerabilityProvider{reader: rdr}
	return provider
}

// vulnerabilityProvider implements vulnerability.Provider on top of a v6 database Reader.
type vulnerabilityProvider struct {
	reader Reader // every DB lookup made by the provider goes through this reader
}

// compile-time check that *vulnerabilityProvider satisfies vulnerability.Provider; repeating an assertion on the
// blank identifier is harmless
var _ vulnerability.Provider = (*vulnerabilityProvider)(nil)

// VulnerabilityMetadata returns the metadata for the vulnerability identified by ref.
//
// If ref.Internal already carries a *VulnerabilityHandle, that handle is used directly. Otherwise the vulnerability
// is looked up by ref.ID within the provider named by the first ":"-separated segment of ref.Namespace. When no
// vulnerability can be found, minimal metadata with an unknown severity is returned rather than an error.
//
// Deprecated: vulnerability.Vulnerability objects now have metadata included
func (vp vulnerabilityProvider) VulnerabilityMetadata(ref vulnerability.Reference) (*vulnerability.Metadata, error) {
	handle, isHandle := ref.Internal.(*VulnerabilityHandle)
	if !isHandle {
		fetched, err := vp.fetchVulnerability(ref)
		if err != nil {
			return nil, err
		}
		handle = fetched
	}

	if handle != nil {
		return vp.getVulnerabilityMetadata(handle, ref.Namespace)
	}

	log.WithFields("id", ref.ID, "namespace", ref.Namespace).Debug("unable to find vulnerability for given reference")
	// the provider name is the portion of the namespace before the first ":" (the whole namespace if there is none)
	dataSource, _, _ := strings.Cut(ref.Namespace, ":")
	return &vulnerability.Metadata{
		ID:         ref.ID,
		DataSource: dataSource,
		Namespace:  ref.Namespace,
		Severity:   toSeverityString(vulnerability.UnknownSeverity),
	}, nil
}

// getVulnerabilityMetadata builds the metadata for vuln under namespace. It enriches that metadata with the
// known-exploited (KEV) and EPSS records found for the vulnerability's CVE identifiers.
//
// Enrichment is best-effort: lookup failures are logged at debug level, and whatever records were collected are
// still used.
func (vp vulnerabilityProvider) getVulnerabilityMetadata(vuln *VulnerabilityHandle, namespace string) (*vulnerability.Metadata, error) {
	cveIDs := getCVEs(vuln)

	knownExploited, kevErr := vp.fetchKnownExploited(cveIDs)
	if kevErr != nil {
		log.WithFields("id", vuln.Name, "vulnerability", vuln.String(), "error", kevErr).Debug("unable to fetch known exploited from vulnerability")
	}

	epssScores, epssErr := vp.fetchEpss(cveIDs)
	if epssErr != nil {
		log.WithFields("id", vuln.Name, "vulnerability", vuln.String(), "error", epssErr).Debug("unable to fetch epss from vulnerability")
	}

	return newVulnerabilityMetadata(vuln, namespace, knownExploited, epssScores)
}

// newVulnerabilityMetadata converts a vulnerability handle into vulnerability.Metadata reported under namespace,
// attaching the given known-exploited and EPSS records as-is.
//
// A nil vuln yields (nil, nil). A failure to extract severities is logged and does not fail the call; whatever
// severity and CVSS values extractSeverities returned are used. The fields read from the vulnerability blob
// (DataSource, URLs, Description) are left empty when the handle has no blob, instead of dereferencing nil.
// The returned error is currently always nil.
func newVulnerabilityMetadata(vuln *VulnerabilityHandle, namespace string, kevs []vulnerability.KnownExploited, epss []vulnerability.EPSS) (*vulnerability.Metadata, error) {
	if vuln == nil {
		return nil, nil
	}

	sev, cvss, err := extractSeverities(vuln)
	if err != nil {
		// include the underlying error so the debug log explains *why* extraction failed
		log.WithFields("id", vuln.Name, "vulnerability", vuln.String(), "error", err).Debug("unable to extract severity from vulnerability")
	}

	metadata := &vulnerability.Metadata{
		ID:             vuln.Name,
		Namespace:      namespace,
		Severity:       toSeverityString(sev),
		Cvss:           cvss,
		KnownExploited: kevs,
		EPSS:           epss,
	}

	// blob-derived fields; the blob may be absent (getCVEs also guards against this)
	if vuln.BlobValue != nil {
		metadata.DataSource = firstReferenceURL(vuln)
		metadata.URLs = lastReferenceURLs(vuln)
		metadata.Description = vuln.BlobValue.Description
	}

	return metadata, nil
}

// DataProvenance reports, for every provider recorded in the DB, when its data was captured and the digest of its
// input, keyed by provider ID. A provider with no capture date gets the zero time.Time.
func (vp vulnerabilityProvider) DataProvenance() (map[string]vulnerability.DataProvenance, error) {
	providers, err := vp.reader.AllProviders()
	if err != nil {
		return nil, err
	}

	provenance := make(map[string]vulnerability.DataProvenance, len(providers))
	for _, p := range providers {
		var captured time.Time
		if p.DateCaptured != nil {
			captured = *p.DateCaptured
		}
		provenance[p.ID] = vulnerability.DataProvenance{
			DateCaptured: captured,
			InputDigest:  p.InputDigest,
		}
	}
	return provenance, nil
}

// fetchVulnerability looks up the vulnerability named ref.ID within the provider given by the first ":"-separated
// segment of ref.Namespace, preloading related data. It returns the first match, or (nil, nil) when nothing matches.
func (vp vulnerabilityProvider) fetchVulnerability(ref vulnerability.Reference) (*VulnerabilityHandle, error) {
	providerName, _, _ := strings.Cut(ref.Namespace, ":")
	spec := &VulnerabilitySpecifier{Name: ref.ID, Providers: []string{providerName}}
	opts := &GetVulnerabilityOptions{Preload: true}

	vulns, err := vp.reader.GetVulnerabilities(spec, opts)
	switch {
	case err != nil:
		return nil, err
	case len(vulns) == 0:
		return nil, nil
	default:
		return &vulns[0], nil
	}
}

// fetchKnownExploited collects the known-exploited (KEV) records stored for each of the given CVE IDs.
//
// A lookup that fails is recorded in a multierror and skipped. The returned records may therefore be partial even
// when the returned error is non-nil.
func (vp vulnerabilityProvider) fetchKnownExploited(cves []string) ([]vulnerability.KnownExploited, error) {
	var (
		results  []vulnerability.KnownExploited
		fetchErr error
	)
	for _, cveID := range cves {
		records, err := vp.reader.GetKnownExploitedVulnerabilities(cveID)
		if err != nil {
			fetchErr = multierror.Append(fetchErr, err)
			continue
		}
		for _, record := range records {
			blob := record.BlobValue
			results = append(results, vulnerability.KnownExploited{
				CVE:                        record.Cve,
				VendorProject:              blob.VendorProject,
				Product:                    blob.Product,
				DateAdded:                  blob.DateAdded,
				RequiredAction:             blob.RequiredAction,
				DueDate:                    blob.DueDate,
				KnownRansomwareCampaignUse: blob.KnownRansomwareCampaignUse,
				Notes:                      blob.Notes,
				URLs:                       blob.URLs,
				CWEs:                       blob.CWEs,
			})
		}
	}
	return results, fetchErr
}

// fetchEpss collects the EPSS score entries stored for each of the given CVE IDs.
//
// A lookup that fails is recorded in a multierror and skipped. The returned entries may therefore be partial even
// when the returned error is non-nil.
func (vp vulnerabilityProvider) fetchEpss(cves []string) ([]vulnerability.EPSS, error) {
	var results []vulnerability.EPSS
	var fetchErr error
	for _, cveID := range cves {
		entries, err := vp.reader.GetEpss(cveID)
		if err != nil {
			fetchErr = multierror.Append(fetchErr, err)
			continue
		}
		for _, e := range entries {
			score := vulnerability.EPSS{
				CVE:        e.Cve,
				EPSS:       e.Epss,
				Percentile: e.Percentile,
				Date:       e.Date,
			}
			results = append(results, score)
		}
	}
	return results, fetchErr
}

// PackageSearchNames returns the names under which vulnerabilities for p may be recorded, as computed by
// name.PackageNames.
func (vp vulnerabilityProvider) PackageSearchNames(p pkg.Package) []string {
	names := name.PackageNames(p)
	return names
}

// Close releases the resources held by the underlying reader when that reader implements io.Closer.
//
// A reader with nothing to close makes Close a no-op that returns nil. Previously such a reader caused a panic from
// the unchecked type assertion.
func (vp vulnerabilityProvider) Close() error {
	if closer, ok := vp.reader.(io.Closer); ok {
		return closer.Close()
	}
	return nil
}

// FindVulnerabilities returns every vulnerability matching the given criteria.
//
// The criteria are validated, then expanded by search.CriteriaIterator into one or more flattened criteria sets.
// The results of all sets are concatenated. For each set:
//   - package name, ecosystem, ID, CPE and distro criteria are translated into DB specifiers;
//   - affected-package records are fetched when there is a package specifier or a vulnerability ID, and
//     affected-CPE records when there is a CPE specifier;
//   - version constraint criteria are applied to the affected ranges;
//   - any remaining criteria are evaluated against the resulting vulnerabilities.
//
// A set whose distro is not present in the DB contributes no results. It does not discard the results of other
// sets. An error is returned for a CPE criterion without a product, or when any DB read fails.
//
//nolint:funlen,gocognit,gocyclo
func (vp vulnerabilityProvider) FindVulnerabilities(criteria ...vulnerability.Criteria) ([]vulnerability.Vulnerability, error) {
	if err := search.ValidateCriteria(criteria); err != nil {
		return nil, err
	}

	var out []vulnerability.Vulnerability
	for _, criteriaSet := range search.CriteriaIterator(criteria) {
		var vulnSpecs VulnerabilitySpecifiers
		var osSpecs OSSpecifiers
		var pkgSpec *PackageSpecifier
		var cpeSpec *cpe.Attributes
		var pkgType syftPkg.Type

		// Criteria not fully expressed as DB specifiers are evaluated later. They are collected into a new slice,
		// not deleted in place, so the slice yielded by the iterator (which this function does not own) is never
		// mutated.
		var unapplied []vulnerability.Criteria

		for _, criterion := range criteriaSet {
			switch c := criterion.(type) {
			case *search.PackageNameCriteria:
				if pkgSpec == nil {
					pkgSpec = &PackageSpecifier{}
				}
				pkgSpec.Name = c.PackageName
			case *search.EcosystemCriteria:
				if pkgSpec == nil {
					pkgSpec = &PackageSpecifier{}
				}
				// the v6 store normalizes ecosystems around the syft package type, so that field is preferred
				switch {
				case c.PackageType != "" && c.PackageType != syftPkg.UnknownPkg:
					// prefer to match by a non-blank, known package type
					pkgType = c.PackageType
					pkgSpec.Ecosystem = string(c.PackageType)
				case c.Language != "":
					// if there's no known package type, but there is a non-blank language
					// try that.
					pkgSpec.Ecosystem = string(c.Language)
				case c.PackageType == syftPkg.UnknownPkg:
					// if language is blank, and package type is explicitly "UnknownPkg" and not
					// just blank, use that.
					pkgType = c.PackageType
					pkgSpec.Ecosystem = string(c.PackageType)
				}
			case *search.IDCriteria:
				vulnSpecs = append(vulnSpecs, VulnerabilitySpecifier{
					Name: c.ID,
				})
			case *search.CPECriteria:
				if cpeSpec == nil {
					cpeSpec = &cpe.Attributes{}
				}
				*cpeSpec = c.CPE.Attributes
				if cpeSpec.Product == cpe.Any {
					return nil, fmt.Errorf("must specify product to search by CPE; got: %s", c.CPE.Attributes.BindToFmtString())
				}
				if pkgSpec == nil {
					pkgSpec = &PackageSpecifier{}
				}
				pkgSpec.CPE = &c.CPE.Attributes
			case *search.DistroCriteria:
				for _, d := range c.Distros {
					osSpecs = append(osSpecs, &OSSpecifier{
						Name:             d.Name(),
						MajorVersion:     d.MajorVersion(),
						MinorVersion:     d.MinorVersion(),
						RemainingVersion: d.RemainingVersion(),
						LabelVersion:     d.Codename,
					})
				}
			default:
				unapplied = append(unapplied, criterion)
			}
		}

		if len(osSpecs) == 0 {
			// we don't want to search across all distros, instead if the user did not specify a distro we should assume that
			// they want to search across affected packages not associated with any distro.
			osSpecs = append(osSpecs, NoOSSpecified)
		}

		// if there is an ecosystem provided and a name, we need to make certain that we're using the name normalization
		// rules specific to the ecosystem before searching.
		if pkgType != "" && pkgSpec != nil && pkgSpec.Name != "" {
			pkgSpec.Name = name.Normalize(pkgSpec.Name, pkgType)
		}

		versionMatcher, remainingCriteria := splitConstraintMatcher(unapplied...)

		var affectedPackages []AffectedPackageHandle
		var affectedCPEs []AffectedCPEHandle

		if pkgSpec != nil || len(vulnSpecs) > 0 {
			var err error
			affectedPackages, err = vp.reader.GetAffectedPackages(pkgSpec, &GetAffectedPackageOptions{
				OSs:             osSpecs,
				Vulnerabilities: vulnSpecs,
				PreloadBlob:     true,
			})
			if err != nil {
				if errors.Is(err, ErrOSNotPresent) {
					// this set cannot match anything; move on without discarding results gathered from other sets
					log.WithFields("os", osSpecs).Debug("no OS found in the DB for the given criteria")
					continue
				}
				return nil, err
			}

			affectedPackages = filterAffectedPackageVersions(versionMatcher, affectedPackages)

			// after filtering, read vulnerability data
			if err = fillAffectedPackageHandles(vp.reader, ptrs(affectedPackages)); err != nil {
				return nil, err
			}
		}

		if cpeSpec != nil {
			var err error
			affectedCPEs, err = vp.reader.GetAffectedCPEs(cpeSpec, &GetAffectedCPEOptions{
				Vulnerabilities: vulnSpecs,
				PreloadBlob:     true,
			})
			if err != nil {
				return nil, err
			}

			affectedCPEs = filterAffectedCPEVersions(versionMatcher, affectedCPEs, cpeSpec)

			// after filtering, read vulnerability data
			if err = fillAffectedCPEHandles(vp.reader, ptrs(affectedCPEs)); err != nil {
				return nil, err
			}
		}

		// fill complete vulnerabilities for this set -- these should have already had all properties lazy loaded
		vulns, err := vp.toVulnerabilities(affectedPackages, affectedCPEs)
		if err != nil {
			return nil, err
		}

		// filter vulnerabilities by any remaining criteria such as ByQualifiedPackages
		vulns, err = vp.filterVulnerabilities(vulns, remainingCriteria...)
		if err != nil {
			return nil, err
		}

		out = append(out, vulns...)
	}

	return out, nil
}

// filterVulnerabilities returns the vulnerabilities in vulns that satisfy every criterion. Version constraint
// matchers are skipped, because they have already been applied to the affected ranges.
//
// Each dropped vulnerability is logged with the reason from the first criterion it failed. If a criterion returns
// an error, filtering stops and that error is returned.
//
// The result reuses the backing array of vulns, so callers must use the returned slice instead of the input.
func (vp vulnerabilityProvider) filterVulnerabilities(vulns []vulnerability.Vulnerability, criteria ...vulnerability.Criteria) ([]vulnerability.Vulnerability, error) {
	isMatch := func(v vulnerability.Vulnerability) (bool, error) {
		for _, c := range criteria {
			if _, ok := c.(search.VersionConstraintMatcher); ok {
				continue // already run
			}
			matches, reason, err := c.MatchesVulnerability(v)
			if !matches || err != nil {
				fields := logger.Fields{
					"vulnerability": v,
				}
				if err != nil {
					fields["error"] = err
				}

				logDroppedVulnerability(v.ID, reason, fields)
				return false, err
			}
		}
		return true, nil
	}

	// Filter in place in one O(n) pass. Repeatedly splicing out each dropped element shifts the tail every time,
	// which is quadratic. Writes land at indexes <= i, so no unread element is overwritten.
	kept := vulns[:0]
	for i := range vulns {
		matches, err := isMatch(vulns[i])
		if err != nil {
			return nil, err
		}
		if matches {
			kept = append(kept, vulns[i])
		}
	}
	return kept, nil
}

// toVulnerabilities takes fully-filled handles and returns all vulnerabilities built from them: affected-package
// handles first, then affected-CPE handles. Handles without a blob value, or for which no vulnerability is built,
// are skipped. An error building a vulnerability from a handle is returned.
//
// Metadata is attached on a best-effort basis. If it cannot be built, the failure is logged and the vulnerability is
// returned without it.
func (vp vulnerabilityProvider) toVulnerabilities(packageHandles []AffectedPackageHandle, cpeHandles []AffectedCPEHandle) ([]vulnerability.Vulnerability, error) { //nolint:funlen,gocognit
	var out []vulnerability.Vulnerability

	// Metadata depends on both the vulnerability record and the namespace it is reported under: the namespace is
	// copied into the metadata, and its first segment names the provider. Caching on the vulnerability name alone
	// would hand one record's metadata (namespace, severity, description) to a different provider's record that
	// shares the same ID, e.g. the same CVE reported by two sources.
	type metadataKey struct {
		name      string
		namespace string
	}
	metadataCache := make(map[metadataKey]*vulnerability.Metadata)

	getMetadata := func(vuln *VulnerabilityHandle, namespace string) (*vulnerability.Metadata, error) {
		if vuln == nil {
			return nil, nil
		}

		key := metadataKey{name: vuln.Name, namespace: namespace}
		if metadata, ok := metadataCache[key]; ok {
			return metadata, nil
		}

		metadata, err := vp.getVulnerabilityMetadata(vuln, namespace)
		if err != nil {
			return nil, err
		}

		metadataCache[key] = metadata
		return metadata, nil
	}

	for _, packageHandle := range packageHandles {
		if packageHandle.BlobValue == nil {
			log.Debugf("unable to find blobValue for %+v", packageHandle)
			continue
		}
		v, err := newVulnerabilityFromAffectedPackageHandle(packageHandle, packageHandle.BlobValue.Ranges)
		if err != nil {
			return nil, err
		}
		if v == nil {
			continue
		}

		meta, err := getMetadata(packageHandle.Vulnerability, v.Namespace)
		if err != nil {
			log.WithFields("error", err, "vulnerability", v.String()).Debug("unable to fetch metadata for vulnerability")
		} else {
			v.Metadata = meta
		}

		out = append(out, *v)
	}

	for _, c := range cpeHandles {
		if c.BlobValue == nil {
			log.Debugf("unable to find blobValue for %+v", c)
			continue
		}
		v, err := newVulnerabilityFromAffectedCPEHandle(c, c.BlobValue.Ranges)
		if err != nil {
			return nil, err
		}
		if v == nil {
			continue
		}

		meta, err := getMetadata(c.Vulnerability, v.Namespace)
		if err != nil {
			log.WithFields("error", err, "vulnerability", v.String()).Debug("unable to fetch metadata for vulnerability")
		} else {
			v.Metadata = meta
		}

		out = append(out, *v)
	}

	return out, nil
}

// splitConstraintMatcher partitions criteria into version constraint matchers and everything else. All
// search.VersionConstraintMatcher values are folded, in order, into one matcher via search.MultiConstraintMatcher
// (nil if there are none). The remaining criteria are returned in their original order.
func splitConstraintMatcher(criteria ...vulnerability.Criteria) (search.VersionConstraintMatcher, []vulnerability.Criteria) {
	var (
		combined search.VersionConstraintMatcher
		rest     []vulnerability.Criteria
	)
	for _, c := range criteria {
		m, isMatcher := c.(search.VersionConstraintMatcher)
		switch {
		case !isMatcher:
			rest = append(rest, c)
		case combined == nil:
			combined = m
		default:
			combined = search.MultiConstraintMatcher(combined, m)
		}
	}
	return combined, rest
}

// filterAffectedPackageVersions keeps the affected-package handles that filterAffectedPackageRanges does not report
// as fully dropped by constraintMatcher. When constraintMatcher is nil, packages is returned unchanged.
//
// Each dropped handle is logged with the constraints it failed, identified by its package when one is attached and
// by the whole handle otherwise.
func filterAffectedPackageVersions(constraintMatcher search.VersionConstraintMatcher, packages []AffectedPackageHandle) []AffectedPackageHandle {
	if constraintMatcher == nil {
		return packages
	}

	var kept []AffectedPackageHandle
	for _, handle := range packages {
		vuln := handle.vulnerability()
		allDropped, unmatched := filterAffectedPackageRanges(constraintMatcher, handle.BlobValue)
		if !allDropped {
			kept = append(kept, handle)
			continue
		}

		reason := fmt.Sprintf("not within vulnerability version constraints: %q", strings.Join(unmatched, ", "))
		fields := logger.Fields{}
		if handle.Package == nil {
			fields["affectedPackage"] = handle
		} else {
			fields["package"] = handle.Package.String()
		}
		logDroppedVulnerability(vuln, reason, fields)
	}
	return kept
}

// filterAffectedCPEVersions keeps the affected-CPE handles that filterAffectedPackageRanges does not report as fully
// dropped by constraintMatcher. When constraintMatcher is nil, handles is returned unchanged.
//
// Each dropped handle is logged against cpeSpec. cpeSpec is used only for that logging, and the visible caller
// always passes a non-nil spec.
func filterAffectedCPEVersions(constraintMatcher search.VersionConstraintMatcher, handles []AffectedCPEHandle, cpeSpec *cpe.Attributes) []AffectedCPEHandle {
	if constraintMatcher == nil {
		return handles
	}

	var kept []AffectedCPEHandle
	for _, handle := range handles {
		vuln := handle.vulnerability()
		allDropped, unmatched := filterAffectedPackageRanges(constraintMatcher, handle.BlobValue)
		if allDropped {
			reason := fmt.Sprintf("not within vulnerability version constraints: %q", strings.Join(unmatched, ", "))
			logDroppedVulnerability(vuln, reason, logger.Fields{
				"cpe": cpeSpec.String(),
			})
			continue
		}
		kept = append(kept, handle)
	}
	return kept
}

// filterAffectedPackageRanges checks each version range on b against matcher. It reports whether every range was
// rejected (meaning the owning handle should be dropped), together with the constraint strings that were rejected.
//
// Ranges whose constraint cannot be parsed are logged and skipped. They count as neither accepted nor rejected, so a
// blob containing one is never reported as fully dropped. A blob with no ranges is reported as fully dropped, and a
// nil blob is treated the same way instead of being dereferenced. Matcher errors are logged, and the match result
// returned alongside the error is used as-is.
func filterAffectedPackageRanges(matcher search.VersionConstraintMatcher, b *AffectedPackageBlob) (bool, []string) {
	if b == nil {
		// nothing to match against: same outcome as an empty range list
		return true, nil
	}

	var unmatchedConstraints []string
	for _, r := range b.Ranges {
		v := r.Version
		format := version.ParseFormat(v.Type)
		constraint, err := version.GetConstraint(v.Constraint, format)
		if err != nil || constraint == nil {
			log.WithFields("error", err, "constraint", v.Constraint, "format", v.Type).Debug("unable to parse constraint")
			continue
		}
		matches, err := matcher.MatchesConstraint(constraint)
		if err != nil {
			log.WithFields("error", err, "constraint", v.Constraint, "format", v.Type).Debug("match constraint error")
		}
		if !matches {
			unmatchedConstraints = append(unmatchedConstraints, v.Constraint)
		}
	}
	return len(b.Ranges) == len(unmatchedConstraints), unmatchedConstraints
}

// toSeverityString renders a severity as the CamelCase form of its String() value.
func toSeverityString(sev vulnerability.Severity) string {
	label := sev.String()
	return strcase.ToCamel(label)
}

// firstReferenceURL returns the URL of the vulnerability's first reference, which is used to populate the metadata
// DataSource. It returns "" when there are no references or the handle has no blob to read them from.
func firstReferenceURL(vuln *VulnerabilityHandle) string {
	if vuln == nil || vuln.BlobValue == nil || len(vuln.BlobValue.References) == 0 {
		return ""
	}
	return vuln.BlobValue.References[0].URL
}

// lastReferenceURLs returns the URLs of every reference except the first, which is used for the DataSource. These
// URLs populate the metadata URLs. It returns nil when there are fewer than two references or the handle has no blob.
func lastReferenceURLs(vuln *VulnerabilityHandle) []string {
	if vuln == nil || vuln.BlobValue == nil || len(vuln.BlobValue.References) < 2 {
		return nil
	}

	rest := vuln.BlobValue.References[1:]
	out := make([]string, 0, len(rest))
	for _, ref := range rest {
		out = append(out, ref.URL)
	}
	return out
}

// getCVEs returns the CVE identifiers associated with vuln, drawn in order from its name, its blob ID and its blob
// aliases. An identifier counts as a CVE when it starts with "cve-" case-insensitively. Duplicates are removed
// case-insensitively, and the first spelling seen is kept. Returns nil for a nil handle or when no CVEs are found.
func getCVEs(vuln *VulnerabilityHandle) []string {
	if vuln == nil {
		return nil
	}

	candidates := []string{vuln.Name}
	if vuln.BlobValue != nil {
		candidates = append(candidates, vuln.BlobValue.ID)
		for _, alias := range vuln.BlobValue.Aliases {
			candidates = append(candidates, alias)
		}
	}

	var cves []string
	seen := strset.New()
	for _, id := range candidates {
		key := strings.ToLower(id)
		if !strings.HasPrefix(key, "cve-") || seen.Has(key) {
			continue
		}
		seen.Add(key)
		cves = append(cves, id)
	}
	return cves
}
