目录
回溯算法本质上是一个暴力搜索的过程,其常用于解决组合、切割、子集、排列等问题,其一般模板如下:
- void backTracking(参数){
- if(终止条件){
- // 1. 收获结果;
- // 2. return;
- }
-
- for(..遍历){
- // 1. 处理节点
- // 2. 递归搜索
- // 3. 回溯 // 即撤销对节点的处理
- }
- return;
- }
主要思路:
基于回溯法,暴力枚举 k 个数,需注意回溯弹出元素的操作;
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> combine(int n, int k) { - std::vector<int> path;
- backTracking(n, k, path, 1); // 从第 1 个数开始
- return res;
- }
-
- void backTracking(int n, int k, std::vector<int> path, int start){
- if(path.size() == k){
- res.push_back(path);
- return;
- }
-
- for(int i = start; i <= n; i++){
- path.push_back(i);
- backTracking(n, k, path, i + 1); // 递归暴力搜索下一个数
- path.pop_back(); // 回溯
- }
- }
-
- private:
- std::vector
int>> res; - };
-
- int main(int argc, char* argv[]){
- int n = 4, k = 2;
- Solution S1;
- std::vector
int>> res = S1.combine(n, k); - for(auto v : res){
- for(int item : v) std::cout << item << " ";
- std::cout << std::endl;
- }
- return 0;
- }
上题的组合问题中,对于进入循环体 for(int i = start; i <= n; i++):
已选择的元素数量为:path.size()
仍然所需的元素数量为:k - path.size()
剩余的元素集合为:n - i + 1
则为了满足要求,必须满足:k-path.size() <= n - i + 1,即 i <= n - k + path.size() + 1
因此,可以通过以下条件完成剪枝操作:
for(int i = start; i <= i <= n - k + path.size() + 1; i++)
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> combine(int n, int k) { - std::vector<int> path;
- backTracking(n, k, path, 1); // 从第 1 个数开始
- return res;
- }
-
- void backTracking(int n, int k, std::vector<int> path, int start){
- if(path.size() == k){
- res.push_back(path);
- return;
- }
-
- for(int i = start; i <= n - k + path.size() + 1; i++){
- path.push_back(i);
- backTracking(n, k, path, i + 1); // 暴力下一个数
- path.pop_back(); // 回溯
- }
- }
-
- private:
- std::vector
int>> res; - };
-
- int main(int argc, char* argv[]){
- int n = 4, k = 2;
- Solution S1;
- std::vector
int>> res = S1.combine(n, k); - for(auto v : res){
- for(int item : v) std::cout << item << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
类似于上面的组合问题,基于回溯来暴力枚举每一个数,需要注意剪枝操作;
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> combinationSum3(int k, int n) { - std::vector<int> path;
- backTracking(k, n, 0, path, 1);
- return res;
- }
-
- void backTracking(int k, int n, int sum, std::vector<int>path, int start){
- if(sum > n) return; // 剪枝
- if(path.size() == k){ // 递归终止
- if(sum == n){
- res.push_back(path);
- }
- return;
- }
-
- for(int i = start; i <= 9 + path.size() - k + 1; i++){ // 剪枝
- path.push_back(i);
- sum += i;
- backTracking(k, n, sum, path, i + 1); // 递归枚举下一个数
- // 回溯
- sum -= i;
- path.pop_back();
- }
-
- }
- private:
- std::vector
int>> res; - };
-
- int main(int argc, char* argv[]){
- int k = 3, n = 7;
- Solution S1;
- std::vector
int>> res = S1.combinationSum3(k, n); - for(auto v : res){
- for(int item : v) std::cout << item << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
根据传入的字符串,遍历每一个数字字符,基于回溯法,递归遍历每一个数字字符对应的字母;
- #include
- #include
- #include
-
- class Solution {
- public:
- std::vector
letterCombinations(std::string digits) { - if(digits.length() == 0) return res;
- std::vector
letter = {"", " ", "abc", "def", "ghi", "jkl", "mno", "pqrs", "tuv", "wxyz"}; - backTracking(digits, letter, 0);
- return res;
- }
-
- void backTracking(std::string digits, std::vector
letter, int cur) { - if(cur == digits.size()){
- res.push_back(tmp);
- return;
- }
- int letter_idx = digits[cur] - '0';
- int letter_size = letter[letter_idx].size();
- for(int i = 0; i < letter_size; i++){
- tmp.push_back(letter[letter_idx][i]);
- backTracking(digits, letter, cur+1);
- // 回溯
- tmp.pop_back();
- }
- }
-
- private:
- std::vector
res; - std::string tmp;
- };
-
- int main(int argc, char* argv[]){
- std::string test = "23";
- Solution S1;
- std::vector
res = S1.letterCombinations(test); - for(std::string str : res) std::cout << str << " ";
- std::cout << std::endl;
- return 0;
- }
主要思路:
经典组合问题,通过回溯法暴力递归遍历;
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> combinationSum(std::vector<int>& candidates, int target) { - if(candidates.size() == 0) return res;
- backTracking(candidates, target, 0, 0, path);
- return res;
- }
-
- void backTracking(std::vector<int>& candidates, int target, int sum, int start, std::vector<int>path){
- if(sum > target) return; // 剪枝
- if(sum == target) {
- res.push_back(path);
- return;
- }
-
- for(int i = start; i < candidates.size(); i++){
- sum += candidates[i];
- path.push_back(candidates[i]);
- backTracking(candidates, target, sum, i, path);
- // 回溯
- sum -= candidates[i];
- path.pop_back();
- }
- }
-
- private:
- std::vector
int>> res; - std::vector<int> path;
- };
-
- int main(int argc, char* argv[]){
- // candidates = [2,3,6,7], target = 7
- std::vector<int> test = {2, 3, 6, 7};
- int target = 7;
- Solution S1;
- std::vector
int>> res = S1.combinationSum(test, target); -
- for(auto vec : res){
- for(int val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
-
- return 0;
- }
主要思路:
基于回溯法,暴力递归遍历数组;
本题不能包含重复的组合,但由于有重复的数字,因此需要进行树层上的去重;
- #include
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> combinationSum2(std::vector<int>& candidates, int target) { - if(candidates.size() == 0) return res;
- std::sort(candidates.begin(), candidates.end());
- std::vector<bool> visited(candidates.size(), false);
- backTracking(candidates, target, 0, 0, visited);
- return res;
- }
-
- void backTracking(std::vector<int>& candidates, int target, int sum, int start, std::vector<bool>& visited){
- if(sum > target) return; // 剪枝
- if(sum == target){
- res.push_back(path);
- return;
- }
- for(int i = start; i < candidates.size(); i++){
- // 树层剪枝去重
- if(i > 0 && candidates[i-1] == candidates[i]){
- if(visited[i-1] == false) continue;
- }
- sum += candidates[i];
- path.push_back(candidates[i]);
- visited[i] = true;
- backTracking(candidates, target, sum, i+1, visited);
- // 回溯
- sum -= candidates[i];
- path.pop_back();
- visited[i] = false;
- }
- }
-
- private:
- std::vector
int>> res; - std::vector<int> path;
- };
-
- int main(int argc, char* argv[]){
- // candidates = [10,1,2,7,6,1,5], target = 8
- std::vector<int> test = {10, 1, 2, 7, 6, 1, 5};
- int target = 8;
- Solution S1;
- std::vector
int>> res = S1.combinationSum2(test, target); -
- for(auto vec : res){
- for(int val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
-
- return 0;
- }
主要思路:
基于回溯法,暴力加入子串,判断每一个子串是否是一个回文子串,如果不是回文子串则跳过,否则继续,直到将字符串的所有字符都遍历完;
- #include
- #include
- #include
-
- class Solution {
- public:
- std::vector
> partition(std::string s) { - if(s.length() == 0) return res;
- backTracking(s, 0);
- return res;
- }
-
- void backTracking(std::string s, int startIdx){
- if(startIdx >= s.size()){
- res.push_back(tmp);
- return;
- }
-
- for(int i = startIdx; i < s.size(); i++){
- std::string str = s.substr(startIdx, i - startIdx + 1);
- if(is_valid(str) == false) continue; // 不合法的
- // 合法的
- tmp.push_back(str);
- backTracking(s, i+1);
- // 回溯
- tmp.pop_back();
- }
- }
-
- // 判断是否是回文子串
- bool is_valid(std::string str){
- for(int i = 0, j = str.size()-1; i <= j; i++, j--){
- if(str[i] != str[j]) return false;
- }
- return true;
- }
-
- private:
- std::vector
> res; - std::vector
tmp; - };
-
- int main(int argc, char* argv[]){
- // s = "aab"
- std::string str = "aab";
- Solution S1;
- std::vector
> res = S1.partition(str); - for(auto vec : res){
- for(std::string s : vec) std::cout << s << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
基于回溯法,递归遍历数字字符串,需要判断数字字符串是否合法;
递归终止条件是当前加入的 ‘.’ 号已经为 3 个了,这时还需要判断最后一个数字字符串是否合法;
比较难的地方可能在于判断字符是否有效;
- #include
- #include
- #include
-
- class Solution {
- public:
- std::vector
restoreIpAddresses(std::string s) { - if(s.length() < 4 || s.length() > 12) return res; // 剪枝
- backTracking(s, 0, 0);
- return res;
- }
-
- void backTracking(std::string s, int start_inx, int numpoints){
- if(numpoints == 3){
- std::string str = s.substr(start_inx, s.length() - start_inx);
- if(isvalid(str)){
- res.push_back(s);
- return;
- }
- }
-
- for(int i = start_inx; i < s.length(); i++){
- std::string str = s.substr(start_inx, i - start_inx + 1);
- if(isvalid(str) == false) break;
- s.insert(s.begin() + i + 1, '.'); // 插入 '.'
- numpoints += 1;
- backTracking(s, i + 2, numpoints); // 从 i + 2 开始,因为新加入了一个 '.'字符
- // 回溯
- s.erase(s.begin() + i + 1);
- numpoints -= 1;
- }
- }
-
- bool isvalid(std::string str){
- int len = str.length();
- if(len <= 0 || len > 3) return false;
- if (str[0] == '0' && len > 1) { // 不合法
- return false;
- }
- int num = 0;
- for(int i = 0; i < len; i++){
- num = num * 10 + (str[i] - '0');
- if(num > 255) return false;
- }
- return true;
- }
-
- private:
- std::vector
res; - int numpoints = 0;
- };
-
- int main(int argc, char* argv[]){
- // s = "25525511135"
- std::string test = "25525511135";
- Solution S1;
- std::vector
res = S1.restoreIpAddresses(test); - for(auto str : res) std::cout << str << " ";
- std::cout << std::endl;
- return 0;
- }
主要思路:
基于回溯法,暴力遍历加入每一个元素,每加入一个元素就收获一次结果;
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> subsets(std::vector<int>& nums) { - dfs(nums, 0);
- return res;
- }
-
- void dfs(std::vector<int>& nums, int start_idx){
- if(start_idx > nums.size()) return; // 可以删去
- res.push_back(path);
- for(int i = start_idx; i < nums.size(); i++){
- path.push_back(nums[i]);
- dfs(nums, i+1);
- // 回溯
- path.pop_back();
- }
- }
-
- private:
- std::vector
int>> res; - std::vector<int> path;
- };
-
-
- int main(int argc, char* argv[]){
- // nums = [1,2,3]
- std::vector<int> nums = {1, 2, 3};
- Solution S1;
- std::vector
int>> res = S1.subsets(nums); - for(auto vec : res){
- for(auto val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
类似于上题的子集问题,暴力遍历每一个元素,每加入一个元素收获一次结果;需要注意的是,本题需要去重,因此采用类似组合问题的去重逻辑:先对数组进行排序,再进行树层上的去重;
- #include
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> subsetsWithDup(std::vector<int>& nums) { - std::sort(nums.begin(), nums.end());
- std::vector<bool> visited(nums.size(), false);
- dfs(nums, 0, visited);
- return res;
- }
-
- void dfs(std::vector<int>& nums, int start_idx, std::vector<bool>& visited){
- if(start_idx > nums.size()) return; // 可以省略
- res.push_back(path);
- for(int i = start_idx; i < nums.size(); i++){
- // 树层上去重
- if(i > 0 && nums[i-1] == nums[i] && visited[i-1] == false){
- continue;
- }
- path.push_back(nums[i]);
- visited[i] = true;
- dfs(nums, i+1, visited);
- // 回溯
- path.pop_back();
- visited[i] = false;
- }
- }
-
- private:
- std::vector
int>> res; - std::vector<int> path;
- };
-
-
- int main(int argc, char* argv[]){
- // nums = [1,2,2]
- std::vector<int> nums = {1, 2, 2};
- Solution S1;
- std::vector
int>> res = S1.subsetsWithDup(nums); - for(auto vec : res){
- for(auto val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
基于回溯法,暴力遍历每一个元素;
判断一个元素是否可以加入路径的条件:元素大于路径的最后一个结点(保证递增);因为有重复的元素,因此需要在树层上进行去重,树层去重需要判断当前元素是否在遍历之前出现过,如果出现过则舍弃;
- #include
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> findSubsequences(std::vector<int>& nums) { - if(nums.size() == 0) return res;
- backTracking(nums, 0);
- return res;
- }
-
- void backTracking(std::vector<int>& nums, int start_idx){
- if(path.size() > 1){
- res.push_back(path); // 每一层递归加入一个结果
- }
-
- std::set<int> my_set; // 每一层递归新建一个set
- for(int i = start_idx; i < nums.size(); i++){
- if(path.size() > 0 && nums[i] < path.back() || my_set.find(nums[i]) != my_set.end()){ // 树层上去重
- continue;
- }
- path.push_back(nums[i]);
- my_set.insert(nums[i]);
- backTracking(nums, i+1);
- // 回溯
- path.pop_back();
- }
- }
-
- private:
- std::vector
int>> res; - std::vector<int> path;
- };
-
- int main(int argc, char* argv[]){
- // nums = [4,6,7,7]
- std::vector<int> nums = {4, 6, 7, 7};
- Solution S1;
- std::vector
int>> res = S1.findSubsequences(nums); - for(auto vec : res){
- for(int val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
基于回溯法,暴力遍历每一个元素;由于是排列,因此结果的元素数量肯定与数组大小相同;
每一轮递归,都需要从索引 0 开始遍历;使用一个数组来记录当前元素是否被访问过;
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> permute(std::vector<int>& nums) { - if(nums.size() == 0) return res;
- std::vector<bool> visited(nums.size(), false);
- backTracking(nums, visited);
- return res;
- }
-
- void backTracking(std::vector<int>& nums, std::vector<bool>& visited){
- if(path.size() == nums.size()){
- res.push_back(path);
- return;
- }
- for(int i = 0; i < nums.size(); i++){
- if(visited[i] == true) continue;
- path.push_back(nums[i]);
- visited[i] = true;
- backTracking(nums, visited);
- // 回溯
- visited[i] = false;
- path.pop_back();
- }
- }
- private:
- std::vector
int>> res; - std::vector<int> path;
- };
-
- int main(int argc, char* argv[]){
- // nums = [1,2,3]
- std::vector<int> nums = {1, 2, 3};
- Solution S1;
- std::vector
int>> res = S1.permute(nums); - for(auto vec : res){
- for(auto val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
主要思路类似于上题的全排列,但本题的数组含有重复元素,因此需要进行树层上的去重(先对数组进行排序);
- #include
- #include
- #include
-
- class Solution {
- public:
- std::vector
int>> permuteUnique(std::vector<int>& nums) { - if(nums.size() == 0) return res;
- std::sort(nums.begin(), nums.end());
- std::vector<bool> visited(nums.size(), false);
- backTracking(nums, visited);
- return res;
- }
-
- void backTracking(std::vector<int>& nums, std::vector<bool>& visited){
- if(path.size() == nums.size()){
- res.push_back(path);
- return;
- }
- for(int i = 0; i < nums.size(); i++){
- if(i > 0 && nums[i-1] == nums[i] && visited[i-1] == false) continue; // 树层上去重
- if(visited[i] == true) continue; // 重复元素
-
- path.push_back(nums[i]);
- visited[i] = true;
- backTracking(nums, visited);
- // 回溯
- path.pop_back();
- visited[i] = false;
- }
- }
-
- private:
- std::vector
int>> res; - std::vector<int> path;
- };
-
- int main(int argc, char* argv[]){
- // nums = [1,1,2]
- std::vector<int> nums = {1, 1, 2};
- Solution S1;
- std::vector
int>> res = S1.permuteUnique(nums); - for(auto vec : res){
- for(auto val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
基于回溯法,递归每一行,暴力枚举每一列是否适合放皇后;
- #include
- #include
- #include
-
- class Solution {
- public:
- std::vector
> solveNQueens(int n) { - std::vector
board(n, std::string(n, '.')) ; - backTracking(n, 0, board);
- return res;
- }
-
- void backTracking(int n, int row, std::vector
& board) { - if(row == n){
- res.push_back(board);
- return;
- }
- for(int col = 0; col < n; col++){
- if(isValid(board, row, col)){ // [row, col]可以放一个皇后
- board[row][col] = 'Q';
- backTracking(n, row+1, board);
- // 回溯
- board[row][col] = '.';
- }
- }
- }
-
- bool isValid(const std::vector
& board, int row, int col) { - // 检查行
- for(int i = 0; i < row; i++){
- if(board[i][col] == 'Q') return false;
- }
- // 检查列
- for(int i = 0; i < board.size(); i++){
- if(board[row][i] == 'Q') return false;
- }
- // 检查主对角线
- for(int i = row - 1, j = col - 1; i >=0 && j >=0; i--, j--){
- if(board[i][j] == 'Q') return false;
- }
- // 检查反对角线
- for(int i = row - 1, j = col + 1; i >= 0 && j <= board.size(); i--, j++){
- if(board[i][j] == 'Q') return false;
- }
- return true;
- }
-
- private:
- std::vector
> res; - };
-
- int main(int argc, char* argv[]){
- // n = 4
- int n = 4;
- Solution S1;
- std::vector
> res = S1.solveNQueens(n); - for(auto vec : res){
- for(auto val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
- return 0;
- }
主要思路:
基于回溯法,两层循环分别判断行和列,枚举 9 个数是否适合加入到当前行和列;
- #include
- #include
- #include
-
- class Solution {
- public:
- void solveSudoku(std::vector
char >>& board) { - backTracking(board);
- }
-
- bool backTracking(std::vector
char >>& board){ - for(int row = 0; row < board.size(); row++){
- for(int col = 0; col < board[0].size(); col++){
- if(board[row][col] == '.'){
- for(char k = '1'; k <= '9'; k++){
- if(isValid(board, row, col, k)){
- board[row][col] = k;
- if(backTracking(board) == true){ // 递归
- return true;
- }
- // 回溯
- board[row][col] = '.';
- }
- }
- return false; // 尝试了 9 个数也不能递归返回 true, 就返回 false
- }
- }
- }
- return true; // 遍历完没有返回 false,说明找到的合适棋盘
- }
-
- bool isValid(const std::vector
char >>& board, int row, int col, char val){ - // 检查同一行
- for(int i = 0; i < board[0].size(); i++){
- if(board[row][i] == val) return false;
- }
-
- // 检查同一列
- for(int i = 0; i < board.size(); i++){
- if(board[i][col] == val) return false;
- }
-
- // 检查九宫格
- int startRow = (row / 3) * 3;
- int startCol = (col / 3) * 3;
- for (int i = startRow; i < startRow + 3; i++){ // 判断九宫格内是否重复
- for (int j = startCol; j < startCol + 3; j++){
- if (board[i][j] == val){
- return false;
- }
- }
- }
- return true;
- }
- };
-
- int main(int argc, char* argv[]){
- std::vector
char>> board = {{'5', '3', '.', '.', '7', '.', '.', '.', '.'}, - {'6', '.', '.', '1', '9', '5', '.', '.', '.'},
- {'.', '9', '8', '.', '.', '.', '.', '6', '.'},
- {'8', '.', '.', '.', '6', '.', '.', '.', '3'},
- {'4', '.', '.', '8', '.', '3', '.', '.', '1'},
- {'7', '.', '.', '.', '2', '.', '.', '.', '6'},
- {'.', '6', '.', '.', '.', '.', '2', '8', '.'},
- {'.', '.', '.', '4', '1', '9', '.', '.', '5'},
- {'.', '.', '.', '.', '8', '.', '.', '7', '9'}};
- Solution S1;
- S1.solveSudoku(board);
- for(auto vec : board){
- for(auto val : vec) std::cout << val << " ";
- std::cout << std::endl;
- }
- return 0;
- }