• 回溯框架总结


    什么是回溯算法

    其实回溯算法和我们常说的 DFS 算法非常类似,本质上就是一种暴力穷举算法。回溯算法和 DFS 算法的细微差别是:回溯算法是在遍历「树枝」,DFS 算法是在遍历「节点」,本文就是简单提一下,等你看到后文图论算法基础 时就能深刻理解这句话的含义了。
    废话不多说,直接上回溯算法框架,解决一个回溯问题,实际上就是一个决策树的遍历过程,站在回溯树的一个节点上,你只需要思考 3 个问题:
    1、路径:也就是已经做出的选择。
    2、选择列表:也就是你当前可以做的选择。
    3、结束条件:也就是到达决策树底层,无法再做选择的条件。

    代码方面,回溯算法的框架:

    result = []
    def backtrack(路径, 选择列表):
        if 满足结束条件:
            result.add(路径)
            return
        for 选择 in 选择列表:
        	
        	if(illegal()):(剪枝函数,如果当前选择列表当中的元素不合法,就跳过本轮循环)
        		continue;
            做选择
            backtrack(路径, 选择列表)
            撤销选择
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    其核心就是 for 循环里面的递归,在递归调用之前「做选择」,在递归调用之后「撤销选择」,特别简单。

    什么问题可以用回溯算法

    回溯算法多用于解决排列组合问题,遇到问题可以先尝试画递归树,如果问题能够用递归树的形式进行描述,那么就可以用回溯法来解决。
    值得一提的是,回溯算法还经常被用来遍历二维数组(岛屿问题、矩阵最长递增路径),如果你把二维矩阵中的每一个位置看做一个节点,这个节点的上下左右四个位置就是相邻节点,那么整个矩阵就可以抽象成一个树结构,注意,这里的上是原路返回到父节点,这与普通的树是不同的,二维数组的遍历框架如下:

    // 方向数组,分别代表上、下、左、右
    int[][] dirs = new int[][]{{-1,0}, {1,0}, {0,-1}, {0,1}};
    
    void dfs(int[][] grid, int i, int j, boolean[][] visited) {
        int m = grid.length, n = grid[0].length;
        if (i < 0 || j < 0 || i >= m || j >= n) {
            // 超出索引边界
            return;
        }
        if (visited[i][j]) {
            // 已遍历过 (i, j)
            return;
        }
    
        // 进入节点 (i, j)
        visited[i][j] = true;
        // 递归遍历上下左右的节点
        for (int[] d : dirs) {
            int next_i = i + d[0];
            int next_j = j + d[1];
            dfs(grid, next_i, next_j, visited);
        }
        // 离开节点 (i, j)
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    总的来说,回溯算法能够解决的问题主要有两类,能够转换成树结构的问题和二维数组的遍历问题。
    代表问题有以下几种:

    全排列问题

    无重复数字的全排列

    class Solution:
        def permute(self , num: List[int]) -> List[List[int]]:
            # write code here
            used=[False]*len(num)
            res=[]
            temp=[]
            def backtrack():
                if len(temp)==len(num):
                    res.append(temp[:])
                    return
                for i in range(len(num)):
                    if used[i]:
                        continue
                    temp.append(num[i])
                    used[i]=True
                    backtrack()
                    temp.pop(-1)
                    used[i]=False
            backtrack()
            return res
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    有重复数字的全排列

    class Solution:
        def permuteUnique(self , num: List[int]) -> List[List[int]]:
            # write code here
            num.sort()
            used=[False]*len(num)
            res=[]
            temp=[]
            def backtrack():
                if len(temp)==len(num):
                    res.append(temp[:])
                    return
                for i in range(len(num)):
                    if used[i]:
                        continue
                    if i>0 and num[i]==num[i-1] and used[i-1]:
                        continue
                    temp.append(num[i])
                    used[i]=True
                    backtrack()
                    temp.pop(-1)
                    used[i]=False
            backtrack()
            return res
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    括号生成问题

    class Solution:
        def generateParenthesis(self , n: int) -> List[str]:
            # write code here
            left=n
            right=n
            res=[]
            temp=[]
            def backtrack(left,right):
                if right<left:
                    return
                if right<0 or left<0:
                    return
                if left==0 and right==0:
                    s="".join(temp)
                    res.append(s)
                temp.append("(")
                backtrack(left-1,right)
                temp.pop(-1)
    
                temp.append(")")
                backtrack(left,right-1)
                temp.pop(-1)
            backtrack(left,right)
            return res
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    N皇后问题

    class Solution:
        def Nqueen(self , n: int) -> int:
            # write code here
            matrix = [["." for _ in range(n)] for _ in range(n)]
            res = 0
            def check(r, c, matrix):
                for i in range(r):
                    if matrix[i][c] == "Q":
                        return False
                i, j = r, c
                while i > 0 and j > 0:
                    if matrix[i - 1][j - 1] == "Q":
                        return False
                    i -= 1
                    j -= 1
                i, j = r, c
                while i > 0 and j < n - 1:
                    if matrix[i - 1][j + 1] == "Q":
                        return False
                    i -= 1
                    j += 1
                return True
            def dfs(r):
                nonlocal res, matrix
                if r == n:
                    res += 1
                    return
                for i in range(n):
                    if check(r, i, matrix):
                        matrix[r][i] = "Q"
                        dfs(r + 1)
                        matrix[r][i] = "."
            dfs(0)
            return res
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    矩阵路径问题

    矩阵最长递增路径

    class Solution:
        global dirs
        #记录四个方向
        dirs = [[-1, 0], [1, 0], [0, -1], [0, 1]] 
        global n, m
        #深度优先搜索,返回最大单元格数
        def dfs(self, matrix:List[List[int]], dp: List[List[int]], i:int, j:int) :
            if dp[i][j] != 0:
                return dp[i][j]
            dp[i][j] += 1
            for k in range(4):
                nexti = i + dirs[k][0]
                nextj = j + dirs[k][1]
                #判断条件
                if  nexti >= 0 and nexti < n and nextj >= 0 and nextj < m and matrix[nexti][nextj] > matrix[i][j]:
                    dp[i][j] = max(dp[i][j], self.dfs(matrix, dp, nexti, nextj) + 1)
            return dp[i][j]
        
        def solve(self , matrix: List[List[int]]) -> int:
            global n,m
            #矩阵不为空
            if len(matrix) == 0 or len(matrix[0]) == 0:
                return 0
            res = 0
            n = len(matrix)
            m = len(matrix[0])
            #i,j处的单元格拥有的最长递增路径
            dp = [[0 for col in range(m)] for row in range(n)]  
            for i in range(n):
                for j in range(m):
                    #更新最大值
                    res = max(res, self.dfs(matrix, dp, i, j)) 
            return res
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    岛屿问题

    class Solution:
        def solve(self , grid: List[List[str]]) -> int:
            # write code here
            def dfs(i,j):
                if i<0 or j<0 or i>len(grid)-1 or j >len(grid[0])-1:
                    return
                if grid[i][j]=="0":
                    return
                grid[i][j]="0"
                dfs(i+1,j)
                dfs(i,j+1)
                dfs(i-1,j)
                dfs(i,j-1)
            res=0
            for i in range(len(grid)):
                for j in range(len(grid[0])):
                    if grid[i][j]=="1":
                        res+=1
                        dfs(i,j)
            return res
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
  • 相关阅读:
    Vue3 实现一个自定义toast(小弹窗)
    Jira使用教程-不古出品
    使用Docker配置深度学习的运行环境
    Apache DolphinScheduler 3.0.0 升级到 3.1.8 教程
    【分布式系统】Filebeat+Kafka+ELK 的服务部署
    修复Arch Linux和Manjaro Linux无法显示emoji的问题
    猿创征文|机器学习实战(8)——随机森林
    【SpringCloud学习07】微服务保护之Sentinel(1)
    Arduino PID整定
    博弈论:gym104065j
  • 原文地址:https://blog.csdn.net/weixin_42385782/article/details/128053644