-- Filename: square_struct.vhd.vhd
-- Created by HDL-SCHEM-Editor at Fri Oct 18 10:44:46 2024
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;
use ieee.math_real.all;
architecture struct of square is
    -- Number of square_step stages (additions) that are chained within one period.
    -- For g_latency=0 this is g_operand_width (one stage per operand bit);
    -- otherwise it is g_operand_width divided by g_latency, rounded up.
    function calculate_additions_per_period return natural is
        variable additions_v : natural;
    begin
        if g_latency/=0 then
            additions_v := integer(ceil(real(g_operand_width)/real(g_latency)));
        else
            additions_v := g_operand_width;
        end if;
        return additions_v;
    end function;
    constant c_additions_per_period : natural := calculate_additions_per_period;

    -- Operand width used inside this architecture.
    -- For g_latency=0 it is g_operand_width; otherwise it is
    -- c_additions_per_period*g_latency, i.e. the number of stages per period
    -- times the number of periods (which may exceed g_operand_width).
    function calculate_operand_width_internal return natural is
    begin
        if g_latency/=0 then
            return c_additions_per_period*g_latency;
        else
            return g_operand_width;
        end if;
    end function;
    constant c_operand_width_internal : natural := calculate_operand_width_internal;

    -- Number of periods needed for one calculation:
    -- g_latency periods, but at least 1 (g_latency=0 is treated as one period).
    function calculate_number_of_periods return natural is
        variable periods_v : natural := 1;
    begin
        if g_latency/=0 then
            periods_v := g_latency;
        end if;
        return periods_v;
    end function;
    constant c_number_of_periods : natural := calculate_number_of_periods;

    -- One summand per addition stage, each c_operand_width_internal-1 bits wide.
    type t_summand_array is array (natural range <>) of unsigned(c_operand_width_internal-2 downto 0);
    -- One set of per-stage masks for each period.
    type t_mask          is array (natural range <>) of t_summand_array(c_additions_per_period-1 downto 0);
    -- Builds the constant mask table.
    -- Entry (p)(a) has its p*c_additions_per_period+a+1 least significant bits
    -- set to '1' and all remaining bits cleared. When that count exceeds the
    -- element width, the element is simply all ones.
    -- The mask is built bit by bit instead of via to_unsigned(2**n-1, width):
    -- the integer expression 2**n overflows for n>=31, and the largest entry
    -- did not fit into the element width, so to_unsigned had to truncate it.
    -- The resulting bit patterns are the same as the truncated values.
    function init_mask return t_mask is
        variable mask_v       : t_mask(c_number_of_periods-1 downto 0);
        variable ones_count_v : natural;
    begin
        for p in 0 to c_number_of_periods-1 loop
            for a in 0 to c_additions_per_period-1 loop
                -- Number of low-order '1' bits of this mask entry:
                ones_count_v := p*c_additions_per_period+a+1;
                mask_v(p)(a) := (others => '0');
                for b in 0 to c_operand_width_internal-2 loop
                    if b<ones_count_v then
                        mask_v(p)(a)(b) := '1';
                    end if;
                end loop;
            end loop;
        end loop;
        return mask_v;
    end function;
    constant c_mask : t_mask(c_number_of_periods-1 downto 0) := init_mask;

    -- A set of shift registers, each c_number_of_periods bits long.
    type t_shift_registers is array (natural range <>) of unsigned(c_number_of_periods-1 downto 0);
    -- counter_o of square_control_inst; selects the entry of c_mask.
    signal counter                         : natural range 0 to c_number_of_periods-1;
    -- Bit i is the factor_bit_i input of square_step instance i.
    signal factor_bits                     : signed(c_additions_per_period-1 downto 0);
    -- Bit i is the fix_negative_operand_i input of square_step instance i.
    signal fix_negative_operand            : std_logic_vector(c_additions_per_period-1 downto 0);
    -- last_step_o of square_control_inst.
    signal last_step                       : std_logic;
    -- Masks for the current counter value, one per square_step instance.
    signal mask                            : t_summand_array(c_additions_per_period-1 downto 0);
    -- Bit 0 of each shift register.
    signal operand_bits_from_shiftregister : signed(c_additions_per_period-1 downto 0);
    -- operand_i resized to c_operand_width_internal bits.
    signal operand_int                     : unsigned(c_operand_width_internal-1 downto 0);
    -- operand_int(0), stored at start_i when g_latency/=0, otherwise connected directly.
    signal operand_int_lsb_s               : std_logic;
    -- Chain of partial sums: element 0 is the input of the first square_step,
    -- element i+1 is the output of square_step instance i.
    signal partial_sum                     : t_summand_array(c_additions_per_period downto 0);
    -- partial_sum_lsb_o of each square_step instance.
    signal partial_sum_lsb                 : std_logic_vector(c_additions_per_period-1 downto 0);
    -- Output of the last square_step, stored when g_latency/=0, otherwise connected directly.
    signal partial_sum_stored              : unsigned(c_operand_width_internal-2 downto 0);
    -- ready_o of square_control_inst.
    signal ready                           : std_logic;
    -- reg_enable_o of square_control_inst.
    signal reg_enable                      : std_logic;
    -- One shift register per square_step instance.
    signal shift_registers                 : t_shift_registers(c_additions_per_period-1 downto 0);
    -- All collected result bits; square_o is taken from its lower bits.
    signal square_all                      : signed(2*c_operand_width_internal downto 0);
    -- Component declaration of the control block instantiated as square_control_inst.
    -- Its implementation is not part of this file; the port descriptions below
    -- are derived from how the ports are used in this architecture.
    component square_control is
        generic (
            -- Upper bound of counter_o (set to c_number_of_periods-1 at the instance).
            g_counter_max : natural := 8
        );
        port (
            clk_i        : in  std_logic;
            res_i        : in  std_logic;
            start_i      : in  std_logic;
            -- Used in this architecture as index into c_mask.
            counter_o    : out natural range 0 to g_counter_max;
            -- Drives the leftmost bit of fix_negative_operand.
            last_step_o  : out std_logic;
            -- Drives ready_o when g_latency/=0.
            ready_o      : out std_logic;
            -- Enables storing and shifting in the register process when start_i is not active.
            reg_enable_o : out std_logic
        );
    end component;
    -- Component declaration of a single stage of the calculation; it is
    -- instantiated c_additions_per_period times in square_step_g.
    -- Its implementation is not part of this file.
    component square_step is
        generic (
            -- All vector ports are g_operand_width_internal-1 bits wide.
            g_operand_width_internal : natural := 8
        );
        port (
            factor_bit_i           : in  std_logic;
            fix_negative_operand_i : in  std_logic;
            mask_i                 : in  unsigned(g_operand_width_internal-2 downto 0);
            operand_int_i          : in  unsigned(g_operand_width_internal-2 downto 0);
            -- Partial sum coming from the previous stage (or from the stored value).
            partial_sum_i          : in  unsigned(g_operand_width_internal-2 downto 0);
            -- Single bit output which is collected in the shift registers.
            partial_sum_lsb_o      : out std_logic;
            -- Partial sum handed to the next stage.
            partial_sum_o          : out unsigned(g_operand_width_internal-2 downto 0)
        );
    end component;
begin
    -- Control block: provides the period counter and the control signals
    -- last_step, ready and reg_enable used below.
    square_control_inst : square_control
        generic map (
            -- The counter runs over the range 0 to c_number_of_periods-1.
            g_counter_max => c_number_of_periods-1
        )
        port map (
            clk_i        => clk_i,
            res_i        => res_i,
            start_i      => start_i,
            counter_o    => counter,
            last_step_o  => last_step,
            ready_o      => ready,
            reg_enable_o => reg_enable
        );
    -- Only the leftmost element (index c_additions_per_period-1) carries last_step,
    -- all other elements are '0':
    fix_negative_operand <= (last_step, others => '0');
    -- Select the masks which belong to the current counter value:
    mask <= c_mask(counter);
    -- When g_operand_width is not an integer multiple of g_latency,
    -- then the operand must be extended by additional bits until its
    -- length is an integer multiple of g_latency
    -- (operand_int has c_operand_width_internal bits):
    operand_int <= unsigned(resize(operand_i, operand_int'length));
    -- Selects the inputs of the first square_step stage.
    -- By default the calculation continues with the stored partial sum and
    -- with the operand bits delivered by the shift registers.
    -- In the first step (start_i='1'), or when there is only a single step
    -- (g_latency is 0 or 1), the partial sum starts at 0 and the operand bits
    -- are taken directly from operand_i.
    process (start_i, operand_i, operand_bits_from_shiftregister, partial_sum_stored)
    begin
        partial_sum(0) <= partial_sum_stored;
        factor_bits    <= operand_bits_from_shiftregister;
        if start_i='1' or g_latency=0 or g_latency=1 then
            partial_sum(0) <= (partial_sum_stored'range => '0');
            if c_additions_per_period/=g_operand_width then
                factor_bits <= operand_i(c_additions_per_period downto 1);
            else
                -- All operand bits are handled in this step, so the topmost
                -- factor bit is a copy of the operand's most significant bit:
                factor_bits <= operand_i(g_operand_width-1) & operand_i(g_operand_width-1 downto 1);
            end if;
        end if;
    end process;
    -- Chain of c_additions_per_period square_step instances:
    -- instance i reads partial_sum(i) and drives partial_sum(i+1).
    square_step_g: for i in 0 to c_additions_per_period-1 generate
        square_step_inst : square_step
            generic map (
                g_operand_width_internal => c_operand_width_internal
            )
            port map (
                factor_bit_i           => factor_bits(i),
                fix_negative_operand_i => fix_negative_operand(i),
                mask_i                 => mask(i),
                -- The most significant bit of operand_int is not connected:
                operand_int_i          => operand_int(c_operand_width_internal-2 downto 0),
                partial_sum_i          => partial_sum(i),
                partial_sum_lsb_o      => partial_sum_lsb(i),
                partial_sum_o          => partial_sum(i+1)
            );
    end generate square_step_g;
    -- Registers for the multi-period implementation (g_latency/=0).
    register_g: if g_latency/=0 generate
        ready_o <= ready;
        process(res_i, clk_i)
            variable bit_index_v : natural;
        begin
            if res_i='1' then
                -- Asynchronous reset of all registers:
                operand_int_lsb_s  <= '0';
                shift_registers    <= (others => (others => '0'));
                partial_sum_stored <= (others => '0');
            elsif rising_edge(clk_i) then
                if start_i='1' then
                    -- First step: store the first results and load the
                    -- remaining operand bits into the shift registers.
                    partial_sum_stored <= partial_sum(c_additions_per_period);
                    operand_int_lsb_s  <= operand_int(0);
                    -- Example with operand_int = "abcdefgh":
                    -- When c_additions_per_period=1 and c_number_of_periods=8:
                    -- shiftregister(0) = "Saabcdef"
                    -- When c_additions_per_period=2 and c_number_of_periods=4:
                    -- shiftregister(0) = "Sace"
                    -- shiftregister(1) = "Sabd"
                    -- ("S" is the bit partial_sum_lsb(a) of the first step.)
                    for a in 0 to c_additions_per_period-1 loop
                        -- The topmost bit gets the first result bit of this stage:
                        shift_registers(a)(c_number_of_periods-1) <= partial_sum_lsb(a);
                        -- All other bits get the operand bits of the following periods:
                        for p in 0 to c_number_of_periods-2 loop
                            bit_index_v := (p+1)*c_additions_per_period+a+1;
                            if bit_index_v>operand_int'high then
                                -- Beyond the operand: use the bit below instead.
                                bit_index_v := bit_index_v-1;
                            end if;
                            shift_registers(a)(p) <= operand_int(bit_index_v);
                        end loop;
                    end loop;
                elsif reg_enable='1' then
                    -- Following steps: shift each register right by one position
                    -- and insert the new result bit at the top.
                    for a in 0 to c_additions_per_period-1 loop
                        shift_registers(a) <= partial_sum_lsb(a) & shift_registers(a)(c_number_of_periods-1 downto 1);
                    end loop;
                    partial_sum_stored <= partial_sum(c_additions_per_period);
                end if;
            end if;
        end process;
    end generate register_g;
    -- Purely combinatorial implementation (g_latency=0):
    -- no registers, all "stored" signals are connected directly.
    combinatoric_g: if g_latency=0 generate
        partial_sum_stored <= partial_sum(c_additions_per_period);
        operand_int_lsb_s  <= operand_int(0);
        ready_o            <= start_i;
        -- Each shift register consists of a single bit here,
        -- which carries the result bit of its stage:
        process(partial_sum_lsb)
        begin
            for stage in partial_sum_lsb'range loop
                shift_registers(stage)(0) <= partial_sum_lsb(stage);
            end loop;
        end process;
    end generate combinatoric_g;
    -- Collect bit 0 of every shift register (the operand_i bits for the
    -- current step) in one vector:
    process (shift_registers)
    begin
        for stage in operand_bits_from_shiftregister'range loop
            operand_bits_from_shiftregister(stage) <= shift_registers(stage)(0);
        end loop;
    end process;
    -- Assemble all result bits in square_all:
    --   bit 0           : least significant bit of the operand
    --   bit 1           : always '0'
    --   following bits  : contents of the shift registers, ordered by period and stage
    --   remaining bits  : stored partial sum
    process(shift_registers, operand_int_lsb_s, partial_sum_stored)
        -- Position of partial_sum_stored(0) inside square_all:
        constant c_stored_offset : natural := c_number_of_periods*c_additions_per_period+2;
    begin
        square_all(0) <= operand_int_lsb_s;
        square_all(1) <= '0';
        for a in 0 to c_additions_per_period-1 loop
            for p in 0 to c_number_of_periods-1 loop
                square_all(p*c_additions_per_period+2+a) <= shift_registers(a)(p);
            end loop;
        end loop;
        for i in partial_sum_stored'range loop
            square_all(c_stored_offset+i) <= partial_sum_stored(i);
        end loop;
    end process;
    -- Connect the output: the lower 2*g_operand_width-1 bits of square_all,
    -- with a '0' prepended as most significant bit:
    square_o <= '0'&square_all(2*g_operand_width-2 downto 0);
end architecture;
