/********************************************************************
* libavio/src/Decoder.cpp
*
* Copyright (c) 2022  Stephen Rhodes
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*********************************************************************/

#include "Decoder.h"
#include "Player.h"

// Hardware surface pixel format that get_hw_format() looks for. It is written
// by the Decoder constructor when a hardware device type is requested and is
// compared against decoded frame formats in Decoder::decode().
// NOTE(review): this is a single file-scope variable, so every Decoder in the
// process shares it; decoders using different hardware device types would
// overwrite each other's value - confirm this is acceptable.
AVPixelFormat hw_pix_fmt = AV_PIX_FMT_NONE;
// Sentinel strings stored in AVCodecContext::opaque by get_hw_format() to record
// whether the hardware pixel format was offered by the decoder ("good") or not
// ("bad"). Decoder::decode() compares the opaque pointer's text against `good`.
const char * good = "good";
const char * bad = "bad";

// Format-negotiation callback installed as AVCodecContext::get_format.
//
// Scans the AV_PIX_FMT_NONE-terminated list offered by the decoder for the
// file-level hw_pix_fmt. The outcome is recorded in ctx->opaque using the
// `good` / `bad` sentinel strings so that the decoder can report a failed
// hardware initialisation later on.
//
// Returns the hardware format when it is offered, AV_PIX_FMT_NONE otherwise.
AVPixelFormat get_hw_format(AVCodecContext* ctx, const AVPixelFormat* pix_fmts)
{
    for (int idx = 0; pix_fmts[idx] != AV_PIX_FMT_NONE; ++idx) {
        if (pix_fmts[idx] != hw_pix_fmt)
            continue;

        // The requested hardware surface format is available.
        ctx->opaque = (void*)good;
        return pix_fmts[idx];
    }

    // Ran off the end of the list without finding the hardware format.
    fprintf(stderr, "Failed to get HW surface format.\n");
    ctx->opaque = (void*)bad;
    return AV_PIX_FMT_NONE;
}

namespace avio
{

// Builds a decoder for the best stream of the given media type found in
// reader->fmt_ctx.
//
// reader          source whose format context supplies the stream
// mediaType       which kind of stream to decode (video, audio, ...)
// hw_device_type  AV_HWDEVICE_TYPE_NONE for software decoding, otherwise the
//                 hardware device type to decode with
//
// For video whose native pixel format is not YUV420P, and for every hardware
// decoder, a swscale context and a YUV420P destination frame (cvt_frame) are
// prepared. The hardware path assumes frames downloaded from the device are
// NV12.
//
// Throws Exception (message prefixed with "Decoder constructor exception: ")
// when any step fails. Before rethrowing, everything allocated so far is
// released, because the destructor does not run for an object whose
// constructor throws; request_reconnect on the owning Player is also cleared.
Decoder::Decoder(Reader* reader, AVMediaType mediaType, AVHWDeviceType hw_device_type) : reader(reader), mediaType(mediaType)
{
    // (Re)creates the src_fmt -> YUV420P scaler and its destination frame at the
    // codec context's dimensions. Anything created by an earlier call is released
    // first, so setting up both the software and the hardware conversion does not
    // leak the first pair. Relies on sws_ctx / cvt_frame starting out null, which
    // the destructor's null checks already assume.
    auto make_converter = [this](AVPixelFormat src_fmt) {
        sws_freeContext(sws_ctx);
        sws_ctx = nullptr;
        ex.ck(sws_ctx = sws_getContext(dec_ctx->width, dec_ctx->height, src_fmt,
            dec_ctx->width, dec_ctx->height, AV_PIX_FMT_YUV420P, SWS_BICUBIC, NULL, NULL, NULL), SGC);

        av_frame_free(&cvt_frame);
        ex.ck(cvt_frame = av_frame_alloc(), AFA);
        cvt_frame->width = dec_ctx->width;
        cvt_frame->height = dec_ctx->height;
        cvt_frame->format = AV_PIX_FMT_YUV420P;
        ex.ck(av_frame_get_buffer(cvt_frame, 0), "av_frame_get_buffer");
    };

    try {
        const char* type_name = av_get_media_type_string(mediaType);
        strMediaType = (type_name ? type_name : "UNKNOWN MEDIA TYPE");

        stream_index = av_find_best_stream(reader->fmt_ctx, mediaType, -1, -1, NULL, 0);
        if (stream_index < 0) {
            std::stringstream str;
            str << "Error opening stream, unable to find " << strMediaType << " stream";
            throw Exception(str.str());
        }
        stream = reader->fmt_ctx->streams[stream_index];
        dec = avcodec_find_decoder(stream->codecpar->codec_id);

        if (!dec) {
            std::stringstream str;
            str << "avcodec_find_decoder could not find " << avcodec_get_name(stream->codecpar->codec_id);
            throw Exception(str.str());
        }

        ex.ck(dec_ctx = avcodec_alloc_context3(dec), AAC3);
        ex.ck(avcodec_parameters_to_context(dec_ctx, stream->codecpar), APTC);
        // opaque is later set by get_hw_format() to report the negotiation result
        dec_ctx->opaque = nullptr;

        if (mediaType == AVMEDIA_TYPE_VIDEO && dec_ctx->pix_fmt != AV_PIX_FMT_YUV420P) {
            make_converter(dec_ctx->pix_fmt);
        }

        this->hw_device_type = hw_device_type;
        ex.ck(frame = av_frame_alloc(), AFA);
        if (hw_device_type != AV_HWDEVICE_TYPE_NONE) {
            ex.ck(sw_frame = av_frame_alloc(), AFA);

            // Walk the codec's hardware configurations until one matches the
            // requested device type; running out of configurations is an error.
            for (int i = 0;; i++) {
                const AVCodecHWConfig* config = avcodec_get_hw_config(dec, i);

                if (!config) {
                    std::stringstream str;
                    str << strMediaType << " Decoder " << dec->name << " does not support device type " << av_hwdevice_get_type_name(hw_device_type);
                    throw Exception(str.str());
                }

                if (config->methods & AV_CODEC_HW_CONFIG_METHOD_HW_DEVICE_CTX && config->device_type == hw_device_type) {
                    hw_pix_fmt = config->pix_fmt;
                    break;
                }
            }

            ex.ck(av_hwdevice_ctx_create(&hw_device_ctx, hw_device_type, NULL, NULL, 0), AHCC);
            dec_ctx->hw_device_ctx = av_buffer_ref(hw_device_ctx);
            if (!dec_ctx->hw_device_ctx) {
                throw Exception("av_buffer_ref failed to reference hardware device context");
            }
            dec_ctx->get_format = get_hw_format;
            hwPixFmtName = av_get_pix_fmt_name(hw_pix_fmt);

            if (hw_pix_fmt == AV_PIX_FMT_VAAPI && dec_ctx->codec_id != AV_CODEC_ID_H264) {
                std::stringstream str;
                str << "Hardware decoder VAAPI incompatible with codec " << dec_ctx->codec->long_name;
                throw Exception(str.str());
            }

            // Replaces any software converter created above; hardware frames are
            // converted from NV12 after being downloaded from the device.
            make_converter(AV_PIX_FMT_NV12);
        }

        ex.ck(avcodec_open2(dec_ctx, dec, NULL), AO2);
    }
    catch (const Exception& e) {
        // The destructor is not invoked when a constructor throws, so release
        // whatever was acquired before the failure. The codec context was never
        // successfully opened on this path, so it is freed without flushing.
        av_frame_free(&frame);
        av_frame_free(&sw_frame);
        av_frame_free(&cvt_frame);
        avcodec_free_context(&dec_ctx);
        av_buffer_unref(&hw_device_ctx);
        sws_freeContext(sws_ctx);
        sws_ctx = nullptr;

        std::stringstream str;
        std::string msg = e.what();
        // Present a friendlier message when the hardware device could not be created
        if (msg.find("av_hwdevice_ctx_create") != std::string::npos)
            msg = "Hardware decoder not supported";
        
        str << "Decoder constructor exception: " << msg;
        ((Player*)reader->player)->request_reconnect = false;
        throw Exception(str.str());
    }
}

// Releases every FFmpeg resource the decoder owns. Each member is only
// released when it is non-null, so a decoder that never used hardware
// decoding or pixel conversion is torn down safely.
Decoder::~Decoder()
{
    // The scaler does not depend on any of the other members.
    if (sws_ctx != nullptr)
        sws_freeContext(sws_ctx);

    // Frame holders.
    if (cvt_frame != nullptr)
        av_frame_free(&cvt_frame);
    if (sw_frame != nullptr)
        av_frame_free(&sw_frame);
    if (frame != nullptr)
        av_frame_free(&frame);

    // Flush and free the codec context before dropping the decoder's own
    // reference to the hardware device.
    if (dec_ctx != nullptr) {
        flush();
        avcodec_free_context(&dec_ctx);
    }
    if (hw_device_ctx != nullptr)
        av_buffer_unref(&hw_device_ctx);
}

// Discards the codec's internally buffered data via avcodec_flush_buffers().
// Throws Exception when there is no codec context.
void Decoder::flush()
{
    if (dec_ctx == nullptr) {
        throw Exception("dec_ctx null");
    }

    avcodec_flush_buffers(dec_ctx);
}

// Sends one packet to the codec and pushes every frame it yields onto frame_q.
//
// pkt  packet to decode; it is passed straight to avcodec_send_packet().
//
// Frames in the hardware pixel format are downloaded into sw_frame, converted
// to YUV420P through sws_ctx / cvt_frame and pushed; all other frames are
// pushed as they are. When the decoded frame size differs from the codec
// context's size on entry, the helper frames and the scaler are rebuilt.
//
// Returns 0 once the codec reports EAGAIN or EOF, -1 when an Exception was
// caught (it is reported through infoCallback, or std::cout when no callback is
// set), and the last non-negative FFmpeg result if the output queue was closed.
// Throws Exception only when there is no codec context.
int Decoder::decode(AVPacket* pkt)
{
    if (!dec_ctx) throw Exception("dec_ctx null");

    // get_hw_format() leaves a sentinel string in opaque; anything other than
    // `good` means the hardware pixel format was not offered by the decoder.
    if (dec_ctx->opaque && pkt) {
        if (strcmp(good, (const char *)dec_ctx->opaque)) {
            std::stringstream str;
            str << "hardware decoder failed to initialize: " << hwPixFmtName;
            std::cout << str.str() << std::endl;
            // Guarded like infoCallback below so an unset callback is not invoked
            if (errorCallback) errorCallback(str.str(), ((Player*)reader->player)->uri, false);
        }
    }

    // Gives cvt_frame fresh YUV420P buffers at the codec context's size. The
    // allocation result is checked so a failure surfaces as an Exception rather
    // than leaving sws_scale with a frame that has no buffers.
    auto prepare_cvt_frame = [this]() {
        cvt_frame->width = dec_ctx->width;
        cvt_frame->height = dec_ctx->height;
        cvt_frame->format = AV_PIX_FMT_YUV420P;
        ex.ck(av_frame_get_buffer(cvt_frame, 0), "av_frame_get_buffer");
    };

    int ret = 0;
    try 
    {
        // Dimensions before this packet is decoded, used to spot a size change
        int width = dec_ctx->width;
        int height = dec_ctx->height;

        ex.ck(ret = avcodec_send_packet(dec_ctx, pkt), ASP);

        while (ret >= 0) {
            ret = avcodec_receive_frame(dec_ctx, frame);
            if (ret == AVERROR_EOF || ret == AVERROR(EAGAIN)) {
                // Codec needs more input or is fully drained
                return 0;
            }
            if (ret < 0) {
                ex.ck(ret, "error during decoding");
            }

            if (frame->width != width || frame->height != height) {
                // Frame size changed: rebuild whichever helpers are in use.
                if (sw_frame) {
                    av_frame_free(&sw_frame);
                    ex.ck(sw_frame = av_frame_alloc(), AFA);
                }
                if (cvt_frame) {
                    av_frame_free(&cvt_frame);
                    ex.ck(cvt_frame = av_frame_alloc(), AFA);
                    prepare_cvt_frame();
                }
                if (sws_ctx) {
                    // NOTE(review): the source format is always NV12 here, while the
                    // constructor uses dec_ctx->pix_fmt for the software-only
                    // converter - confirm whether that converter is used elsewhere.
                    sws_freeContext(sws_ctx);
                    ex.ck(sws_ctx = sws_getContext(dec_ctx->width, dec_ctx->height, AV_PIX_FMT_NV12,
                        dec_ctx->width, dec_ctx->height, AV_PIX_FMT_YUV420P, SWS_BICUBIC, NULL, NULL, NULL), SGC);
                }
            }

            Frame f;
            if (frame->format == hw_pix_fmt) {
                // Download from the device, then convert to YUV420P
                ex.ck(ret = av_hwframe_transfer_data(sw_frame, frame, 0), AHTD);
                ex.ck(av_frame_copy_props(sw_frame, frame));
                ex.ck(sws_scale(sws_ctx, sw_frame->data, sw_frame->linesize, 0, dec_ctx->height, 
                    cvt_frame->data, cvt_frame->linesize), SS);
                cvt_frame->pts = sw_frame->pts;

                f = Frame(cvt_frame);
                // Presumably Frame takes over cvt_frame's buffers (TODO confirm in
                // Frame), so set it up again for the next conversion.
                prepare_cvt_frame();
            }
            else {
                f = Frame(frame);
            }

            f.set_rts(stream);
            if (show_frames) std::cout << strMediaType << " decoder " << f.description() << std::endl;
            frame_q->push_move(f);
        }
    }
    catch (const QueueClosedException&) { std::cout << "in process decoder closed queue exception" << std::endl; }
    catch (const Exception& e) {
        std::stringstream str;
        str << strMediaType << " Decoder::decode exception: " << e.what();
        if (infoCallback) infoCallback(str.str(), ((Player*)reader->player)->uri);
        else std::cout << str.str() << std::endl;
        ret = -1;
    }

    return ret;
}


}
