import { z } from 'zod'

import { assertNever } from '../assert.js'
import {
    collectCciFromExpr,
    evaluateExpression,
    ExpressionAdapter,
    getExpressionSchema,
    traverseExpression,
    validateExpressionAndType,
} from '../expressions.js'
import { getAnyNumReq } from '../type-utils.js'
import { BinaryExpression, VariableType } from '../types.js'

/**
 * Adapter for `binary` expressions: arithmetic on two operand sub-expressions.
 *
 * Supported operators: `+`, `-`, `*`, `/` and `^` (exponentiation, via `**`).
 * Both operands must satisfy `getBinaryExpressionRequiredTypes().operand`
 * (checked in `validate`); the result type is always reported as a decimal
 * number.
 */
export const binaryAdapter: ExpressionAdapter<BinaryExpression> = {
    /**
     * Evaluates the left operand, then the right one, and combines them with
     * the expression's operator using plain JS number arithmetic.
     */
    evaluate: (context, expr) => {
        // Cast rather than runtime-checked: `validate` requires numeric operand
        // types, so evaluated values are presumed to be numbers.
        const lhs = evaluateExpression(context, expr.left) as number
        const rhs = evaluateExpression(context, expr.right) as number

        switch (expr.operator) {
            case '+':
                return lhs + rhs
            case '-':
                return lhs - rhs
            case '*':
                return lhs * rhs
            case '/':
                return lhs / rhs
            case '^':
                return lhs ** rhs
            default:
                // Unreachable when the operator union is fully handled above.
                throw assertNever(expr.operator, 'binary expression operator')
        }
    },

    /** The result is always reported as a decimal number. */
    getType: (): VariableType => ({ kind: 'number', format: 'decimal' }), // TODO format

    /** Strict zod schema: no keys beyond `type`, `left`, `operator`, `right`. */
    getSchema: () => {
        const shape = {
            type: z.literal('binary'),
            left: getExpressionSchema(),
            operator: z.enum(['*', '/', '+', '-', '^']),
            right: getExpressionSchema(),
        }
        return z.object(shape).strict()
    },

    /** Validates each operand (left first) against the numeric operand requirement. */
    validate: (context, expr) => {
        const operands = [
            [expr.left, 'BinaryExpression.left'],
            [expr.right, 'BinaryExpression.right'],
        ] as const

        for (const [operand, location] of operands) {
            validateExpressionAndType(
                context,
                operand,
                getBinaryExpressionRequiredTypes().operand,
                location,
            )
        }
    },

    /** Delegates CCI collection to the left operand, then the right. */
    collectCci: (context, expr) => {
        for (const operand of [expr.left, expr.right]) {
            collectCciFromExpr(context, operand)
        }
    },

    /** Delegates traversal to the left operand, then the right. */
    traverse: (context, expr) => {
        for (const operand of [expr.left, expr.right]) {
            traverseExpression(context, operand)
        }
    },
}

/**
 * Type requirements for the operands of a binary expression.
 *
 * @returns An object whose `operand` field is the requirement produced by
 *   `getAnyNumReq()`; a fresh requirement is built on every call.
 */
export const getBinaryExpressionRequiredTypes = (): {
    operand: ReturnType<typeof getAnyNumReq>
} => {
    const operand = getAnyNumReq()
    return { operand }
}
