// Most of the code in this file is adapted from `prosemirror-utils`:
// https://github.com/atlassian/prosemirror-utils
//
// ------------------------------------------------------------------
//
// Copyright 2018 Atlassian Pty Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import { findParentNode } from "@tiptap/core";
import { Node } from "@tiptap/pm/model";
import { Selection, Transaction } from "@tiptap/pm/state";
import { TableMap, CellSelection } from "@tiptap/pm/tables";

/**
 * A table cell paired with the document position at which it starts.
 */
type CellWithPos = {
    // Document position just before the cell node: the table-relative offset from
    // `TableMap.cellsInRect` shifted by `table.pos + 1`, so it can be passed to `doc.resolve`.
    pos: number;
    // The cell node found at that table-relative offset via `nodeAt`, or `null` if none was found.
    node: Node | null;
};

/**
 * Looks up the closest ancestor node of type `table` that contains `selection`.
 *
 * @param selection The selection to start searching from.
 * @returns The `findParentNode` result for the nearest `table` ancestor, or `null`
 *   when the selection is not inside a table.
 */
export const findTable = (selection: Selection) => {
    const isTableNode = (candidate: Node) => candidate.type.name === "table";
    const found = findParentNode(isTableNode)(selection);
    return found ?? null;
};

/**
 * Collects every cell of the table that encloses `selection`.
 *
 * @param selection Selection used to locate the surrounding table.
 * @returns One entry per cell reported by `TableMap.cellsInRect` over the whole table, with
 *   document positions; `null` when the selection is not inside a table.
 */
export const getCellsInTable = (selection: Selection): CellWithPos[] | null => {
    const table = findTable(selection);
    if (table === null) {
        return null;
    }

    const { node: tableNode, pos: tablePos } = table;
    const tableMap = TableMap.get(tableNode);
    const wholeTable = { left: 0, top: 0, right: tableMap.width, bottom: tableMap.height };

    // `cellsInRect` yields offsets relative to the table's content; shift them to document positions.
    return tableMap.cellsInRect(wholeTable).map(offset => ({
        pos: tablePos + 1 + offset,
        node: tableNode.nodeAt(offset),
    }));
};

/**
 * Builds a getter that collects the cells of column `columnIndex` in the table enclosing a selection.
 *
 * The returned function yields `null` when the selection is not inside a table. Each entry's
 * `pos` is shifted by the table's position so it can be resolved against the document.
 * NOTE(review): `TableMap.cellsInRect` appears to skip cells that start left of the column
 * (colspans from earlier columns), so the result may be shorter than the table height, or empty.
 *
 * @param columnIndex Zero-based column index; must be an integer in `[0, map.width - 1]`.
 * @throws Error("Column index is out of bounds") when the selection is inside a table and
 *   `columnIndex` is negative, too large, fractional or `NaN`.
 */
export const getCellsInColumn =
    (columnIndex: number) =>
    (selection: Selection): CellWithPos[] | null => {
        const table = findTable(selection);
        if (!table) return null;

        const map = TableMap.get(table.node);

        // `NaN` and fractional indices slip past a plain range comparison but address no map
        // slot, which would silently produce an empty list; treat them as out of bounds too.
        const isOutOfBounds = !Number.isInteger(columnIndex) || columnIndex < 0 || columnIndex > map.width - 1;
        if (isOutOfBounds) throw new Error("Column index is out of bounds");

        const cells = map.cellsInRect({ left: columnIndex, right: columnIndex + 1, top: 0, bottom: map.height });
        // Offsets are relative to the table content start (`table.pos + 1`).
        return cells.map(pos => ({ pos: pos + table.pos + 1, node: table.node.nodeAt(pos) }));
    };

/**
 * Builds a getter that collects the cells of row `rowIndex` in the table enclosing a selection.
 *
 * The returned function yields `null` when the selection is not inside a table. Each entry's
 * `pos` is shifted by the table's position so it can be resolved against the document.
 * NOTE(review): `TableMap.cellsInRect` appears to skip cells that start above the row
 * (rowspans from earlier rows), so the result may be shorter than the table width, or empty.
 *
 * @param rowIndex Zero-based row index; must be an integer in `[0, map.height - 1]`.
 * @throws Error("Row index is out of bounds") when the selection is inside a table and
 *   `rowIndex` is negative, too large, fractional or `NaN`.
 */
export const getCellsInRow =
    (rowIndex: number) =>
    (selection: Selection): CellWithPos[] | null => {
        const table = findTable(selection);
        if (!table) return null;

        const map = TableMap.get(table.node);

        // `NaN` and fractional indices slip past a plain range comparison but address no map
        // slot, which would silently produce an empty list; treat them as out of bounds too.
        const isOutOfBounds = !Number.isInteger(rowIndex) || rowIndex < 0 || rowIndex > map.height - 1;
        if (isOutOfBounds) throw new Error("Row index is out of bounds");

        const cells = map.cellsInRect({ left: 0, right: map.width, top: rowIndex, bottom: rowIndex + 1 });
        // Offsets are relative to the table content start (`table.pos + 1`).
        return cells.map(pos => ({ pos: pos + table.pos + 1, node: table.node.nodeAt(pos) }));
    };

/**
 * Type guard that narrows `selection` to a `CellSelection` using an `instanceof` check.
 *
 * @param selection Any editor selection.
 * @returns `true` when `selection` is a `CellSelection` instance.
 */
export const isCellSelection = (selection: Selection): selection is CellSelection =>
    selection instanceof CellSelection;

/**
 * Returns a transaction updater that selects the cells reported by `getCellsInColumn`
 * for `columnIndex` in the table around `tr.selection`.
 *
 * When the selection is not inside a table, `tr` is returned as-is.
 *
 * @param columnIndex Zero-based column to select.
 * @throws Error propagated from `getCellsInColumn` when `columnIndex` is out of bounds.
 */
export const selectColumn = (columnIndex: number) => (tr: Transaction) => {
    const columnCells = getCellsInColumn(columnIndex)(tr.selection);
    if (columnCells === null) {
        return tr;
    }
    return selectCells(columnCells)(tr);
};

/**
 * Returns a transaction updater that selects the cells reported by `getCellsInRow`
 * for `rowIndex` in the table around `tr.selection`.
 *
 * When the selection is not inside a table, `tr` is returned as-is.
 *
 * @param rowIndex Zero-based row to select.
 * @throws Error propagated from `getCellsInRow` when `rowIndex` is out of bounds.
 */
export const selectRow = (rowIndex: number) => (tr: Transaction) => {
    const rowCells = getCellsInRow(rowIndex)(tr.selection);
    if (rowCells === null) {
        return tr;
    }
    return selectCells(rowCells)(tr);
};

/**
 * Selects every cell of the table around `tr.selection` as a cell selection.
 *
 * When the selection is not inside a table, `tr` is returned as-is.
 *
 * @param tr The transaction whose selection is updated.
 * @returns The same transaction, with its selection replaced when a table was found.
 */
export const selectTable = (tr: Transaction) => {
    const allCells = getCellsInTable(tr.selection);
    if (allCells === null) {
        return tr;
    }
    return selectCells(allCells)(tr);
};

/**
 * Returns a transaction updater that sets a `CellSelection` spanning from the first to the
 * last entry of `cells`.
 *
 * An empty `cells` list (e.g. a column or row whose cells all start outside it because of
 * merged cells) leaves `tr` unchanged, mirroring the "not in a table" no-op of the callers,
 * instead of failing while resolving a non-existent boundary cell.
 *
 * @param cells Cells with document positions, in `TableMap.cellsInRect` order.
 */
const selectCells = (cells: CellWithPos[]) => (tr: Transaction) => {
    if (cells.length === 0) return tr;

    const { $anchor, $head } = getCellsBoundaries(cells)(tr);
    return tr.setSelection(new CellSelection($anchor, $head));
};

/**
 * Resolves the first and last entries of `cells` against `tr.doc`, giving the anchor and
 * head positions used to build a `CellSelection`.
 *
 * @param cells Non-empty list of cells with document positions.
 * @returns `{ $anchor, $head }` resolved from the first and last cell positions.
 * @throws Error when `cells` is empty, since there is no boundary cell to resolve.
 */
const getCellsBoundaries = (cells: CellWithPos[]) => (tr: Transaction) => {
    const first = cells[0];
    const last = cells[cells.length - 1];
    // Explicit guard: without it an empty list fails with an opaque TypeError on `undefined.pos`,
    // and the indexed accesses would not type-check under `noUncheckedIndexedAccess`.
    if (first === undefined || last === undefined) {
        throw new Error("Cannot compute cell boundaries of an empty cell list");
    }

    const $anchor = tr.doc.resolve(first.pos);
    const $head = tr.doc.resolve(last.pos);

    return { $anchor, $head };
};
