LFU 缓存

目录

一个缓存结构需要实现如下功能:

void set(int key,int value):加入或者修改 key 对应的 value

int get(int key):查询 key 对应的 value 值

但是缓存最多放 K 条记录,如果新的 K + 1 条记录需要加入,就需要根据策略删掉一条记录,然后才能把新记录加入。

这个策略为:在缓存结构的 K 条记录中,哪一个 key 从进入缓存结构的时刻开始,被调用 set 或者 get 次数最少,就删掉这个key 的记录;如果调用次数最少的 key 有多个,上次调用发送最早的 key 被删除。

这个就是 LFU 缓存替换算法。实现这个结构,K 作为参数。

解法一:哈希表 + 有序表

缓存中的数据是 k,v 对,所以用 map 存储数据。由于缓存数据需要根据:使用次数和使用时间进行淘汰,如果遍历数据查找需要淘汰的数据,耗时比较高。因此需要维护一个有序结构,方便查找要淘汰的数据。

  • 有序数组:
    • 查找要淘汰数据的耗时为:O(1),删除后需要移动数据,耗时为:O(K)
    • 每次操作都需要更新 count 和 time 并维护数组的有序,此时也需要查找并移动数据,耗时为为:O(K)
  • 有序链表:
    • 查找要淘汰数据的耗时为:O(1),(比如头结点或者尾结点)
    • 更新操作时,查找对应节点的耗时为:O(K)
  • 有序表:
    • 查找并移除要淘汰的数据的耗时为:O(log K)
    • 更新操作的耗时时为:O(log K)
  • 小顶堆:
    • 查找并移除要淘汰的数据的耗时为:O(1),移除堆顶后需要堆化的耗时为:O(log K)。
    • 更新数据后,也需要堆化,耗时为:O(log K)。

时间复杂度:O(log K)

空间复杂度:O(K)

import java.util.*;

public class LFU1 {
    public static class Node implements Comparable<Node> {
        public int key;
        public int value;
        // 这个节点发生get或者set的次数总和
        public int count;
        // 最后一次操作的时间
        public int time;
        
        public Node(int key, int value, int count, int time) {
            this.key = key;
            this.value = value;
            this.count = count;
            this.time = time;
        }

        // 缓存淘汰优先级
        // 最少使用(count 越小越容易被淘汰)
        // count 相同,时间越早越容易被淘汰(time 越小越容易被淘汰)
        @Override
        public int compareTo(Node o) {
            return count == o.count ? Integer.compare(time, o.time) : Integer.compare(count, o.count);
        }

        @Override
        public String toString() {
            return "Node{" + "key=" + key + ", value=" + value + ", count=" + count + ", time=" + time + '}';
        }
    }

    public static class LFUCache {
        // 缓存过期优先级
        TreeSet<Node> set = new TreeSet<>();
        Map<Integer, Node> map = new HashMap<>();
        int capacity;
        int time = 0; // 用来计算缓存时间

        public LFUCache(int capacity) {
            this.capacity = Math.max(capacity, 0);
        }

        public Integer get(int key) {
            if (!map.containsKey(key)) {
                return null;
            }
            Node node = map.get(key);
            set(key, node.value);
            return node.value;
        }

        public void set(int key, int value) {
            this.time += 1;
            // 更新
            if (map.containsKey(key)) {
                Node node = map.get(key);
                // 删除再插入(node 的count 和 time 变化了,TreeSet 认为是不同的数据)
                set.remove(node);
                node.time = this.time;
                node.count++;
                node.value = value;
                set.add(node);
                map.put(key, node);
                return;
            }

            // 新增
            // 如果内存满了,淘汰一条旧数据
            if (this.capacity == this.map.size()) {
                remove();
            }
            Node node = new Node(key, value, 1, this.time);
            map.put(key, node);
            set.add(node);
        }

        public void remove() {
            if (map.size() == 0) {
                return;
            }
            Node node = set.first();
            map.remove(node.key);
            set.remove(node);
        }
    }
}

解法二:哈希表 + 小顶堆

将有序表更换为小顶堆。

删除数据时,heap.pop()

更新数据后,堆化:heap.heapify(node)。更新数据使得 time 和 count 变大,因此只需要从 node 节点开始向下堆化

时间复杂度:O(log K)

空间复杂度:O(K)

import java.util.*;

public class LFU3 {
    public static class Node {
        public int key;
        public int value;
        // 这个节点发生get或者set的次数总和
        public int count;
        // 最后一次操作的时间
        public int time;

        public Node(int key, int value, int count, int time) {
            this.key = key;
            this.value = value;
            this.count = count;
            this.time = time;
        }

        @Override
        public String toString() {
            return "Node{" + "key=" + key + ", value=" + value + ", count=" + count + ", time=" + time + '}';
        }
    }


    public static class NodeComparator implements Comparator<Node> {

        // 缓存淘汰优先级
        // 最少使用(count 越小越容易被淘汰)
        // count 相同,时间越早越容易被淘汰(time 越小越容易被淘汰)
        @Override
        public int compare(Node o1, Node o2) {
            return o1.count == o2.count ? Integer.compare(o1.time, o2.time) : Integer.compare(o1.count, o2.count);
        }
    }

    public static class LFUCache {
        // 缓存过期优先级
        HeapGreater<Node> heap = new HeapGreater<>(new NodeComparator());


        Map<Integer, Node> map = new HashMap<>();
        int capacity;
        int time = 0; // 用来计算缓存时间

        public LFUCache(int capacity) {
            this.capacity = Math.max(capacity, 0);
        }

        public Integer get(int key) {
            if (!map.containsKey(key)) {
                return null;
            }
            Node node = map.get(key);
            set(key, node.value);
            return node.value;
        }

        public void set(int key, int value) {
            this.time += 1;
            // 更新
            if (map.containsKey(key)) {
                Node node = map.get(key);
                // 删除再插入(node 的count 和 time 变化了,TreeSet 认为是不同的数据)
                node.time = this.time;
                node.count++;
                node.value = value;
                heap.heapify(node);
                map.put(key, node);
                return;
            }

            // 新增
            // 如果内存慢了,淘汰一条旧数据
            if (this.capacity == this.map.size()) {
                remove();
            }
            Node node = new Node(key, value, 1, this.time);
            map.put(key, node);
            heap.push(node);
        }

        public void remove() {
            if (map.size() == 0) {
                return;
            }
            Node node = heap.pop();
            map.remove(node.key);
        }
    }
}

加强堆的部分代码

import java.util.*;

public class HeapGreater<T> {
    private ArrayList<T> heap;
    private HashMap<T, Integer> indexMap;
    private int heapSize;
    private Comparator<? super T> comp;

    public HeapGreater(Comparator<? super T> c) {
        heap = new ArrayList<>();
        indexMap = new HashMap<>();
        heapSize = 0;
        comp = c;
    }

    public void push(T obj) {
        heap.add(obj);
        indexMap.put(obj, heapSize);
        heapInsert(heapSize++);
    }

    public T pop() {
        T ans = heap.get(0);
        swap(0, heapSize - 1);
        indexMap.remove(ans);
        heap.remove(--heapSize);
        heapify(0);
        return ans;
    }

    private void heapInsert(int index) {
        while (comp.compare(heap.get(index), heap.get((index - 1) / 2)) < 0) {
            swap(index, (index - 1) / 2);
            index = (index - 1) / 2;
        }
    }

    public void heapify(T obj) {
        heapify(indexMap.get(obj));
    }

    private void heapify(int index) {
        int left = index * 2 + 1;
        while (left < heapSize) {
            int best = left + 1 < heapSize && comp.compare(heap.get(left + 1), heap.get(left)) < 0 ? (left + 1) : left;
            best = comp.compare(heap.get(best), heap.get(index)) < 0 ? best : index;
            if (best == index) {
                break;
            }
            swap(best, index);
            index = best;
            left = index * 2 + 1;
        }
    }

    private void swap(int i, int j) {
        T o1 = heap.get(i);
        T o2 = heap.get(j);
        heap.set(i, o2);
        heap.set(j, o1);
        indexMap.put(o2, i);
        indexMap.put(o1, j);
    }
}

解法三:哈希表 + 二维双向链表

如下图就是一个二维双向链表。桶与桶之间是双向链表,桶内有一个双向链表。桶内双向链表上的数据拥有相同的操作次数,越靠近头部的数据,操作时间越近(从链表头部插入新数据,那么要过期数据从尾部移除)。所以要过期一个数据,删除操作数最小的桶(头桶)中链表的尾节点。

下图演示了桶之间链表和桶内链表的变化过程。

**原则:缺少操作数的桶,就新建桶。节点移走后出现空桶,将空桶删除。注意在这个过程中保持:桶之间的双向链表正确连接。**

调用 get(B) :B 的count 从 3 变为 4,没有操作数为 4 的桶,就新建桶,将新桶插入在桶 3 和 桶 5 之间。将节点 B 从桶 3 中移除,插入桶 4 中。

再次调用 get(B):B 的 count 从 4 变为 5,有操作数为 5 的桶,将节点 B 从桶 4 中移除,插入桶 5 (注意是头节点)中。桶 4 移除节点 B 后,成为空桶,删除空桶。将桶 3 与 桶 5 直接连接。

当需要删除一条过期数据时,我们需要在头桶中,删除桶内链表的尾节点(尾结点是最早操作的数据)。

时间复杂度:O(1)

空间复杂度:O(K)

import java.util.*;

public class LFU2 {
    public static class Node {
        public int key;
        public int value;
        // 这个节点发生get或者set的次数总和
        public int count;
        // 双向链表上一个节点
        public Node up;
        // 双向链表下一个节点
        public Node down;

        public Node(int key, int value, int count) {
            this.key = key;
            this.value = value;
            this.count = count;
        }

        @Override
        public String toString() {
            return "Node{" + "key=" + key + ", value=" + value + ", count=" + count + '}';
        }
    }

    public static class NodeList {
        // 桶内链表的头节点
        public Node head;
        // 桶内链表的尾节点
        public Node tail;
        // 桶之间的前一个桶
        public NodeList last;
        // 桶之间的后一个桶
        public NodeList next;


        public NodeList(Node node) {
            this.head = node;
            this.tail = node;
        }

        // 把一个新的节点加入这个桶,新的节点都放在顶端变成新的头部
        public void addNodeFromHead(Node newHead) {
            newHead.down = head;
            head.up = newHead;
            head = newHead;
        }

        // 判断这个桶是不是空的
        public boolean isEmpty() {
            return this.head == null;
        }

        // 删除 node 节点并保证 node 的上下环境重新连接
        public void deleteNode(Node node) {
            if (head == tail) {
                this.head = null;
                this.tail = null;
            } else if (node == head) {
                head = node.down;
                head.up = null;
            } else if (node == tail) {
                tail = node.up;
                tail.down = null;
            } else {
                node.up.down = node.down;
                node.down.up = node.up;
            }
            node.up = null;
            node.down = null;
        }
    }


    // 总得缓存结构
    public static class LFUCache {
        // 缓存的大小限制
        public int capacity;
        // 缓存中目前有多少个节点
        public int size = 0;

        public Map<Integer, Node> map = new HashMap<>();
        // 表示节点 node在 哪个桶里
        public Map<Node, NodeList> heads = new HashMap<>();
        // 整个桶链表的头节点
        private NodeList headList;

        public LFUCache(int k) {
            this.capacity = k;
        }

        // removeNodeList:刚刚减少了一个节点的桶
        // 这个函数的功能是,判断刚刚减少了一个节点的桶是不是已经空了。
        // 1)如果不空,什么也不做
        // 2)如果空了,removeNodeList 还是整个缓存结构最左的桶 (headList)。
        // 删掉这个桶的同时也要让最左的桶变成removeNodeList的下一个。
        // 3)如果空了,removeNodeList不是整个缓存结构最左的桶(headList)。
        // 把这个桶删除,并保证上一个的桶和下一个桶之间还是双向链表的连接方式
        // 函数的返回值表示刚刚减少了一个节点的桶是不是已经空了,空了返回true;不空返回false
        private boolean modifyHeadList(NodeList removeNodeList) {
            if (!removeNodeList.isEmpty()) {
                return false;
            }
            if (removeNodeList == headList) {
                headList = removeNodeList.next;
                if (headList != null) {
                    headList.last = null;
                }
            } else {
                removeNodeList.last.next = removeNodeList.next;
                if (removeNodeList.next != null) {
                    removeNodeList.next.last = removeNodeList.last;
                }
            }
            return true;
        }

        // node 这个节点的次数 +1 了,这个节点原来在 oldNodeList 里。
        // 把 node 从 oldNodeList 删掉,然后放到次数 +1 的桶中
        // 整个过程既要保证桶之间仍然是双向链表,也要保证节点之间仍然是双向链表
        private void move(Node node, NodeList oldNodeList) {
            // 从 oldNodeList 中删除
            oldNodeList.deleteNode(node);
            // preList表示次数 +1 的桶的前一个桶是谁
            // 如果 oldNodeList 删掉 node 之后还有节点(oldNodeList 不需要删除),oldNodeList 就是次数 +1 的桶的前一个桶
            // 如果 oldNodeList 删掉 node 之后空了,oldNodeList是需要删除的,所以次数 +1 的桶的前一个桶,是 oldNodeList 的前一个
            NodeList preList = modifyHeadList(oldNodeList) ? oldNodeList.last : oldNodeList;


            NodeList nextList = oldNodeList.next;
            // 如果 oldNodeList 没有后续了,那么肯定需要新建一个桶来盛放 node
            if (nextList == null) {
                NodeList newList = new NodeList(node);
                if (preList != null) {
                    preList.next = newList;
                }
                newList.last = preList;
                if (headList == null) {
                    headList = newList;
                }
                heads.put(node, newList);
            } else {
                // oldNodeList 有后续了,并且 oldNodeList 的后续count == node.count,直接将 node 添加到这个桶里。
                if (nextList.head.count == node.count) {
                    nextList.addNodeFromHead(node);
                    heads.put(node, nextList);
                } else {
                    // oldNodeList 的后续 count != node.count ,那么需要新建一个桶来放 node。
                    NodeList newList = new NodeList(node);
                    if (preList != null) {
                        preList.next = newList;
                    }
                    newList.last = preList;
                    newList.next = nextList;
                    nextList.last = newList;
                    if (headList == nextList) {
                        headList = newList;
                    }
                    heads.put(node, newList);
                }
            }
        }

        public void set(int key, int value) {
            // 更新
            if (map.containsKey(key)) {
                Node node = map.get(key);
                node.count++;
                node.value = value;
                move(node, heads.get(node));
            } else {
                // 新增
                // 如果缓存已满,需要淘汰一条旧数据
                if (size == capacity) {
                    // 从头部新增,从尾部删除。桶内双向链表的顺序,就是相同 count 的 time 值的排序。
                    // headList 是 count 值最小的桶,headList.tail 是 time 最小的节点。
                    Node node = headList.tail;
                    headList.deleteNode(node);
                    // 删除数据节点后,维护一下 桶,看看是否需要删除
                    modifyHeadList(headList);

                    map.remove(node.key);
                    heads.remove(node);
                    size--;
                }

                Node node = new Node(key, value, 1);
                if (headList == null) {
                    headList = new NodeList(node);
                } else {
                    // 新增节点 count = 1,如果 headList 的count 也是 1,直接将 node 加入 headList
                    if (headList.head.count == node.count) {
                        headList.addNodeFromHead(node);
                    } else {
                        // 如果如果 headList 的count 不是 1,需要新建一个 count = 1 的桶,作为 headList
                        NodeList newList = new NodeList(node);
                        newList.next = headList;
                        headList.last = newList;
                        headList = newList;
                    }
                }
                size++;
                map.put(key, node);
                heads.put(node, headList);
            }
        }

        public Integer get(int key) {
            if (!map.containsKey(key)) {
                return null;
            }

            Node node = map.get(key);
            node.count++;
            move(node, heads.get(node));
            return node.value;
        }
    }
}

对数器

        public static boolean check(LFU1.LFUCache lfu1, LFU2.LFUCache lfu2, LFU3.LFUCache lfu3) {
            if (lfu1.map.size() != lfu2.heads.size() || lfu1.map.size() != lfu3.map.size()) {
                return false;
            }

            for (int key : lfu1.map.keySet()) {
                if (!lfu2.map.containsKey(key) || !lfu3.map.containsKey(key)) {
                    return false;
                }
                Node node = lfu1.map.get(key);
                LFU2.Node node2 = lfu2.map.get(key);
                LFU2.Node node3 = lfu2.map.get(key);

                if (node.value != node2.value || node.count != node2.count || node.value != node3.value || node.count != node3.count) {
                    return false;
                }
            }
            return true;

        }

        public static void check() {
            LFU1.LFUCache lfu1 = new LFU1.LFUCache(3);
            LFU2.LFUCache lfu2 = new LFU2.LFUCache(3);
            LFU3.LFUCache lfu3 = new LFU3.LFUCache(3);

            for (int i = 0; i < 100000; i++) {
                int command = (int) (Math.random() * 3) % 2;
                int key = (int) (Math.random() * 10);
                int value = (int) (Math.random() * 10);

                if (command == 0) {
                    lfu1.set(key, value);
                    lfu2.set(key, value);
                    lfu3.set(key, value);
                } else {
                    lfu1.get(key);
                    lfu2.get(key);
                    lfu3.get(key);
                }
                if (!check(lfu1, lfu2, lfu3)) {
                    System.out.println("ERROR:res1:" + key);
                }
            }

            System.out.println("Nice");
        }

        public static void main(String[] args) {
            check();
        }