Java实现双数组Trie树(DoubleArrayTrie,DAT)

开发 后端
传统的Trie实现简单,但是占用的空间实在是难以接受,特别是当字符集不仅限于英文26个字符的时候,爆炸起来的空间根本无法接受。双数组Trie就是优化了空间的Trie树,原理本文就不讲了,请参考An Efficient Implementation of Trie Structures,本程序的编写也是参考这篇论文的。

传统的Trie实现简单,但是占用的空间实在是难以接受,特别是当字符集不仅限于英文26个字符的时候,爆炸起来的空间根本无法接受。

双数组Trie就是优化了空间的Trie树,原理本文就不讲了,请参考An Efficient Implementation of Trie Structures,本程序的编写也是参考这篇论文的。

关于几点论文没有提及的细节和与论文不一一致的实现:

1.对于插入字符串,如果有一个字符串是另一个字符串的子串的话,我是将结束符也作为一条边,产生一个新的结点,这个结点新节点的Base我置为0

所以一个字符串结束也有2中情况:一个是Base值为负,存储剩余字符(可能只有一个结束符)到Tail数组;另一个是Base为0。

所以在查询的时候要考虑一下这两种情况

2.对于***种冲突(论文中的Case 3),可能要将Tail中的字符串取出一部分,作为边放到索引中。论文是使用将尾串左移的方式,我的方式直接修改Base值,而不是移动尾串。

下面是java实现的代码,可以处理相同字符串插入,子串的插入等情况

  1. /*  
  2.  * Name:   Double Array Trie  
  3.  * Author: Yaguang Ding  
  4.  * Mail: dingyaguang117@gmail.com  
  5.  * Blog: blog.csdn.net/dingyaguang117  
  6.  * Date:   2012/5/21  
  7.  * Note: a word ends may be either of these two case:  
  8.  * 1. Base[cur_p] == pos  ( pos<0 and Tail[-pos] == 'END_CHAR' )  
  9.  * 2. Check[Base[cur_p] + Code('END_CHAR')] ==  cur_p  
  10.  */ 
  11.  
  12.  
  13. import java.util.ArrayList;  
  14. import java.util.HashMap;  
  15. import java.util.Map;  
  16. import java.util.Arrays;  
  17.  
  18.  
  19. public class DoubleArrayTrie {  
  20.     final char END_CHAR = '\0';  
  21.     final int DEFAULT_LEN = 1024;  
  22.     int Base[]  = new int [DEFAULT_LEN];  
  23.     int Check[] = new int [DEFAULT_LEN];  
  24.     char Tail[] = new char [DEFAULT_LEN];  
  25.     int Pos = 1;  
  26.     Map<Character ,Integer> CharMap = new HashMap<Character,Integer>();  
  27.     ArrayList<Character> CharList = new ArrayList<Character>();  
  28.       
  29.     public DoubleArrayTrie()  
  30.     {  
  31.         Base[1] = 1;  
  32.           
  33.         CharMap.put(END_CHAR,1);  
  34.         CharList.add(END_CHAR);  
  35.         CharList.add(END_CHAR);  
  36.         for(int i=0;i<26;++i)  
  37.         {  
  38.             CharMap.put((char)('a'+i),CharMap.size()+1);  
  39.             CharList.add((char)('a'+i));  
  40.         }  
  41.           
  42.     }  
  43.     private void Extend_Array()  
  44.     {  
  45.         Base = Arrays.copyOf(Base, Base.length*2);  
  46.         Check = Arrays.copyOf(Check, Check.length*2);  
  47.     }  
  48.       
  49.     private void Extend_Tail()  
  50.     {  
  51.         Tail = Arrays.copyOf(Tail, Tail.length*2);  
  52.     }  
  53.       
  54.     private int GetCharCode(char c)  
  55.     {  
  56.         if (!CharMap.containsKey(c))  
  57.         {  
  58.             CharMap.put(c,CharMap.size()+1);  
  59.             CharList.add(c);  
  60.         }  
  61.         return CharMap.get(c);  
  62.     }  
  63.     private int CopyToTailArray(String s,int p)  
  64.     {  
  65.         int _Pos = Pos;  
  66.         while(s.length()-p+1 > Tail.length-Pos)  
  67.         {  
  68.             Extend_Tail();  
  69.         }  
  70.         for(int i=p; i<s.length();++i)  
  71.         {  
  72.             Tail[_Pos] = s.charAt(i);  
  73.             _Pos++;  
  74.         }  
  75.         return _Pos;  
  76.     }  
  77.       
  78.     private int x_check(Integer []set)  
  79.     {  
  80.         for(int i=1; ; ++i)  
  81.         {  
  82.             boolean flag = true;  
  83.             for(int j=0;j<set.length;++j)  
  84.             {  
  85.                 int cur_p = i+set[j];  
  86.                 if(cur_p>= Base.length) Extend_Array();  
  87.                 if(Base[cur_p]!= 0 || Check[cur_p]!= 0)  
  88.                 {  
  89.                     flag = false;  
  90.                     break;  
  91.                 }  
  92.             }  
  93.             if (flag) return i;  
  94.         }  
  95.     }  
  96.       
  97.     private ArrayList<Integer> GetChildList(int p)  
  98.     {  
  99.         ArrayList<Integer> ret = new ArrayList<Integer>();  
  100.         for(int i=1; i<=CharMap.size();++i)  
  101.         {  
  102.             if(Base[p]+i >= Check.length) break;  
  103.             if(Check[Base[p]+i] == p)  
  104.             {  
  105.                 ret.add(i);  
  106.             }  
  107.         }  
  108.         return ret;  
  109.     }  
  110.       
  111.     private boolean TailContainString(int start,String s2)  
  112.     {  
  113.         for(int i=0;i<s2.length();++i)  
  114.         {  
  115.             if(s2.charAt(i) != Tail[i+start]) return false;  
  116.         }  
  117.           
  118.         return true;  
  119.     }  
  120.     private boolean TailMatchString(int start,String s2)  
  121.     {  
  122.         s2 += END_CHAR;  
  123.         for(int i=0;i<s2.length();++i)  
  124.         {  
  125.             if(s2.charAt(i) != Tail[i+start]) return false;  
  126.         }  
  127.         return true;  
  128.     }  
  129.       
  130.       
  131.     public void Insert(String s) throws Exception  
  132.     {  
  133.         s += END_CHAR;  
  134.           
  135.         int pre_p = 1;  
  136.         int cur_p;  
  137.         for(int i=0; i<s.length(); ++i)  
  138.         {  
  139.             //获取状态位置  
  140.             cur_p = Base[pre_p]+GetCharCode(s.charAt(i));  
  141.             //如果长度超过现有,拓展数组  
  142.             if (cur_p >= Base.length) Extend_Array();  
  143.               
  144.             //空闲状态  
  145.             if(Base[cur_p] == 0 && Check[cur_p] == 0)  
  146.             {  
  147.                 Base[cur_p] = -Pos;  
  148.                 Check[cur_p] = pre_p;  
  149.                 Pos = CopyToTailArray(s,i+1);  
  150.                 break;  
  151.             }else 
  152.             //已存在状态  
  153.             if(Base[cur_p] > 0 && Check[cur_p] == pre_p)  
  154.             {  
  155.                 pre_p = cur_p;  
  156.                 continue;  
  157.             }else 
  158.             //冲突 1:遇到 Base[cur_p]小于0的,即遇到一个被压缩存到Tail中的字符串  
  159.             if(Base[cur_p] < 0 && Check[cur_p] == pre_p)  
  160.             {  
  161.                 int head = -Base[cur_p];  
  162.                   
  163.                 if(s.charAt(i+1)== END_CHAR && Tail[head]==END_CHAR)    //插入重复字符串  
  164.                 {  
  165.                     break;  
  166.                 }  
  167.                   
  168.                 //公共字母的情况,因为上一个判断已经排除了结束符,所以一定是2个都不是结束符  
  169.                 if (Tail[head] == s.charAt(i+1))  
  170.                 {  
  171.                     int avail_base = x_check(new Integer[]{GetCharCode(s.charAt(i+1))});  
  172.                     Base[cur_p] = avail_base;  
  173.                       
  174.                     Check[avail_base+GetCharCode(s.charAt(i+1))] = cur_p;  
  175.                     Base[avail_base+GetCharCode(s.charAt(i+1))] = -(head+1);  
  176.                     pre_p = cur_p;  
  177.                     continue;  
  178.                 }  
  179.                 else 
  180.                 {  
  181.                     //2个字母不相同的情况,可能有一个为结束符  
  182.                     int avail_base ;  
  183.                     avail_base = x_check(new Integer[]{GetCharCode(s.charAt(i+1)),GetCharCode(Tail[head])});  
  184.                       
  185.                     Base[cur_p] = avail_base;  
  186.                       
  187.                     Check[avail_base+GetCharCode(Tail[head])] = cur_p;  
  188.                     Check[avail_base+GetCharCode(s.charAt(i+1))] = cur_p;  
  189.                       
  190.                     //Tail 为END_FLAG 的情况  
  191.                     if(Tail[head] == END_CHAR)  
  192.                         Base[avail_base+GetCharCode(Tail[head])] = 0;  
  193.                     else 
  194.                         Base[avail_base+GetCharCode(Tail[head])] = -(head+1);  
  195.                     if(s.charAt(i+1) == END_CHAR)   
  196.                         Base[avail_base+GetCharCode(s.charAt(i+1))] = 0;  
  197.                     else 
  198.                         Base[avail_base+GetCharCode(s.charAt(i+1))] = -Pos;  
  199.                       
  200.                     Pos = CopyToTailArray(s,i+2);  
  201.                     break;  
  202.                 }  
  203.             }else 
  204.             //冲突2:当前结点已经被占用,需要调整pre的base  
  205.             if(Check[cur_p] != pre_p)  
  206.             {  
  207.                 ArrayList<Integer> list1 = GetChildList(pre_p);  
  208.                 int toBeAdjust;  
  209.                 ArrayList<Integer> list = null;  
  210.                 if(true)  
  211.                 {  
  212.                     toBeAdjust = pre_p;  
  213.                     list = list1;  
  214.                 }  
  215.                   
  216.                 int origin_base = Base[toBeAdjust];  
  217.                 list.add(GetCharCode(s.charAt(i)));  
  218.                 int avail_base = x_check((Integer[])list.toArray(new Integer[list.size()]));  
  219.                 list.remove(list.size()-1);  
  220.                   
  221.                 Base[toBeAdjust] = avail_base;  
  222.                 for(int j=0; j<list.size(); ++j)  
  223.                 {  
  224.                     //BUG   
  225.                     int tmp1 = origin_base + list.get(j);  
  226.                     int tmp2 = avail_base + list.get(j);  
  227.                       
  228.                     Base[tmp2] = Base[tmp1];  
  229.                     Check[tmp2] = Check[tmp1];  
  230.                       
  231.                     //有后续  
  232.                     if(Base[tmp1] > 0)  
  233.                     {  
  234.                         ArrayList<Integer> subsequence = GetChildList(tmp1);  
  235.                         for(int k=0; k<subsequence.size(); ++k)  
  236.                         {  
  237.                             Check[Base[tmp1]+subsequence.get(k)] = tmp2;  
  238.                         }  
  239.                     }  
  240.                       
  241.                     Base[tmp1] = 0;  
  242.                     Check[tmp1] = 0;  
  243.                 }  
  244.                   
  245.                 //更新新的cur_p  
  246.                 cur_p = Base[pre_p]+GetCharCode(s.charAt(i));  
  247.                   
  248.                 if(s.charAt(i) == END_CHAR)  
  249.                     Base[cur_p] = 0;  
  250.                 else 
  251.                     Base[cur_p] = -Pos;  
  252.                 Check[cur_p] = pre_p;  
  253.                 Pos = CopyToTailArray(s,i+1);  
  254.                 break;  
  255.             }  
  256.         }  
  257.     }  
  258.       
  259.     public boolean Exists(String word)  
  260.     {  
  261.         int pre_p = 1;  
  262.         int cur_p = 0;  
  263.           
  264.         for(int i=0;i<word.length();++i)  
  265.         {  
  266.             cur_p = Base[pre_p]+GetCharCode(word.charAt(i));  
  267.             if(Check[cur_p] != pre_p) return false;  
  268.             if(Base[cur_p] < 0)  
  269.             {  
  270.                 if(TailMatchString(-Base[cur_p],word.substring(i+1)))  
  271.                     return true;  
  272.                 return false;  
  273.             }  
  274.             pre_p = cur_p;  
  275.         }  
  276.         if(Check[Base[cur_p]+GetCharCode(END_CHAR)] == cur_p)  
  277.             return true;  
  278.         return false;  
  279.     }  
  280.       
  281.     //内部函数,返回匹配单词的最靠后的Base index,  
  282.     class FindStruct  
  283.     {  
  284.         int p;  
  285.         String prefix="";  
  286.     }  
  287.     private FindStruct Find(String word)  
  288.     {  
  289.         int pre_p = 1;  
  290.         int cur_p = 0;  
  291.         FindStruct fs = new FindStruct();  
  292.         for(int i=0;i<word.length();++i)  
  293.         {  
  294.             // BUG  
  295.             fs.prefix += word.charAt(i);  
  296.             cur_p = Base[pre_p]+GetCharCode(word.charAt(i));  
  297.             if(Check[cur_p] != pre_p)  
  298.             {  
  299.                 fs.p = -1;  
  300.                 return fs;  
  301.             }  
  302.             if(Base[cur_p] < 0)  
  303.             {  
  304.                 if(TailContainString(-Base[cur_p],word.substring(i+1)))  
  305.                 {  
  306.                     fs.p = cur_p;  
  307.                     return fs;  
  308.                 }  
  309.                 fs.p = -1;  
  310.                 return fs;  
  311.             }  
  312.             pre_p = cur_p;  
  313.         }  
  314.         fs.p =  cur_p;  
  315.         return fs;  
  316.     }  
  317.       
  318.     public ArrayList<String> GetAllChildWord(int index)  
  319.     {  
  320.         ArrayList<String> result = new ArrayList<String>();  
  321.         if(Base[index] == 0)  
  322.         {  
  323.             result.add("");  
  324.             return result;  
  325.         }  
  326.         if(Base[index] < 0)  
  327.         {  
  328.             String r="";  
  329.             for(int i=-Base[index];Tail[i]!=END_CHAR;++i)  
  330.             {  
  331.                 r+= Tail[i];  
  332.             }  
  333.             result.add(r);  
  334.             return result;  
  335.         }  
  336.         for(int i=1;i<=CharMap.size();++i)  
  337.         {  
  338.             if(Check[Base[index]+i] == index)  
  339.             {  
  340.                 for(String s:GetAllChildWord(Base[index]+i))  
  341.                 {  
  342.                     result.add(CharList.get(i)+s);  
  343.                 }  
  344.                 //result.addAll(GetAllChildWord(Base[index]+i));  
  345.             }  
  346.         }  
  347.         return result;  
  348.     }  
  349.       
  350.     public ArrayList<String> FindAllWords(String word)  
  351.     {  
  352.         ArrayList<String> result = new ArrayList<String>();  
  353.         String prefix = "";  
  354.         FindStruct fs = Find(word);  
  355.         int p = fs.p;  
  356.         if (p == -1return result;  
  357.         if(Base[p]<0)  
  358.         {  
  359.             String r="";  
  360.             for(int i=-Base[p];Tail[i]!=END_CHAR;++i)  
  361.             {  
  362.                 r+= Tail[i];  
  363.             }  
  364.             result.add(fs.prefix+r);  
  365.             return result;  
  366.         }  
  367.           
  368.         if(Base[p] > 0)  
  369.         {  
  370.             ArrayList<String> r =  GetAllChildWord(p);  
  371.             for(int i=0;i<r.size();++i)  
  372.             {  
  373.                 r.set(i, fs.prefix+r.get(i));  
  374.             }  
  375.             return r;  
  376.         }  
  377.           
  378.         return result;  
  379.     }  
  380.       

测  试

  1. import java.io.BufferedReader;  
  2. import java.io.FileInputStream;  
  3. import java.io.IOException;  
  4. import java.io.InputStream;  
  5. import java.io.InputStreamReader;  
  6. import java.util.ArrayList;  
  7. import java.util.Scanner;  
  8.  
  9. import javax.xml.crypto.Data;  
  10.  
  11.  
  12. public class Main {  
  13.  
  14.     public static void main(String[] args) throws Exception {  
  15.         ArrayList<String> words = new ArrayList<String>();  
  16.         BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream("E:/兔子的试验学习中心[课内]/ACM大赛/ACM第四届校赛/E命令提示/words3.dic")));  
  17.         String s;  
  18.         int num = 0;  
  19.         while((s=reader.readLine()) != null)  
  20.         {  
  21.             words.add(s);  
  22.             num ++;  
  23.         }  
  24.         DoubleArrayTrie dat = new DoubleArrayTrie();  
  25.           
  26.         for(String word: words)  
  27.         {  
  28.             dat.Insert(word);  
  29.         }  
  30.           
  31.         System.out.println(dat.Base.length);  
  32.         System.out.println(dat.Tail.length);  
  33.           
  34.         Scanner sc = new Scanner(System.in);  
  35.         while(sc.hasNext())  
  36.         {  
  37.             String word = sc.next();  
  38.             System.out.println(dat.Exists(word));  
  39.             System.out.println(dat.FindAllWords(word));  
  40.         }  
  41.           
  42.     }  
  43.  
  44. }  

下面是测试结果,构造6W英文单词的DAT,大概需要20秒

 

我增长数组的时候是每次长度增加到2倍,初始1024

Base和Check数组的长度为131072

Tail的长度为262144

原文地址:Java实现双数组Trie树(DoubleArrayTrie,DAT)

责任编辑:林师授 来源: dingyaguang117博客
相关推荐

2021-06-30 17:38:03

Trie 树字符Java

2020-10-30 09:56:59

Trie树之美

2021-06-04 10:18:03

Trie字典树数据

2012-09-25 09:19:26

Spring数据库双数据库

2022-09-14 07:59:27

字典树Trie基数树

2017-09-06 10:55:19

Java

2016-12-08 11:01:39

红黑树Java

2022-10-28 09:10:40

数据结构字典树

2012-04-09 16:22:43

C#

2010-10-27 17:00:32

oracle树查询

2009-11-16 16:17:45

PHP数组排序

2021-05-12 19:19:44

字典树数据结构

2024-11-12 08:00:00

LSM树GolangMemTable

2023-01-09 18:15:21

数组Python类型

2012-01-06 15:18:53

Java

2023-09-27 09:39:08

Java优化

2009-08-13 10:35:05

Scala数组排序

2014-12-10 10:02:14

华为银行数据中心网络

2021-09-07 11:01:41

二叉搜索树序数组

2009-05-07 13:36:38

Java静态数组动态数组
点赞
收藏

51CTO技术栈公众号