2020年4月11日 星期六

Leetcode題解 Python:Find All Good Strings

這題相當困難,從通過的人數就知道。為什麼困難?因為這需要使用 KMP + 數位DP。

如果不清楚 KMP 跟 數位DP 的人,請先去看我之前的入門級介紹。本篇主要參考

給兩個字串 s1, s2,同樣是 n 大小,給一字串 evil,回傳在 s1, s2 區間內不含 evil 的字串數。

這題目的敘述,便是典型的數位DP。

因此我們先將數位DP的模型寫出來。

       from functools import lru_cache

       @lru_cache(None) #紀錄傳過的參數與結果,若下次回入一樣的參數,便可以直接回傳結果。
        def dfs(pos, stats, bound):
            if stats == np: return 0                
            if pos == n: return 1

            l = ord(s1[pos]) if bound & 1 else ord("a")
            r = ord(s2[pos]) if bound & 2 else ord("z")

            result = 0
            for u in range(l, r+1):
                char = chr(u)
                if bound == 3 and char == s1[pos] and char == s2[pos]:
                    nextBound = 3
                elif bound & 2 and char == s2[pos]:
                    nextBound = 2
                elif bound & 1 and char == s1[pos]:
                    nextBound = 1
                else:
                    nextBound = 0

                nextStats = search(stats, char) #此時尚未安排
                result += dfs(pos+1, nextStats, nextBound)                
                
            return result % (10**9+7) # 題目要求取餘
這裡沒有數字,只有a-z,就算沒有數字,也能用ord()把字母轉成unicode,把 a-z 當成二十六進位制,因此可以取出左右範圍。

接著要講 search() ,能不能使用暴力匹配法呢?

如果這樣使用,逐一匹配,過程中失敗後從 Target 的下一位開始從頭匹配。然而 數位DP 在過程中有部分匹配時,不論Target或Pattern都已經往下一位,萬一發生匹配失敗,Target是無法回到匹配開頭的下一位開始。

既然不能使Target回溯,暴力匹配法也會遇到阻礙,那有甚麼搜尋法是可以讓Target的索引一直遞增下去?使用KMP搜尋。

使用KMP,Target的索引會逐漸遞增到結尾,Target匹配過的部分就不需要再匹配,這能與數位DP結合上。

直接把KMP的模型套入,也決定了search()。
        np = len(evil)
        # 建立 prefixTable
        prefixTable = [0] * np
        for i in range(np):
            if i == 0:
                prefixTable[i] = -1
            else:
                pi = prefixTable[i-1]
                while pi >= -1:
                    if evil[i-1] == evil[pi]:
                        prefixTable[i] = pi + 1
                        break
                    else:
                        if  pi == -1:
                            prefixTable[i] = 0
                            break
                        pi = prefixTable[pi] 

        def search(stats, char):
            while stats > -1 and char != evil[stats]:
                stats = prefixTable[stats]
            return  stats +1 if char == evil[stats] else 0
將兩部分整合之後,可以得到一個完整代碼。
class Solution:
    def findGoodStrings(self, n: int, s1: str, s2: str, evil: str) -> int:
        from functools import lru_cache
        
        np = len(evil)
        # 建立 prefixTable
        prefixTable = [0] * np
        for i in range(np):
            if i == 0:
                prefixTable[i] = -1
            else:
                pi = prefixTable[i-1]
                while pi >= -1:
                    if evil[i-1] == evil[pi]:
                        prefixTable[i] = pi + 1
                        break
                    else:
                        if  pi == -1:
                            prefixTable[i] = 0
                            break
                        pi = prefixTable[pi] 

        def search(stats, char):
            while stats > -1 and char != evil[stats]:
                stats = prefixTable[stats]
            return  stats +1 if char == evil[stats] else 0

        @lru_cache(None)
        def dfs(pos, stats, bound):
            if stats == np: return 0                
            if pos == n: return 1

            l = ord(s1[pos]) if bound & 1 else ord("a")
            r = ord(s2[pos]) if bound & 2 else ord("z")

            result = 0
            for u in range(l, r+1):
                char = chr(u)
                if bound == 3 and char == s1[pos] and char == s2[pos]:
                    nextBound = 3
                elif bound & 2 and char == s2[pos]:
                    nextBound = 2
                elif bound & 1 and char == s1[pos]:
                    nextBound = 1
                else:
                    nextBound = 0

                nextStats = search(stats, char)
                result += dfs(pos+1, nextStats, nextBound)                
                
            return result % (10**9+7)

        return dfs(0, 0, 3)