import { Input, Stack, Table } from '@mantine/core';
import { IconSearch } from '@tabler/icons-react';
import {
  ColumnDef,
  ColumnFiltersState,
  flexRender,
  getCoreRowModel,
  getFilteredRowModel,
  getSortedRowModel,
  SortingState,
  useReactTable
} from '@tanstack/react-table';
import { useVirtualizer } from '@tanstack/react-virtual';
import React from 'react';
import { VirtualRow } from './virtual-row';

/**
 * Props for {@link DataTableVirtual}.
 */
interface DataTableVirtualProps<TData, TValue> {
  /**
   * Identifier for the table.
   * NOTE(review): not currently read by `DataTableVirtual`; confirm its intended use
   * (e.g. a DOM id) before relying on it.
   */
  id: string;
  /** TanStack column definitions used to build the header cells and row cells. */
  columns: ColumnDef<TData, TValue>[];
  /** Rows to display. */
  data: TData[];
  /** Estimated height of one row (px), passed to the virtualizer's `estimateSize`. */
  rowSize: number;
  /** Enables TanStack sorting. Defaults to `false`. */
  sortable?: boolean;
  /** Shows a search box and enables the global `includesString` filter. Defaults to `false`. */
  filterable?: boolean;
  /**
   * Height of the scroll container (a number is treated as px, or any CSS length).
   * When set, the container also gets `overflow: auto`, which gives the virtualizer a
   * bounded, scrollable viewport so only rows near it are rendered. When omitted, no
   * height/overflow is applied (the original behavior); unless something outside
   * constrains the container, it then grows with its content and every row renders.
   */
  height?: number | string;
  /** Placeholder text for the search box. Defaults to `'Filter enrolments...'`. */
  filterPlaceholder?: string;
}

/**
 * Table that renders only the rows near the visible part of its scroll container.
 *
 * TanStack Table builds the row model (sorting and global string filtering).
 * TanStack Virtual decides which of those rows to mount. Each mounted row is
 * rendered by `VirtualRow`.
 */
export const DataTableVirtual = <TData, TValue>({
  columns,
  data,
  rowSize,
  sortable,
  filterable,
  height,
  filterPlaceholder = 'Filter enrolments...'
}: DataTableVirtualProps<TData, TValue>) => {
  const [sorting, setSorting] = React.useState<SortingState>([]);
  // The global filter is the search-box text, so it is a string; '' means "no filter".
  const [globalFilter, setGlobalFilter] = React.useState<string>('');

  const table = useReactTable({
    data,
    columns,
    getCoreRowModel: getCoreRowModel(),
    onSortingChange: setSorting,
    getSortedRowModel: getSortedRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    globalFilterFn: 'includesString',
    onGlobalFilterChange: setGlobalFilter,
    state: {
      sorting,
      globalFilter
    },
    // `?? false` (not a destructuring default) so an explicit null also disables these.
    enableSorting: sortable ?? false,
    enableGlobalFilter: filterable ?? false
  });

  // Rows after filtering and sorting; the virtualizer indexes into this array.
  const { rows } = table.getRowModel();

  const parentRef = React.useRef<HTMLDivElement>(null);

  const virtualizer = useVirtualizer({
    count: rows.length,
    getScrollElement: () => parentRef.current,
    estimateSize: () => rowSize,
    overscan: 20
  });

  return (
    <Stack spacing="lg">
      {filterable && (
        <Input
          icon={<IconSearch size="20" />}
          placeholder={filterPlaceholder}
          // Controlled so the box always reflects the table's filter state.
          value={globalFilter}
          onChange={e => table.setGlobalFilter(String(e.target.value))}
        />
      )}
      {/* Scroll element observed by the virtualizer. It only limits the rendered rows
          when it has a bounded height and scrolls, which is what `height` provides. */}
      <div
        ref={parentRef}
        style={height === undefined ? undefined : { height, overflow: 'auto' }}
      >
        <div style={{ height: `${virtualizer.getTotalSize()}px` }}>
          <Table>
            <thead>
              {table.getHeaderGroups().map(headerGroup => (
                <tr key={headerGroup.id}>
                  {headerGroup.headers.map(header => {
                    // NOTE(review): no sort toggle is attached to the <th>. Sorting
                    // presumably relies on the column's header renderer calling
                    // column.toggleSorting(); confirm before adding an onClick here,
                    // which would double-toggle in that case.
                    return (
                      <th key={header.id}>
                        {header.isPlaceholder
                          ? null
                          : flexRender(
                              header.column.columnDef.header,
                              header.getContext()
                            )}
                      </th>
                    );
                  })}
                </tr>
              ))}
            </thead>
            <tbody>
              {virtualizer.getVirtualItems().map((virtualRow, index) => {
                // virtualRow.index is the position in the full `rows` array;
                // `index` is the position within the currently mounted window.
                const row = rows[virtualRow.index];

                return (
                  <VirtualRow
                    key={row.id}
                    row={row}
                    virtualRow={virtualRow}
                    virtualIndex={index}
                  />
                );
              })}
            </tbody>
          </Table>
        </div>
      </div>
    </Stack>
  );
};
