Skip to content

Commit

Permalink
Merge pull request #880 from ronawho/opt-array-transfer-server
Browse files Browse the repository at this point in the history
Optimize server conversions for client-server array transfers
  • Loading branch information
reuster986 authored Jul 14, 2021
2 parents 661fab1 + ebc80ce commit 2df9393
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 122 deletions.
22 changes: 14 additions & 8 deletions arkouda/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@
__all__ = ["DTypes", "DTypeObjects", "dtype", "bool", "int64", "float64",
"uint8", "str_", "check_np_dtype", "translate_np_dtype",
"resolve_scalar_dtype", "ARKOUDA_SUPPORTED_DTYPES", "bool_scalars",
"float_scalars", "int_scalars", "numeric_scalars", "numpy_scalars",
"str_scalars", "all_scalars", "get_byteorder"]

# supported dtypes
structDtypeCodes = {'int64': 'q',
'float64': 'd',
'bool': '?',
'uint8': 'B'}
"float_scalars", "int_scalars", "numeric_scalars", "numpy_scalars",
"str_scalars", "all_scalars", "get_byteorder",
"get_server_byteorder"]

NUMBER_FORMAT_STRINGS = {'bool': '{}',
'int64': '{:n}',
'float64': '{:.17f}',
Expand Down Expand Up @@ -176,3 +172,13 @@ def get_byteorder(dt: np.dtype) -> str:
raise ValueError("Client byteorder must be 'little' or 'big'")
else:
return dt.byteorder

def get_server_byteorder() -> str:
    """
    Report the byteorder advertised by the arkouda server.

    The value is read from the ``'byteorder'`` entry of the server
    configuration returned by ``arkouda.client.get_config``.

    Returns
    -------
    str
        Either ``'little'`` or ``'big'``.

    Raises
    ------
    ValueError
        If the configured byteorder is anything other than
        ``'little'`` or ``'big'``.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids a circular import between
    # arkouda.dtypes and arkouda.client -- confirm.
    from arkouda.client import get_config
    server_order = get_config()['byteorder']
    if server_order in ('little', 'big'):
        return cast('str', server_order)
    raise ValueError("Server byteorder must be 'little' or 'big'")
31 changes: 9 additions & 22 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations
from typing import cast, List, Sequence
from typeguard import typechecked
import json, struct
import json
import numpy as np # type: ignore
from arkouda.client import generic_msg
from arkouda.dtypes import dtype, DTypes, resolve_scalar_dtype, \
structDtypeCodes, translate_np_dtype, NUMBER_FORMAT_STRINGS, \
int_scalars, numeric_scalars, numpy_scalars
translate_np_dtype, NUMBER_FORMAT_STRINGS, \
int_scalars, numeric_scalars, numpy_scalars, get_server_byteorder
from arkouda.dtypes import int64 as akint64
from arkouda.dtypes import str_ as akstr_
from arkouda.dtypes import bool as npbool
Expand Down Expand Up @@ -856,10 +856,13 @@ def to_ndarray(self) -> np.ndarray:
if len(rep_msg) != self.size*self.dtype.itemsize:
raise RuntimeError("Expected {} bytes but received {}".\
format(self.size*self.dtype.itemsize, len(rep_msg)))
# The server sends us big-endian bytes so we need to account for that.
# The server sends us native-endian bytes so we need to account for that.
# Since bytes are immutable, we need to copy the np array to be mutable
dt = np.dtype(self.dtype)
dt = dt.newbyteorder('>')
if get_server_byteorder() == 'big':
dt = dt.newbyteorder('>')
else:
dt = dt.newbyteorder('<')
return np.frombuffer(rep_msg, dt).copy()

def to_cuda(self):
Expand Down Expand Up @@ -915,24 +918,8 @@ def to_cuda(self):
raise ModuleNotFoundError(('Numba is not enabled or installed and ' +
'is required for GPU support.'))

# Total number of bytes in the array data
arraybytes = self.size * self.dtype.itemsize

from arkouda.client import maxTransferBytes
# Guard against overflowing client memory
if arraybytes > maxTransferBytes:
raise RuntimeError(("Array exceeds allowed size for transfer. " +
"Increase client.maxTransferBytes to allow"))
# The reply from the server will be a bytes object
rep_msg = generic_msg(cmd="tondarray", args="{}".format(self.name), recv_bytes=True)
# Make sure the received data has the expected length
if len(rep_msg) != self.size*self.dtype.itemsize:
raise RuntimeError("Expected {} bytes but received {}".\
format(self.size*self.dtype.itemsize, len(rep_msg)))
# Use struct to interpret bytes as a big-endian numeric array
fmt = '>{:n}{}'.format(self.size, structDtypeCodes[self.dtype.name])
# Return a numba devicendarray
return cuda.to_device(struct.unpack(fmt, rep_msg))
return cuda.to_device(self.to_ndarray())

@typechecked
def save(self, prefix_path : str, dataset : str='array', mode : str='truncate') -> str:
Expand Down
12 changes: 7 additions & 5 deletions arkouda/pdarraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import cast, Iterable, Optional, Union
from typeguard import typechecked
from arkouda.client import generic_msg
from arkouda.dtypes import structDtypeCodes, NUMBER_FORMAT_STRINGS, float64, int64, \
from arkouda.dtypes import NUMBER_FORMAT_STRINGS, float64, int64, \
DTypes, isSupportedInt, isSupportedNumber, NumericDTypes, SeriesDTypes,\
int_scalars, numeric_scalars, get_byteorder
int_scalars, numeric_scalars, get_byteorder, get_server_byteorder
from arkouda.dtypes import dtype as akdtype
from arkouda.pdarrayclass import pdarray, create_pdarray
from arkouda.strings import Strings
Expand Down Expand Up @@ -203,9 +203,11 @@ def array(a : Union[pdarray,np.ndarray, Iterable]) -> Union[pdarray, Strings]:
raise RuntimeError(("Array exceeds allowed transfer size. Increase " +
"ak.maxTransferBytes to allow"))
# Pack binary array data into a bytes object with a command header
# including the dtype and size. Note that the server expects big-endian, so
# if we're using little-endian, swap the bytes before sending.
if get_byteorder(a.dtype) == '<':
# including the dtype and size. If the server has a different byteorder
# than our numpy array we need to swap to match since the server expects
# native endian bytes
if ((get_byteorder(a.dtype) == '<' and get_server_byteorder() == 'big') or
(get_byteorder(a.dtype) == '>' and get_server_byteorder() == 'little')):
abytes = a.byteswap().tobytes()
else:
abytes = a.tobytes()
Expand Down
115 changes: 35 additions & 80 deletions src/GenSymIO.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -55,39 +55,29 @@ module GenSymIO {
}

overMemLimit(2*8*size);
var tmpf:file; defer { ensureClose(tmpf); }

gsLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"dtype: %t size: %i".format(dtype,size));

// Write the data payload composing the pdarray to a memory buffer
try {
tmpf = openmem();
var tmpw = tmpf.writer(kind=iobig);
tmpw.write(data);
tmpw.close();
} catch {
var errorMsg = "Could not write to memory buffer";
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
/*
 * Build a new SymEntry holding `size` elements of type `t` from the raw
 * buffer of `data`, register it in the symbol table `st`, and return the
 * name it was registered under.
 *
 * No byte swapping is done here: the buffer is reinterpreted as-is.
 * NOTE(review): assumes `data` holds at least `size` elements' worth of
 * bytes of type `t` -- confirm the caller validates the payload length.
 */
proc bytesToSymEntry(size:int, type t, st: borrowed SymTab, ref data:bytes): string throws {
var entry = new shared SymEntry(size, t);
// View the bytes buffer as an array of `size` elements of type `t`
// by casting its C pointer to c_ptr(t).
var localA = makeArrayFromPtr(data.c_str():c_void_ptr:c_ptr(t), size:uint);
// Array assignment copies the elements into the entry's own array.
entry.a = localA;
// Register the entry under a freshly generated name.
var name = st.nextName();
st.addEntry(name, entry);
return name;
}

try { // Read data in SymEntry based on type
if dtype == DType.Int64 {
rname = makeEntry(size, int, st, tmpf);
} else if dtype == DType.Float64 {
rname = makeEntry(size, real, st, tmpf);
} else if dtype == DType.Bool {
rname = makeEntry(size, bool, st, tmpf);
} else if dtype == DType.UInt8 {
rname = makeEntry(size, uint(8), st, tmpf);
} else {
msg = "Unhandled data type %s".format(dtypeBytes);
msgType = MsgType.ERROR;
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),msg);
}
} catch {
msg = "Could not read from memory buffer into SymEntry";
if dtype == DType.Int64 {
rname = bytesToSymEntry(size, int, st, data);
} else if dtype == DType.Float64 {
rname = bytesToSymEntry(size, real, st, data);
} else if dtype == DType.Bool {
rname = bytesToSymEntry(size, bool, st, data);
} else if dtype == DType.UInt8 {
rname = bytesToSymEntry(size, uint(8), st, data);
} else {
msg = "Unhandled data type %s".format(dtypeBytes);
msgType = MsgType.ERROR;
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),msg);
}
Expand All @@ -100,22 +90,6 @@ module GenSymIO {
return new MsgTuple(msg, msgType);
}

/*
* Read the data payload from the memory buffer, encapsulate
* within a SymEntry, and write to the SymTab cache
* Here tmpf is a memory buffer which contains the data we want to read.
*
* Returns the name under which the new entry was added to `st`.
* Errors from the reader or the symbol table propagate to the caller.
*/
private proc makeEntry(size:int, type t, st: borrowed SymTab, tmpf:file): string throws {
var entry = new shared SymEntry(size, t);
// Read from the beginning of the buffer using big-endian binary format.
var tmpr = tmpf.reader(kind=iobig, start=0);
// Temporary array covering the same index range as the entry's domain.
var localA: [entry.aD.low..entry.aD.high] t;
tmpr.read(localA);
// Copy the values that were read into the entry's array.
entry.a = localA;
tmpr.close();
// Register the entry under a freshly generated name.
var name = st.nextName();
st.addEntry(name, entry);
return name;
}

/*
* Ensure the file is closed, disregard errors
Expand All @@ -139,48 +113,29 @@ module GenSymIO {
var arrayBytes: bytes;
var entry = st.lookup(payload);
overMemLimit(2*entry.size*entry.itemsize);
var tmpf: file; defer { ensureClose(tmpf); }

proc localizeArr(A: [?D] ?eltType) {
const localA:[D.low..D.high] eltType = A;
return localA;
}
try {
tmpf = openmem();
var tmpw = tmpf.writer(kind=iobig);
if entry.dtype == DType.Int64 {
tmpw.write(localizeArr(toSymEntry(entry, int).a));
} else if entry.dtype == DType.Float64 {
tmpw.write(localizeArr(toSymEntry(entry, real).a));
} else if entry.dtype == DType.Bool {
tmpw.write(localizeArr(toSymEntry(entry, bool).a));
} else if entry.dtype == DType.UInt8 {
tmpw.write(localizeArr(toSymEntry(entry, uint(8)).a));
} else {
var errorMsg = "Error: Unhandled dtype %s".format(entry.dtype);
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return errorMsg.encode(); // return as bytes
}
tmpw.close();
} catch {
return b"Error: Unable to write SymEntry to memory buffer";
/*
 * Copy the array `A` into a newly allocated C buffer and return that
 * buffer wrapped as a `bytes` value. Elements are copied as-is, with no
 * byte-order conversion.
 */
proc distArrToBytes(A: [?D] ?eltType) {
// Allocate room for D.size elements of eltType.
var ptr = c_malloc(eltType, D.size);
// Wrap the buffer as an array so `A` can be copied with array assignment.
var localA = makeArrayFromPtr(ptr, D.size:uint);
localA = A;
// Total payload size in bytes.
const size = D.size*c_sizeof(eltType):int;
// The buffer is not freed here.
// NOTE(review): createBytesWithOwnedBuffer is expected to make the
// returned bytes value the owner of `ptr` -- confirm against the Chapel docs.
return createBytesWithOwnedBuffer(ptr:c_ptr(uint(8)), size, size);
}

try {
var tmpr = tmpf.reader(kind=iobig, start=0);
tmpr.readbytes(arrayBytes);
tmpr.close();
} catch {
return b"Error: Unable to copy array from memory buffer to string";
if entry.dtype == DType.Int64 {
arrayBytes = distArrToBytes(toSymEntry(entry, int).a);
} else if entry.dtype == DType.Float64 {
arrayBytes = distArrToBytes(toSymEntry(entry, real).a);
} else if entry.dtype == DType.Bool {
arrayBytes = distArrToBytes(toSymEntry(entry, bool).a);
} else if entry.dtype == DType.UInt8 {
arrayBytes = distArrToBytes(toSymEntry(entry, uint(8)).a);
} else {
var errorMsg = "Error: Unhandled dtype %s".format(entry.dtype);
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return errorMsg.encode(); // return as bytes
}
//var repMsg = try! "Array: %i".format(arraystr.length) + arraystr;
/*
Engin: fwiw, if you want to achieve the above, you can:

return b"Array: %i %|t".format(arrayBytes.length, arrayBytes);
But I think the main problem is how to separate the length from the data
*/
return arrayBytes;
}

Expand Down
16 changes: 15 additions & 1 deletion src/ServerConfig.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ module ServerConfig
const LocaleConfigs: [LocaleSpace] owned LocaleConfig;
const authenticate: bool;
const logLevel: LogLevel;
const byteorder: string;
}
var (Zmajor, Zminor, Zmicro) = ZMQ.version;
var H5major: c_uint, H5minor: c_uint, H5micro: c_uint;
Expand All @@ -111,7 +112,8 @@ module ServerConfig
distributionType = (makeDistDom(10).type):string,
LocaleConfigs = [loc in LocaleSpace] new owned LocaleConfig(loc),
authenticate = authenticate,
logLevel = logLevel
logLevel = logLevel,
byteorder = try! getByteorder()
);

return cfg;
Expand All @@ -138,6 +140,18 @@ module ServerConfig
return here.physicalMemory();
}

/*
Get the byteorder (endianness) of this locale

Returns the string "big" or "little". The answer is determined by
writing the integer 1 to an in-memory file in big-endian format and
reading it back in native format. Errors from the IO calls propagate
to the caller.
*/
proc getByteorder() throws {
use IO;
var writeVal = 1, readVal = 0;
var tmpf = openmem();
// Write the probe value using big-endian binary encoding.
tmpf.writer(kind=iobig).write(writeVal);
// Read it back from offset 0 using the native encoding.
tmpf.reader(kind=ionative, start=0).read(readVal);
// The value only survives the round trip if native order is big-endian.
return if writeVal == readVal then "big" else "little";
}

/*
Get the memory used on this locale
*/
Expand Down
6 changes: 0 additions & 6 deletions tests/dtypes_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,3 @@ def test_number_format_strings(self):
self.assertEqual('{:.17f}', dtypes.NUMBER_FORMAT_STRINGS['float64'])
self.assertEqual('f', dtypes.NUMBER_FORMAT_STRINGS['np.float64'])
self.assertEqual('{:n}', dtypes.NUMBER_FORMAT_STRINGS['uint8'])

def test_structDtypeCodes(self):
    """Check that each dtype name maps to its expected struct format code."""
    expected_codes = (
        ('int64', 'q'),
        ('float64', 'd'),
        ('bool', '?'),
        ('uint8', 'B'),
    )
    # Checked in a fixed order; the first mismatch fails the test.
    for dtype_name, code in expected_codes:
        self.assertEqual(code, dtypes.structDtypeCodes[dtype_name])

0 comments on commit 2df9393

Please sign in to comment.