使用预处理的方式执行SQL

package main

import (
   "database/sql"
   "fmt"
   _ "github.com/go-sql-driver/mysql"
   "strconv"
)

//========================= 通用模型·开始 =========================//

// CommonModel 通用模型
type CommonModel struct {
   Table string // 表名(不含表前缀)
}

var db *sql.DB    // 数据库操作对象
var prefix string // 表名前缀

// GetDB 获取数据库操作对象
func (commonModel *CommonModel) GetDB() *sql.DB {
   if db == nil {
      // 驱动和DSN,其中DSN格式为『账号:密码@tcp(主机地址:端口号)/数据库名?charset=字符集』
      driverName := "mysql"
      dataSourceName := "username:password@tcp(localhost:3306)/dbname?charset=utf8mb4"

      var err error

      db, err = sql.Open(driverName, dataSourceName)
      if err != nil {
         panic("数据库参数错误:" + err.Error())
      }

      // 连接数据库
      if db.Ping() != nil {
         panic("数据库连接失败:" + err.Error())
      }
   }

   return db
}

// GetPrefix 获取表名前缀
func (commonModel *CommonModel) GetPrefix() string {
   if prefix == "" {
      prefix = "tmp_"
   }

   return prefix
}

// Query 执行SELECT类语句
// 说明:为方便使用,字段值的数据类型统一为string,调用方如有需要可在拿到返回数据后再自行转换。
func (commonModel *CommonModel) Query(query string, args []any) (list []map[string]string) {
   // SQL预处理
   stmt, err := commonModel.GetDB().Prepare(query)
   if err != nil {
      panic("SQL预处理失败:" + err.Error())
   }

   // 关闭预处理句柄
   defer func(stmt *sql.Stmt) {
      _ = stmt.Close()
   }(stmt)

   // 传入参数执行SQL,获得查询结果集
   rows, err := stmt.Query(args...) // 注意变量名后面要加三个小数点
   if err != nil {
      panic("SQL执行失败:" + err.Error())
   }

   // 关闭结果集
   defer func(rows *sql.Rows) {
      _ = rows.Close()
   }(rows)

   // 获取所有字段名
   columns, err := rows.Columns()
   if err != nil {
      panic("获取查询结果集字段名失败:" + err.Error())
   }

   values := make([][]byte, len(columns)) // 用于保存字段值(由于字段类型不定,故使用[]byte类型)
   dest := make([]any, len(columns))
   for i := range values {
      dest[i] = &values[i]
   }

   // 遍历结果集
   for rows.Next() {
      err = rows.Scan(dest...)
      if err != nil {
         panic("遍历查询结果集失败:" + err.Error())
      }

      // 遍历单条记录的所有字段值
      row := map[string]string{} // 单条记录的数据
      for j := range values {
         row[columns[j]] = string(values[j]) // row["字段名"] = "字段值"
      }

      list = append(list, row)
   }

   return list
}

// Exec 执行INSERT|UPDATE|DELETE类语句
func (commonModel *CommonModel) Exec(exec string, args []any) (result sql.Result) {
   // SQL预处理
   stmt, err := commonModel.GetDB().Prepare(exec)
   if err != nil {
      panic("SQL预处理失败:" + err.Error())
   }

   // 关闭预处理句柄
   defer func(stmt *sql.Stmt) {
      _ = stmt.Close()
   }(stmt)

   // 传入参数执行SQL
   result, err = stmt.Exec(args...) // 注意变量名后面要加三个小数点
   if err != nil {
      panic("SQL执行失败:" + err.Error())
   }

   return
}

//========================= 通用模型·结束 =========================//

//========================= 文章模型·开始 =========================//

// ArticleModel 文章模型(继承通用模型)
type ArticleModel struct {
   CommonModel // 继承CommonModel
}

// ArticleModelNew ArticleModel构造方法
func ArticleModelNew() *ArticleModel {
   articleModel := new(ArticleModel)
   articleModel.Table = articleModel.GetPrefix() + "article" // 拼接上表前缀,构造出完整表名
   return articleModel
}

// GetOneById 根据id获取单条记录
func (articleModel *ArticleModel) GetOneById(id int64) (row map[string]string) {
   query := fmt.Sprintf("SELECT * FROM `%s` WHERE `id` = ? LIMIT 1", articleModel.Table)

   // 构造查询参数(顺序要和SQL语句中的问号占位符一一对应)
   var args []any
   args = append(args, id)

   rows := articleModel.Query(query, args)
   if len(rows) == 1 {
      row = rows[0]
   }

   return
}

// GetList 获取多条记录(指定起始id和结束id)
func (articleModel *ArticleModel) GetList(start int64, end int64) (rows []map[string]string) {
   query := fmt.Sprintf("SELECT * FROM `%s` WHERE ? <= `id` AND `id` <= ?", articleModel.Table)

   // 构造查询参数(顺序要和SQL语句中的问号占位符一一对应)
   var args []any
   args = append(args, start)
   args = append(args, end)

   rows = articleModel.Query(query, args)

   return
}

// Add 插入记录
func (articleModel *ArticleModel) Add(value map[string]any) (insertId int64) {
   query := fmt.Sprintf("INSERT INTO `%s` SET `title` = ?, `click` = ?", articleModel.Table)

   // 构造查询参数(顺序要和SQL语句中的问号占位符一一对应)
   var args []any
   args = append(args, value["title"])
   args = append(args, value["click"])

   result := articleModel.Exec(query, args)

   // 获取影响记录数
   rowsAffected, err := result.RowsAffected()
   if err != nil {
      panic("获取影响记录数失败:" + err.Error())
   }

   if rowsAffected == 1 {
      insertId, err = result.LastInsertId()
      if err != nil {
         panic("获取插入记录文章ID失败:" + err.Error())
      }
   }

   return
}

// DeleteById 根据id删除记录(单条)
func (articleModel *ArticleModel) DeleteById(id int64) (rowsAffected int64) {
   query := fmt.Sprintf("DELETE FROM `%s` WHERE `id` = ? LIMIT 1", articleModel.Table)

   // 构造查询参数(顺序要和SQL语句中的问号占位符一一对应)
   var args []any
   args = append(args, id)

   result := articleModel.Exec(query, args)

   // 获取影响记录数
   rowsAffected, err := result.RowsAffected()
   if err != nil {
      panic("获取影响记录数失败:" + err.Error())
   }

   return
}

//========================= 文章模型·结束 =========================//

func main() {
   articleModel := ArticleModelNew()

   //========== 查询单条记录(存在) ==========//
   row := articleModel.GetOneById(255)
   if len(row) > 0 {
      fmt.Printf("第%s篇文章  %s  [点击:%s] \n", row["id"], row["title"], row["click"])
      // 第255篇文章  PHP是世界上最好の语言  [点击:9527]
   } else {
      fmt.Println("文章不存在")
   }

   //========== 查询单条记录(不存在) ==========//
   row = articleModel.GetOneById(99999)
   if len(row) > 0 {
      fmt.Printf("第%s篇文章  %s  [点击:%s] \n", row["id"], row["title"], row["click"])
   } else {
      fmt.Println("文章不存在") // 文章不存在
   }

   //========== 查询多条记录 ==========//
   rows := articleModel.GetList(3, 7) // 查询id为3~7的记录
   for _, row = range rows {
      fmt.Printf("第%s篇文章  %s  [点击:%s] \n", row["id"], row["title"], row["click"])
   }
   // 第3篇文章  ……此处省略內容若干……  [点击:1024]
   // 第4篇文章  ……此处省略內容若干……  [点击:1024]
   // 第5篇文章  ……此处省略內容若干……  [点击:1024]
   // 第6篇文章  ……此处省略內容若干……  [点击:1024]
   // 第7篇文章  ……此处省略內容若干……  [点击:1024]

   //========== 插入记录 ==========//
   value := map[string]any{}
   value["title"] = "颈椎病的康复与治疗"
   value["click"] = 512
   insertId := articleModel.Add(value)
   if insertId > 0 {
      fmt.Println("插入记录成功,文章ID为:" + strconv.FormatInt(insertId, 10)) // 插入记录成功,文章ID为:*****
   } else {
      fmt.Println("插入记录失败")
   }

   //========== 删除记录 ==========//
   rowsAffected := articleModel.DeleteById(insertId)
   fmt.Printf("共删除%d条记录 \n", rowsAffected) // 共删除1条记录
}

//========== 总结 ==========//
// 1、使用预处理的方式执行SQL能够彻底解决SQL注入问题,所以强烈建议使用该方式操作数据库。
// 2、使用预处理的方式执行SQL可以在general日志里看到Prepare和Execute两条信息,其中Prepare就是预处理的SQL(带问号占位符的字
//    符串),而Execute是最终执行的SQL(问号占位符已经替换成实际内容)。

Copyright © 2024 码农人生. All Rights Reserved