
Go and concurrency noob here. I wrote a program that handles TCP connections to index (and remove) packages. I want it to be able to handle hundreds of connections and messages that might come in at the same time. I am using a hash table to store the data and channels to pass the incoming messages to the worker.

My understanding is that I will not need any locks around the data store, since every add/remove will be blocked until the previous connection has finished writing and reading. I just want to make sure I understand Go channels correctly and that my solution is efficient. In the future, in lieu of the hash table, does it make sense to use something like Redis? Or even to store a list on the file system?

I would like to understand

  • Whether my use of channels is correct and additional locks won't be necessary. In the future, the comments will be a list of other package dependencies, so it's important that each request writes to the data correctly.
  • I'm using a hash table as the data store, but since it lives in memory, would it eventually make sense to use something like Redis? Or even to write to a local file?
  • How would one test sending hundreds of messages to this server?
  • I check the format of the message, but what happens if a client disconnects abruptly?

Usage of the app:

go run main.go
# in a separate terminal
echo -n "INDEX|vim|comment\n" | nc localhost 3333
echo -n "REMOVE|vim|blah,blah\n" | nc localhost 3333
echo -n "INDEX|vim|comment\n" | nc localhost 3333
echo -n "INDEX|vim|comment\n" | nc localhost 3333
echo -n "COUNT|vim|comment\n" | nc localhost 3333 

The code

package main

import (
	"bufio"
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"strings"
)

const (
	CONN_HOST      = "localhost"
	CONN_PORT      = "3333"
	CONN_TYPE      = "tcp"
	INDEX_COMMAND  = "INDEX"
	REMOVE_COMMAND = "REMOVE"
	COUNT_COMMAND  = "COUNT"
)

func main() {
	l, err := net.Listen(CONN_TYPE, CONN_HOST+":"+CONN_PORT)
	if err != nil {
		fmt.Println("Error listening:", err.Error())
		os.Exit(1)
	}
	defer l.Close()
	fmt.Println("Listening on " + CONN_HOST + ":" + CONN_PORT)
	msgCh := make(chan string)
	resultCh := make(chan string)
	ds := newDataStore()
	go ds.msgHandlerWorker(msgCh, resultCh)
	for {
		conn, err := l.Accept()
		if err != nil {
			fmt.Println("Error accepting: ", err.Error())
			os.Exit(1)
		}
		go handleRequest(conn, msgCh, resultCh)
	}
}

type Message struct {
	command string
	pkg     string
	comment string
}

type DataStore struct {
	pkgInfo map[string]string
	pkgRef  map[string]uint
}

func newDataStore() *DataStore {
	return &DataStore{
		pkgInfo: make(map[string]string),
		pkgRef:  make(map[string]uint),
	}
}

func (ds *DataStore) parseMsg(msg string) (Message, error) {
	var parsedMsg Message
	m := strings.TrimSpace(msg)
	s := strings.Split(m, "|")
	if len(s) < 3 {
		return parsedMsg, errors.New("Not correct format")
	}
	parsedMsg.command = s[0]
	parsedMsg.pkg = s[1]
	parsedMsg.comment = s[2]
	return parsedMsg, nil
}

func (ds *DataStore) addToHashTable(msg Message) error {
	ds.pkgInfo[msg.pkg] = msg.comment
	if _, ok := ds.pkgRef[msg.pkg]; ok {
		ds.pkgRef[msg.pkg] += 1
	} else {
		ds.pkgRef[msg.pkg] = 0
	}
	return nil
}

func (ds *DataStore) removeFromHashTable(msg Message) error {
	delete(ds.pkgInfo, msg.pkg)
	if _, ok := ds.pkgRef[msg.pkg]; ok {
		ds.pkgRef[msg.pkg] -= 1
	}
	return nil
}

func (ds *DataStore) msgHandlerWorker(msgCh chan string, resultCh chan string) {
	for {
		msg := <-msgCh
		parsedMsg, err := ds.parseMsg(msg)
		if err != nil {
			log.Print("pailed to parse msg")
			resultCh <- "FAIL\n"
			continue
		}
		switch parsedMsg.command {
		case INDEX_COMMAND:
			if err := ds.addToHashTable(parsedMsg); err != nil {
				log.Print("Failed To add")
				resultCh <- "FAIL\n"
				continue
			}
			resultCh <- "OK\n"
		case REMOVE_COMMAND:
			if err := ds.removeFromHashTable(parsedMsg); err != nil {
				log.Print("Failed To add")
				resultCh <- "FAIL\n"
				continue
			}
			resultCh <- "OK\n"
		case COUNT_COMMAND:
			if val, ok := ds.pkgRef[parsedMsg.pkg]; ok {
				resultCh <- fmt.Sprintf("%d\n", val)
				continue
			}
			resultCh <- "FAIL\n"
		default:
			log.Print("got invalid command")
			resultCh <- "FAIL\n"
			continue
		}
	}
}

func handleRequest(conn net.Conn, msgCh chan string, resultCh chan string) {
	msg, err := bufio.NewReader(conn).ReadString('\n')
	log.Printf("msg %v", msg)
	if err != nil {
		fmt.Println("Error reading:", err.Error())
	}
	msgCh <- msg
	response := <-resultCh
	conn.Write([]byte(response))
	conn.Close()
}

Thanks for reading!

asked Nov 27, 2017 at 17:47

1 Answer


Before addressing your concerns, a few things on coding style:

Coding style

Most of these can easily be detected with tools like golint and go vet.

  • In Go, always use camelCase (MixedCaps) for variable and constant names, so CONN_HOST should be connHost, or simply host

  • You can directly increment a value stored in a map, so instead of ds.pkgRef[msg.pkg] += 1, simply use ds.pkgRef[msg.pkg]++

  • Use fmt.Printf("Error listening: %v\n", err) instead of fmt.Println("Error listening:", err.Error()); the %v verb calls Error() for you

  • A more idiomatic way of creating a message would be this:

parsedMsg := Message{
	command: s[0],
	pkg:     s[1],
	comment: s[2],
}

instead of:

var parsedMsg Message
parsedMsg.command = s[0]
parsedMsg.pkg = s[1]
parsedMsg.comment = s[2]

  • Some methods always return a nil error (e.g. addToHashTable()), so there is no need to return or check an error for them

  • In msgHandlerWorker(), you do not log the actual errors that occur:

For example,

if err != nil {
	log.Print("pailed to parse msg")
	resultCh <- "FAIL\n"
	continue
}

should be:

if err != nil {
	log.Printf("failed to parse msg: %v", err)
	resultCh <- "FAIL\n"
	continue
}
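Applied together, the bullets on ++ and on always-nil errors could turn the question's two helpers into something like the sketch below. The method names add and remove are hypothetical, and the counting rule is kept from the question (the first INDEX stores 0). One deliberate change: pkgRef holds uint values, so a REMOVE on a count of 0 (which the INDEX-then-REMOVE usage in the question triggers) would wrap around to the largest uint; the sketch only decrements positive counts.

```go
package main

import "fmt"

// DataStore is the struct from the question, unchanged.
type DataStore struct {
	pkgInfo map[string]string
	pkgRef  map[string]uint
}

func newDataStore() *DataStore {
	return &DataStore{
		pkgInfo: make(map[string]string),
		pkgRef:  make(map[string]uint),
	}
}

// add can never fail, so it has no error result for callers to check.
func (ds *DataStore) add(pkg, comment string) {
	ds.pkgInfo[pkg] = comment
	if _, ok := ds.pkgRef[pkg]; ok {
		ds.pkgRef[pkg]++ // instead of ds.pkgRef[pkg] += 1
	} else {
		ds.pkgRef[pkg] = 0 // question's rule: the first INDEX stores 0
	}
}

// remove can never fail either. Decrementing a uint that is already 0
// would wrap around to the largest uint, so only positive counts go down.
func (ds *DataStore) remove(pkg string) {
	delete(ds.pkgInfo, pkg)
	if n, ok := ds.pkgRef[pkg]; ok && n > 0 {
		ds.pkgRef[pkg]--
	}
}

func main() {
	ds := newDataStore()
	ds.add("vim", "comment")
	ds.add("vim", "comment")
	fmt.Println(ds.pkgRef["vim"]) // 1
	ds.remove("vim")
	ds.remove("vim")
	fmt.Println(ds.pkgRef["vim"]) // 0, rather than wrapping around
}
```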

Concurrency

Maps are not safe for concurrent use; for more details, see the map documentation.

Whenever a map can be read and modified from multiple goroutines, every access has to be synchronized. There are two common solutions:

  • Use a sync.RWMutex
  • Read/write the map from a single goroutine

I'll go for the second method, as there is a simple way to handle all connections in the main goroutine.

Instead of go handleRequest(...), let's handle each incoming connection synchronously. There's no need for channels then, and the accept loop can be rewritten like this:

for {
	conn, err := l.Accept()
	if err != nil {
		fmt.Printf("Error accepting: %v\n", err)
		os.Exit(1)
	}
	// read the incoming message
	msg, err := bufio.NewReader(conn).ReadString('\n')
	if err != nil {
		fmt.Printf("Error reading: %v\n", err)
	}
	// directly parse the message
	response, err := ds.handleMsg(msg)
	if err != nil {
		fmt.Printf("fail to handle message: %v\n", err)
	}
	// send the response back
	conn.Write([]byte(response + "\n"))
	conn.Close()
}
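One trade-off of the synchronous loop above is that while it waits on one slow client, no other connection is accepted. If you'd rather keep one goroutine per connection, the single-goroutine rule can still be kept with channels. In the question's code every handleRequest shares one resultCh; as far as I can tell, replies only reach the right connection because the worker finishes each message before receiving the next, which would break as soon as, say, a second worker is added. A common alternative is to send each request together with its own reply channel. In this sketch, request, owner and ask are hypothetical names, and counting messages stands in for the real INDEX/REMOVE/COUNT logic.

```go
package main

import (
	"fmt"
	"sync"
)

// request pairs one incoming message with the channel on which its answer
// must be delivered, so replies can never be handed to the wrong connection.
type request struct {
	msg   string
	reply chan string
}

// owner is the only goroutine that touches seen, so the map needs no lock.
// Counting messages is a hypothetical stand-in for the INDEX/REMOVE/COUNT
// switch in msgHandlerWorker.
func owner(reqs <-chan request) {
	seen := make(map[string]int)
	for r := range reqs {
		seen[r.msg]++
		r.reply <- fmt.Sprintf("%s seen %d time(s)", r.msg, seen[r.msg])
	}
}

// ask is what each connection goroutine would call instead of sharing a
// single resultCh. The reply channel is buffered, so the owner can always
// deliver the answer without waiting on the asker.
func ask(reqs chan<- request, msg string) string {
	r := request{msg: msg, reply: make(chan string, 1)}
	reqs <- r
	return <-r.reply
}

func main() {
	reqs := make(chan request)
	go owner(reqs)

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(ask(reqs, "INDEX|vim|comment"))
		}()
	}
	wg.Wait()
	close(reqs) // lets owner's range loop end
}
```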

The final code would look like:

package main

import (
	"bufio"
	"fmt"
	"net"
	"os"
	"strings"
)

const (
	host     = "localhost"
	port     = "3333"
	address  = host + ":" + port
	connType = "tcp"
	index    = "INDEX"
	remove   = "REMOVE"
	count    = "COUNT"
)

func main() {
	l, err := net.Listen(connType, address)
	if err != nil {
		fmt.Printf("Error listening: %v\n", err)
		os.Exit(1)
	}
	defer l.Close()
	fmt.Println("Listening on " + address)
	ds := newDataStore()
	for {
		conn, err := l.Accept()
		if err != nil {
			fmt.Printf("Error accepting: %v\n", err)
			os.Exit(1)
		}
		msg, err := bufio.NewReader(conn).ReadString('\n')
		if err != nil {
			fmt.Printf("Error reading: %v\n", err)
		}
		response, err := ds.handleMsg(msg)
		if err != nil {
			fmt.Printf("fail to handle message: %v\n", err)
		}
		conn.Write([]byte(response + "\n"))
		conn.Close()
	}
}

type Message struct {
	command string
	pkg     string
	comment string
}

type DataStore struct {
	pkgInfo map[string]string
	pkgRef  map[string]uint
}

func newDataStore() *DataStore {
	return &DataStore{
		pkgInfo: make(map[string]string),
		pkgRef:  make(map[string]uint),
	}
}

func (ds *DataStore) handleMsg(message string) (string, error) {
	fmt.Printf("received message %s\n", message)
	m := strings.TrimSpace(message)
	s := strings.Split(m, "|")
	if len(s) < 3 {
		return "FAIL", fmt.Errorf("incorrect format for string %v", message)
	}
	msg := Message{
		command: s[0],
		pkg:     s[1],
		comment: s[2],
	}
	response := "OK"
	switch msg.command {
	case index:
		ds.pkgInfo[msg.pkg] = msg.comment
		if _, ok := ds.pkgRef[msg.pkg]; ok {
			ds.pkgRef[msg.pkg]++
		} else {
			ds.pkgRef[msg.pkg] = 0
		}
	case remove:
		delete(ds.pkgInfo, msg.pkg)
		if _, ok := ds.pkgRef[msg.pkg]; ok {
			ds.pkgRef[msg.pkg]--
		}
	case count:
		val, ok := ds.pkgRef[msg.pkg]
		if !ok {
			return "FAIL", fmt.Errorf("fail to get count for package %v", msg.pkg)
		}
		response = fmt.Sprintf("%d", val)
	default:
		return "FAIL", fmt.Errorf("got invalid command: %v", msg.command)
	}
	return response, nil
}

And here is a simple program to test concurrent connections:

package main

import (
	"fmt"
	"io/ioutil"
	"math/rand"
	"net"
	"sync"
)

var (
	message = []string{
		"INDEX|vim|comment",
		"REMOVE|vim|blah,blah",
		"INDEX|vim|comment",
		"INDEX|vim|comment",
		"COUNT|vim|comment",
	}
)

func sendMessage(i int) error {
	conn, err := net.Dial("tcp", "localhost:3333")
	if err != nil {
		return fmt.Errorf("error: %v", err)
	}
	defer conn.Close()
	index := rand.Int31n(int32(len(message)))
	_, err = conn.Write([]byte(message[index] + "\n"))
	if err != nil {
		return fmt.Errorf("error: %v", err)
	}
	buf, err := ioutil.ReadAll(conn)
	if err != nil {
		return fmt.Errorf("error: %v", err)
	}
	fmt.Printf("response for conn %v: %v", i, string(buf))
	return nil
}

func main() {
	var wg sync.WaitGroup
	nbGoroutines := 3
	wg.Add(nbGoroutines)
	for k := 0; k < nbGoroutines; k++ {
		go func() {
			for i := 1; i <= 100; i++ {
				err := sendMessage(i)
				if err != nil {
					fmt.Printf("fail: %v\n", err)
					break
				}
			}
			wg.Done()
		}()
	}
	wg.Wait()
}

Other concerns

  • Using Redis could be a good idea if there is a really huge number of entries in your map, but if you run into performance issues, first profile your code with pprof to make sure that map access is actually the bottleneck

  • If the client disconnects, you won't be able to send the response, so calling Write(...) will just return an error (which the code above currently ignores)

Hope this helps!

answered Nov 28, 2017 at 14:14