Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
wz1000 committed Jul 11, 2024
1 parent fdc918e commit d2e809a
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 79 deletions.
117 changes: 68 additions & 49 deletions trace-foreign-calls/src/Plugin/TraceForeignCalls.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}

module Plugin.TraceForeignCalls (plugin) where

Expand Down Expand Up @@ -43,8 +44,9 @@ processRenamed ::
-> TcGblEnv
-> HsGroup GhcRn
-> TcM (TcGblEnv, HsGroup GhcRn)
processRenamed options tcGblEnv group = do
runInstrument options $ (tcGblEnv,) <$> processGroup group
processRenamed options tcGblEnv group
| moduleUnit (tcg_mod tcGblEnv) `elem` [primUnit, bignumUnit] = pure (tcGblEnv, group)
| otherwise = runInstrument options $ (tcGblEnv,) <$> processGroup group

{-------------------------------------------------------------------------------
Binding groups
Expand Down Expand Up @@ -109,12 +111,15 @@ processForeignDecl ::
-> Instrument (Either (LForeignDecl GhcRn) ReplacedForeignImport)
processForeignDecl decl@(L _ ForeignExport{}) =
return $ Left decl
processForeignDecl (L _ ForeignImport{
processForeignDecl decl@(L _ ForeignImport{
fd_i_ext = NoExtField
, fd_name = rfiOriginalName
, fd_sig_ty = rfiSigType
, fd_fi = rfiForeignImport
}) = do
})
| CImport _ (unLoc -> conv) _ _ _ <- rfiForeignImport
, conv == PrimCallConv = return $ Left decl
| otherwise = do
rfiSuffixedName <- renameForeignImport rfiOriginalName
return $ Right ReplacedForeignImport{
rfiOriginalName
Expand Down Expand Up @@ -232,56 +237,67 @@ mkWrapperBody ::
ReplacedForeignImport
-> Instrument ([Name], LHsExpr GhcRn)
mkWrapperBody rfi@ReplacedForeignImport {rfiSuffixedName, rfiSigType} = do
traceEventIO <- findName nameTraceEventIO
let callTraceEventIO :: LHsExpr GhcRn -> ExprLStmt GhcRn
callTraceEventIO arg = noLocValue $
BodyStmt
noValue
(callNamedFn traceEventIO [arg])
regularBodyStmt
NoSyntaxExprRn

evaluate <- findName nameEvaluate
let callEvaluate :: LHsExpr GhcRn -> LHsExpr GhcRn
callEvaluate arg = callNamedFn evaluate [arg]

unsafePerformIO <- findName nameUnsafePerformIO
let callUnsafePerformIO :: LHsExpr GhcRn -> LHsExpr GhcRn
callUnsafePerformIO arg = callNamedFn unsafePerformIO [arg]
traceEventHash <- findName nameTraceEventHash
let callTraceEvent :: LHsExpr GhcRn -> LHsExpr GhcRn -> LHsExpr GhcRn
callTraceEvent arg st = callNamedFn traceEventHash [arg, st]

seqHash <- findName nameSeq
let callSeq :: LHsExpr GhcRn -> LHsExpr GhcRn -> LHsExpr GhcRn
callSeq arg st = callNamedFn seqHash [arg, st]

runRW <- findName nameRunRW
let callRunRW :: LHsExpr GhcRn -> LHsExpr GhcRn
callRunRW arg = callNamedFn runRW [arg]

noDuplicate <- findName nameNoDuplicate
let callNoDuplicate :: LHsExpr GhcRn -> LHsExpr GhcRn
callNoDuplicate arg = callNamedFn noDuplicate [arg]

(args, resultTy) <- uniqArgsFor (sig_body $ unLoc rfiSigType)
let callUninstrumented :: LHsExpr GhcRn
callUninstrumented = callLNamedFn rfiSuffixedName (map namedVar args)

s <- uniqInternalName "s"
s' <- uniqInternalName "s'"
f <- uniqInternalName "f"
result <- uniqInternalName "result"
result' <- uniqInternalName "result'"
eventLogCall <- mkEventLogCall rfi
eventLogReturn <- mkEventLogReturn rfi
let doBlock :: LHsExpr GhcRn
doBlock = noLocValue $ HsDo noValue (DoExpr Nothing) $ noLocValue [
callTraceEventIO eventLogCall
, noLocValue $
BindStmt
regularBindStmt
(namedVarPat result)
( case checkIsIO resultTy of
Just _ -> callUninstrumented
Nothing -> callEvaluate callUninstrumented
)
, callTraceEventIO eventLogReturn
, noLocValue $
LastStmt
noValue
(callNamedFn returnMName [namedVar result])
Nothing
NoSyntaxExprRn
]

return (
args
, case checkIsIO resultTy of
Just _ -> doBlock
Nothing -> callUnsafePerformIO doBlock
)
let wrapped = case checkIsIO resultTy of
{- case foo of
IO f -> IO (\s -> case f (traceEvent# call s) of
(# s', result #) -> (# traceEvent# return s', result #))
-}
Just _ -> HsCase CaseAlt callUninstrumented
$ mkMatchGroup (Generated OtherExpansion SkipPmc) . noLocValue . pure
$ mkHsCaseAlt (noLocValue $ ConPat noExtField (noLocValue ioDataConName) $ PrefixCon [] [namedVarPat f])
$ mkHsApp (namedVar ioDataConName)
$ mkHsLam [namedVarPat s]
$ noLocValue
$ HsCase CaseAlt (mkHsApp (namedVar f) (callTraceEvent eventLogCall $ namedVar s))
$ mkMatchGroup (Generated OtherExpansion SkipPmc) . noLocValue . pure
$ mkHsCaseAlt (noLocValue $ TuplePat noExtField [namedVarPat s', namedVarPat result] Unboxed)
$ noLocValue $ ExplicitTuple noExtField [Present noExtField (callTraceEvent eventLogReturn (namedVar s')), Present noExtField (namedVar result) ] Unboxed
{- case runRW# (\s -> case seq# foo (traceEvent# call (noDuplicate s)) of
(# s', result #) -> (# traceEvent# return s', result #)) of
(# _ , result' #) -> result'
-}
Nothing -> HsCase CaseAlt
( callRunRW
$ mkHsLam [namedVarPat s]
$ noLocValue
$ HsCase CaseAlt
(callSeq callUninstrumented $ callTraceEvent eventLogCall $ callNoDuplicate $ namedVar s)
$ mkMatchGroup (Generated OtherExpansion SkipPmc) . noLocValue . pure
$ mkHsCaseAlt (noLocValue $ TuplePat noExtField [namedVarPat s', namedVarPat result] Unboxed)
$ noLocValue $ ExplicitTuple noExtField [Present noExtField (callTraceEvent eventLogReturn (namedVar s')), Present noExtField (namedVar result) ] Unboxed
)
( mkMatchGroup (Generated OtherExpansion SkipPmc) . noLocValue . pure
$ mkHsCaseAlt (noLocValue $ TuplePat noExtField [noLocValue (WildPat noExtField), namedVarPat result'] Unboxed)
(namedVar result'))
return ( args , noLocValue wrapped )

{-------------------------------------------------------------------------------
Generate eventlog events
Expand All @@ -293,10 +309,10 @@ mkEventLogCall ReplacedForeignImport{
rfiOriginalName
, rfiForeignImport
} = do
noCallStack <- asksOption optionsDisableCallStack
noCallStack <- pure True -- asksOption optionsDisableCallStack

if noCallStack then
return $ stringExpr prefix
return $ ubstringExpr prefix
else do
callStack <- findName nameCallStack
prettyCalllStack <- findName namePrettyCallStack
Expand Down Expand Up @@ -341,7 +357,7 @@ mkEventLogCall ReplacedForeignImport{
-- | Eventlog description for the return of the foreign function
mkEventLogReturn :: ReplacedForeignImport -> Instrument (LHsExpr GhcRn)
mkEventLogReturn ReplacedForeignImport{rfiOriginalName} = do
return $ stringExpr $ concat [
return $ ubstringExpr $ concat [
"trace-foreign-calls: return "
, occNameString . nameOccName . unLoc $ rfiOriginalName
]
Expand Down Expand Up @@ -403,6 +419,9 @@ emptyWhereClause = EmptyLocalBinds noValue
stringExpr :: String -> LHsExpr GhcRn
stringExpr = noLocValue . HsLit noValue . HsString NoSourceText . fsLit

ubstringExpr :: String -> LHsExpr GhcRn
ubstringExpr = noLocValue . HsLit noValue . mkHsStringPrimLit . fsLit

callLNamedFn :: LIdP GhcRn -> [LHsExpr GhcRn] -> LHsExpr GhcRn
callLNamedFn fn args = mkHsApps (noLocValue $ HsVar noValue fn) args

Expand All @@ -413,4 +432,4 @@ namedVar :: Name -> LHsExpr GhcRn
namedVar = noLocValue . HsVar noValue . noLocValue

namedVarPat :: Name -> LPat GhcRn
namedVarPat = noLocValue . VarPat noValue . noLocValue
namedVarPat = noLocValue . VarPat noValue . noLocValue
14 changes: 8 additions & 6 deletions trace-foreign-calls/src/Plugin/TraceForeignCalls/Instrument.hs
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,21 @@ whenOption_ f = void . whenOption f
-------------------------------------------------------------------------------}

data Names = Names {
nameTraceEventIO :: TcM Name
, nameEvaluate :: TcM Name
, nameUnsafePerformIO :: TcM Name
nameTraceEventHash :: TcM Name
, nameSeq :: TcM Name
, nameRunRW :: TcM Name
, nameNoDuplicate :: TcM Name
, nameHasCallStack :: TcM Name
, nameCallStack :: TcM Name
, namePrettyCallStack :: TcM Name
}

mkNames :: Names
mkNames = Names {
nameTraceEventIO = resolveVarName modlTraceEventIO "traceEventIO"
, nameEvaluate = resolveVarName modlEvaluate "evaluate"
, nameUnsafePerformIO = resolveVarName modlUnsafePerformIO "unsafePerformIO"
nameTraceEventHash = resolveVarName modlTraceEvent "traceEvent#"
, nameSeq = resolveVarName modlSeq "seq#"
, nameRunRW = resolveVarName modlRunRW "runRW#"
, nameNoDuplicate = resolveVarName modlNoDuplicate "noDuplicate#"
, nameHasCallStack = resolveTcName modlHasCallStack "HasCallStack"
, nameCallStack = resolveVarName modlCallStack "callStack"
, namePrettyCallStack = resolveVarName modlPrettyCallStack "prettyCallStack"
Expand Down
37 changes: 13 additions & 24 deletions trace-foreign-calls/src/Plugin/TraceForeignCalls/Util/Shim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ module Plugin.TraceForeignCalls.Util.Shim (
, originGenerated
-- * Name resolution
, modlCallStack
, modlEvaluate
, modlSeq
, modlHasCallStack
, modlPrettyCallStack
, modlTraceEventIO
, modlUnsafePerformIO
, modlTraceEvent
, modlRunRW
, modlNoDuplicate
) where

import GHC
Expand Down Expand Up @@ -77,29 +78,17 @@ originGenerated = Generated OtherExpansion SkipPmc
Defining modules for various symbols
-------------------------------------------------------------------------------}

modlTraceEventIO :: Module
modlTraceEventIO =
#if !MIN_VERSION_ghc(9,9,0)
mkModule baseUnit $ mkModuleName "Debug.Trace"
#else
mkModule ghcInternalUnit $ mkModuleName "GHC.Internal.Debug.Trace"
#endif
modlTraceEvent :: Module
modlTraceEvent = mkModule primUnit $ mkModuleName "GHC.Prim"

modlEvaluate :: Module
modlEvaluate =
#if !MIN_VERSION_ghc(9,9,0)
mkModule baseUnit $ mkModuleName "GHC.IO"
#else
mkModule ghcInternalUnit $ mkModuleName "GHC.Internal.IO"
#endif
modlSeq :: Module
modlSeq = mkModule primUnit $ mkModuleName "GHC.Prim"

modlUnsafePerformIO :: Module
modlUnsafePerformIO =
#if !MIN_VERSION_ghc(9,9,0)
mkModule baseUnit $ mkModuleName "GHC.IO.Unsafe"
#else
mkModule ghcInternalUnit $ mkModuleName "GHC.Internal.IO.Unsafe"
#endif
modlRunRW :: Module
modlRunRW = mkModule primUnit $ mkModuleName "GHC.Magic"

modlNoDuplicate :: Module
modlNoDuplicate = mkModule primUnit $ mkModuleName "GHC.Prim"

modlHasCallStack :: Module
modlHasCallStack =
Expand Down

0 comments on commit d2e809a

Please sign in to comment.