import Paper from '@mui/material/Paper'
import Typography from '@mui/material/Typography'
import { lighten, styled } from '@mui/material/styles'
import ApiPB from '../qai_hub/public_api_pb'

// Comma-separated list of the job state values that count as "running"
// (created, or at any in-progress stage)
export const RUNNING_JOB_STATES = [
  'JOB_STATE_CREATED',
  'JOB_STATE_OPTIMIZING_MODEL',
  'JOB_STATE_QUANTIZING_MODEL',
  'JOB_STATE_PROVISIONING_DEVICE',
  'JOB_STATE_MEASURING_PERFORMANCE',
  'JOB_STATE_RUNNING_INFERENCE',
]
  .map((stateName) => ApiPB.JobState[stateName])
  .join(',')

// Report whether the app runs in development mode: true when NODE_ENV is
// unset/empty or equal to 'development'
export function isDev() {
  const nodeEnv = process.env.NODE_ENV // eslint-disable-line no-undef
  if (!nodeEnv) {
    return true
  }
  return nodeEnv === 'development'
}

// Map a ComputeUnit enum value to its display name ("CPU", "GPU", ...).
// The unspecified unit is shown as "(Unknown)"; values that are not part of
// the enum yield undefined.
export function computeUnitToString(compute_unit_integer) {
  // The value -> name table is derived from the Protobuf enum on the first
  // call and memoized on the function object itself
  if (!computeUnitToString.rev) {
    const unspecifiedValue = ApiPB.ComputeUnit.COMPUTE_UNIT_UNSPECIFIED
    const namesByValue = {}
    for (const enumName in ApiPB.ComputeUnit) {
      const enumValue = ApiPB.ComputeUnit[enumName]
      // The unspecified unit gets a special label; every other name just
      // loses its "COMPUTE_UNIT_" prefix
      namesByValue[enumValue] = enumValue == unspecifiedValue ? '(Unknown)' : enumName.replace(/^(COMPUTE_UNIT_)/, '')
    }
    computeUnitToString.rev = namesByValue
  }
  return computeUnitToString.rev[compute_unit_integer]
}

// Like computeUnitToString, but the unspecified compute unit is rendered as an
// empty string instead of a placeholder label
export function computeUnitToStringIfValid(compute_unit_integer) {
  const isUnspecified = compute_unit_integer == ApiPB.ComputeUnit.COMPUTE_UNIT_UNSPECIFIED
  return isUnspecified ? '' : computeUnitToString(compute_unit_integer)
}

// Base color per compute unit; returned by computeUnitToColor and used as the
// per-unit default in dispatchInfoToColor
const NPU_COLOR = '#641c87'
const GPU_COLOR = '#17991b'
const CPU_COLOR = '#1c7fbd'
// Fallback for missing dispatch info and unrecognized compute units
const UNKNOWN_COLOR = '#555555'

// Variant colors (used for specific delegates)
const NPU_COLOR_VAR1 = '#7d2252' // NPU: nnapi + qti-hta
const NPU_COLOR_VAR2 = '#b941c4' // NPU: nnapi + google-edgetpu
const NPU_COLOR_VAR3 = '#6e59b7' // NPU: hexagon
const GPU_COLOR_VAR1 = '#377819' // GPU: nnapi
const GPU_COLOR_VAR2 = '#739c19' // GPU: gpuv2 + opengl
const CPU_COLOR_VAR1 = '#48a6b0' // CPU: xnnpack

// Pick the base color for a ComputeUnit enum value; anything other than
// CPU/GPU/NPU gets the unknown color
export function computeUnitToColor(compute_unit_integer) {
  const { COMPUTE_UNIT_CPU, COMPUTE_UNIT_GPU, COMPUTE_UNIT_NPU } = ApiPB.ComputeUnit

  if (compute_unit_integer === COMPUTE_UNIT_CPU) return CPU_COLOR
  if (compute_unit_integer === COMPUTE_UNIT_GPU) return GPU_COLOR
  if (compute_unit_integer === COMPUTE_UNIT_NPU) return NPU_COLOR
  return UNKNOWN_COLOR
}

// Pick a color for a dispatch target. The compute unit selects the base color
// and certain delegates (matched case-insensitively by name and extra info)
// select a variant of it.
//
// dispatchInfo: object with computeUnit, delegateName and delegateExtraInfo,
//   or a falsy value.
// Returns a CSS hex color string; UNKNOWN_COLOR when dispatchInfo is falsy or
// the compute unit is not CPU/GPU/NPU.
export function dispatchInfoToColor(dispatchInfo) {
  if (!dispatchInfo) {
    return UNKNOWN_COLOR
  }

  const { computeUnit, delegateName, delegateExtraInfo } = dispatchInfo
  // Treat missing delegate fields as empty strings so that an incomplete
  // object falls back to the base color instead of throwing
  const name = (delegateName || '').toLowerCase()
  const extra = (delegateExtraInfo || '').toLowerCase()

  switch (computeUnit) {
    case ApiPB.ComputeUnit.COMPUTE_UNIT_CPU:
      return name === 'xnnpack' ? CPU_COLOR_VAR1 : CPU_COLOR
    case ApiPB.ComputeUnit.COMPUTE_UNIT_GPU:
      if (name === 'nnapi') return GPU_COLOR_VAR1
      if (name === 'gpuv2' && extra === 'opengl') return GPU_COLOR_VAR2
      // Everything else (including gpuv2 + opencl) uses the base GPU color
      return GPU_COLOR
    case ApiPB.ComputeUnit.COMPUTE_UNIT_NPU:
      if (name === 'nnapi' && extra === 'qti-hta') return NPU_COLOR_VAR1
      if (name === 'nnapi' && extra === 'google-edgetpu') return NPU_COLOR_VAR2
      if (name === 'hexagon') return NPU_COLOR_VAR3
      // Everything else (including nnapi + qti-dsp) uses the base NPU color
      return NPU_COLOR
    default:
      return UNKNOWN_COLOR
  }
}

// Choose between the singular and plural form of a word for the given count;
// only a count loosely equal to 1 selects the singular form
export function pluralize(number, singularForm, pluralForm) {
  if (number == 1) {
    return singularForm
  }
  return pluralForm
}

// Upper-case the first character of a string and leave the rest untouched
export function capitalizeFirstLetter(string) {
  const head = string.charAt(0)
  const tail = string.slice(1)
  return `${head.toUpperCase()}${tail}`
}

// Convert an underscore-separated name to UpperCamelCase,
// e.g. SIGN_BIT => SignBit.
//
// Empty segments (empty input, leading/trailing or doubled underscores) are
// skipped rather than causing an error.
export function allCapsToUpperCamelCase(s) {
  return s
    .split('_')
    .filter((word) => word.length > 0)
    .map((word) => {
      // Destructuring iterates by code point, so the first character and the
      // remainder always line up, even outside the Basic Multilingual Plane
      const [first, ...rest] = word
      return first.toUpperCase() + rest.join('').toLowerCase()
    })
    .join('')
}

// Paper variant with body2 typography, centered secondary-colored text and a
// fixed 3.6em height whose overflow is clipped
export const HeaderPaper = styled(Paper)(({ theme }) => {
  const surfaceColor = theme.palette.mode === 'dark' ? '#1A2027' : '#fff'

  return {
    backgroundColor: surfaceColor,
    ...theme.typography.body2,
    padding: theme.spacing(1.2),
    textAlign: 'center',
    color: theme.palette.text.secondary,
    height: '3.6em',
    overflow: 'hidden',
  }
})

// Paper variant with body2 typography, secondary-colored text and one
// spacing unit of padding
export const DefaultPaper = styled(Paper)(({ theme }) => {
  const surfaceColor = theme.palette.mode === 'dark' ? '#1A2027' : '#fff'

  return {
    backgroundColor: surfaceColor,
    ...theme.typography.body2,
    padding: theme.spacing(1),
    color: theme.palette.text.secondary,
  }
})

export function parseStatusMessage(message) {
  return message.split('\n').map((v, index) => {
    return <Typography key={index}>{v}</Typography>
  })
}

// Translate a TensorDtype enum value into its short dtype name (e.g.
// 'float32'); values without a known name give an empty string
export function typeToString(type) {
  const { TensorDtype } = ApiPB
  const dtypeNames = [
    [TensorDtype.TENSOR_DTYPE_FLOAT32, 'float32'],
    [TensorDtype.TENSOR_DTYPE_FLOAT16, 'float16'],
    [TensorDtype.TENSOR_DTYPE_INT32, 'int32'],
    [TensorDtype.TENSOR_DTYPE_INT16, 'int16'],
    [TensorDtype.TENSOR_DTYPE_UINT16, 'uint16'],
    [TensorDtype.TENSOR_DTYPE_INT8, 'int8'],
    [TensorDtype.TENSOR_DTYPE_UINT8, 'uint8'],
    [TensorDtype.TENSOR_DTYPE_INT64, 'int64'],
  ]

  // First strict match wins
  const match = dtypeNames.find(([dtype]) => dtype === type)
  return match ? match[1] : ''
}

// Format a tensor spec as "<dtype>[<dim>, <dim>, ...]", using the spec's
// getShapeList() and getDtype() accessors
export function formatSpec(spec) {
  const dims = spec.getShapeList()
  const dtypeName = typeToString(spec.getDtype())
  const dimsText = dims.length == 1 ? dims[0] : dims.join(', ')
  return `${dtypeName}[${dimsText}]`
}

// Translate a ModelType enum value into a human-readable model type name;
// values without a dedicated name are reported as 'Unknown'
export function modelTypeToString(model_type) {
  const { ModelType } = ApiPB
  const modelTypeNames = [
    [ModelType.MODEL_TYPE_TORCHSCRIPT, 'TorchScript'],
    [ModelType.MODEL_TYPE_MLMODEL, 'Core ML'],
    [ModelType.MODEL_TYPE_TFLITE, 'TensorFlow Lite'],
    [ModelType.MODEL_TYPE_MLMODELC, 'Compiled Core ML'],
    [ModelType.MODEL_TYPE_ONNX, 'ONNX'],
    [ModelType.MODEL_TYPE_ORT, 'ONNX Runtime'],
    [ModelType.MODEL_TYPE_MLPACKAGE, 'Core ML Package'],
    [ModelType.MODEL_TYPE_QNN_LIB_AARCH64_ANDROID, 'QNN Model Library for AArch64 Android'],
    [ModelType.MODEL_TYPE_QNN_LIB_X86_64_LINUX, 'QNN Model Library for x86-64 Linux'],
    [ModelType.MODEL_TYPE_QNN_CONTEXT_BINARY, 'QNN Context Binary'],
    [ModelType.MODEL_TYPE_AIMET_ONNX, 'AIMET ONNX Package'],
    [ModelType.MODEL_TYPE_AIMET_PT, 'AIMET Torchscript Package'],
    [ModelType.MODEL_TYPE_PRECOMPILED_QNN_ONNX, 'Precompiled QNN Context Binary in ONNX'],
    // These two are listed explicitly but share the fallback label
    [ModelType.MODEL_TYPE_TETRART, 'Unknown'],
    [ModelType.MODEL_TYPE_UNSPECIFIED, 'Unknown'],
  ]

  // First strict match wins
  const match = modelTypeNames.find(([modelType]) => modelType === model_type)
  return match ? match[1] : 'Unknown'
}

// Translate a JobState enum value into a display label. The labels are
// hand-written rather than derived from the Protobuf enum key names.
//
// When include_in_progress is truthy, the label of every non-terminal state is
// prefixed with "In Progress: ". Unrecognized states give 'Unknown'.
export function jobStateToString(state_integer, include_in_progress = false) {
  const { JobState } = ApiPB

  // Terminal states never carry the prefix
  if (state_integer === JobState.JOB_STATE_DONE) return 'Results Ready'
  if (state_integer === JobState.JOB_STATE_FAILED) return 'Failed'

  const inProgressLabels = [
    [JobState.JOB_STATE_CREATED, 'Created'],
    [JobState.JOB_STATE_OPTIMIZING_MODEL, 'Optimizing Model'],
    [JobState.JOB_STATE_QUANTIZING_MODEL, 'Quantizing Model'],
    [JobState.JOB_STATE_PROVISIONING_DEVICE, 'Provisioning Device'],
    [JobState.JOB_STATE_MEASURING_PERFORMANCE, 'Measuring Performance'],
    [JobState.JOB_STATE_RUNNING_INFERENCE, 'Running Inference'],
  ]

  const match = inProgressLabels.find(([state]) => state === state_integer)
  if (!match) {
    return 'Unknown'
  }

  const label = match[1]
  return include_in_progress ? `In Progress: ${label}` : label
}

// Build the URL of an API endpoint under /api/v1/.
//
// relativePath: path below /api/v1/; it is appended verbatim.
//   NOTE(review): the original comment here read "Should not have trailing
//   slash"; it is unclear which part that refers to, confirm with callers.
// params: optional query parameters in any form URLSearchParams accepts.
// Returns the URL string. The "?" separator is only added when there is at
// least one query parameter, so an empty params object gives a clean URL.
export function apiURL(relativePath, params) {
  const queryString = params ? new URLSearchParams(params).toString() : ''
  const querySuffix = queryString ? `?${queryString}` : ''

  return `/api/v1/${relativePath}${querySuffix}`
}

// Get display name, depending on which fields are available: the email when
// present, otherwise the primary key
export function getDisplayName(userInfo) {
  // TODO: Prefer userInfo.first_name here once it is always loaded
  // (disabled until then)
  return userInfo.email || userInfo.pk
}

// Derive up to two upper-case initials for a user: first + last name
// initials, else the first two letters of the first name, else the first two
// letters of the email, else an empty string
export function getDisplayNameInitials(userInfo) {
  // TODO: first_name/last_name currently not made available
  const { first_name: firstName, last_name: lastName, email } = userInfo

  if (firstName && lastName) {
    return firstName[0].toUpperCase() + lastName[0].toUpperCase()
  }

  const fallbackSource = firstName || email
  return fallbackSource ? fallbackSource.slice(0, 2).toUpperCase() : ''
}

// Name of the HTTP header that carries the auth token.
//
// To fix a Safari issue, production uses X-Auth-Token and lets NGINX forward
// it to Authorization. NGINX is not running in dev, so the original
// Authorization header is used directly there.
function tokenHeaderKey() {
  return isDev() ? 'Authorization' : 'X-Auth-Token'
}

// Authentication header.
//
// header: optional request config object; all of its properties are copied
//   into the result.
// Returns a new config object whose `headers` contain the token read from
// localStorage under the key given by tokenHeaderKey(). Headers already
// present in `header.headers` are preserved (previously they were silently
// dropped); the auth entry wins if the same key is supplied. The input object
// is not modified.
export function authHeader(header) {
  const callerHeaders = header && header.headers

  return {
    ...header,
    headers: {
      ...callerHeaders,
      [tokenHeaderKey()]: `Token ${localStorage.getItem('token')}`,
    },
  }
}

// Authentication header for Protobuf messages: the authHeader() config with
// a Protobuf Content-Type header and an 'arraybuffer' response type
export function authProtobufHeader(header) {
  const baseConfig = authHeader(header)
  baseConfig.headers['Content-Type'] = 'application/x-protobuf'

  return {
    ...baseConfig,
    responseType: 'arraybuffer',
  }
}

// Converts supplied time in microseconds to displayable time in milliseconds
// (one decimal place, separated from the unit by a non-breaking space).
export function displayInMillis(time) {
  const nbsp = '\u00A0'
  const millis = time / 1000
  return `${millis.toFixed(1)}${nbsp}ms`
}

// Converts supplied memory usage in bytes to displayable memory in MB
// (one decimal place, separated from the unit by a non-breaking space).
export function displayInMegabytes(bytes) {
  const nbsp = '\u00A0'
  const megabytes = bytes / 1024 / 1024
  return `${megabytes.toFixed(1)}${nbsp}MB`
}

// Format a memory range given in bytes as "<lower> - <upper> MB" using
// non-breaking spaces. One decimal place is shown when the upper bound is at
// most 1 MB, none otherwise. If both bounds format to the same text, a single
// value is shown instead of a range.
export function displayRangeInMegabytes(lowerBytes, upperBytes) {
  const nbsp = '\u00A0'
  const toMegabytes = (bytes) => bytes / 1024 / 1024

  const upperMB = toMegabytes(upperBytes)
  const fractionDigits = upperMB <= 1 ? 1 : 0

  const lowerText = toMegabytes(lowerBytes).toFixed(fractionDigits)
  const upperText = upperMB.toFixed(fractionDigits)

  return lowerText === upperText
    ? `${lowerText}${nbsp}MB`
    : `${lowerText}${nbsp}-${nbsp}${upperText}${nbsp}MB`
}

// Format a range message (anything with getLower()/getUpper() accessors
// returning bytes) via displayRangeInMegabytes
export function displayRangePbInMegabytes(rangePb) {
  const lowerBytes = rangePb.getLower()
  const upperBytes = rangePb.getUpper()
  return displayRangeInMegabytes(lowerBytes, upperBytes)
}

// Assemble one row object for the layers table.
//
// The placement text is the compute unit name, followed by the delegate name
// in parentheses when a delegate name is given and differs from the unit
// name. Rows start with a zero-width sparkline in the compute unit's color.
function makeLayerRow(
  name,
  typeName,
  computeUnit,
  delegateName,
  placementRowSpan,
  deleageteOps,
  execTime,
  execCycles,
  execRowSpan,
) {
  const unitLabel = computeUnitToString(computeUnit)
  const showDelegate = delegateName.length > 0 && unitLabel !== delegateName
  const placement = showDelegate ? `${unitLabel} (${delegateName})` : unitLabel

  return {
    name,
    typeName: capitalizeFirstLetter(typeName),
    placement,
    placementRowSpan,
    // NOTE: the misspelled key is part of the row's shape; consumers read it
    deleageteOps,
    execTime,
    execCycles,
    execRowSpan,
    sparklineWidth: 0,
    sparklineColor: computeUnitToColor(computeUnit),
  }
}

// Build the row data for the layers table from a profile message.
//
// profileInfo: object exposing getSegmentDetailsList() and
//   getLayerDetailsList().
// Returns an array of rows (see makeLayerRow): first the layers of every
// segment, in segment order, then the layers that belong to no segment.
//
// Row spans: within a segment the placement cell is emitted on the first row
// and spans the whole segment. The execution time cell does the same only
// when the segment has neither per-layer times nor per-layer cycles. A span
// of -1 marks a cell that must be omitted because an earlier row covers it.
//
// Finally each row that has its own exec cell gets a sparkline width (0-100)
// and a lightened color, scaled between the smallest and largest timing.
export function getLayersTableData(profileInfo) {
  const layersTableData = []
  const allLayers = profileInfo.getLayerDetailsList()

  for (const segment of profileInfo.getSegmentDetailsList()) {
    const segmentId = segment.getId()
    const segmentLayers = allLayers.filter((layer) => layer.getSegmentId() == segmentId)

    // determine if this segment has per-layer timing to show.
    const segmentHasPerLayerTimes = segmentLayers.some((layer) => layer.getExecutionTime() > 0)
    const segmentHasPerLayerCycles = segmentLayers.some((layer) => layer.getExecutionCycles() > 0)
    const execCellSpansSegment = !segmentHasPerLayerTimes && !segmentHasPerLayerCycles

    segmentLayers.forEach((layer, rowIndex) => {
      const isFirstRow = rowIndex === 0

      // placement always spans the whole segment from its first row
      const placementRowSpan = isFirstRow ? segmentLayers.length : -1

      // exec time spans the segment only when it is a per-segment value
      let execTimeRowSpan = 1
      if (execCellSpansSegment) {
        execTimeRowSpan = isFirstRow ? segmentLayers.length : -1
      }

      layersTableData.push(
        makeLayerRow(
          layer.getName(),
          layer.getLayerTypeName(),
          segment.getComputeUnit(),
          segment.getDelegateName(),
          placementRowSpan,
          layer.getDelegateReportedOpsList(),
          // without per-layer times, every row carries the segment's time
          segmentHasPerLayerTimes ? layer.getExecutionTime() : segment.getExecutionTime(),
          segmentHasPerLayerCycles ? layer.getExecutionCycles() : '',
          execTimeRowSpan,
        ),
      )
    })
  }

  // handle data without segments
  for (const layer of allLayers.filter((candidate) => candidate.getSegmentId().length == 0)) {
    // Call the getter: the row must hold the cycle count, not the method
    const execCycles = layer.getExecutionCycles()
    layersTableData.push(
      makeLayerRow(
        layer.getName(),
        layer.getLayerTypeName(),
        layer.getComputeUnit(),
        layer.getDelegateName(),
        1,
        layer.getDelegateReportedOpsList(),
        layer.getExecutionTime(),
        execCycles ? execCycles : '',
        1,
      ),
    )
  }

  // Which field should we use to calculate the spark line?
  // Use (wall clock) time unless we only have cycles.
  const timingField = layersTableData.some((row) => row.execTime > 0) ? 'execTime' : 'execCycles'

  // Rows whose exec cell is omitted (span -1) take no part in the sparkline
  const rowsWithExecCell = layersTableData.filter((row) => row.execRowSpan >= 1)

  let minTime = Number.MAX_SAFE_INTEGER
  let maxTime = 0
  for (const row of rowsWithExecCell) {
    minTime = Math.min(minTime, row[timingField])
    maxTime = Math.max(maxTime, row[timingField])
  }

  const timeRange = maxTime - minTime
  if (timeRange > 0) {
    for (const row of rowsWithExecCell) {
      const fraction = (row[timingField] - minTime) / timeRange
      row.sparklineWidth = fraction * 100
      row.sparklineColor = lighten(row.sparklineColor, 1 - fraction)
    }
  }

  return layersTableData
}
