Go and concurrency noob here. I wrote a program that handles TCP connections to index (and remove) packages. I want it to be able to handle hundreds of connections and messages that might come in at the same time. I am using a hash table to store the data and channels to pass the incoming messages to the worker.
My understanding is that I will not need any locks around the data store, as every add/remove blocks until the previous connection has finished reading and writing. I just want to make sure I understand Go channels correctly and have written an efficient solution. In the future, in lieu of the hash table, would it make sense to use something like Redis? Or even store a list on the file system?
I would like to understand:
- Whether my use of channels is correct and additional locks are unnecessary. In the future the comments will be a list of other package dependencies, so it's important that each request can write to the data correctly.
- I'm using a hash table as the data store, but since it lives in memory, would it eventually make sense to use something like Redis? Or even to write to a local file?
- How would one test sending hundreds of messages to this server?
- I do check the format of the message, but what happens if a client disconnects abruptly?
Usage of the app:
go run main.go
# in a separate terminal
echo -n "INDEX|vim|comment\n" | nc localhost 3333
echo -n "REMOVE|vim|blah,blah\n" | nc localhost 3333
echo -n "INDEX|vim|comment\n" | nc localhost 3333
echo -n "INDEX|vim|comment\n" | nc localhost 3333
echo -n "COUNT|vim|comment\n" | nc localhost 3333
The code
package main
import (
"bufio"
"errors"
"fmt"
"log"
"net"
"os"
"strings"
)
const (
CONN_HOST = "localhost"
CONN_PORT = "3333"
CONN_TYPE = "tcp"
INDEX_COMMAND = "INDEX"
REMOVE_COMMAND = "REMOVE"
COUNT_COMMAND = "COUNT"
)
func main() {
l, err := net.Listen(CONN_TYPE, CONN_HOST+":"+CONN_PORT)
if err != nil {
fmt.Println("Error listening:", err.Error())
os.Exit(1)
}
defer l.Close()
fmt.Println("Listening on " + CONN_HOST + ":" + CONN_PORT)
msgCh := make(chan string)
resultCh := make(chan string)
ds := newDataStore()
go ds.msgHandlerWorker(msgCh, resultCh)
for {
conn, err := l.Accept()
if err != nil {
fmt.Println("Error accepting: ", err.Error())
os.Exit(1)
}
go handleRequest(conn, msgCh, resultCh)
}
}
type Message struct {
command string
pkg string
comment string
}
type DataStore struct {
pkgInfo map[string]string
pkgRef map[string]uint
}
func newDataStore() *DataStore {
return &DataStore{
pkgInfo: make(map[string]string),
pkgRef: make(map[string]uint),
}
}
func (ds *DataStore) parseMsg(msg string) (Message, error) {
var parsedMsg Message
m := strings.TrimSpace(msg)
s := strings.Split(m, "|")
if len(s) < 3 {
return parsedMsg, errors.New("Not correct format")
}
parsedMsg.command = s[0]
parsedMsg.pkg = s[1]
parsedMsg.comment = s[2]
return parsedMsg, nil
}
func (ds *DataStore) addToHashTable(msg Message) error {
ds.pkgInfo[msg.pkg] = msg.comment
if _, ok := ds.pkgRef[msg.pkg]; ok {
ds.pkgRef[msg.pkg] += 1
} else {
ds.pkgRef[msg.pkg] = 0
}
return nil
}
func (ds *DataStore) removeFromHashTable(msg Message) error {
delete(ds.pkgInfo, msg.pkg)
if _, ok := ds.pkgRef[msg.pkg]; ok {
ds.pkgRef[msg.pkg] -= 1
}
return nil
}
func (ds *DataStore) msgHandlerWorker(msgCh chan string, resultCh chan string) {
for {
msg := <-msgCh
parsedMsg, err := ds.parseMsg(msg)
if err != nil {
log.Print("pailed to parse msg")
resultCh <- "FAIL\n"
continue
}
switch parsedMsg.command {
case INDEX_COMMAND:
if err := ds.addToHashTable(parsedMsg); err != nil {
log.Print("Failed To add")
resultCh <- "FAIL\n"
continue
}
resultCh <- "OK\n"
case REMOVE_COMMAND:
if err := ds.removeFromHashTable(parsedMsg); err != nil {
log.Print("Failed To add")
resultCh <- "FAIL\n"
continue
}
resultCh <- "OK\n"
case COUNT_COMMAND:
if val, ok := ds.pkgRef[parsedMsg.pkg]; ok {
resultCh <- fmt.Sprintf("%d\n", val)
continue
}
resultCh <- "FAIL\n"
default:
log.Print("got invalid command")
resultCh <- "FAIL\n"
continue
}
}
}
func handleRequest(conn net.Conn, msgCh chan string, resultCh chan string) {
msg, err := bufio.NewReader(conn).ReadString('\n')
log.Printf("msg %v", msg)
if err != nil {
fmt.Println("Error reading:", err.Error())
}
msgCh <- msg
response := <-resultCh
conn.Write([]byte(response))
conn.Close()
}
Thanks for reading!
1 Answer
Before addressing your concerns, a few remarks on coding style.
Coding style
Most of these can easily be detected with tools like golint and go vet:
- In Go, use MixedCaps rather than underscores for variable and constant names, so CONN_HOST should be connHost, or simply host.
- You can directly increment a value stored in a map, so instead of ds.pkgRef[msg.pkg] += 1, simply use ds.pkgRef[msg.pkg]++.
- Use fmt.Printf("Error listening: %v\n", err) instead of fmt.Println("Error listening:", err.Error()).
- A more idiomatic way of creating a message is a composite literal:
parsedMsg := Message{
command: s[0],
pkg: s[1],
comment: s[2],
}
instead of
var parsedMsg Message
parsedMsg.command = s[0]
parsedMsg.pkg = s[1]
parsedMsg.comment = s[2]
- Some methods always return a nil error (e.g. addToHashTable()), so there is no need to return or check an error when calling them.
- In msgHandlerWorker(), you do not log the errors that occur. For example,
if err != nil {
log.Print("pailed to parse msg")
resultCh <- "FAIL\n"
continue
}
should be
if err != nil {
log.Printf("failed to parse msg: %v", err)
resultCh <- "FAIL\n"
continue
}
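As a small illustration of the map increment point above, here is a minimal sketch. The refCounts type and its add method are hypothetical names, not part of the original code. Because a missing key reads as the zero value, the existence check before incrementing can be dropped. Note that this counts the first INDEX as 1, whereas the original code starts at 0.

```go
package main

import "fmt"

// refCounts is a hypothetical stand-in for the pkgRef map in the question.
type refCounts map[string]uint

// add increments the counter for pkg. A missing key reads as the zero
// value, so no existence check is needed before incrementing.
func (r refCounts) add(pkg string) {
	r[pkg]++
}

func main() {
	refs := refCounts{}
	refs.add("vim")
	refs.add("vim")
	refs.add("emacs")
	fmt.Println(refs["vim"], refs["emacs"], refs["nano"]) // 2 1 0
}
```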
Concurrency
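Maps are not safe for concurrent use; for more details see the map documentation.
In your code only msgHandlerWorker touches the maps, so they are accessed from a single goroutine and no extra lock is needed. The design is fragile, though: every connection shares the same resultCh, and a response only reaches the right connection because both channels are unbuffered and there is exactly one worker. In general there are two ways to keep a map safe:
- use a sync.RWMutex
- read/write to the map from a single goroutine
I'll stay with the second method, as there is a simpler way to get it: handle all connections in the main goroutine.
Instead of go handleRequest(...), let's handle incoming connections synchronously. There's no need for channels here. The trade-off is that connections are served one at a time, so a slow client delays the others. We can rewrite the loop like this: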
for {
conn, err := l.Accept()
if err != nil {
fmt.Printf("Error accepting: %v\n", err)
os.Exit(1)
}
// read the incoming message
msg, err := bufio.NewReader(conn).ReadString('\n')
if err != nil {
fmt.Printf("Error reading: %v\n", err)
}
// directly parse the message
response, err := ds.handleMsg(msg)
if err != nil {
fmt.Printf("fail to handle message: %v\n", err)
}
// send the response back
conn.Write([]byte(response + "\n"))
conn.Close()
}
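For comparison, the first option listed above, a sync.RWMutex, could be sketched roughly as follows. The safeStore type and its index and count methods are hypothetical names introduced for illustration; only the reference counter is shown, and it counts from 1 rather than 0 to keep the example short.

```go
package main

import (
	"fmt"
	"sync"
)

// safeStore is a hypothetical lock-protected variant of DataStore.
type safeStore struct {
	mu     sync.RWMutex
	pkgRef map[string]uint
}

func newSafeStore() *safeStore {
	return &safeStore{pkgRef: make(map[string]uint)}
}

// index takes the write lock because it mutates the map.
func (s *safeStore) index(pkg string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.pkgRef[pkg]++
}

// count takes only the read lock, so concurrent readers do not block each other.
func (s *safeStore) count(pkg string) (uint, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	n, ok := s.pkgRef[pkg]
	return n, ok
}

func main() {
	s := newSafeStore()
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.index("vim") // safe to call from many goroutines
		}()
	}
	wg.Wait()
	n, _ := s.count("vim")
	fmt.Println(n) // 100
}
```

With this variant each connection could keep its own goroutine and call the store directly, at the cost of having to remember the lock in every method that touches the map.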
The final code would look like:
package main
import (
"bufio"
"fmt"
"net"
"os"
"strings"
)
const (
host = "localhost"
port = "3333"
address = host + ":" + port
connType = "tcp"
index = "INDEX"
remove = "REMOVE"
count = "COUNT"
)
func main() {
l, err := net.Listen(connType, address)
if err != nil {
fmt.Printf("Error listening: %v\n", err)
os.Exit(1)
}
defer l.Close()
fmt.Println("Listening on " + address)
ds := newDataStore()
for {
conn, err := l.Accept()
if err != nil {
fmt.Printf("Error accepting: %v\n", err)
os.Exit(1)
}
msg, err := bufio.NewReader(conn).ReadString('\n')
if err != nil {
fmt.Printf("Error reading: %v\n", err)
}
response, err := ds.handleMsg(msg)
if err != nil {
fmt.Printf("fail to handle message: %v\n", err)
}
conn.Write([]byte(response + "\n"))
conn.Close()
}
}
type Message struct {
command string
pkg string
comment string
}
type DataStore struct {
pkgInfo map[string]string
pkgRef map[string]uint
}
func newDataStore() *DataStore {
return &DataStore{
pkgInfo: make(map[string]string),
pkgRef: make(map[string]uint),
}
}
func (ds *DataStore) handleMsg(message string) (string, error) {
fmt.Printf("received message %s\n", message)
m := strings.TrimSpace(message)
s := strings.Split(m, "|")
if len(s) < 3 {
return "FAIL", fmt.Errorf("incorrect format for message %q", message)
}
msg := Message{
command: s[0],
pkg: s[1],
comment: s[2],
}
response := "OK"
switch msg.command {
case index:
ds.pkgInfo[msg.pkg] = msg.comment
if _, ok := ds.pkgRef[msg.pkg]; ok {
ds.pkgRef[msg.pkg]++
} else {
ds.pkgRef[msg.pkg] = 0
}
case remove:
delete(ds.pkgInfo, msg.pkg)
// pkgRef is unsigned: only decrement while positive, otherwise it wraps around
if n, ok := ds.pkgRef[msg.pkg]; ok && n > 0 {
ds.pkgRef[msg.pkg]--
}
case count:
val, ok := ds.pkgRef[msg.pkg]
if !ok {
return "FAIL", fmt.Errorf("fail to get count for package %v", msg.pkg)
}
response = fmt.Sprintf("%d", val)
default:
return "FAIL", fmt.Errorf("got invalid command: %v", msg.command)
}
return response, nil
}
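The loop above serves one connection at a time. If you would rather keep one goroutine per connection while still confining the map to a single goroutine, one possible pattern is to let each request carry its own reply channel, so a response cannot be delivered to the wrong connection. The sketch below is an assumption about how that could look, not the original design; the request type and the worker and index functions are hypothetical names, and only a counter is kept.

```go
package main

import "fmt"

// request is a hypothetical message type: each request carries its own reply channel.
type request struct {
	pkg   string
	reply chan uint
}

// worker owns the map; no other goroutine touches it.
func worker(reqs <-chan request) {
	refs := make(map[string]uint)
	for r := range reqs {
		refs[r.pkg]++
		r.reply <- refs[r.pkg]
	}
}

// index sends one request and waits for the answer addressed to it.
func index(reqs chan<- request, pkg string) uint {
	reply := make(chan uint, 1) // buffered so the worker never blocks on the reply
	reqs <- request{pkg: pkg, reply: reply}
	return <-reply
}

func main() {
	reqs := make(chan request)
	go worker(reqs)
	fmt.Println(index(reqs, "vim")) // 1
	fmt.Println(index(reqs, "vim")) // 2
	close(reqs)
}
```

In a server, each connection goroutine would call something like index after reading its message, and the worker would remain the only place where the map is read or written.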
Here is a simple program to test concurrent connections:
package main
import (
"fmt"
"io/ioutil"
"math/rand"
"net"
"sync"
)
var (
message = []string{
"INDEX|vim|comment",
"REMOVE|vim|blah,blah",
"INDEX|vim|comment",
"INDEX|vim|comment",
"COUNT|vim|comment",
}
)
func sendMessage(i int) error {
conn, err := net.Dial("tcp", "localhost:3333")
if err != nil {
return fmt.Errorf("error: %v", err)
}
defer conn.Close()
index := rand.Int31n(int32(len(message)))
_, err = conn.Write([]byte(message[index] + "\n"))
if err != nil {
return fmt.Errorf("error: %v", err)
}
buf, err := ioutil.ReadAll(conn)
if err != nil {
return fmt.Errorf("error: %v", err)
}
fmt.Printf("response for conn %v: %v", i, string(buf))
return nil
}
func main() {
var wg sync.WaitGroup
nbGoroutines := 3
wg.Add(nbGoroutines)
for k := 0; k < nbGoroutines; k++ {
go func() {
for i := 1; i <= 100; i++ {
err := sendMessage(i)
if err != nil {
fmt.Printf("fail: %v\n", err)
break
}
}
wg.Done()
}()
}
wg.Wait()
}
Other concerns
- Using Redis could be a good idea if there is a really huge number of entries in your map. If you face performance issues, though, first profile your code with pprof to make sure that map access is the bottleneck.
- If the client disconnects abruptly, ReadString returns an error together with whatever was read so far, and you won't be able to send the response: calling Write(...) will typically just return an error, so check it.
Hope this helps!