# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import random

import pytest

from pyiceberg.utils.bin_packing import ListPacker, PackingIterator

# Largest signed 32-bit integer; the reverse-packing tests use it as an effectively unlimited lookback.
INT_MAX = 2**31 - 1


@pytest.mark.parametrize(
    "splits, lookback, split_size, open_cost",
    [
        ([random.randint(0, 64) for _ in range(200)], 20, 128, 4),  # random splits
        ([], 20, 128, 4),  # no splits
        (
            # A handful of random splits surrounded by empty ones. This must be a
            # comprehension: `randint(...) in range(10)` would yield a single bool.
            [0] * 100 + [random.randint(0, 64) for _ in range(10)] + [0] * 100,
            20,
            128,
            4,
        ),  # sparse
    ],
)
def test_bin_packing(splits: list[int], lookback: int, split_size: int, open_cost: int) -> None:
    """Check PackingIterator output: bins stay within ``split_size`` and every input item is kept.

    Items are weighted as ``max(x, open_cost)``. In every case above, no single weight
    exceeds 64, which is below the 128 ``split_size``.
    """

    def weight_func(x: int) -> int:
        # Every item weighs at least ``open_cost``, even zero-sized splits.
        return max(x, open_cost)

    bins = list(PackingIterator(splits, split_size, lookback, weight_func))

    # No bin's raw contents may exceed the target size.
    assert all(split_size >= sum(bin_items) >= 0 for bin_items in bins)
    # Packing only regroups items: the multiset of inputs must be preserved exactly.
    assert sorted(item for bin_items in bins for item in bin_items) == sorted(splits)


@pytest.mark.parametrize(
    "splits, target_weight, lookback, largest_bin_first, expected_lists",
    [
        ([36, 36, 36, 36, 73, 110, 128], 128, 2, True, [[110], [128], [36, 73], [36, 36, 36]]),
        ([36, 36, 36, 36, 73, 110, 128], 128, 2, False, [[36, 36, 36], [36, 73], [110], [128]]),
        ([64, 64, 128, 32, 32, 32, 32], 128, 1, True, [[64, 64], [128], [32, 32, 32, 32]]),
        ([64, 64, 128, 32, 32, 32, 32], 128, 1, False, [[64, 64], [128], [32, 32, 32, 32]]),
    ],
)
def test_bin_packing_lookback(
    splits: list[int], target_weight: int, lookback: int, largest_bin_first: bool, expected_lists: list[list[int]]
) -> None:
    """Both PackingIterator and ListPacker.pack must yield exactly ``expected_lists``.

    Each item's weight is the item value itself.
    """

    def identity_weight(item: int) -> int:
        return item

    # Drive the iterator directly.
    via_iterator = list(PackingIterator(splits, target_weight, lookback, identity_weight, largest_bin_first))
    assert via_iterator == expected_lists

    # Drive the higher-level packer with the same settings.
    list_packer: ListPacker[int] = ListPacker(target_weight, lookback, largest_bin_first)
    via_packer = list(list_packer.pack(splits, identity_weight))
    assert via_packer == expected_lists


@pytest.mark.parametrize(
    "splits, target_weight, lookback, largest_bin_first, expected_lists",
    [
        # lookback = 1
        ([1, 2, 3, 4, 5], 3, 1, False, [[1, 2], [3], [4], [5]]),
        ([1, 2, 3, 4, 5], 4, 1, False, [[1, 2], [3], [4], [5]]),
        ([1, 2, 3, 4, 5], 5, 1, False, [[1], [2, 3], [4], [5]]),
        ([1, 2, 3, 4, 5], 6, 1, False, [[1, 2, 3], [4], [5]]),
        ([1, 2, 3, 4, 5], 7, 1, False, [[1, 2], [3, 4], [5]]),
        ([1, 2, 3, 4, 5], 8, 1, False, [[1, 2], [3, 4], [5]]),
        ([1, 2, 3, 4, 5], 9, 1, False, [[1, 2, 3], [4, 5]]),
        ([1, 2, 3, 4, 5], 11, 1, False, [[1, 2, 3], [4, 5]]),
        ([1, 2, 3, 4, 5], 12, 1, False, [[1, 2], [3, 4, 5]]),
        ([1, 2, 3, 4, 5], 14, 1, False, [[1], [2, 3, 4, 5]]),
        ([1, 2, 3, 4, 5], 15, 1, False, [[1, 2, 3, 4, 5]]),
        # lookback = INT_MAX, i.e. effectively unlimited
        ([1, 2, 3, 4, 5], 3, INT_MAX, False, [[1, 2], [3], [4], [5]]),
        ([1, 2, 3, 4, 5], 4, INT_MAX, False, [[2], [1, 3], [4], [5]]),
        ([1, 2, 3, 4, 5], 5, INT_MAX, False, [[2, 3], [1, 4], [5]]),
        ([1, 2, 3, 4, 5], 6, INT_MAX, False, [[3], [2, 4], [1, 5]]),
        ([1, 2, 3, 4, 5], 7, INT_MAX, False, [[1], [3, 4], [2, 5]]),
        ([1, 2, 3, 4, 5], 8, INT_MAX, False, [[1, 2, 4], [3, 5]]),
        ([1, 2, 3, 4, 5], 9, INT_MAX, False, [[1, 2, 3], [4, 5]]),
        ([1, 2, 3, 4, 5], 10, INT_MAX, False, [[2, 3], [1, 4, 5]]),
        ([1, 2, 3, 4, 5], 11, INT_MAX, False, [[1, 3], [2, 4, 5]]),
        ([1, 2, 3, 4, 5], 12, INT_MAX, False, [[1, 2], [3, 4, 5]]),
        ([1, 2, 3, 4, 5], 13, INT_MAX, False, [[2], [1, 3, 4, 5]]),
        ([1, 2, 3, 4, 5], 14, INT_MAX, False, [[1], [2, 3, 4, 5]]),
        ([1, 2, 3, 4, 5], 15, INT_MAX, False, [[1, 2, 3, 4, 5]]),
    ],
)
def test_reverse_bin_packing_lookback(
    splits: list[int], target_weight: int, lookback: int, largest_bin_first: bool, expected_lists: list[list[int]]
) -> None:
    """ListPacker.pack_end must return exactly ``expected_lists``.

    Each item's weight is the item value itself.
    """

    def identity_weight(value: int) -> int:
        return value

    packed_from_end = ListPacker(target_weight, lookback, largest_bin_first).pack_end(splits, identity_weight)
    assert packed_from_end == expected_lists
