# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import warnings

import awkward_cpp

import awkward as ak
from awkward._backends.backend import Backend, KernelKeyType
from awkward._backends.dispatch import register_backend
from awkward._kernels import JaxKernel
from awkward._nplikes.jax import Jax
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import Final

# Module-level handles obtained through each class's ``instance()`` accessor.
# NOTE(review): neither name is used by the JaxBackend class below — presumably
# kept for parity with sibling backend modules; confirm before removing.
np = NumpyMetadata.instance()
numpy = Numpy.instance()


@register_backend(Jax)  # type: ignore[type-abstract]
class JaxBackend(Backend):
    """Backend registered for the ``Jax`` nplike.

    Kernel lookups are answered from ``awkward_cpp``'s CPU kernel table and
    wrapped in ``JaxKernel``; reducers are translated by ``get_jax_reducer``.
    Constructing the backend emits a ``DeprecationWarning``.
    """

    name: Final[str] = "jax"

    # Populated in ``__init__`` with ``Jax.instance()``.
    _jax: Jax

    def __init__(self):
        # Warn before touching any state; ``stacklevel=2`` attributes the
        # warning to whoever called the constructor rather than this line.
        deprecation_notice = (
            "The JAX backend is deprecated and will be removed in a future release of Awkward Array. "
            "Please plan to migrate your code accordingly."
        )
        warnings.warn(deprecation_notice, category=DeprecationWarning, stacklevel=2)
        self._jax = Jax.instance()

    @property
    def nplike(self) -> Jax:
        """The ``Jax`` nplike instance held by this backend."""
        return self._jax

    def __getitem__(self, index: KernelKeyType) -> JaxKernel:
        """Look up the kernel for ``index`` and wrap it in a ``JaxKernel``."""
        # Index-only operations under JAX are served by Awkward's C++ kernels.
        cpu_kernel = awkward_cpp.cpu_kernels.kernel[index]
        return JaxKernel(cpu_kernel, index)

    def prepare_reducer(self, reducer: ak._reducers.Reducer) -> ak._reducers.Reducer:
        """Return what ``get_jax_reducer`` yields for ``reducer``."""
        # Imported at call time rather than at module load.
        from awkward._connect.jax import get_jax_reducer as to_jax_reducer

        return to_jax_reducer(reducer)
