import * as d3 from 'd3'
import { groupBy, orderBy, uniq } from 'lodash'
import React, { useState } from 'react'
import SelectedArticlesList from './SelectedArticlesList'
import FocusedArticleTooltip from './FocusedArticleTooltip'


const Scatterplot = ({
  dimensions,
  articles: articlesData,
  keyword,
  minKnownCitationsCount,
  minNeighborhoodCitationsCount,
  neighborhoodJournals,
  neighborhoodJournalLabels,
  neighborhoodArticlePmid,
  selectedArticle,
  setSelectedArticle,
  neighborhoodLoaded,
  setNeighborhoodLoaded,
}) => {
  // DOM handles that the d3 code in the effects below selects and mutates directly.
  const svgRef = React.useRef(null)
  const plotContainerRef = React.useRef(null)
  // Empty <g> elements: the main effect attaches one zoom behaviour to each and
  // reads their zoom transforms to track the x and y zoom state separately.
  const xGroupRef = React.useRef(null)
  const yGroupRef = React.useRef(null)
  const xAxisGroupRef = React.useRef(null)
  const yAxisGroupRef = React.useRef(null)
  const circlesGroupRef = React.useRef(null)
  const selectionPathRef = React.useRef(null)
  // Articles passed to <SelectedArticlesList>; recomputed from selectedPmids in an effect.
  const [selectedArticles, setSelectedArticles] = useState([])
  // pmids of the dots captured by the most recent rectangle selection.
  const [selectedPmids, setSelectedPmids] = useState([])
  // Article plus pointer position handed to <FocusedArticleTooltip>.
  const [focusedArticleProps, setFocusedArticleProps] = useState({ article: null, x: null, y: null })

  // Lookup from each journal label name to its [minY, maxY] boundary pair.
  const keywordLabelTicks = {}
  neighborhoodJournalLabels.forEach(({ name, minY, maxY }) => {
    keywordLabelTicks[name] = [minY, maxY]
  })

  // Every boundary value flattened into one list, and the label names themselves.
  const keywordLabelTickValues = Object.values(keywordLabelTicks).flat()
  const keywordLabelTickLabels = Object.keys(keywordLabelTicks)

  // Plain locals, not React state: they are re-initialised every time this
  // component function runs, and each closure keeps the binding of the render
  // that created it.
  // Set to false while a shift-drag selection is armed; checked by the zoom
  // handler and by the dot hover handler.
  let zoomEnabled = true
  // Set to true when a drag selection finishes, and back to false when it is dismissed.
  let selectedArticlesListVisible = false
  // pmid of the selected article; reassigned by dotOnClick.
  let selectedArticlePmid = selectedArticle.pmid

  /**
   * Click handler for a dot: selects the article whose pmid is stored in the
   * dot's `data-pmid` attribute. Does nothing when no article in
   * `articlesData` carries that pmid.
   *
   * @param {Event} event - click event whose target exposes `dataset.pmid`.
   * @returns {Promise<void>}
   */
  const dotOnClick = async (event) => {
    const targetNodeId = Number.parseInt(event.target.dataset.pmid, 10)
    const newSelectedArticle = articlesData.find(({ pmid }) => targetNodeId === pmid)

    // Without this guard a stale or malformed data-pmid dereferenced
    // `undefined` and surfaced as an unhandled promise rejection.
    if (!newSelectedArticle) { return }

    selectedArticlePmid = newSelectedArticle.pmid
    setSelectedArticle(newSelectedArticle)
  }

  /**
   * Hover handler for a dot: thickens its stroke and shows the tooltip for the
   * matching article at the pointer position. Ignored while zooming is
   * disabled or while the selected-articles list is visible.
   *
   * @param {{x: number, y: number, target: Element}} event - mouse event.
   */
  const dotOnMouseOver = ({ x, y, target }) => {
    const hoverSuppressed = !zoomEnabled || selectedArticlesListVisible
    if (hoverSuppressed) { return }

    target.style['stroke-width'] = '0.25rem'
    const hoveredPmid = parseInt(target.dataset.pmid)
    const article = articles.find((candidate) => candidate.pmid === hoveredPmid)
    setFocusedArticleProps({ article, x, y })
  }

  /**
   * Mouse-leave handler for a dot: restores the default stroke width and,
   * unless the selected-articles list is visible, clears the tooltip.
   *
   * @param {Event} event - mouse event whose target is the dot.
   */
  const dotOnMouseLeave = ({ target }) => {
    target.style['stroke-width'] = '0.15rem'

    if (selectedArticlesListVisible) { return }

    setFocusedArticleProps({ article: null, x: null, y: null })
  }

  // Alias: the rest of the component refers to the article list as `articles`.
  const articles = articlesData
  // Shortlist of "top" article pmids, always led by the neighborhood article.
  let topArticlePmids = [neighborhoodArticlePmid]
  const similarlyTitledArticles = []
  const lessSimilarlyTitledArticles = []

  // Split the articles on title similarity (threshold 0.4).
  articles.forEach((article) => {
    const bucket = article.titleSimilarity >= 0.4 ? similarlyTitledArticles : lessSimilarlyTitledArticles
    bucket.push(article)
  })

  // Less similar articles (minus the neighborhood article) grouped by neighbor type.
  const neighborGroups = groupBy(
    lessSimilarlyTitledArticles.filter(({ pmid }) => pmid !== neighborhoodArticlePmid),
    'neighborType',
  )

  // The citation field is `citationsCount` everywhere else in this component
  // (radius function, citation filters). The orderings below previously named
  // 'citationCount', which silently made them no-ops.
  const orderedAncestors = orderBy(neighborGroups['ANCESTOR'], 'publishedOn', ['asc']).slice(0, 5)
  const orderedDescendants = orderBy(neighborGroups['DESCENDANT'], ['publishedOn', 'citationsCount'], ['desc', 'desc']).slice(0, 5)

  // One queue per neighbor type, ascending by citations so pop() yields the
  // most cited article first, plus one queue holding the similarly titled ones.
  const candidateQueues = Object.values(neighborGroups).map((group) => orderBy(group, 'citationsCount', 'asc'))
  candidateQueues.push(similarlyTitledArticles)

  // Round-robin over the queues until the shortlist exceeds 40 entries or
  // every queue is empty.
  while (topArticlePmids.length <= 40 && candidateQueues.some((queue) => queue.length > 0)) {
    candidateQueues.forEach((queue) => {
      const pmid = queue.pop()?.pmid

      if (pmid) {
        topArticlePmids.push(pmid)
      }
    })
  }

  topArticlePmids = uniq(topArticlePmids)

  /**
   * Focal point for a zoom step.
   *
   * @param {object} event - d3 zoom event.
   * @param {Element} target - element the pointer positions are measured against.
   * @returns {number[]} mean pointer position when the event has a
   *   `sourceEvent`, otherwise the middle of the plot area.
   */
  function center(event, target) {
    if (!event.sourceEvent) {
      return [width / 2, height / 2]
    }

    const pointerPositions = d3.pointers(event, target)
    const meanX = d3.mean(pointerPositions, (point) => point[0])
    const meanY = d3.mean(pointerPositions, (point) => point[1])
    return [meanX, meanY]
  }

  // Append the pmids of the ordered ancestors, then of the ordered
  // descendants, to the shortlist.
  for (const { pmid } of [...orderedAncestors, ...orderedDescendants]) {
    topArticlePmids.push(pmid)
  }

  // pmids of the dots that get hover/click handlers; seeded with the shortlist.
  const [interactablePmids, setInteractablePmids] = React.useState(topArticlePmids)

  /**
   * Finishes a rectangle selection: collects the pmids of every plotted dot
   * whose centre lies inside the rectangle spanned by `start` and `end`
   * (edges inclusive, corners in any order) and stores them with
   * `setSelectedPmids`. Leaves the previous selection untouched when the
   * rectangle contains no dot.
   *
   * Uses only the public d3 selection API (`nodes()`, `data()`) rather than
   * the private `_groups` / `__data__` fields, and only looks at circles
   * inside the plot's circle group rather than every <circle> in the document.
   *
   * @param {number[]} start - [x, y] where the drag began.
   * @param {number[]} end - [x, y] where the drag ended.
   * @returns {Promise<void>}
   */
  const endSelection = async (start, end) => {
    const minSelectedX = Math.min(start[0], end[0])
    const maxSelectedX = Math.max(start[0], end[0])
    const minSelectedY = Math.min(start[1], end[1])
    const maxSelectedY = Math.max(start[1], end[1])

    const dots = d3.select(circlesGroupRef.current).selectAll('circle').nodes()
    const selectedDots = dots.filter((dot) => {
      const cx = dot.cx.baseVal.value
      const cy = dot.cy.baseVal.value
      const withinSelectedXRange = cx >= minSelectedX && cx <= maxSelectedX
      const withinSelectedYRange = cy >= minSelectedY && cy <= maxSelectedY
      return withinSelectedXRange && withinSelectedYRange
    })

    if (selectedDots.length === 0) { return }

    const selectionPmids = d3.selectAll(selectedDots).data().map(({ pmid }) => pmid)
    setSelectedPmids(selectionPmids)
  }

  // Articles on the shortlist.
  const topArticles = articles.filter((article) => topArticlePmids.includes(article.pmid))
  // Distinct journal titles among them.
  const topArticleJournalTitles = uniq(topArticles.map((article) => article.journalTitle))
  // Neighborhood journal records for those titles, highest articlesCount first.
  const topArticleJournals = orderBy(
    neighborhoodJournals.filter((entry) => topArticleJournalTitles.includes(entry.journal)),
    'articlesCount',
    'desc',
  )

  /**
   * Radius accessor for a dot: the natural log of a weight that starts at 10,
   * adds the citation count, adds 2 for shortlisted articles and adds
   * 3 / titleSimilarity for articles whose title similarity is at least 0.4.
   *
   * @param {{titleSimilarity: number, citationsCount: number, pmid: number}} article
   * @returns {number}
   */
  const r = ({ titleSimilarity, citationsCount, pmid }) => {
    const isShortlisted = topArticlePmids.includes(pmid)
    const hasSimilarTitle = titleSimilarity >= 0.4

    let weight = 10 + citationsCount
    if (isShortlisted) {
      weight = weight + 2
    }
    if (hasSimilarTitle) {
      weight = weight + (1 / titleSimilarity) * 3
    }

    return Math.log(weight)
  }

  // Initial dot opacity: fully opaque for the selected article, 0.1 for every other one.
  const opacity = ({ pmid }) => (pmid === selectedArticle.pmid ? 1 : 0.1)

  // Fill opacity of a dot: 0.5 for the selected article, 0.25 otherwise.
  const fillOpacity = ({ pmid }) => (selectedArticle.pmid === pmid ? 0.5 : 0.25)

  // Plot area size and the margins around it; the <svg> spans both.
  const { width, height, margin } = dimensions
  const svgWidth = width + margin.left + margin.right
  const svgHeight = height + margin.top + margin.bottom

  // Data extents for the scale domains. d3.extent compares values in natural
  // (numeric) order; the previous bare Array.prototype.sort() compared them as
  // strings, so e.g. 10000 sorted before 9 and the domains came out wrong.
  const [minX, maxX] = d3.extent(articles, ({ x }) => x)
  const [minY, maxY] = d3.extent(articles, ({ y }) => y)

  // Data x grows to the right; data y grows upwards (pixel range is inverted).
  const xScale = d3.scaleLinear()
    .domain([minX, maxX])
    .range([0, width])
  const yScale = d3.scaleLinear()
    .domain([minY, maxY])
    .range([height, 0])

  // Builds the plot imperatively with d3: axes, one circle per article,
  // independent x / y zooming and shift-drag rectangle selection.
  // Re-runs whenever `articlesData` changes.
  React.useEffect(() => {
    const svgEl = d3.select(svgRef.current)
    const svg = d3.select(plotContainerRef.current)
    // gx / gy draw nothing; each one only stores the zoom transform of one axis.
    const gx = d3.select(xGroupRef.current)
    const gy = d3.select(yGroupRef.current)

    // Transform last reported by the combined `zoom` behaviour; each event is
    // diffed against it to obtain the relative pan / scale step.
    let currentTransform = d3.zoomIdentity

    const xAxis = d3.axisBottom(xScale)
      .ticks(6)
      .tickSize(-height + margin.bottom)
      // NOTE(review): presumably x counts days starting from the year 1781 — confirm against the data source.
      .tickFormat((tickValue) => Math.floor((tickValue / 365) + 1781))

    const xAxisGroup = d3.select(xAxisGroupRef.current).call(xAxis)

    xAxisGroup.select('.domain').remove()
    xAxisGroup.selectAll('line').attr('stroke', 'lightgray')
    xAxisGroup.selectAll('text')
      .attr('class', 'font-sans text-xl')
      .attr('color', 'black')

    // Add Y grid lines with labels
    let yAxisTickValues = keywordLabelTickValues

    // Tick formatter: title of the journal whose minY equals the tick value,
    // with a leading "The" removed, built word by word up to roughly
    // maxTitleLength characters and cut with " ..." when it runs over.
    // Yields an empty string when no journal matches.
    const journalTitle = (tickValue) => {
      const maxTitleLength = 37
      const journal = neighborhoodJournals.find(({ minY }) => minY === tickValue)?.journal
      const journalTitleParts = journal?.replace(/^The/, '')?.split(' ') || []
      let formattedJournalTitleParts = []

      journalTitleParts.forEach((titlePart) => {
        if (formattedJournalTitleParts.join(' ').length <= maxTitleLength) {
          formattedJournalTitleParts = formattedJournalTitleParts.concat([titlePart])
        }
      })

      let formattedJournalTitle = formattedJournalTitleParts.join(' ')

      if (formattedJournalTitle.length > maxTitleLength) {
        const suffix = ' ...'
        formattedJournalTitle = formattedJournalTitle.slice(0, maxTitleLength - suffix.length) + suffix
      }

      return formattedJournalTitle.trim()
    }

    let yAxis = d3.axisLeft(yScale)
      .tickValues(yAxisTickValues)
      .tickFormat(journalTitle)
      .tickSize(width - margin.left)
    let yAxisGroup = d3.select(yAxisGroupRef.current).call(yAxis)
    yAxisGroup.selectAll('text')
      .attr('class', 'font-sans text-base')
      .attr('color', 'black')
    yAxisGroup.selectAll('line').attr('stroke', 'lightgray')
    yAxisGroup.select('.domain').remove()

    // One circle per article, keyed by pmid. A keyed join reuses the circles
    // left by a previous run of this effect; the former selectAll('dot')
    // matched nothing, so every re-run appended a duplicate set of circles.
    const gDot = d3.select(circlesGroupRef.current)
      .selectAll('circle')
      .data(articles, ({ pmid }) => pmid)
      .join('circle')
      .attr('data-pmid', ({ pmid }) => pmid)
      .attr('cx', ({ x }) => xScale(x))
      .attr('cy', ({ y }) => yScale(y))
      .attr('r', r)
      .style('stroke', 'black')
      .style('stroke-width', '0.15rem')
      .style('fill', 'black')
      .style('opacity', opacity)
      .style('fill-opacity', fillOpacity)

    const {
      top,
      left,
      bottom,
      right,
    } = svgEl.node().getBoundingClientRect()

    // Two zoom behaviours with identical settings, one per axis.
    const zoomX = d3
      .zoom()
      .extent([[left, top], [right, bottom]])
      .scaleExtent([0.5, 10])

    const zoomY = d3
      .zoom()
      .extent([[left, top], [right, bottom]])
      .scaleExtent([0.5, 10])

    // Current per-axis transforms, read back from the carrier groups.
    const tx = () => d3.zoomTransform(gx.node())
    const ty = () => d3.zoomTransform(gy.node())

    // Zoom listener of the combined behaviour. Must stay a `function` so that
    // `this` is the element the behaviour is attached to.
    function zoomed(event) {
      if (!zoomEnabled) { return }
      const { transform, sourceEvent } = event
      const { k: newK, x: newX, y: newY } = transform
      // Scale factor relative to the previous event; exactly 1 means a pure pan.
      const k = newK / currentTransform.k
      const ctrlKey = sourceEvent?.ctrlKey

      if (k === 1) {
        const x = (newX - currentTransform.x) / tx().k
        const y = (newY - currentTransform.y) / ty().k
        gx.call(zoomX.translateBy, x, 0)
        gy.call(zoomY.translateBy, 0, y)
      } else {
        const point = center(event, this)

        // ctrl-modified gestures scale y only, other pointer gestures scale x
        // only, and programmatic transforms (no sourceEvent) scale both.
        // NOTE(review): zoomX is applied to gy and zoomY to gx; both behaviours
        // are configured identically above, so this presumably has no effect — confirm.
        if (ctrlKey) {
          gy.call(zoomX.scaleBy, k, point)
        } else if (!ctrlKey && sourceEvent) {
          gx.call(zoomY.scaleBy, k, point)
        } else if (!sourceEvent) {
          gy.call(zoomX.scaleBy, k, point)
          gx.call(zoomY.scaleBy, k, point)
        }
      }

      currentTransform = transform
      const currentYK = ty().k

      const newYScale = ty().rescaleY(yScale)
      const newXScale = tx().rescaleX(xScale)

      // Up to a y zoom of 6.5 the y axis shows the label bands that fit
      // entirely inside the visible domain; beyond that it shows one tick per
      // visible journal.
      if (currentYK <= 6.5) {
        const newYDomain = yAxis.scale(newYScale).scale().domain()
        const filteredTickValues = keywordLabelTickLabels
          .reduce((acc, label) => {
            const minY = keywordLabelTicks[label][0]
            const maxY = keywordLabelTicks[label][1]

            if (minY >= newYDomain[0] && maxY <= newYDomain[1]) {
              acc.push(minY)
              acc.push(maxY)
            }
            return acc
          }, [])
        yAxisGroup.remove()
        yAxis = d3.axisLeft(yScale)
          .tickValues(filteredTickValues)
          // Only the upper boundary (maxY) of a band gets the label text.
          .tickFormat((tickValue) => {
            const words = keywordLabelTickLabels
            const label = words.find((w) => keywordLabelTicks[w][1] === tickValue)
            return label
          })
          .tickSize(width - margin.left)
      } else if (currentYK > 6.5) {
        const newYDomain = yAxis.scale(newYScale).scale().domain()
        yAxisTickValues = neighborhoodJournals
          .filter(({ minY }) => minY >= newYDomain[0] && minY <= newYDomain[1])
          .map(({ minY }) => minY)
        yAxisGroup.remove()
        yAxis = d3.axisLeft(yScale)
          .tickValues(yAxisTickValues)
          .tickFormat(journalTitle)
          .tickSize(width - margin.left)
      }

      // The y axis group is rebuilt on every zoom event and appended last in
      // the plot container.
      yAxisGroup = svg.append('g')
        .call(yAxis)
        .attr('transform', `translate(${width - margin.left}, 0)`)

      xAxisGroup.call(xAxis.scale(newXScale))
      yAxisGroup.call(yAxis.scale(newYScale))
      gDot.attr('cx', ({ x }) => newXScale(x))
      gDot.attr('cy', ({ y }) => newYScale(y))

      xAxisGroup.selectAll('line').attr('stroke', 'lightgray')
      xAxisGroup.select('.domain').remove()
      xAxisGroup.selectAll('text')
        .attr('class', 'font-sans text-xl')
        .attr('color', 'black')
      yAxisGroup.select('.domain').remove()

      if (currentYK > 6.5) {
        yAxisGroup.selectAll('text')
          .attr('class', 'font-sans text-base')
          .attr('color', 'black')
      } else if (currentYK <= 6.5) {
        // Band labels get the flood-filled background filter defined in the JSX.
        yAxisGroup.selectAll('text')
          .attr('class', 'font-sans text-base')
          .attr('color', 'black')
          .attr('opacity', 0.8)
          .attr("filter","url(#y-label-background)")
      }

      yAxisGroup.selectAll('line').attr('stroke', 'lightgray')
    }

    const zoom = d3.zoom().on('zoom', zoomed)
    gx.call(zoomX).attr('pointer-events', 'none')
    gy.call(zoomY).attr('pointer-events', 'none')

    // SVG path data for an axis-aligned rectangle at (x, y) with size w by h.
    function rect(x, y, w, h) {
      return 'M'+[x,y]+' l'+[w,0]+' l'+[0,h]+' l'+[-w,0]+'z'
    }

    const selection = d3.select(selectionPathRef.current)

    // Shows the dashed selection outline as a zero-size rectangle at the drag
    // origin. (The y coordinate used to be start[0], which placed it at (x, x).)
    const startSelection = (start) => {
      selection.attr('d', rect(start[0], start[1], 0, 0))
        .attr('visibility', 'visible')
        .style('fill-opacity', 0)
        .attr('stroke', '#44403c')
        .attr('stroke-width', 2)
        .attr('stroke-dasharray', '10 10')
    }

    // Stretches the outline from the drag origin to the current pointer position.
    const moveSelection = (start, moved) => {
      selection.attr('d', rect(start[0], start[1], moved[0]-start[0], moved[1]-start[1]))
    }

    // Escape dismisses the selection; pressing left Shift detaches the zoom
    // listeners so that dragging can draw a selection instead.
    svg.on('keydown', (event) => {
      if (event.code === 'Escape') {
        setSelectedArticles([])
        selection.attr('visibility', 'hidden')
        selectedArticlesListVisible = false
        return zoomEnabled = true
      }

      if (event.code === 'ShiftLeft' && zoomEnabled) {
        event.preventDefault()
        svg.on('.zoom', null)
        return zoomEnabled = false
      }
    })

    // Releasing left Shift before a selection was completed re-attaches zooming.
    svg.on('keyup', (event) => {
      const initiatedSelection = event.code === 'ShiftLeft'
      const completedSelection = selectedArticlesListVisible && !zoomEnabled
      const cancelledSelection = initiatedSelection && !completedSelection

      if (cancelledSelection) {
        event.preventDefault()
        svg.call(zoom)
        return zoomEnabled = true
      }
    })

    svg.on('mousedown', (event) => {
      if (zoomEnabled) { return }

      const drawingSelection = event.shiftKey

      if (drawingSelection) {
        event.preventDefault()
        const subject = svgEl
        const parent = event.target.parentNode
        const start = d3.pointer(event, parent)
        startSelection(start)
        // Namespaced listeners live only for the duration of this drag.
        subject
          .on('mousemove.selection', (e) => {
            e.preventDefault()
            moveSelection(start, d3.pointer(e, parent))
          }).on('mouseup.selection', (e) => {
            selectedArticlesListVisible = true
            endSelection(start, d3.pointer(e, parent))
            subject.on('mousemove.selection', null).on('mouseup.selection', null)
          })
      } else {
        // A plain click while zoom is disabled dismisses the selection.
        selection.attr('visibility', 'hidden')
        setSelectedArticles([])
        selectedArticlesListVisible = false
        svg.call(zoom)
        zoomEnabled = true
      }
    })

    // Attach zooming and start zoomed out to half scale.
    svg
      .call(zoom)
      .call(zoom.transform, d3.zoomIdentity.translate(0, 0).scale(0.5))

    setNeighborhoodLoaded(true)
  }, [articlesData])

  // When the selected article changes: reset the fill opacity of every circle
  // through the fillOpacity accessor, then make the selected article's circle
  // fully filled.
  React.useEffect(() => {
    const selectedDotSelector = `circle[data-pmid="${selectedArticle.pmid}"]`

    d3.selectAll('circle').style('fill-opacity', fillOpacity)
    d3.selectAll(selectedDotSelector).style('fill-opacity', 1)
  }, [selectedArticle])

  // When the rectangle selection changes: list the articles that are both
  // inside the selection and currently interactable.
  React.useEffect(() => {
    const isSelectedAndInteractable = ({ pmid }) => (
      selectedPmids.includes(pmid) && interactablePmids.includes(pmid)
    )

    setSelectedArticles(articles.filter(isSelectedAndInteractable))
  }, [selectedPmids])

  // Once the plot has been built: make the neighborhood article the only
  // interactable dot and show its tooltip at a fixed pixel offset from its circle.
  React.useEffect(() => {
    if (!neighborhoodLoaded) { return }

    setInteractablePmids([neighborhoodArticlePmid])

    // selection.node() is the public way to get the first matched element and
    // returns null when nothing matches (previously read from the private
    // `_groups` field and dereferenced without a check).
    const neighborhoodArticleCircle = d3.select(`circle[data-pmid="${neighborhoodArticlePmid}"]`).node()

    // No circle rendered for the neighborhood article: nothing to anchor the tooltip to.
    if (!neighborhoodArticleCircle) { return }

    const neighborhoodArticle = articles.find(({ pmid }) => pmid === neighborhoodArticlePmid)
    const cx = neighborhoodArticleCircle.cx.baseVal.value
    const cy = neighborhoodArticleCircle.cy.baseVal.value
    const xOffsetPx = 310
    const yOffsetPx = 160
    const x = cx + xOffsetPx
    const y = cy + yOffsetPx
    setFocusedArticleProps({ article: neighborhoodArticle, x, y })
  }, [neighborhoodLoaded])

  // Recomputes which dots are interactable whenever a filter changes.
  //  - All filters at their neutral value (0, 0, ''): only the selected article.
  //  - At least one filter active: the articles passing every active filter.
  //  - Anything else (e.g. a negative minimum): every article.
  React.useEffect(() => {
    if (minKnownCitationsCount === 0 && minNeighborhoodCitationsCount === 0 && keyword === '') {
      setInteractablePmids([selectedArticle.pmid])
      return
    }

    const knownCitationsFilterApplied = minKnownCitationsCount > 0
    const neighborhoodCitationsFilterApplied = minNeighborhoodCitationsCount > 0
    const keywordFilterApplied = keyword.length > 0

    if (!knownCitationsFilterApplied && !neighborhoodCitationsFilterApplied && !keywordFilterApplied) {
      setInteractablePmids(articlesData.map(({ pmid }) => pmid))
      return
    }

    // The keyword is matched as a literal, case-insensitive substring. It was
    // previously handed to String.prototype.match, which compiles it as a
    // regular expression: input such as "(" threw a SyntaxError and any
    // upper-case letter could never match the lower-cased titles.
    const normalizedKeyword = keyword.toLowerCase()
    const containsKeyword = (text) => (text ?? '').toLowerCase().includes(normalizedKeyword)

    // Fields are only read for the filters that are actually active.
    const eligiblePmids = articlesData
      .filter(({ citationsCount, citedByArticlePmids, title, journalTitle }) => {
        if (neighborhoodCitationsFilterApplied) {
          const neighborhoodCitationsCountAboveMin = (citedByArticlePmids?.length ?? 0) >= minNeighborhoodCitationsCount
          if (!neighborhoodCitationsCountAboveMin) { return false }
        }

        if (knownCitationsFilterApplied) {
          const citationsCountAboveMin = citationsCount >= minKnownCitationsCount
          if (!citationsCountAboveMin) { return false }
        }

        if (keywordFilterApplied) {
          return containsKeyword(title) || containsKeyword(journalTitle)
        }

        return true
      })
      .map(({ pmid }) => pmid)

    setInteractablePmids(eligiblePmids)
  }, [minKnownCitationsCount, minNeighborhoodCitationsCount, keyword])

  // Applies the interactable set to the plotted dots: interactable dots get
  // the hover/click handlers and full opacity, all others lose their handlers
  // and are dimmed.
  React.useEffect(() => {
    // Set lookup instead of Array.includes inside the per-dot filters.
    const interactable = new Set(interactablePmids)
    // Only the circles drawn by this component. A document-wide
    // selectAll('circle') also matched circles without a bound datum, on
    // which reading `.pmid` throws.
    const dots = d3.select(circlesGroupRef.current).selectAll('circle')
    const interactableDots = dots.filter(({ pmid }) => interactable.has(pmid))
    const nonInteractableDots = dots.filter(({ pmid }) => !interactable.has(pmid))
    zoomEnabled = true

    nonInteractableDots
      .on('mouseover', null)
      .on('mouseleave', null)
      .on('click', null)
      .style('stroke', 'Black')
      .style('stroke-width', '0.15rem')
      .style('fill', 'black')
      .style('opacity', 0.15)
      .style('fill-opacity', fillOpacity)

    interactableDots
      .on('mouseover', dotOnMouseOver)
      .on('mouseleave', dotOnMouseLeave)
      .on('click', dotOnClick)
      .style('stroke', 'black')
      .style('stroke-width', '0.15rem')
      .style('fill', 'black')
      .style('opacity', 1)
      .style('fill-opacity', 0.25)

  }, [interactablePmids])

  const viewBoxMinY = 100

  return(
    <div className={'h-screen w-screen'}>
      <SelectedArticlesList selectedArticlePmid={selectedArticle.pmid} articles={selectedArticles} setSelectedArticle={setSelectedArticle} />
      <svg viewBox={`0 ${viewBoxMinY} ${svgWidth} ${svgHeight}`} ref={svgRef} width={svgWidth} height={svgHeight}>
        <g
          ref={plotContainerRef}
          viewBox={`0 0 ${width} ${height}`}
          transform={`translate(${margin.left},${margin.top})`}
          onFocus={() => null}
        >
          <rect
            width={width}
            height={height}
            fill={'White'}
            tabIndex={'-1'}
            onFocus={({ target }) => target.style['outline'] = 0}
          />

          <defs>
            <filter
              x='0'
              y='0'
              width='1'
              height='1'
              id='y-label-background'
            >
              <feFlood floodColor='#FAFAF9' />
              <feComposite in='SourceGraphic' />
            </filter>
          </defs>
          <g ref={xGroupRef} />
          <g ref={yGroupRef} />
          <g ref={xAxisGroupRef} transform={`translate(0, ${height - margin.bottom})`} />
          <g ref={yAxisGroupRef} transform={`translate(${width - margin.left}, 0)`}/>
          <g ref={circlesGroupRef} />
          <path ref={selectionPathRef} className={'selection'} visibility={'hidden'} />
        </g>
      </svg>

      <FocusedArticleTooltip props={focusedArticleProps} />
    </div>
  )
}

export default Scatterplot
