import React from 'react';
import Box from '@mui/material/Box';
import Table from '@mui/material/Table';
import TableBody from '@mui/material/TableBody';
import TableCell from '@mui/material/TableCell';
import TableHead from '@mui/material/TableHead';
import TableRow from '@mui/material/TableRow';
import Typography from '@mui/material/Typography';
import { useTheme } from '@mui/material/styles';
import {
  categoryMapReverse,
  riskCategories,
  Severity,
  severityDisplayNames,
  type TopLevelCategory,
} from '@promptfoo/redteam/constants';
import type { IssueDTO } from '@shared/dto';

/**
 * One row of the heat map: open-issue occurrence counts for a single top-level risk
 * category, split into one counter per severity column.
 */
export interface HeatMapData {
  /** Top-level risk category this row aggregates. */
  category: TopLevelCategory;
  [Severity.Critical]: number;
  [Severity.High]: number;
  [Severity.Medium]: number;
  [Severity.Low]: number;
  /** Total occurrences counted for this category across all severities. */
  total: number;
}

/** Heat-map rows; `createHeatMapMatrix` emits one row per `riskCategories` key. */
export type HeatMapMatrix = HeatMapData[];

/**
 * Aggregates open issues into a per-category, per-severity occurrence matrix.
 *
 * Every key of `riskCategories` gets a row initialised to zero, so categories without any
 * issues still appear. Only issues whose `status` is `'open'` are counted; each one adds
 * `issue.occurrences` to its severity cell and to the row's `total`.
 *
 * An issue is skipped, with a console message, when:
 * - its `pluginId` has no entry in `categoryMapReverse`,
 * - the resulting category has no row in the matrix, or
 * - its lower-cased severity is not one of the four severities the matrix tracks.
 * Skipping unknown severities keeps `total` equal to the sum of the severity cells and
 * prevents stray (`NaN`) keys from being added to a row.
 *
 * @param issues - Issues to aggregate; not mutated.
 * @returns One row per `riskCategories` key, in `Object.keys` order.
 */
export function createHeatMapMatrix(issues: IssueDTO[]): HeatMapMatrix {
  type TrackedSeverity = Severity.Critical | Severity.High | Severity.Medium | Severity.Low;
  const trackedSeverities: readonly string[] = [
    Severity.Critical,
    Severity.High,
    Severity.Medium,
    Severity.Low,
  ];
  const isTrackedSeverity = (value: string): value is TrackedSeverity =>
    trackedSeverities.includes(value);

  // Start every risk category at zero so categories with no issues still render.
  const matrix: HeatMapMatrix = Object.keys(riskCategories).map((category) => ({
    category: category as TopLevelCategory,
    [Severity.Critical]: 0,
    [Severity.High]: 0,
    [Severity.Medium]: 0,
    [Severity.Low]: 0,
    total: 0,
  }));

  for (const issue of issues) {
    if (issue.status !== 'open') {
      continue;
    }

    // Map the issue's plugin to its top-level risk category.
    const category = categoryMapReverse[issue.pluginId];
    if (!category) {
      console.warn('No category found for issue', issue);
      continue;
    }

    const row = matrix.find((r) => r.category === category);
    if (!row) {
      console.error('No row found for category', category);
      continue;
    }

    // Matrix keys are the lower-case severity values; reject anything outside the four columns
    // rather than writing an unknown key onto the row.
    const severity = issue.severity.toLowerCase();
    if (!isTrackedSeverity(severity)) {
      console.warn('Unsupported severity for issue', issue);
      continue;
    }

    row[severity] += issue.occurrences;
    row.total += issue.occurrences;
  }

  return matrix;
}

/**
 * Severity columns of the heat map, in display order (Critical → Low).
 * Declared `as const` so each element keeps its literal enum-member type.
 */
export const SEVERITY_ORDER = [
  Severity.Critical,
  Severity.High,
  Severity.Medium,
  Severity.Low,
] as const;

/**
 * Returns a CSS colour for a heat-map cell whose shade scales with `count / maxCount`.
 *
 * - Counts that are zero, negative, or `NaN` get a neutral "empty" colour.
 * - Otherwise intensity is `count / maxCount`, capped at 1. A non-positive or `NaN`
 *   `maxCount` is treated as full intensity.
 * - Light mode fades from white (intensity → 0) to pure red (intensity 1).
 * - Dark mode uses a translucent scale from dark teal (intensity → 0) to red (intensity 1).
 *
 * @param count - Value for this cell.
 * @param maxCount - Value that maps to full intensity (normally the largest cell count).
 * @param isDarkMode - Whether to use the dark-theme palette.
 * @returns A `#rrggbb`, `rgb(...)` or `rgba(...)` CSS colour string.
 */
export function getHeatMapColor(count: number, maxCount: number, isDarkMode: boolean): string {
  if (Number.isNaN(count) || count <= 0) {
    return isDarkMode ? 'rgba(255, 255, 255, 0.05)' : '#f5f5f5';
  }

  // count > 0 here, so intensity lands in (0, 1]. Dividing by a non-positive maxCount would
  // produce Infinity or a negative ratio, so treat it as full intensity.
  const intensity = maxCount > 0 ? Math.min(count / maxCount, 1) : 1;

  if (isDarkMode) {
    const r = Math.round(255 * intensity);
    const g = Math.round(50 * (1 - intensity));
    const b = Math.round(50 * (1 - intensity));
    return `rgba(${r}, ${g}, ${b}, 0.8)`;
  }

  // Light mode: red channel stays saturated while green/blue fade out with intensity.
  const fade = Math.round(255 * (1 - intensity));
  return `rgb(255, ${fade}, ${fade})`;
}

/** Props for the `IssueHeatMap` component. */
interface IssueHeatMapProps {
  /** Issues to aggregate; only issues with status `'open'` are counted. */
  issues: IssueDTO[];
  /** Invoked with `{ severity, riskCategory }` when a non-zero heat-map cell is clicked. */
  navigateToIssues: (params: Record<string, string>) => void;
}

/**
 * Renders the "Attack Risk Posture" heat map: one row per risk category, one column per
 * severity in `SEVERITY_ORDER`, plus a per-row total.
 *
 * Severity cells are shaded by `getHeatMapColor` relative to the largest severity cell in the
 * table. A cell with a non-zero count is interactive: clicking it, or pressing Enter/Space
 * while it has keyboard focus, calls `navigateToIssues({ severity, riskCategory })`.
 */
export default function IssueHeatMap({ issues, navigateToIssues }: IssueHeatMapProps) {
  const theme = useTheme();
  const isDarkMode = theme.palette.mode === 'dark';
  const textColor = theme.palette.text.primary;

  // Re-aggregate only when a new `issues` array is passed in.
  const heatMapData = React.useMemo(() => createHeatMapMatrix(issues), [issues]);

  // Largest severity cell (at least 1) used to scale cell colours. The Total column is
  // not shaded, so it is not included.
  const maxCount = React.useMemo(
    () =>
      Math.max(
        ...heatMapData.flatMap((row) => SEVERITY_ORDER.map((severity) => row[severity])),
        1,
      ),
    [heatMapData],
  );

  const handleCellClick = (severity: Severity, category: TopLevelCategory) => {
    navigateToIssues({
      severity,
      riskCategory: category,
    });
  };

  return (
    <Box sx={{ width: '100%', overflow: 'auto' }}>
      <Typography variant="h6" gutterBottom>
        Attack Risk Posture
      </Typography>
      <Table size="small" sx={{ color: textColor }}>
        <TableHead>
          <TableRow>
            <TableCell sx={{ color: textColor }}>Category</TableCell>
            {SEVERITY_ORDER.map((severity) => (
              <TableCell key={severity} align="center" sx={{ color: textColor }}>
                {severityDisplayNames[severity]}
              </TableCell>
            ))}
            <TableCell align="center" sx={{ color: textColor }}>
              Total
            </TableCell>
          </TableRow>
        </TableHead>
        <TableBody>
          {heatMapData.map((row) => (
            <TableRow key={row.category}>
              <TableCell component="th" scope="row" sx={{ color: textColor }}>
                {row.category}
              </TableCell>
              {SEVERITY_ORDER.map((severity) => {
                // Indexing with the enum member is already typed `number`; no cast needed.
                const count = row[severity];
                const isClickable = count > 0;
                const activate = () => handleCellClick(severity, row.category);

                // Non-zero cells use fixed black/white text instead of the theme text colour.
                let cellTextColor = textColor;
                if (isClickable) {
                  cellTextColor = isDarkMode ? '#ffffff' : '#000000';
                }

                return (
                  <TableCell
                    key={severity}
                    align="center"
                    // Make interactive cells reachable and operable from the keyboard.
                    tabIndex={isClickable ? 0 : undefined}
                    onClick={isClickable ? activate : undefined}
                    onKeyDown={
                      isClickable
                        ? (event: React.KeyboardEvent<HTMLTableCellElement>) => {
                            if (event.key === 'Enter' || event.key === ' ') {
                              // Stop Space from scrolling the page.
                              event.preventDefault();
                              activate();
                            }
                          }
                        : undefined
                    }
                    sx={{
                      backgroundColor: getHeatMapColor(count, maxCount, isDarkMode),
                      fontWeight: 'bold',
                      cursor: isClickable ? 'pointer' : 'default',
                      color: cellTextColor,
                    }}
                  >
                    {count}
                  </TableCell>
                );
              })}
              <TableCell align="center" sx={{ fontWeight: 'bold', color: textColor }}>
                {row.total}
              </TableCell>
            </TableRow>
          ))}
        </TableBody>
      </Table>
    </Box>
  );
}
