import './survival-plot.scss';

import * as d3 from 'd3';
import * as React from 'react';
import { flatten, max } from 'lodash';
import { observer } from 'mobx-react';

/**
 * One survival curve. `key` identifies the group; each entry of `values` is a
 * point on the curve. `percentile` is the point's y-value (the plot's y scale
 * spans 0-100) and a `time` of `null` marks a point without a time value.
 */
export interface Datum {
    key: string;
    values: { percentile: number; time: number | null }[];
}

/** Every curve handed to the plot, one entry per group. */
export type Data = Array<Datum>;

/** Props accepted by the survival plot component. */
export interface ISurvivalPlotProps {
    /** Curves for every group; only the selected group's curve is drawn. */
    data: Data;
    /** Group whose curve is drawn; compared (loosely) against `Datum.key`. */
    selectedGroup: number;
    /** Value shown in the "Median survival ... years" annotation. */
    medianSurvival: number;
    /** Initial time position of the draggable vertical marker. */
    selectedSurvivalTime: number;
    /** Invoked with the new marker time while the marker is dragged. */
    updateSurvivalTime: (survivalTime: number) => void;
}

/**
 * React wrapper around the imperative d3 chart built by `drawSurvivalPlot`.
 *
 * The chart is created once, on mount. Later updates only push the selected
 * group and the median survival into it; changes to `data` or
 * `selectedSurvivalTime` after mount are not re-rendered.
 */
@observer
export class SurvivalPlot extends React.Component<ISurvivalPlotProps> {
    // Set in ref callback function
    rootNode!: HTMLDivElement;
    // Chart closure returned by drawSurvivalPlot. It carries extra methods
    // (Group, medianSurvivalTime) attached at runtime, hence the loose type.
    survivalPlot: any;

    componentDidMount() {
        this.survivalPlot = drawSurvivalPlot(this.props);

        // Render into the node this component owns instead of looking it up
        // by id, so the chart attaches to this instance's element.
        d3.select(this.rootNode).call(this.survivalPlot);
    }

    componentDidUpdate() {
        // Group first: it refreshes the annotations, then the median text is
        // set to the latest prop value.
        this.survivalPlot.Group(this.props.selectedGroup);
        this.survivalPlot.medianSurvivalTime(this.props.medianSurvival);
    }

    render() {
        return (
            <div
                id="SurvivalPlot"
                ref={rootNode => {
                    this.rootNode = rootNode as HTMLDivElement;
                }}
            />
        );
    }
}

/**
 * Builds a d3 reusable-chart closure for the survival plot.
 *
 * Calling `chart(selection)` appends, to each selected element, an SVG with
 * x/y axes, the survival curve of the selected group, a vertical time marker
 * and two text annotations (survival at the marker, median survival).
 * Dragging on the SVG moves the marker and reports the new time through
 * `updateSurvivalTime`.
 *
 * The returned function also carries:
 *  - `chart.Group(value?)`: with no argument returns the selected group;
 *    otherwise sets it, redraws the curve and refreshes the annotations.
 *  - `chart.medianSurvivalTime(value)`: updates the median-survival text.
 *
 * Times are in the units of `Datum.time`; the axis and annotation texts label
 * them as years.
 */
function drawSurvivalPlot({
    data,
    selectedGroup,
    selectedSurvivalTime,
    medianSurvival,
    updateSurvivalTime,
}: ISurvivalPlotProps) {
    // selectedGroup, selectedSurvivalTime and medianSurvival are reassigned
    // below; they hold the chart's current state for the closures.
    type SurvivalPoint = Datum['values'][number];

    const outerWidth = 800;
    const outerHeight = 300;
    const margin = {
        top: 20,
        right: 180,
        bottom: 50,
        left: 50,
    };

    // Size of the plotting area inside the margins.
    const width = outerWidth - margin.left - margin.right;
    const height = outerHeight - margin.top - margin.bottom;

    const allTimes = flatten(
        data.map(datum => datum.values.map(({ time }) => time)),
    ).filter((time): time is number => time !== null);
    const largestTime = max(allTimes);
    // With no non-null times, fall back to 0 instead of a NaN domain.
    const largestTimeValue = Math.ceil(
        largestTime === undefined ? 0 : largestTime,
    );

    const x = d3
        .scaleLinear()
        .range([0, width])
        .domain([0, largestTimeValue])
        // Keep values within [0, largestTimeValue]. Clamping also applies to
        // x.invert(), which bounds the time produced by dragging.
        .clamp(true);

    const y = d3
        .scaleLinear()
        .range([height, 0])
        .domain([0, 100]);

    // A null (or missing) time is a gap in the curve. Compare against null
    // explicitly so a point at time 0 is still drawn; a truthiness test
    // dropped it. d3 only calls the x accessor for defined points, so the
    // cast below is safe.
    const survivalLine = d3
        .line<SurvivalPoint>()
        .curve(d3.curveBasis)
        .defined(d => d.time != null)
        .x(d => x(d.time as number))
        .y(d => y(d.percentile));

    // Redraws the selected curve; assigned once the chart has been rendered.
    let updateGroup: (() => void) | undefined;

    // True when `datum` is the currently selected group.
    function isSelectedGroup(datum: Datum): boolean {
        // Datum.key is a string and selectedGroup a number; loose equality is
        // kept deliberately so that e.g. key "1" matches group 1.
        //@ts-ignore
        return datum.key == selectedGroup;
    }

    // Appends the survival path to the selected group's <g> only. Filtering
    // before appending avoids leaving an empty <path> in every other group,
    // which were never removed and piled up on each redraw.
    function drawSelectedLine(
        groups: d3.Selection<SVGGElement, Datum, any, any>,
    ) {
        groups
            .filter(isSelectedGroup)
            .append('path')
            .attr('class', 'line')
            .attr('d', d => survivalLine(d.values))
            .attr('id', d => 'tag' + d.key.replace(/\s+/g, ''))
            .classed('selectedLine', true);
    }

    function chart(selection: d3.Selection<any, any, any, any>) {
        selection.each(function() {
            // Render into the element the chart was called on.
            const svg = d3
                .select(this)
                .append('svg')
                .call(
                    //@ts-ignore
                    d3
                        .drag()
                        .on('start', moveLine)
                        .on('drag', moveLine)
                        .on('end', dragEnded),
                );

            const g = svg
                .attr('width', outerWidth)
                .attr('height', outerHeight)
                .append('g')
                .attr(
                    'transform',
                    'translate(' + margin.left + ',' + margin.top + ')',
                );

            // x-axis
            g
                .append('g')
                .attr('transform', 'translate(0,' + height + ')')
                .call(
                    //@ts-ignore
                    d3.axisBottom(x),
                )
                .append('text')
                .attr('y', -2)
                .attr('x', width - 30)
                .text('Time (years)')
                .attr('class', 'axis-label');

            // y-axis, ticks extended across the plot width
            g.append('g').call(
                //@ts-ignore
                d3.axisLeft(y).tickSize(-width),
            );
            // NOTE(review): this label uses class 'axisLabel' while the x-axis
            // label uses 'axis-label'; confirm which one survival-plot.scss
            // actually styles.
            g
                .append('g')
                .append('text')
                .attr('transform', 'rotate(-90)')
                .attr('y', 10)
                .attr('x', -height)
                .text('Survival (probability)')
                .attr('class', 'axisLabel');

            // One <g> per group; only the selected group's gets a path.
            const group = g
                .selectAll('.group')
                .data(data)
                .enter()
                .append('g')
                .attr('class', 'group');

            drawSelectedLine(group);

            // Appends an initially empty annotation text that Annotation()
            // fills in.
            const appendAnnotation = (id: string, dy: string) =>
                g
                    .append('text')
                    .attr(
                        'transform',
                        'translate(' + margin.left + ',' + margin.top + ')',
                    )
                    .attr('x', 375)
                    .attr('y', 0)
                    .attr('dy', dy)
                    .attr('id', id)
                    .style('font', '24px sans-serif')
                    .raise();

            appendAnnotation('Annotation_1', '0em');
            appendAnnotation('Annotation_2', '2em');

            Annotation(selectedSurvivalTime);

            // Vertical marker at the selected survival time; moved by dragging.
            g
                .append('line')
                .attr('x1', x(selectedSurvivalTime))
                .attr('y1', 0)
                .attr('x2', x(selectedSurvivalTime))
                .attr('y2', height)
                .attr('stroke-width', 5)
                .attr('stroke', 'lightblue')
                .attr('id', 'timeLine')
                .style('opacity', 1);

            updateGroup = function() {
                group.selectAll('.line').remove();
                drawSelectedLine(group);
                Annotation(selectedSurvivalTime);
            };
        }); // end selection each
    } // end chart selection.

    //@ts-ignore
    chart.Group = function(value) {
        if (!arguments.length) return selectedGroup;
        selectedGroup = value;
        if (typeof updateGroup === 'function') updateGroup();
        return chart;
    };

    //@ts-ignore
    chart.medianSurvivalTime = function(newMedianSurvival: number) {
        // Store the value. Annotation() rewrites this text (e.g. on every
        // drag) and would otherwise revert it to the initial median.
        medianSurvival = newMedianSurvival;
        d3
            .select('#Annotation_2')
            .text('Median survival ' + medianSurvival + ' years');
    };

    // Writes the survival-at-time and median-survival annotation texts.
    function Annotation(survivalTime: number) {
        const survivalTimeFormat = d3.format('.1f');
        const survivalPercentFormat = d3.format('.0f');
        const selectedLine = d3.select('.selectedLine').node();

        if (selectedLine) {
            // Find the curve's pixel y at the marker's pixel x, then map it
            // back to the 0-100 survival scale.
            const pointOnLine = findYatX(
                x(survivalTime),
                selectedLine,
                0.0000001,
            );
            const survival = survivalPercentFormat(y.invert(pointOnLine[1]));
            d3
                .select('#Annotation_1')
                .text(
                    'Survival ' +
                        survival +
                        '% at ' +
                        survivalTimeFormat(survivalTime) +
                        ' years',
                );
        } else {
            // No path for the selected group: clear the text instead of
            // throwing inside findYatX.
            d3.select('#Annotation_1').text('');
        }

        d3
            .select('#Annotation_2')
            .text('Median survival ' + medianSurvival + ' years');
    }

    // Drag handler: moves the marker to the pointer and reports the new time.
    function moveLine() {
        // Offset the drag x by margin.left to match the translated plot <g>;
        // x is clamped, so the inverted time stays within the domain.
        selectedSurvivalTime = x.invert(d3.event.x - margin.left);
        updateSurvivalTime(selectedSurvivalTime);
        d3
            .select('#timeLine')
            .attr('x1', x(selectedSurvivalTime))
            .attr('x2', x(selectedSurvivalTime))
            .raise()
            .classed('active', true);
        Annotation(selectedSurvivalTime);
    }

    function dragEnded() {
        d3.select('#timeLine').classed('active', false);
    }

    return chart;
}

/** Minimal subset of the SVGPathElement API used by the path searches below. */
interface PathLike {
    getTotalLength(): number;
    getPointAtLength(distance: number): { x: number; y: number };
}

/**
 * Bisects along `path` by arc length to find the point whose `axis`
 * coordinate lies within `error` of `target`.
 *
 * Assumes that coordinate increases monotonically along the path: the search
 * moves toward the path start when `target` is below the current point. Stops
 * after 50 halvings and returns the last point examined, which converges to a
 * path end when `target` lies outside the path's range.
 *
 * @returns `[x, y]` of the point found.
 */
function bisectPath(
    target: number,
    path: PathLike,
    axis: 'x' | 'y',
    error: number,
): [number, number] {
    const maxIterations = 50;
    let start = 0;
    let end = path.getTotalLength();
    let point = path.getPointAtLength((start + end) / 2);

    for (
        let i = 0;
        i < maxIterations && Math.abs(point[axis] - target) > error;
        i++
    ) {
        if (target < point[axis]) {
            end = (start + end) / 2;
        } else {
            start = (start + end) / 2;
        }
        point = path.getPointAtLength((start + end) / 2);
    }
    return [point.x, point.y];
}

// For an SVG path, find the point whose y coordinate is `y` (± error).
// `path` is typed loosely so callers can pass a d3 `.node()` result directly.
// NOTE(review): findXatY has no caller in this file; the directive below
// presumably suppresses an unused-declaration error (noUnusedLocals). Confirm,
// or delete the function.
//@ts-ignore
function findXatY(y: number, path: any, error = 0.01): [number, number] {
    return bisectPath(y, path, 'y', error);
}

// For an SVG path, find the point whose x coordinate is `x` (± error).
// `path` is typed loosely so callers can pass a d3 `.node()` result directly.
function findYatX(x: number, path: any, error = 0.01): [number, number] {
    return bisectPath(x, path, 'x', error);
}
