/**
 * Copyright 2022 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.micrometer.tracing.brave.bridge;

import brave.Span;
import brave.Tracer;
import brave.internal.baggage.BaggageFields;
import brave.internal.propagation.StringPropagationAdapter;
import brave.propagation.Propagation;
import brave.propagation.TraceContext;
import brave.propagation.TraceContextOrSamplingFlags;
import io.micrometer.common.lang.Nullable;
import io.micrometer.common.util.StringUtils;
import io.micrometer.common.util.internal.logging.InternalLogger;
import io.micrometer.common.util.internal.logging.InternalLoggerFactory;
import io.micrometer.tracing.Baggage;
import io.micrometer.tracing.BaggageManager;
import io.micrometer.tracing.internal.EncodingUtils;

import java.util.*;

import static java.util.Collections.singletonList;

/**
 * Adopted from OpenTelemetry API.
 * <p>
 * Implementation of the TraceContext propagation protocol. See <a
 * href=https://github.com/w3c/distributed-tracing>w3c/distributed-tracing</a>.
 *
 * @author OpenTelemetry Authors
 * @author Marcin Grzejszczak
 * @since 1.0.0
 */
@SuppressWarnings({ "unchecked", "deprecation" })
public class W3CPropagation extends Propagation.Factory implements Propagation<String> {

    static final String TRACEPARENT = "traceparent";

    static final String TRACESTATE = "tracestate";

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(W3CPropagation.class.getName());

    private static final List<String> FIELDS = Collections.unmodifiableList(Arrays.asList(TRACEPARENT, TRACESTATE));

    private static final String VERSION = "00";

    private static final int VERSION_SIZE = 2;

    private static final char TRACEPARENT_DELIMITER = '-';

    private static final int TRACEPARENT_DELIMITER_SIZE = 1;

    private static final int LONG_BYTES = Long.SIZE / Byte.SIZE;

    private static final int BYTE_BASE16 = 2;

    private static final int LONG_BASE16 = BYTE_BASE16 * LONG_BYTES;

    private static final int TRACE_ID_HEX_SIZE = 2 * LONG_BASE16;

    private static final int SPAN_ID_SIZE = 8;

    private static final int SPAN_ID_HEX_SIZE = 2 * SPAN_ID_SIZE;

    private static final int FLAGS_SIZE = 1;

    private static final int TRACE_OPTION_HEX_SIZE = 2 * FLAGS_SIZE;

    private static final int TRACE_ID_OFFSET = VERSION_SIZE + TRACEPARENT_DELIMITER_SIZE;

    private static final int SPAN_ID_OFFSET = TRACE_ID_OFFSET + TRACE_ID_HEX_SIZE + TRACEPARENT_DELIMITER_SIZE;

    private static final int TRACE_OPTION_OFFSET = SPAN_ID_OFFSET + SPAN_ID_HEX_SIZE + TRACEPARENT_DELIMITER_SIZE;

    private static final int TRACEPARENT_HEADER_SIZE = TRACE_OPTION_OFFSET + TRACE_OPTION_HEX_SIZE;

    private static final String INVALID_TRACE_ID = "00000000000000000000000000000000";

    private static final String INVALID_SPAN_ID = "0000000000000000";

    // private static final char TRACESTATE_ENTRY_DELIMITER = ',';

    private static final Set<String> VALID_VERSIONS;

    private static final String VERSION_00 = "00";

    @Nullable
    private final W3CBaggagePropagator baggagePropagator;

    @Nullable
    private final BaggageManager braveBaggageManager;

    /**
     * Creates an instance of {@link W3CPropagation} with baggage support.
     * @param baggageManager baggage manager
     * @param localFields local fields to be registered as baggage
     */
    public W3CPropagation(BaggageManager baggageManager, List<String> localFields) {
        this.baggagePropagator = new W3CBaggagePropagator(baggageManager, localFields);
        this.braveBaggageManager = baggageManager;
    }

    /**
     * Creates an instance of {@link W3CPropagation} without baggage support.
     */
    public W3CPropagation() {
        this.baggagePropagator = null;
        this.braveBaggageManager = null;
    }

    private static boolean isTraceIdValid(CharSequence traceId) {
        return (traceId.length() == TRACE_ID_HEX_SIZE) && !INVALID_TRACE_ID.contentEquals(traceId)
                && EncodingUtils.isValidBase16String(traceId);
    }

    private static boolean isSpanIdValid(String spanId) {
        return (spanId.length() == SPAN_ID_HEX_SIZE) && !INVALID_SPAN_ID.equals(spanId)
                && EncodingUtils.isValidBase16String(spanId);
    }

    private static TraceContext extractContextFromTraceParent(String traceparent) {
        // TODO(bdrutu): Do we need to verify that version is hex and that
        // for the version the length is the expected one?
        boolean isValid = (traceparent.length() == TRACEPARENT_HEADER_SIZE
                || (traceparent.length() > TRACEPARENT_HEADER_SIZE
                        && traceparent.charAt(TRACEPARENT_HEADER_SIZE) == TRACEPARENT_DELIMITER))
                && traceparent.charAt(TRACE_ID_OFFSET - 1) == TRACEPARENT_DELIMITER
                && traceparent.charAt(SPAN_ID_OFFSET - 1) == TRACEPARENT_DELIMITER
                && traceparent.charAt(TRACE_OPTION_OFFSET - 1) == TRACEPARENT_DELIMITER;
        if (!isValid) {
            logger.info("Unparseable traceparent header. Returning INVALID span context.");
            return null;
        }

        try {
            String version = traceparent.substring(0, 2);
            if (!VALID_VERSIONS.contains(version)) {
                return null;
            }
            if (version.equals(VERSION_00) && traceparent.length() > TRACEPARENT_HEADER_SIZE) {
                return null;
            }

            String traceId = traceparent.substring(TRACE_ID_OFFSET, TRACE_ID_OFFSET + TRACE_ID_HEX_SIZE);
            String spanId = traceparent.substring(SPAN_ID_OFFSET, SPAN_ID_OFFSET + SPAN_ID_HEX_SIZE);
            if (isTraceIdValid(traceId) && isSpanIdValid(spanId)) {
                String traceIdHigh = traceId.substring(0, traceId.length() / 2);
                String traceIdLow = traceId.substring(traceId.length() / 2);
                byte isSampled = TraceFlags.byteFromHex(traceparent, TRACE_OPTION_OFFSET);
                return TraceContext.newBuilder()
                    .traceIdHigh(EncodingUtils.longFromBase16String(traceIdHigh))
                    .traceId(EncodingUtils.longFromBase16String(traceIdLow))
                    .spanId(EncodingUtils.longFromBase16String(spanId))
                    .sampled(isSampled == TraceFlags.IS_SAMPLED)
                    .build();
            }
            return null;
        }
        catch (IllegalArgumentException e) {
            logger.info("Unparseable traceparent header. Returning INVALID span context.");
            return null;
        }
    }

    @Override
    public <K> Propagation<K> create(KeyFactory<K> keyFactory) {
        return StringPropagationAdapter.create(this, keyFactory);
    }

    @Override
    public Propagation<String> get() {
        return this;
    }

    @Override
    public List<String> keys() {
        return FIELDS;
    }

    @Override
    public <R> TraceContext.Injector<R> injector(Setter<R, String> setter) {
        return (context, carrier) -> {
            Objects.requireNonNull(context, "context");
            Objects.requireNonNull(setter, "setter");
            char[] chars = TemporaryBuffers.chars(TRACEPARENT_HEADER_SIZE);
            chars[0] = VERSION.charAt(0);
            chars[1] = VERSION.charAt(1);
            chars[2] = TRACEPARENT_DELIMITER;
            String traceId = padLeftWithZeros(context.traceIdString(), TRACE_ID_HEX_SIZE);
            for (int i = 0; i < traceId.length(); i++) {
                chars[TRACE_ID_OFFSET + i] = traceId.charAt(i);
            }
            chars[SPAN_ID_OFFSET - 1] = TRACEPARENT_DELIMITER;
            String spanId = context.spanIdString();
            for (int i = 0; i < spanId.length(); i++) {
                chars[SPAN_ID_OFFSET + i] = spanId.charAt(i);
            }
            chars[TRACE_OPTION_OFFSET - 1] = TRACEPARENT_DELIMITER;
            copyTraceFlagsHexTo(chars, TRACE_OPTION_OFFSET, context);
            setter.put(carrier, TRACEPARENT, new String(chars, 0, TRACEPARENT_HEADER_SIZE));
            addTraceState(setter, context, carrier);
            if (this.baggagePropagator != null) {
                this.baggagePropagator.injector(setter).inject(context, carrier);
            }
        };
    }

    private <R> void addTraceState(Setter<R, String> setter, TraceContext context, R carrier) {
        if (carrier != null && this.braveBaggageManager != null) {
            Baggage baggage = this.braveBaggageManager.getBaggage(BraveTraceContext.fromBrave(context), TRACESTATE);
            if (baggage == null) {
                return;
            }
            String traceState = baggage.get(BraveTraceContext.fromBrave(context));
            if (StringUtils.isNotBlank(traceState)) {
                setter.put(carrier, TRACESTATE, traceState);
            }
        }
    }

    private String padLeftWithZeros(String string, int length) {
        if (string.length() >= length) {
            return string;
        }
        else {
            StringBuilder sb = new StringBuilder(length);
            for (int i = string.length(); i < length; i++) {
                sb.append('0');
            }

            return sb.append(string).toString();
        }
    }

    void copyTraceFlagsHexTo(char[] dest, int destOffset, TraceContext context) {
        dest[destOffset] = '0';
        dest[destOffset + 1] = Boolean.TRUE.equals(context.sampled()) ? '1' : '0';
    }

    /**
     * <h3>This does not set the shared flag when extracting headers</h3>
     *
     * <p>
     * {@link brave.propagation.TraceContext#shared()} is not set here because it is not a
     * remote propagation field. {@code shared} is a field in the Zipkin JSON v2 format
     * only set <em>after</em> header extraction, for {@link Span.Kind#SERVER} spans
     * implicitly via {@link brave.Tracer#joinSpan(TraceContext)}.
     *
     * <p>
     * Blindly setting {@code shared} regardless of this is harmful when
     * {@link Tracer#currentSpan()} or similar are used, as any data tagged with these
     * could also set the shared flag when reporting. Particularly, this can cause
     * problems for multi- {@linkplain Span.Kind#CONSUMER} spans. Regardless, setting
     * invalid flags add overhead.
     *
     * <p>
     * In summary, while {@code shared} is propagated in-process, it has never been
     * propagated out of process, and so should never be set when extracting headers.
     * Hence, this code will not set {@link brave.propagation.TraceContext#shared()}.
     */
    @Override
    public <R> TraceContext.Extractor<R> extractor(Getter<R, String> getter) {
        Objects.requireNonNull(getter, "getter");
        return carrier -> {
            String traceParent = getter.get(carrier, TRACEPARENT);
            if (traceParent == null) {
                return withBaggage(TraceContextOrSamplingFlags.EMPTY, carrier, getter);
            }
            TraceContext contextFromParentHeader = extractContextFromTraceParent(traceParent);
            if (contextFromParentHeader == null) {
                return withBaggage(TraceContextOrSamplingFlags.EMPTY, carrier, getter);
            }
            String traceStateHeader = getter.get(carrier, TRACESTATE);
            TraceContextOrSamplingFlags context = context(contextFromParentHeader, traceStateHeader);
            if (this.baggagePropagator == null || this.braveBaggageManager == null) {
                return context;
            }
            return withBaggage(context, carrier, getter);
        };
    }

    private <R> TraceContextOrSamplingFlags withBaggage(TraceContextOrSamplingFlags context, R carrier,
            Getter<R, String> getter) {
        if (context.context() == null) {
            return context;
        }
        return this.baggagePropagator.contextWithBaggage(carrier, context, getter);
    }

    TraceContextOrSamplingFlags context(TraceContext contextFromParentHeader, String traceStateHeader) {
        if (!StringUtils.isNotBlank(traceStateHeader)) {
            return TraceContextOrSamplingFlags.create(contextFromParentHeader);
        }
        try {
            return TraceContextOrSamplingFlags
                .newBuilder(TraceContext.newBuilder()
                    .traceId(contextFromParentHeader.traceId())
                    .traceIdHigh(contextFromParentHeader.traceIdHigh())
                    .spanId(contextFromParentHeader.spanId())
                    .sampled(contextFromParentHeader.sampled())
                    .build())
                .build();
        }
        catch (IllegalArgumentException e) {
            logger.info("Unparseable tracestate header. Returning span context without state.");
            return TraceContextOrSamplingFlags.create(contextFromParentHeader);
        }
    }

    static {
        // A valid version is 1 byte representing an 8-bit unsigned integer, version ff is
        // invalid.
        VALID_VERSIONS = new HashSet<>();
        for (int i = 0; i < 255; i++) {
            String version = Long.toHexString(i);
            if (version.length() < 2) {
                version = '0' + version;
            }
            VALID_VERSIONS.add(version);
        }
    }

}

/**
 * Taken from OpenTelemetry API.
 * <p>
 * Writes and reads the W3C {@code baggage} header. Fields registered as local are never
 * written, and a {@code tracestate} entry is skipped because {@link W3CPropagation}
 * writes it as its own header. Entry properties/metadata ({@code ;...}) are not supported:
 * they are dropped when reading and never written.
 */
@SuppressWarnings("deprecation")
class W3CBaggagePropagator {

    private static final InternalLogger log = InternalLoggerFactory.getInstance(W3CBaggagePropagator.class);

    private static final String TRACE_STATE = "tracestate";

    private static final String FIELD = "baggage";

    private static final List<String> FIELDS = singletonList(FIELD);

    private final BaggageManager braveBaggageManager;

    private final List<String> localFields;

    W3CBaggagePropagator(BaggageManager baggageManager, List<String> localFields) {
        this.braveBaggageManager = baggageManager;
        this.localFields = localFields;
    }

    public List<String> keys() {
        return FIELDS;
    }

    /**
     * Returns an injector that writes the non-local baggage fields of the context as a
     * comma-separated {@code key=value} list in the {@code baggage} header. Nothing is
     * written when there are no fields to propagate.
     */
    public <R> TraceContext.Injector<R> injector(Propagation.Setter<R, String> setter) {
        return (context, carrier) -> {
            BaggageFields extra = context.findExtra(BaggageFields.class);
            if (extra == null || extra.getAllFields().isEmpty()) {
                return;
            }
            // We ignore local keys - they won't get propagated
            String[] localFieldNames = this.localFields.toArray(new String[0]);
            Map<String, String> filtered = extra.toMapFilteringFieldNames(localFieldNames);
            StringJoiner headerContent = new StringJoiner(",");
            for (Map.Entry<String, String> entry : filtered.entrySet()) {
                // tracestate travels in its own header, not inside baggage.
                if (TRACE_STATE.equalsIgnoreCase(entry.getKey())) {
                    continue;
                }
                // TODO: [OTEL] No metadata support - entry properties are not written.
                headerContent.add(entry.getKey() + "=" + entry.getValue());
            }
            if (headerContent.length() > 0) {
                setter.put(carrier, FIELD, headerContent.toString());
            }
        };
    }

    /**
     * Returns {@code flags} with the baggage parsed from the carrier's {@code baggage}
     * header attached as an extra (an empty list when the header is absent or empty).
     */
    <R> TraceContextOrSamplingFlags contextWithBaggage(R carrier, TraceContextOrSamplingFlags flags,
            Propagation.Getter<R, String> getter) {
        String baggageHeader = getter.get(carrier, FIELD);
        List<AbstractMap.SimpleEntry<Baggage, String>> pairs = baggageHeader == null || baggageHeader.isEmpty()
                ? Collections.emptyList() : addBaggageToContext(baggageHeader);
        return flags.toBuilder().addExtra(new BraveBaggageFields(pairs)).build();
    }

    /**
     * Parses a {@code baggage} header of the form {@code k1=v1;props,k2=v2} into baggage
     * entries. Keys and values are trimmed; properties after {@code ;} are dropped. Entries
     * without {@code =}, with an empty key, or rejected by the baggage manager are skipped.
     * @param baggageHeader non-empty raw header value
     * @return mutable list of parsed entries, in header order
     */
    List<AbstractMap.SimpleEntry<Baggage, String>> addBaggageToContext(String baggageHeader) {
        List<AbstractMap.SimpleEntry<Baggage, String>> pairs = new ArrayList<>();
        for (String entry : baggageHeader.split(",")) {
            int beginningOfMetadata = entry.indexOf(';');
            String keyValue = beginningOfMetadata > 0 ? entry.substring(0, beginningOfMetadata) : entry;
            // Split on the FIRST '=' only: W3C baggage values may themselves contain '='
            // (e.g. base64 padding), which splitting on every '=' would truncate or mangle.
            int separator = keyValue.indexOf('=');
            if (separator < 0) {
                if (log.isDebugEnabled()) {
                    log.debug("Baggage entry [" + entry + "] is not in the key=value format. Will ignore that entry.");
                }
                continue;
            }
            String key = keyValue.substring(0, separator).trim();
            String value = keyValue.substring(separator + 1).trim();
            if (key.isEmpty()) {
                if (log.isDebugEnabled()) {
                    log.debug("Baggage entry [" + entry + "] has an empty key. Will ignore that entry.");
                }
                continue;
            }
            try {
                Baggage baggage = this.braveBaggageManager.createBaggage(key);
                pairs.add(new AbstractMap.SimpleEntry<>(baggage, value));
            }
            catch (RuntimeException e) {
                if (log.isDebugEnabled()) {
                    log.debug("Exception occurred while trying to parse baggage with key value [" + key + ", " + value
                            + "]. Will ignore that entry.", e);
                }
            }
        }
        return pairs;
    }

}

/**
 * Taken from OpenTelemetry API.
 * <p>
 * {@link ThreadLocal} buffers for use when creating new derived objects such as
 * {@link String}s. These buffers are reused within a single thread - it is _not safe_ to
 * use the buffer to generate multiple derived objects at the same time because the same
 * memory will be used. In general, you should get a temporary buffer, fill it with data,
 * and finish by converting into the derived object within the same method to avoid
 * multiple usages of the same buffer.
 */
final class TemporaryBuffers {

    private static final ThreadLocal<char[]> CHAR_ARRAY = new ThreadLocal<>();

    private TemporaryBuffers() {
    }

    /**
     * A {@link ThreadLocal} {@code char[]} of size {@code len}. Take care when using a
     * large value of {@code len} as this buffer will remain for the lifetime of the
     * thread. The returned buffer will not be zeroed and may be larger than the requested
     * size, you must make sure to fill the entire content to the desired value and set
     * the length explicitly when converting to a {@link String}.
     */
    public static char[] chars(int len) {
        char[] buffer = CHAR_ARRAY.get();
        if (buffer == null) {
            buffer = new char[len];
            CHAR_ARRAY.set(buffer);
        }
        else if (buffer.length < len) {
            buffer = new char[len];
            CHAR_ARRAY.set(buffer);
        }
        return buffer;
    }

    // Visible for testing
    static void clearChars() {
        CHAR_ARRAY.set(null);
    }

}

/**
 * Taken from OpenTelemetry API.
 * <p>
 * Helpers for the one-byte W3C {@code trace-flags} field.
 */
final class TraceFlags {

    /** Bit of the trace flags that marks the trace as sampled. */
    static final byte IS_SAMPLED = 0b0000_0001;

    private TraceFlags() {
        // static utility
    }

    /**
     * Decodes the hex-encoded flags byte that starts at {@code srcOffset} in {@code src},
     * delegating to {@link EncodingUtils#byteFromBase16String(CharSequence, int)}.
     * @param src text containing the hex-encoded flags
     * @param srcOffset index of the first hex character of the flags
     * @return the decoded flags byte
     */
    static byte byteFromHex(CharSequence src, int srcOffset) {
        final byte decoded = EncodingUtils.byteFromBase16String(src, srcOffset);
        return decoded;
    }

}
