#! /usr/bin/env escript
%% This line tells emacs to use -*- erlang -*- mode -*- coding: iso-8859-1 -*-
%%! -pa ../ebin/ -pa tmp/ -sname protoexerciser

%%% Copyright (C) 2011  Tomas Abrahamsson
%%%
%%% Author: Tomas Abrahamsson <tab@lysator.liu.se>
%%%
%%% This library is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU Lesser General Public
%%% License as published by the Free Software Foundation; either
%%% version 2.1 of the License, or (at your option) any later version.
%%%
%%% This library is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
%%% Lesser General Public License for more details.
%%%
%%% You should have received a copy of the GNU Lesser General Public
%%% License along with this library; if not, write to the Free Software
%%% Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
%%% MA  02110-1301  USA

-mode(compile).
%-compile(export_all).
-export([main/1]).

%% escript entry point: hands the command line arguments, unchanged,
%% to run_tests/1.
main(Argv) ->
    run_tests(Argv).

%% Work through the argument list one directive at a time:
%%
%%   "--echo", Str
%%       print Str followed by a newline.
%%   Module, MsgName, FileSpec
%%       read the message file(s) named by FileSpec (see
%%       pickup_msg_file_or_files/1), decode them once with
%%       Module:decode_msg/2 to obtain the terms to encode, and then
%%       time encoding and decoding through run_test/3.
%%
%% Returns ok once every argument has been consumed.
run_tests([]) ->
    ok;
run_tests(["--echo", Text | Remaining]) ->
    io:format("~s~n", [Text]),
    run_tests(Remaining);
run_tests([ModStr, NameStr | Tail]) ->
    Mod = list_to_atom(ModStr),
    Name = list_to_atom(NameStr),
    {Paths, Remaining} = pickup_msg_file_or_files(Tail),
    ReadFile =
        fun(Path) ->
            %% Pairing the result with the path makes a failed read
            %% show the file name in the resulting badmatch error.
            {{ok, Contents}, _} = {file:read_file(Path), Path},
            Contents
        end,
    Bins = lists:map(ReadFile, Paths),
    TotalBytes = iolist_size(Bins),
    EncName =
        case encode_needs_msgname(Mod) of
            true -> Name;
            false -> undefined
        end,
    ok = enif_protobuf:load_cache(Mod:get_msg_defs()),
    Decoded = lists:map(fun(Bin) -> Mod:decode_msg(Bin, Name) end, Bins),
    EncodeAll = fun() -> encode_msgs(Decoded, Mod, EncName) end,
    DecodeAll = fun() -> decode_bins(Bins, Mod, Name) end,

    io:format("Benchmarking ~s ~s with file ~s~n",
        [Mod, Name, string:join(Paths, ",")]),
    run_test("Serialize to binary", TotalBytes, EncodeAll),
    run_test("Deserialize from binary", TotalBytes, DecodeAll),
    io:format("~n"),
    run_tests(Remaining).

%% Tell whether the message name must be passed along when encoding
%% messages of MsgModule.
%%
%% Returns false when MsgModule exports encode_msg/1 (records) and
%% true when it exports encode_msg/2 but not encode_msg/1 (maps).
%% Crashes with a badmatch if the module exports neither.
%%
%% The export list is tested for the exact {Name, Arity} pairs rather
%% than taking the first encode_msg entry: a module may export several
%% arities of encode_msg, and the order of module_info(exports) is not
%% something to rely on.
encode_needs_msgname(MsgModule) ->
    Exports = MsgModule:module_info(exports),
    case lists:member({encode_msg, 1}, Exports) of
        true ->
            false; %% records
        false ->
            %% maps; assert that the two-argument encoder really exists
            true = lists:member({encode_msg, 2}, Exports)
    end.

%% Split the message file name(s) off the front of the argument list.
%%
%% A leading "--multi" means that every argument up to the mandatory
%% "--end-multi" marker is a file name; otherwise the first argument
%% is the single file name. Returns {FileNames, RemainingArgs}, with
%% the markers themselves removed. A "--multi" without a matching
%% "--end-multi" fails with a badmatch.
pickup_msg_file_or_files(["--multi" | Args]) ->
    NotEndMarker = fun(Arg) -> Arg /= "--end-multi" end,
    {Files, ["--end-multi" | Remaining]} = lists:splitwith(NotEndMarker, Args),
    {Files, Remaining};
pickup_msg_file_or_files([File | Remaining]) ->
    {[File], Remaining}.

%% decode_bins([MsgBin | Rest], MsgModule, MsgName) ->
%%     MsgModule:decode_msg(MsgBin, MsgName),
%%     decode_bins(Rest, MsgModule, MsgName);
%% decode_bins([], _, _) ->
%%     ok.

%% encode_msgs([Msg | Rest], MsgModule, undefined) ->
%%     MsgModule:encode_msg(Msg), %% MsgName not needed for records
%%     encode_msgs(Rest, MsgModule, undefined);
%% encode_msgs([Msg | Rest], MsgModule, MsgName) ->
%%     MsgModule:encode_msg(Msg, MsgName), %% MsgName needed for maps
%%     encode_msgs(Rest, MsgModule, MsgName);
%% encode_msgs([], _, _) ->
%%     ok.

%% Decode each binary in the list as a MsgName message using
%% enif_protobuf:decode/2, discarding the decoded terms. The module
%% argument is not used. Returns ok.
decode_bins([], _MsgModule, _MsgName) ->
    ok;
decode_bins([Bin | Bins], MsgModule, MsgName) ->
    enif_protobuf:decode(Bin, MsgName),
    decode_bins(Bins, MsgModule, MsgName).

%% Encode each message in the list with enif_protobuf, discarding the
%% encoded output. With the message name given as undefined,
%% enif_protobuf:encode/1 is used (records); with any other name,
%% enif_protobuf:encode/2 is used (maps). The module argument is only
%% passed along. Returns ok.
encode_msgs([], _MsgModule, _MsgName) ->
    ok;
encode_msgs([Record | Records], MsgModule, undefined) ->
    %% Msg0 = setelement(3, Msg, unicode:characters_to_binary(element(3, Msg))),
    enif_protobuf:encode(Record), %% MsgName not needed for records
    encode_msgs(Records, MsgModule, undefined);
encode_msgs([Map | Maps], MsgModule, MsgName) ->
    enif_protobuf:encode(Map, MsgName), %% MsgName needed for maps
    encode_msgs(Maps, MsgModule, MsgName).

%% Run one timed test, run_test_aux/3, in a separate monitored process
%% and wait for that process to terminate.
%%
%% Returns ok if the process exits normally; otherwise raises
%% error({aux_died, Reason}) with the exit reason of the process.
run_test(Description, DataSize, Action) ->
    Worker = fun() -> run_test_aux(Description, DataSize, Action) end,
    {Pid, MonRef} = spawn_monitor(Worker),
    receive
        {'DOWN', MonRef, process, Pid, Reason} ->
            case Reason of
                normal -> ok;
                _ -> error({aux_died, Reason})
            end
    end.

%% Time Action and print the result.
%%
%% First a sampling phase finds an iteration count that takes at least
%% 2 seconds. That count is scaled to aim at a 30 second run, the
%% scaled number of iterations is timed, and one line with the
%% description, the iteration count, the elapsed time and the
%% throughput is printed. The throughput assumes that every iteration
%% processes DataSize bytes. Returns ok.
run_test_aux(Description, DataSize, Action) ->
    MinSampleSecs = 2,
    TargetSecs = 30,
    {SampleSecs, SampleIters} = iterate_until_elapsed(MinSampleSecs, Action),
    Iters = round((TargetSecs / SampleSecs) * SampleIters),
    Secs = time_action(Iters, Action),
    MBPerSec = (Iters * DataSize) / (Secs * 1024 * 1024),
    io:format("~s: ~w iterations in ~.3fs; ~.2fMB/s~n",
        [Description, Iters, Secs, MBPerSec]),
    ok.

%% Find an iteration count, starting from a single iteration, for
%% which running Action takes at least MaxDuration seconds.
%% Returns {ElapsedSeconds, NumIterations} for that run.
iterate_until_elapsed(MaxDuration, Action) ->
    FirstAttempt = 1,
    iterate_until_elapsed_2(FirstAttempt, MaxDuration, Action).

%% Time NumIterations runs of Action. If that took less than
%% MaxDuration seconds, try again with twice as many iterations;
%% otherwise return {ElapsedSeconds, NumIterations} for this run.
iterate_until_elapsed_2(NumIterations, MaxDuration, Action) ->
    Elapsed = time_action(NumIterations, Action),
    if
        Elapsed < MaxDuration ->
            iterate_until_elapsed_2(NumIterations * 2, MaxDuration, Action);
        true ->
            {Elapsed, NumIterations}
    end.

%% Run Action NumIterations times and return the elapsed time in
%% seconds, as a float.
%%
%% The calling process is garbage collected before the clock is
%% started. The duration is taken from the monotonic clock rather than
%% from os:timestamp/0, so that an adjustment of the system clock
%% while the test is running can neither distort the measurement nor
%% make it negative.
time_action(NumIterations, Action) ->
    garbage_collect(),
    T0 = erlang:monotonic_time(),
    iterate_action(NumIterations, Action),
    T1 = erlang:monotonic_time(),
    %% Convert to microseconds first, then to (float) seconds.
    erlang:convert_time_unit(T1 - T0, native, microsecond) / 1000000.

%% Call Action() Count times, ignoring what it returns, and then
%% return ok.
iterate_action(0, _Action) ->
    ok;
iterate_action(Count, Action) when Count > 0 ->
    Action(),
    iterate_action(Count - 1, Action).
