/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.indices;

import java.io.Closeable;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.LRUQueryCache;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryCache;
import org.apache.lucene.search.QueryCachingPolicy;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.UsageTrackingQueryCachingPolicy;
import org.apache.lucene.search.Weight;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.lucene.ShardCoreKeyMap;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.cache.query.QueryCacheStats;

@PublicApi(since="1.0.0")
public class IndicesQueryCache
implements QueryCache,
Closeable {
    private static final Logger logger = LogManager.getLogger(IndicesQueryCache.class);

    /** Node-scoped memory-size setting {@code indices.queries.cache.size}, declared with the value {@code "10%"}. */
    public static final Setting<ByteSizeValue> INDICES_CACHE_QUERY_SIZE_SETTING = Setting.memorySizeSetting(
        "indices.queries.cache.size",
        "10%",
        Setting.Property.NodeScope
    );

    /**
     * Node-scoped integer setting {@code indices.queries.cache.count}
     * (presumably default 10000, minimum 1, following {@code Setting.intSetting}'s argument order).
     */
    public static final Setting<Integer> INDICES_CACHE_QUERY_COUNT_SETTING = Setting.intSetting(
        "indices.queries.cache.count",
        10000,
        1,
        Setting.Property.NodeScope
    );

    /**
     * Node-scoped boolean setting {@code indices.queries.cache.all_segments}, declared with {@code false}.
     * When it reads {@code true} the constructor builds the cache with a leaf predicate that accepts every segment.
     */
    public static final Setting<Boolean> INDICES_QUERIES_CACHE_ALL_SEGMENTS_SETTING = Setting.boolSetting(
        "indices.queries.cache.all_segments",
        false,
        Setting.Property.NodeScope
    );

    /**
     * Dynamic, node-scoped float setting {@code indices.queries.cache.skip_cache_factor}
     * (presumably default 10.0, minimum 1.0, following {@code Setting.floatSetting}'s argument order).
     */
    public static final Setting<Float> INDICES_QUERIES_CACHE_SKIP_CACHE_FACTOR = Setting.floatSetting(
        "indices.queries.cache.skip_cache_factor",
        10.0f,
        1.0f,
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );

    /**
     * Dynamic, node-scoped integer setting {@code indices.queries.cache.min_frequency}
     * (presumably default 5, minimum 1); read by {@link OpenseachUsageTrackingQueryCachingPolicy}.
     */
    public static final Setting<Integer> INDICES_QUERY_CACHE_MIN_FREQUENCY = Setting.intSetting(
        "indices.queries.cache.min_frequency",
        5,
        1,
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );

    /**
     * Dynamic, node-scoped integer setting {@code indices.queries.cache.costly_min_frequency}
     * (presumably default 2, minimum 1); read by {@link OpenseachUsageTrackingQueryCachingPolicy}.
     */
    public static final Setting<Integer> INDICES_QUERY_CACHE_COSTLY_MIN_FREQUENCY = Setting.intSetting(
        "indices.queries.cache.costly_min_frequency",
        2,
        1,
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );

    /** The underlying Lucene cache; every caching decision is delegated to it. */
    private final LRUQueryCache cache;

    /** Maps reader core keys to shards; populated by {@link CachingWeightWrapper} before it delegates. */
    private final ShardCoreKeyMap shardKeyMap = new ShardCoreKeyMap();

    /** Per-shard statistics, created lazily by the cache callbacks and removed in {@link #onClose(ShardId)}. */
    private final Map<ShardId, Stats> shardStats = new ConcurrentHashMap<>();

    /** Bytes reported through the query-level (not per-segment) cache/eviction callbacks. */
    private volatile long sharedRamBytesUsed;

    /**
     * Per-core-key stats plus the number of cached entries for that key, keyed by identity.
     * The eviction callback reads stats from here rather than going through {@link #shardKeyMap}.
     */
    private final Map<Object, StatsAndCount> stats2 = Collections.synchronizedMap(new IdentityHashMap<>());

    /**
     * Creates a cache from {@code settings} without cluster settings, by delegating to
     * {@link #IndicesQueryCache(Settings, ClusterSettings)} with a {@code null} second argument.
     *
     * @param settings node settings the cache size, count and skip factor are read from
     */
    public IndicesQueryCache(Settings settings) {
        this(settings, (ClusterSettings) null);
    }

    /**
     * Creates the node-level query cache.
     *
     * <p>If {@link #INDICES_QUERIES_CACHE_ALL_SEGMENTS_SETTING} is {@code true}, the cache is built with a
     * leaf predicate that accepts every segment and a fixed skip factor of {@code 1.0f}; in that branch the
     * skip-cache-factor setting is neither applied nor registered for updates. Otherwise the configured skip
     * factor is applied and, when {@code clusterSettings} is non-null, an update consumer is registered for it.
     *
     * @param settings        node settings supplying size, count, skip factor and the all-segments flag
     * @param clusterSettings registry for dynamic updates; may be {@code null}, in which case a warning is logged
     */
    public IndicesQueryCache(Settings settings, ClusterSettings clusterSettings) {
        final ByteSizeValue maxBytes = INDICES_CACHE_QUERY_SIZE_SETTING.get(settings);
        final int maxEntries = INDICES_CACHE_QUERY_COUNT_SETTING.get(settings);
        final float initialSkipFactor = INDICES_QUERIES_CACHE_SKIP_CACHE_FACTOR.get(settings);
        logger.debug(
            "using [node] query cache with size [{}] max filter count [{}] skipCacheFactor [{}]",
            maxBytes,
            maxEntries,
            initialSkipFactor
        );

        final boolean cacheEverySegment = INDICES_QUERIES_CACHE_ALL_SEGMENTS_SETTING.get(settings);
        if (!cacheEverySegment) {
            this.cache = new OpenSearchLRUQueryCache(maxEntries, maxBytes.getBytes());
            this.cache.setSkipCacheFactor(initialSkipFactor);
            if (clusterSettings == null) {
                // Without cluster settings there is nothing to subscribe to, so the factor stays fixed.
                logger.warn("clusterSettings is null, so {} is not dynamic", INDICES_QUERIES_CACHE_SKIP_CACHE_FACTOR.getKey());
            } else {
                clusterSettings.addSettingsUpdateConsumer(INDICES_QUERIES_CACHE_SKIP_CACHE_FACTOR, this::setSkipCacheFactor);
            }
        } else {
            this.cache = new OpenSearchLRUQueryCache(maxEntries, maxBytes.getBytes(), leaf -> true, 1.0f);
        }
        this.sharedRamBytesUsed = 0L;
    }

    /**
     * Applies a new skip-cache factor to the underlying cache, logging the previous and new values at debug level.
     *
     * @param skipCacheFactor the value to hand to the underlying cache
     */
    public void setSkipCacheFactor(float skipCacheFactor) {
        // Read the current factor before overwriting it so the log line shows old -> new.
        final float previousFactor = this.cache.getSkipCacheFactor();
        logger.debug(
            "set cluster settings {} {} -> {}",
            INDICES_QUERIES_CACHE_SKIP_CACHE_FACTOR.getKey(),
            previousFactor,
            skipCacheFactor
        );
        this.cache.setSkipCacheFactor(skipCacheFactor);
    }

    /**
     * Returns the query-cache statistics for one shard, including its portion of the shared RAM usage.
     *
     * <p>The shard's own counters are taken from a snapshot of all per-shard stats. The shared RAM
     * ({@code sharedRamBytesUsed}) is then apportioned: if no shard has stats, all of it is attributed to the
     * result; if the shards hold no cached entries in total, it is split evenly across the known shards;
     * otherwise it is weighted by this shard's cache size relative to the total.
     *
     * @param shard the shard to report on; a shard without recorded stats yields empty own counters
     * @return a new stats object combining the shard's counters and its portion of the shared RAM
     */
    public QueryCacheStats getStats(ShardId shard) {
        // Snapshot every shard's counters first so the weight below is computed from one consistent copy.
        final Map<ShardId, QueryCacheStats> snapshot = new HashMap<>();
        for (Map.Entry<ShardId, Stats> perShard : this.shardStats.entrySet()) {
            snapshot.put(perShard.getKey(), perShard.getValue().toQueryCacheStats());
        }

        final QueryCacheStats result = new QueryCacheStats();
        final QueryCacheStats own = snapshot.get(shard);
        result.add(own != null ? own : new QueryCacheStats());

        final long sharedPortion;
        if (snapshot.isEmpty()) {
            sharedPortion = this.sharedRamBytesUsed;
        } else {
            long totalEntries = 0L;
            for (QueryCacheStats each : snapshot.values()) {
                totalEntries += each.getCacheSize();
            }
            final double fraction;
            if (totalEntries == 0L) {
                fraction = 1.0 / (double) snapshot.size();
            } else {
                fraction = (double) result.getCacheSize() / (double) totalEntries;
            }
            sharedPortion = Math.round(fraction * (double) this.sharedRamBytesUsed);
        }
        result.add(new QueryCacheStats(sharedPortion, 0L, 0L, 0L, 0L));
        return result;
    }

    /**
     * Wraps {@code weight} for caching through the underlying cache.
     *
     * <p>Any existing {@link CachingWeightWrapper} layers are peeled off first so the underlying cache
     * sees the innermost weight, and the result is wrapped exactly once.
     *
     * @param weight the weight to cache, possibly already wrapped
     * @param policy the caching policy passed through to the underlying cache
     * @return a new {@link CachingWeightWrapper} around the weight returned by the underlying cache
     */
    public Weight doCache(Weight weight, QueryCachingPolicy policy) {
        Weight innermost = weight;
        while (innermost instanceof CachingWeightWrapper) {
            innermost = ((CachingWeightWrapper) innermost).in;
        }
        return new CachingWeightWrapper(this.cache.doCache(innermost, policy));
    }

    /**
     * Drops cached entries for every reader core key that {@code shardKeyMap} associates with {@code index}.
     * If the cache then reports a size of zero, the whole cache is cleared as well.
     *
     * @param index name of the index whose entries should be removed
     */
    public void clearIndex(String index) {
        final Set<Object> keysOfIndex = this.shardKeyMap.getCoreKeysForIndex(index);
        for (Object key : keysOfIndex) {
            this.cache.clearCoreCacheKey(key);
        }

        final boolean nothingLeft = this.cache.getCacheSize() == 0L;
        if (nothingLeft) {
            this.cache.clear();
        }
    }

    /**
     * Clears the underlying cache. With assertions enabled, first verifies that the core-key map,
     * the per-shard stats and the per-core-key stats are all empty.
     */
    @Override
    public void close() {
        // Each assertion reports the leftover content as its detail message.
        assert this.shardKeyMap.size() == 0 : this.shardKeyMap.size();
        assert this.shardStats.isEmpty() : this.shardStats.keySet();
        assert this.stats2.isEmpty() : this.stats2;

        this.cache.clear();
    }

    /**
     * Tells whether {@code stats} records no cached entries and no RAM usage.
     *
     * @param stats the stats to inspect; {@code null} counts as empty
     * @return {@code true} if {@code stats} is {@code null} or both its cache size and RAM bytes are zero
     */
    private boolean empty(Stats stats) {
        return stats == null || (stats.cacheSize == 0L && stats.ramBytesUsed == 0L);
    }

    /**
     * Forgets the statistics kept for {@code shardId}. With assertions enabled, first checks that
     * those statistics are absent or empty.
     *
     * @param shardId the shard whose stats entry is removed
     */
    public void onClose(ShardId shardId) {
        final Stats remaining = this.shardStats.get(shardId);
        assert this.empty(remaining);
        this.shardStats.remove(shardId);
    }

    /**
     * {@link LRUQueryCache} subclass whose callback overrides keep the enclosing cache's per-shard
     * statistics, per-core-key entry counts and shared RAM counter up to date.
     */
    private class OpenSearchLRUQueryCache extends LRUQueryCache {

        /** Passes all four arguments straight to the matching {@link LRUQueryCache} constructor. */
        OpenSearchLRUQueryCache(int maxSize, long maxRamBytesUsed, Predicate<LeafReaderContext> leavesToCache, float skipFactor) {
            super(maxSize, maxRamBytesUsed, leavesToCache, skipFactor);
        }

        /** Passes both arguments straight to the matching {@link LRUQueryCache} constructor. */
        OpenSearchLRUQueryCache(int maxSize, long maxRamBytesUsed) {
            super(maxSize, maxRamBytesUsed);
        }

        /**
         * Looks up the stats of the shard that owns {@code coreKey}.
         *
         * @return the shard's stats, or {@code null} if the key maps to no shard or the shard has no stats yet
         */
        private Stats getStats(Object coreKey) {
            final ShardId owner = IndicesQueryCache.this.shardKeyMap.getShardId(coreKey);
            return owner == null ? null : IndicesQueryCache.this.shardStats.get(owner);
        }

        /**
         * Looks up the stats of the shard that owns {@code coreKey}, registering a fresh
         * {@link Stats} for that shard when none exists.
         */
        private Stats getOrCreateStats(Object coreKey) {
            final ShardId owner = IndicesQueryCache.this.shardKeyMap.getShardId(coreKey);
            final Stats existing = IndicesQueryCache.this.shardStats.get(owner);
            if (existing != null) {
                return existing;
            }
            final Stats created = new Stats(owner);
            IndicesQueryCache.this.shardStats.put(owner, created);
            return created;
        }

        /**
         * After delegating to super, zeroes the cache size and RAM usage of every shard,
         * empties the per-core-key map and resets the shared RAM counter.
         */
        protected void onClear() {
            super.onClear();
            for (Stats perShard : IndicesQueryCache.this.shardStats.values()) {
                // Hit, miss and cache counts are left as they are; only size and RAM are reset.
                perShard.cacheSize = 0L;
                perShard.ramBytesUsed = 0L;
            }
            IndicesQueryCache.this.stats2.clear();
            IndicesQueryCache.this.sharedRamBytesUsed = 0L;
        }

        /** After delegating to super, adds {@code ramBytesUsed} to the shared RAM counter. */
        protected void onQueryCache(Query filter, long ramBytesUsed) {
            super.onQueryCache(filter, ramBytesUsed);
            IndicesQueryCache.this.sharedRamBytesUsed += ramBytesUsed;
        }

        /** After delegating to super, subtracts {@code ramBytesUsed} from the shared RAM counter. */
        protected void onQueryEviction(Query filter, long ramBytesUsed) {
            super.onQueryEviction(filter, ramBytesUsed);
            IndicesQueryCache.this.sharedRamBytesUsed -= ramBytesUsed;
        }

        /**
         * After delegating to super, bumps the owning shard's cache size, cache count and RAM usage,
         * and increments the entry count tracked for {@code readerCoreKey}.
         */
        protected void onDocIdSetCache(Object readerCoreKey, long ramBytesUsed) {
            super.onDocIdSetCache(readerCoreKey, ramBytesUsed);

            final Stats owner = this.getOrCreateStats(readerCoreKey);
            owner.cacheSize++;
            owner.cacheCount++;
            owner.ramBytesUsed += ramBytesUsed;

            StatsAndCount perCore = IndicesQueryCache.this.stats2.get(readerCoreKey);
            if (perCore == null) {
                perCore = new StatsAndCount(owner);
                IndicesQueryCache.this.stats2.put(readerCoreKey, perCore);
            }
            perCore.count++;
        }

        /**
         * After delegating to super, and only when {@code numEntries} is positive, subtracts the evicted
         * entries and bytes from the stats recorded for {@code readerCoreKey}; the per-core-key record is
         * dropped once its entry count reaches zero.
         */
        protected void onDocIdSetEviction(Object readerCoreKey, int numEntries, long sumRamBytesUsed) {
            super.onDocIdSetEviction(readerCoreKey, numEntries, sumRamBytesUsed);
            if (numEntries <= 0) {
                return;
            }
            // The stats come from stats2 (keyed by the core key itself), not from the shard key map.
            final StatsAndCount perCore = IndicesQueryCache.this.stats2.get(readerCoreKey);
            final Stats owner = perCore.stats;
            owner.cacheSize -= (long) numEntries;
            owner.ramBytesUsed -= sumRamBytesUsed;
            perCore.count -= numEntries;
            if (perCore.count == 0) {
                IndicesQueryCache.this.stats2.remove(readerCoreKey);
            }
        }

        /** After delegating to super, increments the hit count of the shard that owns {@code readerCoreKey}. */
        protected void onHit(Object readerCoreKey, Query filter) {
            super.onHit(readerCoreKey, filter);
            final Stats owner = this.getStats(readerCoreKey);
            owner.hitCount++;
        }

        /**
         * After delegating to super, increments the miss count of the shard that owns
         * {@code readerCoreKey}, creating that shard's stats if needed.
         */
        protected void onMiss(Object readerCoreKey, Query filter) {
            super.onMiss(readerCoreKey, filter);
            final Stats owner = this.getOrCreateStats(readerCoreKey);
            owner.missCount++;
        }
    }

    /**
     * Mutable query-cache counters for one shard. The fields are written by the
     * {@link OpenSearchLRUQueryCache} callbacks and read when building {@link QueryCacheStats}.
     */
    private static class Stats implements Cloneable {
        /** The shard these counters belong to. */
        final ShardId shardId;
        /** Bytes currently attributed to this shard's cached doc-id sets. */
        volatile long ramBytesUsed;
        /** Number of cache hits recorded for this shard. */
        volatile long hitCount;
        /** Number of cache misses recorded for this shard. */
        volatile long missCount;
        /** Number of doc-id sets cached so far; incremented on caching, never decremented on eviction. */
        volatile long cacheCount;
        /** Number of doc-id sets currently cached; incremented on caching and reduced on eviction. */
        volatile long cacheSize;

        /**
         * @param shardId the shard these counters describe
         */
        Stats(ShardId shardId) {
            this.shardId = shardId;
        }

        /**
         * Copies the current counter values into a new {@link QueryCacheStats}.
         *
         * @return a stats object built from RAM bytes, hit count, miss count, cache count and cache size, in that order
         */
        QueryCacheStats toQueryCacheStats() {
            return new QueryCacheStats(this.ramBytesUsed, this.hitCount, this.missCount, this.cacheCount, this.cacheSize);
        }

        /**
         * Renders all counters for diagnostics, e.g. in assertion messages.
         * The RAM label is spelled {@code ramBytesUsed} to match the field name.
         */
        public String toString() {
            return "{shardId=" + this.shardId
                + ", ramBytesUsed=" + this.ramBytesUsed
                + ", hitCount=" + this.hitCount
                + ", missCount=" + this.missCount
                + ", cacheCount=" + this.cacheCount
                + ", cacheSize=" + this.cacheSize
                + "}";
        }
    }

    /**
     * {@link Weight} decorator that adds the leaf's reader to the enclosing cache's
     * {@code shardKeyMap} before delegating {@code explain}, {@code scorerSupplier} and
     * {@code count} to the wrapped weight.
     */
    private class CachingWeightWrapper extends Weight {
        /** The wrapped weight; also read by the enclosing class when unwrapping. */
        private final Weight in;

        /**
         * @param in the weight to delegate to; its query becomes this wrapper's query
         */
        protected CachingWeightWrapper(Weight in) {
            super(in.getQuery());
            this.in = in;
        }

        /** Adds the reader behind {@code context} to the shard key map. */
        private void register(LeafReaderContext context) {
            IndicesQueryCache.this.shardKeyMap.add(context.reader());
        }

        /** Registers the leaf's reader, then returns the wrapped weight's explanation. */
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            this.register(context);
            return this.in.explain(context, doc);
        }

        /** Registers the leaf's reader, then returns the wrapped weight's scorer supplier. */
        public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
            this.register(context);
            return this.in.scorerSupplier(context);
        }

        /** Registers the leaf's reader, then returns the wrapped weight's count. */
        public int count(LeafReaderContext context) throws IOException {
            this.register(context);
            return this.in.count(context);
        }

        /** Delegates to the wrapped weight without touching the shard key map. */
        public boolean isCacheable(LeafReaderContext ctx) {
            return this.in.isCacheable(ctx);
        }
    }

    /**
     * {@link UsageTrackingQueryCachingPolicy} whose minimum-frequency thresholds come from
     * {@link #INDICES_QUERY_CACHE_MIN_FREQUENCY} and {@link #INDICES_QUERY_CACHE_COSTLY_MIN_FREQUENCY}
     * and follow updates to those settings.
     */
    public static class OpenseachUsageTrackingQueryCachingPolicy extends UsageTrackingQueryCachingPolicy {
        /** Threshold returned for queries not classified as costly. */
        private volatile int minFrequency;
        /** Threshold returned for queries classified as costly by {@code isCostly}. */
        private volatile int minFrequencyForCostly;

        /**
         * Reads both thresholds from {@code clusterSettings} and registers update consumers for them.
         *
         * @param clusterSettings source of the current values and of later updates; must not be {@code null}
         */
        public OpenseachUsageTrackingQueryCachingPolicy(ClusterSettings clusterSettings) {
            this.minFrequency = clusterSettings.get(INDICES_QUERY_CACHE_MIN_FREQUENCY);
            this.minFrequencyForCostly = clusterSettings.get(INDICES_QUERY_CACHE_COSTLY_MIN_FREQUENCY);
            clusterSettings.addSettingsUpdateConsumer(INDICES_QUERY_CACHE_MIN_FREQUENCY, this::setMinFrequency);
            clusterSettings.addSettingsUpdateConsumer(INDICES_QUERY_CACHE_COSTLY_MIN_FREQUENCY, this::setMinFrequencyForCostly);
        }

        /**
         * Returns the threshold for {@code query}: the costly threshold for costly queries; otherwise the
         * regular threshold, lowered by one for {@link BooleanQuery} and {@link DisjunctionMaxQuery},
         * and never less than 1.
         */
        protected int minFrequencyToCache(Query query) {
            if (this.isCostly(query)) {
                return this.minFrequencyForCostly;
            }
            final boolean compound = query instanceof BooleanQuery || query instanceof DisjunctionMaxQuery;
            final int threshold = compound ? this.minFrequency - 1 : this.minFrequency;
            return Math.max(1, threshold);
        }

        /**
         * A query is costly if it is a {@link MultiTermQuery}, if its class's simple name is one of the two
         * multi-term constant-score wrapper names, or if it is a point query according to {@code isPointQuery}.
         */
        private boolean isCostly(Query query) {
            if (query instanceof MultiTermQuery) {
                return true;
            }
            // The wrapper classes are matched by simple name rather than by type.
            final String simpleName = query.getClass().getSimpleName();
            if (simpleName.equals("MultiTermQueryConstantScoreBlendedWrapper")) {
                return true;
            }
            if (simpleName.equals("MultiTermQueryConstantScoreWrapper")) {
                return true;
            }
            return this.isPointQuery(query);
        }

        /**
         * Walks from the query's class up to (but excluding) {@link Query} and reports whether any class
         * on the way has a simple name that starts with {@code "Point"} and ends with {@code "Query"}.
         */
        private boolean isPointQuery(Query query) {
            Class<?> current = query.getClass();
            while (current != Query.class) {
                final String name = current.getSimpleName();
                if (name.startsWith("Point") && name.endsWith("Query")) {
                    return true;
                }
                current = current.getSuperclass();
            }
            return false;
        }

        /** Sets the threshold used for queries that are not costly. */
        public void setMinFrequency(int minFrequency) {
            this.minFrequency = minFrequency;
        }

        /** Sets the threshold used for costly queries. */
        public void setMinFrequencyForCostly(int minFrequencyForCostly) {
            this.minFrequencyForCostly = minFrequencyForCostly;
        }
    }

    /** Pairs a shard's {@link Stats} with a mutable entry count; stored per reader core key. */
    private static class StatsAndCount {
        /** Number of cached entries currently tracked for the core key; starts at zero. */
        volatile int count;
        /** The shard stats the entries are accounted against. */
        final Stats stats;

        /**
         * @param stats the shard stats this record points at
         */
        StatsAndCount(Stats stats) {
            this.count = 0;
            this.stats = stats;
        }

        /** Renders the referenced stats and the current count for diagnostics. */
        public String toString() {
            final StringBuilder text = new StringBuilder();
            text.append("{stats=").append(String.valueOf(this.stats));
            text.append(" ,count=").append(this.count);
            text.append("}");
            return text.toString();
        }
    }
}

