import { Arith, RatNum, Context, Z3LowLevel } from 'z3-solver';

import { range } from '@adamburgess/linq/enumerable'

/**
 * Problem description consumed by `findSolutions`.
 */
export interface CalculatorInput {
    /** Duration of one craft; the solver derives crafts per minute as `60 / recipeTime`. */
    recipeTime: number;
    /** Items produced by one craft. */
    recipeAmount: number;
    /** Target output rate (items per minute) that the machines must hit exactly. */
    wantedPerMinute: number;

    /** Base speed of a single machine, multiplied by `1 + sum of module speeds`. */
    machineSpeed: number;
    /** Number of module slots per machine; one solver variable is created per slot. */
    numModules: number;

    /**
     * Modules that may be placed in a slot.
     * `speed` is added to the base multiplier of 1; `cost` is summed into the total cost
     * that the solver minimizes.
     */
    modules: Array<{
        speed: number;
        cost: number;
    }>;

    /** Stop after this many solutions have been reported. */
    only: number;
    /** Bracket for the machine count: solutions satisfy `min < machines <= max`. */
    range: { min: number; max: number };
}

/**
 * One solution reported by `findSolutions` through its `onResult` callback.
 */
export interface CalculatorOutput {
    /** Machine count taken from the solver model (a real value, not necessarily whole). */
    machines: number;
    /**
     * Chosen module index for every slot. Indices refer to the module list after
     * `findSolutions` has prepended its blank module, so `0` means "empty slot".
     */
    modules: number[];
}

/**
 * Enumerates module loadouts that produce exactly `input.wantedPerMinute`, reporting
 * each one through `onResult`.
 *
 * Constraints given to the optimizer:
 *   - `wantedPerMinute = machines * (60 / recipeTime * recipeAmount) * totalSpeed`
 *   - `totalSpeed = machineSpeed * (1 + sum of the chosen modules' speed)`
 *   - `totalCost = sum of the chosen modules' cost`
 *   - `range.min < machines <= range.max`
 *
 * Two objectives are registered, in this order: maximize `machines`, then minimize
 * `totalCost`. After every reported solution the machine count is forced strictly
 * below the previous one, so successive results use fewer and fewer machines.
 * The search ends when the solver no longer answers `sat` or when `input.only`
 * solutions have been reported.
 *
 * Side effect: `input.modules` is replaced by a copy with a blank module
 * (`speed: 0, cost: 0`) prepended. Indices in the reported `modules` arrays refer to
 * that extended list (index 0 = empty slot).
 * NOTE(review): calling this twice with the same `input` object prepends the blank
 * module twice — confirm callers build a fresh input per call.
 *
 * @param Z3 z3-solver context used to build every expression and the optimizer.
 * @param z3lib Low-level Z3 bindings; only referenced by the commented-out SMT dump.
 * @param input Problem description (mutated, see above).
 * @param onResult Called once per solution; its return value is ignored and not awaited.
 * @returns A promise that resolves once the enumeration has finished.
 */
export async function findSolutions(Z3: Context, z3lib: Z3LowLevel, input: CalculatorInput, onResult: (output: CalculatorOutput) => unknown): Promise<void> {
    const solver = new Z3.Optimize();

    // Add a blank module at index 0 so that a slot can be left empty.
    input.modules = [{ speed: 0, cost: 0 }, ...input.modules];

    // Lookup tables (module index -> speed / cost) expressed as Z3 arrays so that the
    // per-slot choice can be a solver variable.
    let speedsArray = Z3.Array.const('speeds', Z3.Int.sort(), Z3.Real.sort());
    let costsArray = Z3.Array.const('costs', Z3.Int.sort(), Z3.Real.sort());
    for (let i = 0; i < input.modules.length; i++) {
        speedsArray = speedsArray.store(i, input.modules[i].speed);
        costsArray = costsArray.store(i, input.modules[i].cost);
    }

    // One integer variable per slot, constrained to a valid module index.
    const modules: Arith[] = [];
    for (let i = 0; i < input.numModules; i++) {
        const m = Z3.Int.const('module' + i);
        modules.push(m);
        solver.add(m.ge(0));
        solver.add(m.lt(input.modules.length));
    }

    // Force a non-increasing order of indices across slots so that permutations of the
    // same loadout collapse into a single solution.
    for (let i = 0; i < input.numModules - 1; i++) {
        for (let j = i + 1; j < input.numModules; j++) {
            solver.add(modules[i].ge(modules[j]));
        }
    }

    const recipeOutputPerMinute = 60 / input.recipeTime * input.recipeAmount;
    const perMinute = Z3.Real.const('per_minute');
    solver.add(perMinute.eq(input.wantedPerMinute));

    const machines = Z3.Real.const('machines');
    const totalSpeed = Z3.Real.const('total_speed');
    const totalCost = Z3.Real.const('total_cost');

    solver.add(perMinute.eq(machines.mul(recipeOutputPerMinute).mul(totalSpeed)));

    solver.add(totalSpeed.eq(
        Z3.Real.val(input.machineSpeed).mul(
            // The leading 1 is the machine's unmodified multiplier.
            Z3.Sum(Z3.Real.val(1),
                ...[...range(0, input.numModules)].map(i => speedsArray.select(modules[i]))
            )
        ))
    );
    solver.add(totalCost.eq(
        // The leading 0 gives Sum a first operand even when there are no slots.
        Z3.Sum(Z3.Real.val(0),
            ...[...range(0, input.numModules)].map(i => costsArray.select(modules[i]))
        )
    ));

    // Only accept machine counts inside the requested bracket (min, max].
    solver.add(machines.le(input.range.max).and(machines.gt(input.range.min)));

    solver.maximize(machines);
    solver.minimize(totalCost);

    //const SMT = z3lib.Z3.optimize_to_string(Z3.ptr, solver.ptr);

    let total = 0;
    while (true) {
        if (await solver.check() !== 'sat') {
            break;
        }

        const model = await solver.model();

        const moduleValues: number[] = [];
        for (let i = 0; i < modules.length; i++) {
            moduleValues.push(Number(model.get(modules[i])));
        }

        onResult({
            machines: (model.get(machines) as RatNum).asNumber(),
            modules: moduleValues
        });

        if (++total === input.only) {
            break;
        }

        // Get unique solutions: the next one must use strictly fewer machines, which
        // also skips other loadouts that reach this same machine count.
        solver.add(machines.lt(model.get(machines) as Arith));

        // Additionally exclude this exact loadout. With zero slots there is nothing to
        // exclude (and no modules[0] to read); the constraint above already guarantees
        // progress.
        if (modules.length > 0) {
            let sameLoadout = modules[0].eq(model.get(modules[0]));
            for (let i = 1; i < modules.length; i++) {
                sameLoadout = sameLoadout.and(modules[i].eq(model.get(modules[i])));
            }
            solver.add(Z3.Not(sameLoadout));
        }
    }

    //return SMT;
}