import { Breakpoint, Breakpoints, useTheme } from '@mui/material';
import { useMemo } from 'react';

import { useCurrentBreakpoint } from './useCurrentBreakpoint';

/** A sparse mapping from breakpoint keys to the value to use at that breakpoint. */
type ValueMap<T> = Partial<Record<Breakpoint, T>>;

/**
 * Decides whether `value` should be interpreted as a per-breakpoint map.
 *
 * Anything that is a non-null, non-array object counts as a map. Primitives, functions,
 * `null`, `undefined` and arrays do not. Note that this means an object-valued `T`
 * (e.g. a style object) is indistinguishable from a map and will be treated as one.
 */
function isValueMap<T>(value: T | ValueMap<T>): value is ValueMap<T> {
  if (value === null || typeof value !== 'object') {
    return false;
  }
  return !Array.isArray(value);
}

/**
 * Finds the largest breakpoint whose width is strictly below `currentBreakpoint`'s.
 *
 * Candidates are the theme's breakpoint keys, optionally restricted to those that also
 * appear in `breakpointKeys`. Keys not present in the theme are ignored, and an empty
 * list yields no candidates. Candidates are ordered by their numeric `values` entry
 * first, so the theme's own key order does not matter. The theme's arrays are not
 * mutated.
 *
 * @param breakpoints - Theme breakpoints; only `keys` and `values` are read.
 * @param currentBreakpoint - The breakpoint to search below.
 * @param breakpointKeys - Optional whitelist of candidate keys.
 * @returns The nearest strictly smaller candidate, or `undefined` if there is none.
 */
function getNearestSmallerBreakpoint(
  breakpoints: Breakpoints,
  currentBreakpoint: Breakpoint,
  breakpointKeys?: Breakpoint[],
): Breakpoint | undefined {
  const { keys, values } = breakpoints;
  const allowed = breakpointKeys;

  let candidates = keys;
  if (allowed) {
    candidates = keys.filter((key) => allowed.includes(key));
  }

  // Work on a copy so the theme's key array is left untouched by the in-place sort.
  const ordered = [...candidates].sort((left, right) => values[left] - values[right]);
  const threshold = values[currentBreakpoint];

  // Everything before the first candidate that is NOT strictly smaller qualifies;
  // the nearest one is therefore the element just before that cut-off.
  const firstNotSmaller = ordered.findIndex((key) => !(values[key] < threshold));
  const lastSmaller = firstNotSmaller === -1 ? ordered.length - 1 : firstNotSmaller - 1;

  return lastSmaller >= 0 ? ordered[lastSmaller] : undefined;
}

/**
 * Resolves a value that may vary by the current MUI breakpoint.
 *
 * `value` is either:
 * - a plain value, which is returned as-is (or `defaultValue` when it is `null`/`undefined`); or
 * - a map from breakpoint keys to values. The entry for the current breakpoint wins. If
 *   it is `undefined` or absent, the entry for the nearest smaller breakpoint that has a
 *   defined value is used (mobile-first cascade). An explicit `null` entry resolves to
 *   `defaultValue` rather than cascading further.
 *
 * When no breakpoint is reported yet, or nothing in the map applies, `defaultValue` is
 * returned.
 *
 * The result is memoized on the identity of `value`, `defaultValue`, the theme
 * breakpoints and the current breakpoint. Pass a stable map (module-level constant or
 * memoized) to benefit from it.
 *
 * @param value - A plain value or a partial breakpoint → value map.
 * @param defaultValue - Fallback used when nothing resolves.
 * @returns The resolved value, or `defaultValue` (which may be `undefined`).
 */
export function useResponsiveValue<T>(value: T | ValueMap<T>, defaultValue: T): T;
export function useResponsiveValue<T>(value: T | ValueMap<T>): T | undefined;
export function useResponsiveValue<T>(value: T | ValueMap<T>, defaultValue?: T): T | undefined {
  // Theme breakpoint definitions (keys + widths) used for the cascade.
  const { breakpoints } = useTheme();

  // The breakpoint currently active, if one has been determined.
  const currentBreakpoint = useCurrentBreakpoint();

  return useMemo(() => {
    if (!isValueMap(value)) {
      return value ?? defaultValue;
    }

    if (!currentBreakpoint) {
      return defaultValue;
    }

    // Bind the narrowed map to a const so the narrowing survives inside callbacks below.
    const map: ValueMap<T> = value;

    const exact = map[currentBreakpoint];
    if (exact !== undefined) {
      return exact ?? defaultValue;
    }

    // Only keys that actually carry a value take part in the cascade. An entry that is
    // present but `undefined` (e.g. `{ xs: 1, md: props.optional }`) must not shadow
    // smaller breakpoints that do define a value.
    const definedKeys = (Object.keys(map) as Breakpoint[]).filter(
      (key) => map[key] !== undefined,
    );

    const nearestSmallerBreakpoint = getNearestSmallerBreakpoint(
      breakpoints,
      currentBreakpoint,
      definedKeys,
    );

    if (nearestSmallerBreakpoint) {
      return map[nearestSmallerBreakpoint] ?? defaultValue;
    }

    return defaultValue;
    // `breakpoints` is read inside the memo, so a theme change must invalidate it.
  }, [breakpoints, currentBreakpoint, value, defaultValue]);
}
