import { seriesCanvasHeatmap } from "d3fc"
import { TimeSeriesData } from "../../../../Data/TimeSeriesData"
import { HeatmapTraceConfig, getHeatmapTraceId } from "../../../../Types/Trace"
import { TimeSeriesPageManager } from "../../../../Data/TimeSeriesPageManager"
import { ModalityPage } from "../../../../Data/ModalityPage"
import { D3Trace } from "./D3Trace"
import { TraceRenderStrategy, TraceRendererOptions } from "./RenderStrategy"
import { ScaleBand, ScaleSequential, range, scaleBand, scaleSequential } from "d3"
import { binMeansUsingTimestampBins, getBinEdges } from "../../Histogram/histogram"
import { getSpectrumInterpolator } from "../../../../Types/ColorSpectrum"

/**
 * Render strategy for heatmap traces.
 *
 * Each page of time-series data is split into at most `desiredNumberOfBars` time bins
 * (see `binAndCalculateMeans`). Every bin is painted onto the renderer's canvas context
 * as a rectangle whose fill colour comes from a sequential colour scale spanning
 * `[config.lowerBound, config.upperBound]`.
 */
export class HeatmapRenderStrategy implements TraceRenderStrategy {
	private directRenderer = seriesCanvasHeatmap()
	private offscreenRenderer = seriesCanvasHeatmap()
	// Maps bin indices ("0", "1", ...) to x pixel positions. Shared by both renderers and
	// re-domained for every page in renderTimeSeriesData.
	private xScaleBand: ScaleBand<any> = scaleBand()

	private pageManager: TimeSeriesPageManager<ModalityPage>
	private config: HeatmapTraceConfig
	private d3Trace: D3Trace
	private colorScale: ScaleSequential<string>
	private desiredNumberOfBars = 2048

	constructor(pageManager: TimeSeriesPageManager<ModalityPage>, d3Trace: D3Trace, config: HeatmapTraceConfig) {
		this.config = config
		this.d3Trace = d3Trace
		this.pageManager = pageManager
		this.xScaleBand.range(this.config.xScale.range())

		// Both renderers consume [xBandKey, yBandKey, value] tuples and share the same scales.
		for (const renderer of [this.directRenderer, this.offscreenRenderer]) {
			renderer
				.xValue((d: [string, string, number]) => d[0])
				.yValue((d: [string, string, number]) => d[1])
				.colorValue((d: [string, string, number]) => d[2])
				.xScale(this.xScaleBand)
				.yScale(this.config.yScale)
		}

		this.xScaleBand.domain(range(0, this.desiredNumberOfBars).map(value => value.toString()))
		this.colorScale = this.createColorScale(this.config)
	}

	/**
	 * Builds a cache key from the graph id, data key, data source, colour spectrum and
	 * colour bounds of the current config, joined with "-".
	 */
	getRenderCacheKey(): string {
		const { graphId, dataKey, dataSource, colorSpectrum, lowerBound, upperBound } = this.config
		return [graphId, dataKey, dataSource, colorSpectrum, lowerBound, upperBound].join("-")
	}

	/**
	 * Replaces the trace config and re-derives all state computed from it.
	 *
	 * The colour scale and the renderers' y scale were built from the previous config, so
	 * they are rebuilt here. Otherwise spectrum/bound changes (which do change
	 * getRenderCacheKey) would never show up in the rendered colours.
	 */
	public updateConfig(traceConfig: HeatmapTraceConfig) {
		this.config = traceConfig
		this.xScaleBand.range(traceConfig.xScale.range())
		this.colorScale = this.createColorScale(traceConfig)

		for (const renderer of [this.directRenderer, this.offscreenRenderer]) {
			renderer.yScale(traceConfig.yScale)
		}
	}

	/**
	 * Returns the offscreen heatmap series.
	 *
	 * When `options` carries scales, only their pixel ranges are copied onto the shared band
	 * x scale and onto the config's y scale; both are mutated in place.
	 */
	public getOffscreenRenderer = (options?: TraceRendererOptions) => {
		if (options?.xScale) {
			this.offscreenRenderer.xScale(this.xScaleBand.range(options.xScale.range()))
		}

		if (options?.yScale) {
			this.offscreenRenderer.yScale(this.config.yScale.range(options.yScale.range()))
		}

		return this.offscreenRenderer
	}

	/**
	 * Returns the direct (on-screen) heatmap series.
	 *
	 * Scale options are applied the same way as in getOffscreenRenderer: ranges are copied
	 * onto the shared scales, which are mutated in place.
	 */
	public getDirectRenderer = (options?: TraceRendererOptions) => {
		if (options?.xScale) {
			this.directRenderer.xScale(this.xScaleBand.range(options.xScale.range()))
		}

		if (options?.yScale) {
			this.directRenderer.yScale(this.config.yScale.range(options.yScale.range()))
		}

		return this.directRenderer
	}

	/** Asks the owning D3Trace to render every page the page manager reports as in view. */
	public render() {
		this.pageManager.getPagesInView().forEach(page => this.d3Trace.renderPage(page))
	}

	/**
	 * Paints one page of heatmap data onto `renderer`'s canvas context.
	 *
	 * @param data     Samples for the page. A single sample is stretched across the whole page.
	 * @param page     Page whose [startTime, endTime] span is binned and drawn.
	 * @param renderer Object whose `context()` returns the drawing context (uses `fillStyle`/`fillRect`).
	 * @param offset   Horizontal pixel offset added to every bar.
	 * @returns Always an empty array.
	 */
	public renderTimeSeriesData(data: TimeSeriesData, page: ModalityPage, renderer: any, offset: number = 0) {
		// If there is only one data point, just pretend that the value is constant.
		const series: TimeSeriesData = data.data.length === 1
			? {
				data: new Float32Array([data.data[0], data.data[0]]),
				times: [page.startTime, page.endTime]
			}
			: data

		const { means, domain } = this.binAndCalculateMeans(series, page)
		const context = renderer.context()

		// Spread the bins evenly over the page's pixel width.
		const pageWidth = this.config.xScale(page.endTime) - this.config.xScale(page.startTime)
		this.xScaleBand.domain(domain).range([0, pageWidth])

		// Rounding prevents white bars from showing up in between each band due to floating point precision.
		// Width, height and y are identical for every bar, so compute them once.
		const barWidth = Math.ceil(this.xScaleBand.bandwidth())
		const barHeight = Math.ceil(this.config.yScale.bandwidth())
		const barY = Math.floor(this.config.yScale(getHeatmapTraceId(this.config)) ?? 0)

		means.forEach((mean, index) => {
			const barX = Math.floor((this.xScaleBand(index.toString()) ?? 0) + offset)
			context.fillStyle = this.colorScale(mean as number)
			context.fillRect(barX, barY, barWidth, barHeight)
		})

		return []
	}

	/** Builds the sequential colour scale for the config's spectrum over [lowerBound, upperBound]. */
	private createColorScale(config: HeatmapTraceConfig): ScaleSequential<string> {
		return scaleSequential(getSpectrumInterpolator(config.colorSpectrum)).domain([config.lowerBound, config.upperBound])
	}

	/**
	 * Splits the page span into equal-width time bins and averages the samples in each.
	 *
	 * The bin width is `pageSpan / desiredNumberOfBars`, but never narrower than the median
	 * sample spacing. The colour scale's lower bound is passed as the third argument of
	 * binMeansUsingTimestampBins (presumably the value used for empty bins — confirm there).
	 *
	 * @returns `means` per bin and `domain`, the bin indices as strings for the band scale.
	 */
	private binAndCalculateMeans(timeSeriesData: TimeSeriesData, page: ModalityPage) {
		const minimumBinSizeMs = this.getMedianBinSizeSorted(timeSeriesData.times)
		const step = Math.max((page.endTime - page.startTime) / this.desiredNumberOfBars, minimumBinSizeMs)
		const edges = getBinEdges(page.startTime, step, page.endTime)
		const domain = range(edges.length - 1).map(val => val.toString())
		const means = binMeansUsingTimestampBins(timeSeriesData, edges, this.colorScale.domain()[0])

		return { means, domain }
	}

	/**
	 * Estimates the sampling interval as the median of the positive gaps between consecutive
	 * timestamps. Using it as the minimum bin width avoids rendering many empty bins (visual
	 * noise) when the data is sparser than the desired bar count.
	 *
	 * Pairs involving a missing (undefined) or zero timestamp, and non-increasing pairs
	 * (duplicates / out-of-order), are ignored rather than counted as infinite gaps. Counting
	 * them as infinite previously pushed the median to Infinity.
	 *
	 * @returns The median positive gap in ms, or 500 ms when fewer than two usable timestamps exist.
	 */
	private getMedianBinSizeSorted(times: (number | undefined)[]): number {
		const fallbackBinSizeMs = 500

		const gaps: number[] = []
		for (let i = 1; i < times.length; i++) {
			const current = times[i]
			const previous = times[i - 1]

			if (!current || !previous) {
				continue
			}

			const gap = current - previous
			if (gap > 0) {
				gaps.push(gap)
			}
		}

		if (gaps.length === 0) {
			return fallbackBinSizeMs
		}

		gaps.sort((a, b) => a - b)
		return gaps[Math.floor(gaps.length / 2)] ?? fallbackBinSizeMs
	}
}
