module Striot.Nodes.TCP
( connectTCP
, sendStreamTCP
, processSocket
) where

import           Control.Concurrent                       (forkFinally)
import           Control.Concurrent.Async                 (async, link)
import           Control.Concurrent.Chan.Unagi.Bounded    as U
import qualified Control.Exception                        as E (bracket,
                                                                bracketOnError,
                                                                catch,
                                                                evaluate)
import           Control.Lens
import           Control.Monad                            (forever)
import qualified Data.ByteString                          as B (ByteString,
                                                                length, null)
import           Data.Store                               (Store, decode,
                                                           encode)
import qualified Data.Store.Streaming                     as SS
import           Network.Socket
import           Network.Socket.ByteString
import           Striot.FunctionalIoTtypes
import           Striot.Nodes.Types
import           System.IO.ByteBuffer                     as BB
import           System.Metrics.Prometheus.Metric.Counter as PC (add, inc)
import           System.Metrics.Prometheus.Metric.Gauge   as PG (dec, inc)


-- | Receive a 'Stream' of events from upstream TCP connections.
--
-- All arguments are handed to 'acceptConnections', which starts the
-- listener and returns the read end of its channel. The channel contents are
-- then exposed as a lazy list, so forcing further elements waits for more
-- events to arrive.
processSocket :: Store alpha => String -> TCPConfig -> Metrics -> IO (Stream alpha)
processSocket name conf met = do
    outChan <- acceptConnections name conf met
    U.getChanContents outChan


-- | Create a bounded channel, run 'connectTCP' on a separate thread to fill
-- it, and return the read end of the channel.
--
-- The listener thread is 'link'ed to the calling thread. 'connectTCP' only
-- ever finishes by throwing (e.g. when the listening socket cannot be
-- created or bound). Previously that exception was captured in a discarded
-- 'Async' and lost, leaving every consumer of the channel blocked forever.
-- Linking makes the failure propagate to the caller instead.
acceptConnections :: Store alpha => String -> TCPConfig -> Metrics -> IO (U.OutChan (Event alpha))
acceptConnections name conf met = do
    (inChan, outChan) <- U.newChan defaultChanSize
    listener <- async $ connectTCP name conf met inChan
    link listener
    return outChan


-- | Capacity of the bounded channel between the TCP listener and its
-- consumers.
defaultChanSize :: Int
defaultChanSize = 10


{- connectTCP opens a listening socket on the configured port and accepts
connections forever. Each accepted connection is handled on its own thread by
processData, which writes the decoded events into the given channel. The
ingressConn gauge is incremented when a handler starts. When the handler
finishes, whether normally or by exception, the gauge is decremented and the
connection is closed. The listening socket itself is released by bracket,
so it is closed if the accept loop is ever interrupted (an exception from
accept, or cancellation of the thread running this function). -}
connectTCP :: Store alpha
           => String
           -> TCPConfig
           -> Metrics
           -> U.InChan (Event alpha)
           -> IO ()
connectTCP _ conf met chan =
    E.bracket (listenSocket $ conf ^. tcpConn . port) close $ \sock ->
        forever $ do
            (conn, _) <- accept sock
            forkFinally (PG.inc (_ingressConn met)
                         >> processData met conn chan)
                        (\_ -> PG.dec (_ingressConn met)
                               >> close conn)


{- processData reads framed events from a connected Socket and writes them to
the channel. Incoming bytes are accumulated in a ByteBuffer and decoded with
store-streaming. Each decoded event bumps the ingressEvents counter before
being written to the channel.

The loop ends when decoding reports end of input (Nothing). That happens when
readFromSocket returns Nothing because the peer closed the connection. This
mirrors how store's conduitDecode treats Nothing.

Exceptions raised while decoding are not caught here. They propagate to the
caller; in this module that is connectTCP's forkFinally, which closes the
socket. -}
processData :: Store alpha => Metrics -> Socket -> U.InChan (Event alpha) -> IO ()
processData met conn eventChan =
    BB.with Nothing $ \buffer ->
        let loop = do
                event <- decodeMessageBS' met buffer (readFromSocket conn)
                case event of
                    -- end of input: the upstream peer closed the connection
                    Nothing -> return ()
                    Just m  -> do
                        PC.inc (_ingressEvents met)
                        U.writeChan eventChan $ SS.fromMessage m
                        loop
        in  loop


{- This is a rewrite of Data.Store.Streaming's decodeMessageBS that also takes
Metrics. Every chunk copied into the ByteBuffer is added to the ingressBytes
counter while decoding. -}
decodeMessageBS' :: Store a
                 => Metrics -> BB.ByteBuffer
                 -> IO (Maybe B.ByteString) -> IO (Maybe (SS.Message a))
decodeMessageBS' met = SS.decodeMessage fill
  where
    -- The Int argument supplied by decodeMessage is not needed here.
    fill bb _ bs = PC.add (B.length bs) (_ingressBytes met)
                   >> BB.copyByteString bb bs


{- Read up to 4096 bytes from the socket. On a TCP socket recv blocks until
data arrives and returns an empty ByteString once the peer has closed the
connection. That case is reported as Nothing, which decodeMessage treats as
end of input, rather than being thrown as an error. -}
readFromSocket :: Socket -> IO (Maybe B.ByteString)
readFromSocket conn = do
    msg <- recv conn 4096
    return $ if B.null msg then Nothing else Just msg


{- sendStreamTCP connects to the configured host and port and writes every
event of the stream to that connection. An empty stream opens no connection.

The socket is managed with bracket, so it is closed and the egressConn gauge
decremented however writing finishes. The gauge is incremented only after
connectSocket has succeeded. Previously the increment ran first, so a failed
connect left the gauge permanently raised, because bracket does not run the
release action when acquisition throws. -}
sendStreamTCP :: Store alpha => String -> TCPConfig -> Metrics -> Stream alpha -> IO ()
sendStreamTCP _ _    _   []     = return ()
sendStreamTCP _ conf met stream =
    E.bracket (connectSocket (conf ^. tcpConn . host) (conf ^. tcpConn . port)
               <* PG.inc (_egressConn met))
              (\conn -> PG.dec (_egressConn met)
                        >> close conn)
              (\conn -> writeSocket conn met stream)


{- Encode each event as a store-streaming Message and send it over the
socket. The egressEvents and egressBytes counters are updated only after
sendAll has returned, so an event whose send throws is not counted as
sent. -}
writeSocket :: Store alpha => Socket -> Metrics -> Stream alpha -> IO ()
writeSocket conn met = mapM_ sendEvent
  where
    sendEvent event = do
        let val = SS.encodeMessage (SS.Message event)
        sendAll conn val
        PC.inc (_egressEvents met)
        PC.add (B.length val) (_egressBytes met)


--- SOCKETS ---

{- Open a stream socket listening on the given port. An empty host is
resolved with AI_PASSIVE, i.e. the wildcard address. SO_REUSEADDR is set and
the listen backlog is maxQConn. If configuring, binding or listening fails
(for example because the port is already in use), the new socket is closed
before the exception propagates instead of being leaked. -}
listenSocket :: ServiceName -> IO Socket
listenSocket port =
    E.bracketOnError (createSocket [] port hints) (close . fst) $ \(sock, addr) -> do
        setSocketOption sock ReuseAddr 1
        bind sock $ addrAddress addr
        listen sock maxQConn
        return sock
  where
    hints = defaultHints { addrFlags      = [AI_PASSIVE]
                         , addrSocketType = Stream
                         }
    maxQConn :: Int
    maxQConn = 10


{- Open a stream socket to host:port with SO_KEEPALIVE enabled. If setting the
option or connecting fails (e.g. connection refused), the new socket is closed
before the exception propagates instead of being leaked. -}
connectSocket :: HostName -> ServiceName -> IO Socket
connectSocket host port =
    E.bracketOnError (createSocket host port hints) (close . fst) $ \(sock, addr) -> do
        setSocketOption sock KeepAlive 1
        connect sock $ addrAddress addr
        return sock
  where
    hints = defaultHints { addrSocketType = Stream }


{- Resolve host and port using the supplied hints, then create (but neither
bind nor connect) a socket matching the first address returned. An empty host
is passed to getAddrInfo as Nothing. Returns the socket together with the
address it was created for. -}
createSocket :: HostName -> ServiceName -> AddrInfo -> IO (Socket, AddrInfo)
createSocket host port hints = do
    addr:_ <- getAddrInfo (Just hints) hostArg (Just port)
    sock   <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
    return (sock, addr)
  where
    hostArg = if null host then Nothing else Just host