import { Selection, EnterElement, ScaleBand, range, scaleLinear, axisLeft, format } from "d3";
import { D3OneToOneRenderable } from "../../../D3/D3OneToOneRenderable";
import { ReactCallbacks } from "../../../../Types/ReactCallbacks";
import { HeatmapTraceConfigJSON, getHeatmapTraceId } from "../../../../Types/Trace";
import { ColorSpectrum, getSpectrumInterpolator } from "../../../../Types/ColorSpectrum";

type D3HeatmapTraceLegendConfig = {
    yScale: ScaleBand<any>
    trace: HeatmapTraceConfigJSON
}

/**
 * Renders the colour legend for one heatmap trace.
 *
 * The legend is a `<g>` translated to the start of the trace's band in `yScale` and shifted left of the
 * parent origin. It contains:
 *  - a nested `<svg>` holding a `<rect>` filled with a vertical linear gradient of the trace's colour spectrum,
 *  - a left axis labelling the lower bound, midpoint and upper bound of the trace's value range,
 *  - a `<foreignObject>` holding an HTML label with the trace name, placed just above the band.
 */
export class D3HeatmapTraceLegend extends D3OneToOneRenderable<SVGGElement, SVGGElement, D3HeatmapTraceLegendConfig> {
    /** Horizontal gap (px) between the colour bar's right edge and the parent group's origin. */
    private readonly gap = 8
    /** Width (px) of the colour bar. */
    private readonly width = 16
    /** Vertical gap (px) between the bottom of the label box and the top of the legend. */
    private readonly textYOffset = 4
    /** Size (px) of the foreignObject that hosts the trace-name label. */
    private readonly labelWidth = 200
    private readonly labelHeight = 20
    /** CSS font size for the label; must carry a unit, a bare number is invalid CSS and gets dropped. */
    private readonly labelFontSize = "12px"
    /** Maps trace values onto vertical pixel positions inside the colour bar (low values at the bottom). */
    private readonly gradientScale = scaleLinear()
    /** Up to 4 decimal places of precision, but fewer if not needed. */
    private readonly tickFormat = format(".4~f")
    private readonly colorScaleAxisClassName = "d3-heatmap-color-scale-axis"
    private readonly labelClassName = "d3-heatmap-legend-label"

    constructor(root: SVGGElement, config: D3HeatmapTraceLegendConfig, className: string, reactCallbacks: ReactCallbacks<any>) {
        super(root, config, className, reactCallbacks)
        this.updateDerivedState()
        this.render()
    }

    /** Recomputes the value → pixel scale from the trace bounds and the current band height. */
    protected updateDerivedState(): void {
        this.gradientScale
            .domain([this.config.trace.lowerBound, this.config.trace.upperBound])
            .range([this.config.yScale.bandwidth() - 1, 0]) // 1 px for the line width of the tick mark
    }

    /** Creates the legend group (colour bar, value axis and name label) for a newly-bound trace. */
    protected enter(newElements: Selection<EnterElement, D3HeatmapTraceLegendConfig, any, any>): Selection<SVGGElement, D3HeatmapTraceLegendConfig, SVGGElement, any> {
        const bandHeight = this.config.yScale.bandwidth()

        const legend = newElements
            .append("g")
            .attr("class", this.className)
            .attr("transform", this.legendTransform())

        // Colour bar: a nested SVG sized to the band, containing the gradient definition and a filled rect.
        const legendSVG = legend.append("svg")
            .attr("width", this.width)
            .attr("height", bandHeight)

        const gradientId = this.createGradient(legendSVG.append("defs"), this.config.trace.colorSpectrum)

        legendSVG
            .append("rect")
            .attr("fill", `url(#${gradientId})`)
            .attr("width", this.width)
            .attr("height", bandHeight)

        // Value axis on the left edge of the colour bar.
        this.renderColorScaleAxis(
            legend.append("g").attr("class", this.colorScaleAxisClassName)
        )

        // Trace-name label, rendered as HTML inside a foreignObject so it can have a styled background box.
        const foreignObject = legend
            .append("foreignObject")
            .attr("width", this.labelWidth)
            .attr("height", this.labelHeight)
            .attr("transform", this.labelTransform())

        foreignObject.append<HTMLDivElement>("xhtml:div")
            .attr("class", this.labelClassName)
            .style("background", "white")
            .style("width", "fit-content")
            .style("color", "black")
            .style("border-radius", "0 6px 0 0")
            .style("padding-left", "4px")
            .style("padding-right", "8px")
            .style('font-family', "Source Sans Pro")
            .style("font-weight", "bold")
            .style("font-size", this.labelFontSize)
            .text(this.config.trace.name)

        return legend
    }

    /** Re-positions and re-sizes an existing legend and refreshes its gradient, axis and label text. */
    protected update(updatedElements: Selection<SVGGElement, D3HeatmapTraceLegendConfig, any, any>): Selection<SVGGElement, D3HeatmapTraceLegendConfig, SVGGElement, any> {
        const bandHeight = this.config.yScale.bandwidth()
        const legend = updatedElements.attr("transform", this.legendTransform())

        const legendSVG = legend.select<SVGSVGElement>("svg")
            .attr("height", bandHeight)

        // Point the bar at the current spectrum; createGradient reuses an existing gradient with the same id.
        const gradientId = this.createGradient(legendSVG.select<SVGDefsElement>("defs"), this.config.trace.colorSpectrum)

        legendSVG.select("rect")
            .attr("fill", `url(#${gradientId})`)
            .attr("height", bandHeight)

        this.renderColorScaleAxis(legend.select<SVGGElement>("." + this.colorScaleAxisClassName))

        // The transform belongs on the SVG foreignObject; the inner div is HTML and ignores SVG transforms.
        legend.select("foreignObject")
            .attr("transform", this.labelTransform())
            .select("." + this.labelClassName)
            .style("font-size", this.labelFontSize)
            .text(this.config.trace.name)

        return legend
    }

    /** Translate for the legend group: left of the parent origin, at the start of this trace's band. */
    private legendTransform(): string {
        return `translate(${-this.width - this.gap}, ${this.config.yScale(getHeatmapTraceId(this.config.trace))})`
    }

    /**
     * Translate for the label's foreignObject, relative to the legend group: x undoes the legend's
     * horizontal offset, y places the whole label box `textYOffset` px above the legend's top.
     */
    private labelTransform(): string {
        return `translate(${this.width + this.gap}, ${-this.textYOffset - this.labelHeight})`
    }

    /**
     * Draws (or redraws) the value axis into `axisGroup`.
     *
     * Ticks sit at the lower bound, the midpoint and the upper bound of the trace's range. Equal values
     * (when the bounds coincide) are collapsed so overlapping ticks are not drawn. The axis domain line is hidden.
     */
    private renderColorScaleAxis(axisGroup: Selection<SVGGElement, any, any, any>): void {
        const { lowerBound, upperBound } = this.config.trace
        const tickValues = Array.from(new Set([lowerBound, (lowerBound + upperBound) / 2, upperBound]))

        axisGroup.call(axisLeft(this.gradientScale).tickValues(tickValues).tickFormat(this.tickFormat))

        // Remove the black domain line; only ticks and labels are wanted.
        axisGroup.select(".domain").attr("stroke", "transparent")
    }

    /**
     * Ensures `defs` contains a vertical linear gradient for `colorSpectrum` and returns its element id.
     *
     * The gradient runs from the bottom (interpolator(0)) to the top (interpolator(1)) with 11 evenly
     * spaced stops. If a gradient with that id already exists inside `defs`, it is reused unchanged.
     *
     * @param defs - The `<defs>` element to look in / append to.
     * @param colorSpectrum - The spectrum to render; its value is also used as the gradient's element id.
     * @returns The gradient's element id, for use in `url(#id)` references.
     */
    private createGradient(defs: Selection<SVGDefsElement, any, any, any>, colorSpectrum: ColorSpectrum) {
        const gradientId = colorSpectrum

        if (!defs.select<SVGLinearGradientElement>(`#${gradientId}`).empty()) {
            return gradientId
        }

        const gradient = defs
            .append("linearGradient")
            .attr("id", gradientId)
            .attr("x1", "0%")
            .attr("x2", "0%")
            .attr("y1", "100%")
            .attr("y2", "0%")

        // Append 11 colour stops sampled evenly from the spectrum's interpolator.
        const interpolator = getSpectrumInterpolator(colorSpectrum)
        range(11).forEach(step => {
            const t = step / 10
            gradient
                .append("stop")
                .attr("offset", `${Math.round(100 * t)}%`)
                .attr("stop-color", interpolator(t))
        })

        return gradientId
    }
}

