import { Chip, type ChipProps, cn, Skeleton, Typography } from '@strise/midgard'
import * as React from 'react'
import { type AdverseFlagFragment, type BaseEntityLikeFragment } from '../../graphqlTypes'
import { FlagsTooltip } from './FlagsTooltip'
import { AdverseFlags } from './AdverseFlags'
import { useEntityFlagsQuery } from '../../graphqlOperations'
import { extractIsGlobalEntity } from '../../utils/entity'
import { type DivProps } from '@strise/react-utils'
import { isString } from 'lodash-es'

/** Flag edges as carried on an entity; each edge wraps a single adverse flag in `node`. */
type FlagEdges = Array<{ node: AdverseFlagFragment }>

/** Optional, already-resolved flags that an entity object may carry alongside its base fields. */
interface Flags {
  flags?: { edges: FlagEdges }
}

/**
 * An entity-like fragment that may come with its flags attached.
 * When the `flags` key is present on the object, `useEntityFlags` uses it and skips the flags query.
 */
export type FlagEntity = BaseEntityLikeFragment & Flags

/**
 * Resolves the adverse flag edges for an entity.
 *
 * - If an entity object is given and it has a `flags` key, its embedded edges are used and
 *   the flags query is skipped.
 * - If the entity is global (per `extractIsGlobalEntity`, or a Company/Person whose
 *   `isGlobalCompany`/`isGlobalPerson` is set), the query is skipped and the result is empty.
 * - Otherwise the flags are fetched with `useEntityFlagsQuery` using the entity id.
 *
 * @param entityOrId - An entity object, or a bare entity id.
 * @returns `flags`: the flag edges (empty until data is available), and `loading`: the query's loading state.
 */
const useEntityFlags = (entityOrId: FlagEntity | string) => {
  const entity = isString(entityOrId) ? null : entityOrId
  const id = isString(entityOrId) ? entityOrId : entityOrId.id

  // Presence of the key is what matters: `{ flags: undefined }` still counts as "embedded"
  // and yields an empty list without querying.
  const hasEmbeddedFlags = entity !== null && 'flags' in entity

  // Global entities have no queryable flags; both the shared helper and the
  // typename-specific fields are consulted.
  const globalByHelper = entity === null ? false : extractIsGlobalEntity(entity)
  const globalByTypename =
    entity !== null &&
    ((entity.__typename === 'Company' && entity.isGlobalCompany) ||
      (entity.__typename === 'Person' && entity.isGlobalPerson))

  const { data, loading } = useEntityFlagsQuery({
    variables: { id },
    skip: hasEmbeddedFlags || globalByHelper || globalByTypename
  })

  const flags = hasEmbeddedFlags ? (entity?.flags?.edges ?? []) : (data?.entity.flags.edges ?? [])

  return { flags, loading }
}

/** Props for `FlagChip`. Any remaining div props are forwarded to the `FlagsTooltip` wrapper. */
interface FlagChipProps extends DivProps {
  /** Rendered as the chip's label. */
  children?: React.ReactNode
  /** Forwarded to the inner `Chip`; spread last, so these override the defaults below. */
  chipProps?: ChipProps
  /** Applied to the inner `Chip`, not to the tooltip wrapper. */
  className?: string
  /** Id of the entity the flags belong to, passed to the tooltip. */
  entityId: string
  /** Flags shown in the tooltip. */
  flags: AdverseFlagFragment[]
}

/** A tertiary, contained `Chip` wrapped in a `FlagsTooltip` for the given entity's flags. */
const FlagChip = (props: FlagChipProps) => {
  const { children, chipProps, className, entityId, flags, ...tooltipProps } = props

  return (
    <FlagsTooltip entityId={entityId} flags={flags} {...tooltipProps}>
      <Chip className={className} variant='contained' palette='tertiary' label={children} {...chipProps} />
    </FlagsTooltip>
  )
}

/** Props for `EntityFlag`; remaining div props are forwarded to the rendered wrapper. */
export interface EntityFlagProps extends DivProps {
  /** Forwarded to the inner `Chip`. */
  chipProps?: ChipProps
  /** The entity (optionally with embedded flags) or a bare entity id. */
  entityOrId: FlagEntity | string
}

/**
 * Shows an entity's adverse flags as a chip with severity icons and a count, wrapped in a tooltip.
 *
 * Renders a rounded skeleton while flags are loading, and nothing at all when the entity has no flags.
 */
export const EntityFlag = ({ chipProps, className, entityOrId, ...rest }: EntityFlagProps) => {
  const { flags, loading } = useEntityFlags(entityOrId)

  if (loading) {
    return (
      <div className={cn('flex items-center', className)} {...rest}>
        <Skeleton className='my-0 h-6 w-[42px] rounded-[16px]' variant='rect' />
      </div>
    )
  }

  if (flags.length === 0) return null

  const flagNodes = flags.map((edge) => edge.node)
  const severities = flagNodes.map((flag) => flag.severity)
  const entityId = isString(entityOrId) ? entityOrId : entityOrId.id

  return (
    <FlagChip className={className} entityId={entityId} flags={flagNodes} chipProps={chipProps} {...rest}>
      <div className='flex items-center'>
        <AdverseFlags flagProps={{ className: 'mr-1' }} severities={severities} />
        <Typography variant='body2'>{flags.length}</Typography>
      </div>
    </FlagChip>
  )
}
