import { z } from 'zod'

import {
    collectCciFromExpr,
    evaluateExpression,
    ExpressionAdapter,
    getExpressionSchema,
    getExprType,
    traverseExpression,
    validateExpressionAndType,
} from '../expressions.js'
import { getPrimitiveReq } from '../type-utils.js'
import { EqualsExpression, VariableType } from '../types.js'
import { mergeTypes, validateType } from '../var-types.js'

/**
 * Expression adapter for `equals` expressions, which compare a `left` and a
 * `right` operand and produce a boolean.
 */
export const equalsAdapter: ExpressionAdapter<EqualsExpression> = {
    /** Evaluates the left operand, then the right one, and compares the results. */
    evaluate: (context, expr) => {
        // For now, we assume primitive types (str, num, bool), not reference
        // types, so strict equality is sufficient.
        return (
            evaluateExpression(context, expr.left) ===
            evaluateExpression(context, expr.right)
        )
    },

    /** The result type is `bool`, independent of the operands. */
    getType: (): VariableType => {
        return { kind: 'bool' }
    },

    /**
     * Builds the schema for the serialized expression: a strict object (unknown
     * keys are rejected) tagged with `type: 'equals'` and carrying two nested
     * expression schemas.
     */
    getSchema: () => {
        const shape = {
            type: z.literal('equals'),
            left: getExpressionSchema(),
            right: getExpressionSchema(),
        }

        return z.object(shape).strict()
    },

    /**
     * Validates each operand against the operand requirement, then merges the
     * two operand types and validates the merged type.
     */
    validate: (context, expr) => {
        // The requirement is fetched separately for each operand.
        const leftReq = getEqualsExpressionRequiredTypes().operand
        validateExpressionAndType(
            context,
            expr.left,
            leftReq,
            'EqualsExpression.left',
        )

        const rightReq = getEqualsExpressionRequiredTypes().operand
        validateExpressionAndType(
            context,
            expr.right,
            rightReq,
            'EqualsExpression.right',
        )

        // Both operand types are resolved first (left, then right) and merged;
        // the merged type is then validated.
        const mergedType = mergeTypes(
            getExprType(context.types, expr.left, true),
            getExprType(context.types, expr.right, true),
        )
        validateType(mergedType)
    },

    /** Delegates CCI collection to the left operand, then the right one. */
    collectCci: (context, expr) => {
        for (const side of ['left', 'right'] as const) {
            collectCciFromExpr(context, expr[side])
        }
    },

    /** Traverses the left operand, then the right one. */
    traverse: (context, expr) => {
        for (const side of ['left', 'right'] as const) {
            traverseExpression(context, expr[side])
        }
    },
}

/**
 * Returns the type requirements for the operands of an equals expression.
 *
 * Each call builds a new object whose single `operand` entry is the
 * requirement produced by `getPrimitiveReq()`. `equalsAdapter.validate`
 * applies this same requirement to both the left and the right operand.
 *
 * @returns An object with the `operand` type requirement.
 */
// eslint-disable-next-line return-types-object-literals/require-return-types-for-object-literals
export const getEqualsExpressionRequiredTypes = () => ({
    operand: getPrimitiveReq(),
})
