import {
  LegacyRef,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from 'react';

import { Box, Tooltip, useTheme } from '@mui/material';

import {
  getSelectionBounds,
  boundsToViewport,
  Node,
  Edge,
  NodeStyle,
} from '@sayari/trellis';

import { withResizeDetector } from 'react-resize-detector';

import { Layout as forceLayout } from '@sayari/trellis/layout/force';
import { Renderer } from '@sayari/trellis/bindings/react/renderer';
import {
  EdgePointerEvent,
  NodeDragEvent,
  NodePointerEvent,
  ViewportDragDecelerateEvent,
  ViewportDragEvent,
  ViewportWheelEvent,
} from '@sayari/trellis/renderers/webgl';

import { Instance } from '@popperjs/core';

import hierarchyLayout from './layouts/hierarchyLayout';
import radialLayout from './layouts/radialLayout';
import { findNodeIdWithMostConnections } from './layouts/layoutUtils';
import {
  ChartDataInfo,
  ChartLegendInfo,
  LayoutType,
  NetworkChartConfig,
} from './types';

import clustersLayout from './layouts/clustersLayout';
import { getShortNodeLabel, getShortRelationLabel } from '../../core/utils/uxUtils';

type Props = {
  /** Graph to display; relationships reference node ids via `source` / `target`. */
  networkData: {
    nodes: Node[];
    relationships: Edge[];
  };

  /** Render-area width; also applied to the wrapper element. */
  width: number;
  /** Render-area height. */
  height: number;
  /** Ref attached to the wrapper `div` (presumably observed by the resize detector — confirm). */
  targetRef?: LegacyRef<HTMLDivElement>;
  /** Extra offset added to the viewport `y` whenever the viewport is fitted to the graph. */
  paddingTop?: number;

  /** Layout, viewport-locking and styling options. */
  config?: NetworkChartConfig;
  /** Use node id `'0'` as the root of hierarchy/radial layouts instead of the most-connected node. */
  firstIsRoot?: boolean;
  /** Legend entries forwarded to the clusters layout. */
  legendData?: ChartLegendInfo[];

  /** Maps a node to its rendered form before layout (e.g. to assign `radius`). */
  getNodeAttributes?: (node: Omit<Node, 'radius'>) => Node;
  /** Maps an edge to its rendered form before layout. */
  getEdgeAttributes?: (edge: Edge) => Edge;
  /** NOTE(review): not read by the chart component in this file. */
  getLegendStyle?: (label: string) => NodeStyle;

  /** Ids of hidden nodes; edges touching a hidden node are hidden as well. */
  invisibleNodeIds?: string[];
  /** Receives the new hidden-node id list when a node click toggles its leaf neighbours. */
  onInvisibleNodesUpdated?: (invisibleNodeIds: string[]) => void;
};

// Radius assigned by the default `getNodeAttributes`; also the base radius a node
// is reset to when the pointer leaves it.
const DEFAULT_RADIUS = 16;
// Padding passed to `getSelectionBounds` when fitting the viewport to the visible
// nodes (presumably in graph units — confirm against Trellis).
const VIEWPORT_PADDING = 70;

/**
 * Chart configuration used when the `config` prop is omitted.
 *
 * - `layoutType`: which layout runs; `'network'` is the Trellis force layout.
 * - `isLayoutLocked`: when true and the user has dragged nodes, refreshes reuse
 *   the dragged positions instead of re-running the layout.
 * - `isViewportLocked`: when true, redraws keep the current pan/zoom except right
 *   after the layout type or the data changed.
 */
export const defaultConfig: NetworkChartConfig = {
  layoutType: 'network',
  isLayoutLocked: false,
  isViewportLocked: false,
  layoutOptions: {
    network: {
      // Passed to the force layout as `nodeStrength`.
      gravity: -600,
      // Passed to the force layout as `linkDistance`.
      linkDistance: 180,
    },
    hierarchy: {
      // NOTE(review): `stack` is not read by the hierarchy branch in this file.
      stack: false,
      anchor: 'top',
      alignment: 'mid',
      // Together these form the hierarchy layout's `nodeSize` tuple.
      widthDistance: 75,
      heightDistance: 200,
    },
    radial: {
      // The radial layout runs with `bfs: !stack`.
      stack: false,
      // `radius` and `linkDistance` form the radial layout's `size` tuple.
      radius: 6.3,
      linkDistance: 650,
      // Separation between nodes with different parents (siblings use 1), divided by depth.
      nodeSeparation: 2,
    },
    clusters: {
      // Passed to the clusters layout as its layout-mode argument.
      layout: 'radial',
      // `clusterPadding` and `collidePadding` form the clusters layout's `size` tuple.
      clusterPadding: 5,
      collidePadding: 5,
    },
  },
  // NOTE(review): not read by the chart component in this file; presumably
  // consumed by callers (e.g. inside `getNodeAttributes`) — confirm.
  stylingOptions: {
    hideBadge: false,
    hideIcon: false,
    hideLabel: false,
    useSameRadius: false,
  },
};

/**
 * Stable fallback for `invisibleNodeIds`. A `[]` literal default would be a new
 * array on every render; the refresh effect is keyed on `invisibleNodeIds`, so it
 * would re-run the layout after every state update and never settle.
 */
const NO_INVISIBLE_NODE_IDS: string[] = [];

/** Stable fallback for `onInvisibleNodesUpdated`. */
const noop = () => {};

/** Debounce window for re-applying a changed `config`. */
const CONFIG_APPLY_DEBOUNCE_MS = 300;

/** Delay before leaf neighbours follow a dragged node. */
const LEAF_FOLLOW_DELAY_MS = 150;

/** Factor applied to a node's radius while the pointer is over it. */
const HOVER_RADIUS_SCALE = 1.4;

/**
 * Interactive network chart rendered with Trellis.
 *
 * Lays out the visible part of `networkData` with the algorithm named by
 * `config.layoutType` and fits the viewport to the result (unless
 * `config.isViewportLocked`). It shows the untrimmed label in a tooltip when a
 * hovered node/edge label was shortened. Nodes can be dragged; their leaf
 * neighbours follow after a short delay. Clicking a node toggles the visibility
 * of its leaf neighbours through `onInvisibleNodesUpdated`.
 */
const NetworkChart = ({
  firstIsRoot,
  width,
  height,
  targetRef,
  getNodeAttributes = (node) => ({ ...node, radius: DEFAULT_RADIUS }),
  getEdgeAttributes = (edge) => edge,
  networkData,
  paddingTop = 0,
  legendData,
  invisibleNodeIds = NO_INVISIBLE_NODE_IDS,
  onInvisibleNodesUpdated = noop,
  config = defaultConfig,
}: Props) => {
  const theme = useTheme();

  // Latest size, read by async layout work that may outlive the render that started it.
  const widthRef = useRef(width);
  widthRef.current = width;

  const heightRef = useRef(height);
  heightRef.current = height;

  // Set when the layout type or data changes so the next draw re-fits the viewport
  // even when the viewport is locked.
  const shouldResetViewport = useRef(false);

  // Bumped by every refresh; an async redraw whose id is no longer current drops its result,
  // so a slow, older layout can never overwrite a newer one.
  const refreshRequestIdRef = useRef(0);

  const { layoutType } = config;
  const { nodes, relationships } = networkData;

  const invisibleNodeIdSet = useMemo(() => new Set(invisibleNodeIds), [invisibleNodeIds]);

  const isNodeVisible = (node: Node) => !invisibleNodeIdSet.has(node.id);

  const isEdgeVisible = (edge: Edge) => (
    !invisibleNodeIdSet.has(edge.source) && !invisibleNodeIdSet.has(edge.target)
  );

  // Number of visible relations touching each node id (a self-loop counts once).
  const relationCountLookup = useMemo(() => {
    const lookup: Record<string, number> = {};
    (relationships || []).forEach((relation) => {
      if (invisibleNodeIdSet.has(relation.source) || invisibleNodeIdSet.has(relation.target)) {
        return;
      }
      lookup[relation.target] = (lookup[relation.target] ?? 0) + 1;
      if (relation.source !== relation.target) {
        lookup[relation.source] = (lookup[relation.source] ?? 0) + 1;
      }
    });
    return lookup;
  }, [relationships, invisibleNodeIdSet]);

  // Untrimmed labels; a hover tooltip is shown only when the rendered label differs.
  const originalLabelByNodeIds = useMemo(() => {
    const lookup: Record<string, string> = {};
    nodes.forEach((node) => {
      lookup[node.id] = node.label || '';
    });
    return lookup;
  }, [nodes]);

  const originalLabelByConnectionIds = useMemo(() => {
    const lookup: Record<string, string> = {};
    relationships.forEach((relation) => {
      lookup[relation.id] = relation.label || '';
    });
    return lookup;
  }, [relationships]);

  const tooltipPositionRef = useRef<{ x: number; y: number }>({
    x: 0,
    y: 0,
  });

  const tooltipPopperRef = useRef<Instance>(null);
  const renderAreaRef = useRef<HTMLDivElement>(null);

  const initialChartData = {
    nodes: [] as Node[],
    edges: [] as Edge[],
    x: 0,
    y: 0,
    zoom: 1,
  };

  const [chartData, setChartData] = useState<ChartDataInfo>(initialChartData);
  const [tooltipText, setTooltipText] = useState<string>();
  const [isTooltipOpen, setTooltipOpen] = useState(false);

  // Full (unfiltered) layout per layout type; used to find neighbours and leaves.
  const [
    backupChartDataByLayoutType,
    setBackupChartDataByLayoutType,
  ] = useState<Record<string, ChartDataInfo | undefined>>({});

  // Snapshot of the chart after the user dragged nodes; reused when the layout is locked.
  const [
    savedChartData,
    setSavedChartData,
  ] = useState<ChartDataInfo | undefined>();

  const lastConfigApplyTimeout = useRef<NodeJS.Timeout>();

  /**
   * Runs the layout named by `type` over the given nodes/edges, after passing them
   * through `getNodeAttributes` / `getEdgeAttributes`.
   *
   * @returns the layout result, or `undefined` for an unrecognised layout type.
   */
  const getLayoutByType = async (
    type: LayoutType,
    givenNodes: Node[],
    givenEdges: Edge[],
  ) => {
    const layoutNodes = givenNodes.map((node) => getNodeAttributes(node));
    const layoutEdges = givenEdges.map((edge) => getEdgeAttributes(edge));

    // NOTE(review): the root is chosen from all relationships rather than `givenEdges`,
    // so it can be a node that is currently hidden — confirm this is intended.
    const getRootId = () => (firstIsRoot ? '0' : findNodeIdWithMostConnections(relationships));

    switch (type) {
      case 'network': {
        const { linkDistance, gravity } = config.layoutOptions.network;
        return forceLayout()({
          nodes: layoutNodes,
          edges: layoutEdges,
          options: {
            linkDistance,
            nodeStrength: gravity,
          },
        });
      }

      case 'hierarchy': {
        const {
          anchor,
          alignment,
          widthDistance,
          heightDistance,
        } = config.layoutOptions.hierarchy;

        return hierarchyLayout()(getRootId(), {
          nodes: layoutNodes,
          edges: layoutEdges,
          options: {
            bfs: true,
            anchor,
            alignment,
            nodeSize: [widthDistance, heightDistance],
            size: undefined,
            separation: undefined,
          },
        });
      }

      case 'radial': {
        const { radial } = config.layoutOptions;

        return radialLayout(getRootId(), {
          nodes: layoutNodes,
          edges: layoutEdges,
          options: {
            bfs: !radial.stack,
            nodeSize: [120, 240],
            size: [radial.radius, radial.linkDistance],
            separation: (a, b) => (a.parent === b.parent
              ? 1
              : radial.nodeSeparation) / a.depth,
          },
        });
      }

      case 'clusters': {
        const { clusters } = config.layoutOptions;

        return clustersLayout({
          nodes: layoutNodes,
          edges: layoutEdges,
          options: {
            size: [clusters.clusterPadding, clusters.collidePadding],
          },
        }, legendData, clusters.layout);
      }

      default:
        return undefined;
    }
  };

  /**
   * Replaces the rendered nodes/edges and, unless the viewport is locked (and no reset
   * is pending), re-fits the viewport to them.
   */
  const commitChartData = (nextNodes: Node[], nextEdges: Edge[]) => {
    const { x, y, zoom } = boundsToViewport(
      getSelectionBounds(nextNodes, VIEWPORT_PADDING),
      { width: widthRef.current, height: heightRef.current },
    );

    const additionalAttrs = layoutType === 'network'
      ? {
        nodeStrength: config.layoutOptions.network.gravity,
        linkDistance: config.layoutOptions.network.linkDistance,
      }
      : undefined;

    // Read the flag now: a setState updater may run later, after the flag is cleared below.
    const shouldFitViewport = !config.isViewportLocked || shouldResetViewport.current;

    setChartData((graph) => ({
      ...graph,
      nodes: nextNodes,
      edges: nextEdges,
      ...(shouldFitViewport
        ? {
          x,
          y: y + paddingTop,
          zoom,
          ...additionalAttrs,
        }
        : {}),
    }));

    shouldResetViewport.current = false;
  };

  /** Re-runs the current layout over the visible nodes/edges, with labels shortened for display. */
  const redrawChart = async (requestId: number) => {
    const labelTrimNodes = nodes
      .filter(isNodeVisible)
      .map((node) => ({ ...node, label: getShortNodeLabel(node) }));

    const labelTrimEdges = relationships
      .filter(isEdgeVisible)
      .map((edge) => ({ ...edge, label: getShortRelationLabel(edge) }));

    const data = await getLayoutByType(layoutType, labelTrimNodes, labelTrimEdges);

    // Drop the result if there is no layout, or a newer refresh started meanwhile.
    if (!data || requestId !== refreshRequestIdRef.current) {
      return;
    }

    commitChartData(data.nodes, data.edges);
  };

  /**
   * Redraws the chart. It reuses the user's dragged positions when the layout is
   * locked and a snapshot exists; otherwise it runs the layout again.
   */
  const refreshChart = async () => {
    refreshRequestIdRef.current += 1;
    const requestId = refreshRequestIdRef.current;

    if (!config.isLayoutLocked || !savedChartData) {
      await redrawChart(requestId);
      return;
    }

    commitChartData(
      (savedChartData.nodes || []).filter(isNodeVisible),
      (savedChartData.edges || []).filter(isEdgeVisible),
    );
  };

  // Lets the debounced config refresh call the latest closure instead of the one
  // from the render that scheduled it.
  const latestRefreshChartRef = useRef(refreshChart);
  latestRefreshChartRef.current = refreshChart;

  useEffect(() => {
    // Ignore results that arrive after `networkData` changed again or the chart unmounted.
    let isStale = false;

    const createBackup = async () => {
      const [
        networkLayoutData,
        hierarchyLayoutData,
        radialLayoutData,
        clustersLayoutData,
      ] = await Promise.all([
        getLayoutByType('network', nodes, relationships),
        getLayoutByType('hierarchy', nodes, relationships),
        getLayoutByType('radial', nodes, relationships),
        getLayoutByType('clusters', nodes, relationships),
      ]);

      if (isStale) {
        return;
      }

      setBackupChartDataByLayoutType({
        network: networkLayoutData as ChartDataInfo,
        hierarchy: hierarchyLayoutData as ChartDataInfo,
        radial: radialLayoutData as ChartDataInfo,
        clusters: clustersLayoutData as ChartDataInfo,
      });
    };

    void createBackup();

    return () => {
      isStale = true;
    };
  }, [networkData]);

  useEffect(() => {
    void refreshChart();
  }, [invisibleNodeIds, networkData]);

  useEffect(() => {
    // Keep the user's dragged positions while the layout is unlocked.
    if (savedChartData && !config.isLayoutLocked) {
      return;
    }

    if (lastConfigApplyTimeout.current) {
      clearTimeout(lastConfigApplyTimeout.current);
    }

    lastConfigApplyTimeout.current = setTimeout(() => {
      void latestRefreshChartRef.current();
    }, CONFIG_APPLY_DEBOUNCE_MS);
  }, [config, savedChartData]);

  // Drop a pending debounced refresh when the chart unmounts.
  useEffect(() => () => {
    if (lastConfigApplyTimeout.current) {
      clearTimeout(lastConfigApplyTimeout.current);
    }
  }, []);

  useEffect(() => {
    shouldResetViewport.current = true;
    setSavedChartData(undefined);
  }, [layoutType, networkData]);

  /** Ids of nodes sharing an edge with `targetId` in the current layout's backup. */
  const getConnectedNodeIds = (targetId: string) => {
    const backupChartData = backupChartDataByLayoutType[layoutType];
    return backupChartData?.edges
      .filter((edge) => edge.source === targetId || edge.target === targetId)
      .map((edge) => (edge.source === targetId ? edge.target : edge.source));
  };

  /** Subset of `nodeIds` that touch exactly one edge in `edges`. */
  const getLeafNodeIds = (nodeIds: string[], edges: Edge[]) => nodeIds.filter(
    (nodeId) => (
      edges.filter(
        (edge) => edge.source === nodeId || edge.target === nodeId,
      ).length === 1
    ),
  );

  const onNodeDrag = useCallback(({
    nodeX,
    nodeY,
    target: { id, x = 0, y = 0 },
  }: NodeDragEvent) => {
    const leafNodeIds = getLeafNodeIds(getConnectedNodeIds(id) || [], chartData.edges);
    const deltaX = nodeX - x;
    const deltaY = nodeY - y;

    // Leaves move by the same delta as the dragged node.
    const moveLeafNode = (node: Node) => (
      leafNodeIds.includes(node.id)
        ? {
          ...node,
          x: node.x !== undefined ? node.x + deltaX : undefined,
          y: node.y !== undefined ? node.y + deltaY : undefined,
        }
        : node
    );

    setChartData((graph) => ({
      ...graph,
      nodes: graph.nodes.map(
        (node) => (node.id === id ? { ...node, x: nodeX, y: nodeY } : node),
      ),
    }));

    // Delay the movement of leaf ("lonely") nodes so they visibly trail the dragged node.
    setTimeout(() => {
      setChartData((graph) => ({
        ...graph,
        nodes: graph.nodes.map(moveLeafNode),
      }));
    }, LEAF_FOLLOW_DELAY_MS);

    setSavedChartData((graph) => {
      const base = graph || chartData;
      return {
        ...base,
        nodes: base.nodes.map(
          (node) => (node.id === id ? { ...node, x: nodeX, y: nodeY } : moveLeafNode(node)),
        ),
      };
    });
  }, [chartData, backupChartDataByLayoutType, layoutType]);

  const onNodeClick = useCallback(({
    target: { id },
  }: NodePointerEvent) => {
    const backupChartData = backupChartDataByLayoutType[layoutType];
    const leafNodeIds = getLeafNodeIds(
      getConnectedNodeIds(id) || [],
      backupChartData?.edges || [],
    );

    if (leafNodeIds.length === 0) {
      return;
    }

    const isAllLeafNodesInvisible = leafNodeIds
      .every((leafNodeId) => invisibleNodeIdSet.has(leafNodeId));

    if (isAllLeafNodesInvisible) {
      // Show them again.
      onInvisibleNodesUpdated(invisibleNodeIds.filter((nodeId) => !leafNodeIds.includes(nodeId)));
    } else {
      // Hide all of them.
      onInvisibleNodesUpdated(Array.from(new Set([
        ...invisibleNodeIds,
        ...leafNodeIds,
      ])));
    }
  }, [
    backupChartDataByLayoutType,
    layoutType,
    invisibleNodeIds,
    invisibleNodeIdSet,
    onInvisibleNodesUpdated,
  ]);

  const onViewportDrag = useCallback(({
    viewportX: x,
    viewportY: y,
  }: ViewportDragEvent | ViewportDragDecelerateEvent) => {
    setChartData((graph) => ({ ...graph, x, y }));
  }, []);

  const onViewportWheel = useCallback(({
    viewportX: x,
    viewportY: y,
    viewportZoom: zoom,
  }: ViewportWheelEvent) => {
    setChartData((graph) => ({
      ...graph,
      x,
      y,
      zoom,
    }));
  }, []);

  const onNodePointerEnter = useCallback(({
    target: {
      id,
      label,
    },
    clientX,
    clientY,
  }: NodePointerEvent) => {
    setChartData((graph) => ({
      ...graph,
      nodes: graph.nodes.map(
        (node) => (node.id === id ? { ...node, radius: node.radius * HOVER_RADIUS_SCALE } : node),
      ),
    }));

    tooltipPositionRef.current = { x: clientX, y: clientY };
    const originalLabel = originalLabelByNodeIds[id];

    // Only show a tooltip when the rendered label was shortened.
    if (originalLabel !== label) {
      setTooltipText(originalLabel);
      setTooltipOpen(true);

      if (tooltipPopperRef.current) {
        void tooltipPopperRef.current.update();
      }
    }
  }, [originalLabelByNodeIds]);

  const onNodePointerLeave = useCallback(({
    target: { id },
  }: NodePointerEvent) => {
    // Nodes whose relations are all hidden are absent from the lookup; treat them as 0
    // instead of producing a NaN radius.
    const relationCount = relationCountLookup[id] ?? 0;
    setChartData((graph) => ({
      ...graph,
      nodes: graph.nodes.map(
        (node) => (node.id === id
          ? {
            ...node,
            // NOTE(review): assumes callers size nodes by this same formula; a custom
            // `getNodeAttributes` radius is not restored here.
            radius: relationCount * 1.8 + DEFAULT_RADIUS,
          }
          : node
        ),
      ),
    }));
    setTooltipOpen(false);
  }, [relationCountLookup]);

  const onEdgePointerEnter = useCallback(({
    target: { id, label },
    clientX,
    clientY,
  }: EdgePointerEvent) => {
    setChartData((graph) => ({
      ...graph,
      edges: graph.edges.map(
        (edge) => (edge.id === id ? {
          ...edge,
          style: {
            ...edge.style,
            stroke: theme.palette.grey[900],
            width: 2,
            arrow: 'forward',
            label: {
              ...edge.style?.label,
              color: theme.palette.common.black,
              background: theme.palette.common.white,
            },
          },
        } : edge),
      ),
    }));

    tooltipPositionRef.current = { x: clientX, y: clientY };
    const originalLabel = originalLabelByConnectionIds[id];

    // Only show a tooltip when the rendered label was shortened.
    if (originalLabel !== label) {
      setTooltipText(originalLabel);
      setTooltipOpen(true);

      if (tooltipPopperRef.current) {
        void tooltipPopperRef.current.update();
      }
    }
  }, [originalLabelByConnectionIds, theme]);

  const onEdgePointerLeave = useCallback(({
    target: { id },
  }: EdgePointerEvent) => {
    setChartData((graph) => ({
      ...graph,
      edges: graph.edges.map(
        (edge) => (edge.id === id ? {
          ...edge,
          // NOTE(review): replaces the whole style, dropping any style set by
          // `getEdgeAttributes` — confirm this is intended.
          style: {
            arrow: 'forward',
          },
        } : edge),
      ),
    }));
    setTooltipOpen(false);
  }, []);

  return (
    <div
      ref={targetRef}
      style={{
        position: 'relative',
        overflow: 'hidden',
        width,
        height: '100%',
      }}
    >
      {width === undefined || height === undefined ? (
        <span />
      ) : (
        <Tooltip
          title={tooltipText || ''}
          placement="top"
          arrow
          enterDelay={500}
          open={isTooltipOpen}
          disableInteractive
          PopperProps={{
            popperRef: tooltipPopperRef,
            anchorEl: {
              getBoundingClientRect: () => (
                new DOMRect(
                  tooltipPositionRef.current.x,
                  tooltipPositionRef.current.y,
                  0,
                  0,
                )
              ),
            },
          }}
        >
          <Box
            ref={renderAreaRef}
          >
            <Renderer
              width={width}
              height={height}
              nodes={chartData.nodes}
              edges={chartData.edges}
              x={chartData.x}
              y={chartData.y}
              zoom={chartData.zoom}
              onNodeDrag={onNodeDrag}
              onNodeClick={onNodeClick}
              onNodePointerEnter={onNodePointerEnter}
              onNodePointerLeave={onNodePointerLeave}
              onEdgePointerEnter={onEdgePointerEnter}
              onEdgePointerLeave={onEdgePointerLeave}
              onViewportDrag={onViewportDrag}
              onViewportWheel={onViewportWheel}
            />
          </Box>
        </Tooltip>
      )}
    </div>
  );
};

export default withResizeDetector(NetworkChart);
