|
| 1 | +{-# LANGUAGE RecordWildCards #-} |
| 2 | +{-# LANGUAGE OverloadedStrings #-} |
| 3 | + |
| 4 | +module Cardano.Db.Statement.Helpers where |
| 5 | + |
| 6 | +import Cardano.BM.Trace (logDebug) |
| 7 | +import Cardano.Db.Error (CallSite (..), DbError (..)) |
| 8 | +import Cardano.Db.Types (DbAction (..), DbTxMode (..), DbTransaction (..), DbEnv (..)) |
| 9 | +import Cardano.Prelude (MonadIO (..), ask, when, MonadError (..)) |
| 10 | +import Data.Time (getCurrentTime, diffUTCTime) |
| 11 | +import GHC.Stack (HasCallStack, getCallStack, callStack, SrcLoc (..)) |
| 12 | +import qualified Data.Text as Text |
| 13 | +import qualified Hasql.Decoders as HsqlD |
| 14 | +import qualified Hasql.Encoders as HsqlE |
| 15 | +import qualified Hasql.Session as HsqlS |
| 16 | +import qualified Hasql.Statement as HsqlS |
| 17 | +import qualified Hasql.Transaction as HsqlT |
| 18 | +import qualified Hasql.Transaction.Sessions as HsqlT |
| 19 | +import qualified Data.Text.Encoding as TextEnc |
| 20 | + |
| 21 | +-- | Runs a database transaction with optional logging. |
| 22 | +-- |
| 23 | +-- This function executes a `DbTransaction` within the `DbAction` monad, handling |
| 24 | +-- the transaction mode (read-only or write) and logging execution details if |
| 25 | +-- enabled in the `DbEnv`. It captures timing information and call site details |
| 26 | +-- for debugging purposes when logging is active. |
| 27 | +-- |
| 28 | +-- ==== Parameters |
| 29 | +-- * @mode@: The transaction mode (`Write` or `ReadOnly`). |
| 30 | +-- * @DbTransaction{..}@: The transaction to execute, containing the function name, |
| 31 | +-- call site, and the `Hasql` transaction. |
| 32 | +-- |
| 33 | +-- ==== Returns |
| 34 | +-- * @DbAction m a@: The result of the transaction wrapped in the `DbAction` monad. |
| 35 | +runDbT |
| 36 | + :: MonadIO m |
| 37 | + => DbTxMode |
| 38 | + -> DbTransaction a |
| 39 | + -> DbAction m a |
| 40 | +runDbT mode DbTransaction{..} = DbAction $ do |
| 41 | + dbEnv <- ask |
| 42 | + let logMsg msg = when (dbEnableLogging dbEnv) $ liftIO $ logDebug (dbTracer dbEnv) msg |
| 43 | + |
| 44 | + -- Run the session and handle the result |
| 45 | + let runSession = do |
| 46 | + result <- liftIO $ HsqlS.run session (dbConnection dbEnv) |
| 47 | + case result of |
| 48 | + Left err -> throwError $ QueryError "Transaction failed" dtCallSite err |
| 49 | + Right val -> pure val |
| 50 | + |
| 51 | + if dbEnableLogging dbEnv |
| 52 | + then do |
| 53 | + logMsg $ "Starting transaction: " <> dtFunctionName <> locationInfo |
| 54 | + start <- liftIO getCurrentTime |
| 55 | + result <- runSession |
| 56 | + end <- liftIO getCurrentTime |
| 57 | + let duration = diffUTCTime end start |
| 58 | + logMsg $ "Transaction completed: " <> dtFunctionName <> locationInfo <> " in " <> Text.pack (show duration) |
| 59 | + pure result |
| 60 | + else runSession |
| 61 | + where |
| 62 | + session = HsqlT.transaction HsqlT.Serializable txMode dtTx |
| 63 | + txMode = case mode of |
| 64 | + Write -> HsqlT.Write |
| 65 | + ReadOnly -> HsqlT.Read |
| 66 | + locationInfo = " at " <> csModule dtCallSite <> ":" <> |
| 67 | + csFile dtCallSite <> ":" <> Text.pack (show $ csLine dtCallSite) |
| 68 | + |
| 69 | +-- | Creates a `DbTransaction` with a function name and call site. |
| 70 | +-- |
| 71 | +-- Constructs a `DbTransaction` record for use with `runDbT`, capturing the |
| 72 | +-- function name and call site from the current stack trace. This is useful |
| 73 | +-- for logging and debugging database operations. |
| 74 | +-- |
| 75 | +-- ==== Parameters |
| 76 | +-- * @funcName@: The name of the function or operation being performed. |
| 77 | +-- * @transx@: The `Hasql` transaction to encapsulate. |
| 78 | +-- |
| 79 | +-- ==== Returns |
| 80 | +-- * @DbTransaction a@: A transaction record with metadata. |
| 81 | +mkDbTransaction :: Text.Text -> HsqlT.Transaction a -> DbTransaction a |
| 82 | +mkDbTransaction funcName transx = |
| 83 | + DbTransaction |
| 84 | + { dtFunctionName = funcName |
| 85 | + , dtCallSite = mkCallSite |
| 86 | + , dtTx = transx |
| 87 | + } |
| 88 | + where |
| 89 | + mkCallSite :: HasCallStack => CallSite |
| 90 | + mkCallSite = |
| 91 | + case reverse (getCallStack callStack) of |
| 92 | + (_, srcLoc) : _ -> CallSite |
| 93 | + { csModule = Text.pack $ srcLocModule srcLoc |
| 94 | + , csFile = Text.pack $ srcLocFile srcLoc |
| 95 | + , csLine = srcLocStartLine srcLoc |
| 96 | + } |
| 97 | + [] -> error "No call stack info" |
| 98 | + |
| 99 | +-- | Inserts multiple records into a table in a single transaction using UNNEST. |
| 100 | +-- |
| 101 | +-- This function performs a bulk insert into a specified table, using PostgreSQL’s |
| 102 | +-- `UNNEST` to expand arrays of field values into rows. It’s designed for efficiency, |
| 103 | +-- executing all inserts in one SQL statement, and returns the generated IDs. |
| 104 | +-- |
| 105 | +-- ==== Parameters |
| 106 | +-- * @table@: Text - The name of the table to insert into. |
| 107 | +-- * @cols@: [Text] - List of column names (excluding the ID column). |
| 108 | +-- * @types@: [Text] - List of PostgreSQL type casts for each column (e.g., "bigint[]"). |
| 109 | +-- * @extract@: ([a] -> [b]) - Function to extract fields from a list of records into a tuple of lists. |
| 110 | +-- * @enc@: HsqlE.Params [b] - Encoder for the extracted fields as a tuple of lists. |
| 111 | +-- * @dec@: HsqlD.Result [c] - Decoder for the returned IDs. |
| 112 | +-- * @xs@: [a] - List of records to insert. |
| 113 | +-- |
| 114 | +-- ==== Returns |
| 115 | +-- * @DbAction m [c]@: The list of generated IDs wrapped in the `DbAction` monad. |
| 116 | +bulkInsert |
| 117 | + :: Text.Text -- Table name |
| 118 | + -> [Text.Text] -- Column names |
| 119 | + -> [Text.Text] -- Type casts for UNNEST |
| 120 | + -> ([a] -> b) -- Field extractor (e.g., to tuple) |
| 121 | + -> HsqlE.Params b -- Bulk encoder |
| 122 | + -> HsqlD.Result [c] -- ID decoder |
| 123 | + -> [a] -- Records |
| 124 | + -> HsqlT.Transaction [c] -- Resulting IDs |
| 125 | +bulkInsert table cols types extract enc dec xs = |
| 126 | + HsqlT.statement params $ HsqlS.Statement sql enc dec True |
| 127 | + where |
| 128 | + params = extract xs |
| 129 | + sql = TextEnc.encodeUtf8 $ |
| 130 | + "INSERT INTO " <> table <> " (" <> Text.intercalate ", " cols <> ") \ |
| 131 | + \SELECT * FROM UNNEST (" <> Text.intercalate ", " (zipWith (\i t -> "$" <> Text.pack (show i) <> "::" <> t) [1..] types) <> ") \ |
| 132 | + \RETURNING id" |
0 commit comments