Skip to content

Commit e4a5c99

Browse files
committed
byte support for upload
1 parent eb0f207 commit e4a5c99

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

python/fedml/api/modules/storage.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,29 @@ def __init__(self, data: dict):
3030
# Kinds of payload that upload() distinguishes, as classified by _get_data_type():
#   FILE      - data_path satisfies os.path.isfile().
#   DIRECTORY - data_path satisfies os.path.isdir(); upload() archives it before uploading.
#   BYTE      - the caller set byte_data_flag, so data_path is not inspected.
#   INVALID   - upload() returns a FAILURE response ("Invalid data path") for this value.
class DataType(Enum):
3131
FILE = "file"
3232
DIRECTORY = "directory"
33+
BYTE = "byte"
3334
INVALID = "invalid"
3435

36+
37+
3538
# Todo (alaydshah): Store service name in metadata
3639
# Todo (alaydshah): If the data already exists, don't upload it again. Instead, suggest using the update command.
3740
# Todo (bhargav): Discuss and remove the service variable. It may be needed sometime later.
3841
def upload(data_path, api_key, name, description, tag_list, service, show_progress, out_progress_to_err, progress_desc,
39-
metadata) -> FedMLResponse:
42+
metadata, byte_data_flag=False, byte_data=None) -> FedMLResponse:
4043
api_key = authenticate(api_key)
4144

4245
user_id, message = _get_user_id_from_api_key(api_key)
4346

4447
if user_id is None:
4548
return FedMLResponse(code=ResponseCode.FAILURE, message=message)
4649

47-
data_type = _get_data_type(data_path)
50+
data_type = _get_data_type(data_path, byte_data_flag)
4851

49-
if(data_type == DataType.INVALID):
52+
if data_type == DataType.INVALID:
5053
return FedMLResponse(code=ResponseCode.FAILURE,message="Invalid data path")
5154

52-
if(data_type == DataType.DIRECTORY):
55+
if data_type == DataType.DIRECTORY:
5356
to_upload_path, message = _archive_data(data_path)
5457
name = os.path.splitext(os.path.basename(to_upload_path))[0] if name is None else name
5558
file_name = name + ".zip"
@@ -67,18 +70,24 @@ def upload(data_path, api_key, name, description, tag_list, service, show_progre
6770

6871
file_name = name
6972

70-
if not to_upload_path:
73+
if not to_upload_path and not byte_data_flag:
7174
return FedMLResponse(code=ResponseCode.FAILURE, message=message)
7275

7376
#TODO(bhargav191098) - Better done on the backend. Remove and pass file_name once completed on backend.
7477
dest_path = os.path.join(user_id, file_name)
75-
file_size = os.path.getsize(to_upload_path)
78+
max_chunk_size = 20 * 1024 * 1024
79+
80+
if byte_data_flag:
81+
file_size = sum(len(chunk) for chunk in get_chunks_from_byte_data(byte_data, max_chunk_size))
7682

77-
file_uploaded_url, message = _upload_multipart(api_key, dest_path, to_upload_path, show_progress,
83+
else:
84+
file_size = os.path.getsize(to_upload_path)
85+
86+
file_uploaded_url, message = _upload_multipart(api_key, dest_path, file_size, max_chunk_size, to_upload_path, show_progress,
7887
out_progress_to_err,
79-
progress_desc, metadata)
88+
progress_desc, metadata, byte_data_flag, byte_data)
8089

81-
if(data_type == "dir"):
90+
if data_type == "dir":
8291
os.remove(to_upload_path)
8392
if not file_uploaded_url:
8493
return FedMLResponse(code=ResponseCode.FAILURE, message=f"Failed to upload file: {to_upload_path}")
@@ -262,6 +271,13 @@ def get_chunks(file_path, chunk_size):
262271
break
263272
yield chunk
264273

274+
def get_chunks_from_byte_data(byte_data, chunk_size):
    """Yield successive chunks of at most ``chunk_size`` bytes from ``byte_data``.

    Args:
        byte_data: Either a readable stream exposing ``read(size)`` (for
            example ``io.BytesIO``) or an in-memory ``bytes``/``bytearray``
            object. A stream is read from its current position; this
            function does not rewind it.
        chunk_size: Maximum size of each chunk. Must be positive.

    Yields:
        Non-empty chunks in order. Every chunk except possibly the last has
        exactly ``chunk_size`` bytes when the source is an in-memory buffer;
        for streams the size is whatever ``read(chunk_size)`` returns.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # read(0) returns nothing (the payload would be silently dropped) and a
    # negative size makes read() return everything as one oversized chunk,
    # so reject both instead of producing a wrong part layout.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    # Raw in-memory buffers have no read(); slice them through a memoryview
    # so each chunk is copied exactly once.
    if isinstance(byte_data, (bytes, bytearray)):
        view = memoryview(byte_data)
        for start in range(0, len(view), chunk_size):
            yield bytes(view[start:start + chunk_size])
        return

    # File-like source: keep reading until the stream reports end of data.
    while True:
        chunk = byte_data.read(chunk_size)
        if not chunk:
            break
        yield chunk
280+
265281

266282
def _get_presigned_url(api_key, request_url, file_name, part_number=None):
267283
cert_path = MLOpsConfigs.get_cert_path_with_version()
@@ -287,7 +303,7 @@ def _upload_part(url,part_data,session):
287303
return response
288304

289305

290-
def _upload_chunk(presigned_url, chunk, part, pbar=None, max_retries=20,session=None):
306+
def _upload_chunk(presigned_url, chunk, part, pbar=None, max_retries=20,session=None, byte_data_flag= False):
291307
for retry_attempt in range(max_retries):
292308
try:
293309
response = _upload_part(presigned_url,chunk,session)
@@ -297,11 +313,12 @@ def _upload_chunk(presigned_url, chunk, part, pbar=None, max_retries=20,session=
297313
else:
298314
raise requests.exceptions.RequestException
299315

300-
if(pbar is not None):
301-
pbar.update(chunk.__sizeof__())
316+
if pbar is not None:
317+
pbar.update(len(chunk))
302318
return {'etag': response.headers['ETag'], 'partNumber': part}
303319
raise requests.exceptions.RequestException
304320

321+
305322
def _process_post_response(response):
306323
if response.status_code != 200:
307324
message = (f"Failed to complete multipart upload with status code = {response.status_code}, "
@@ -345,14 +362,10 @@ def _complete_multipart_upload(api_key, file_key, part_info, upload_id):
345362
return _process_post_response(complete_multipart_response)
346363

347364

348-
def _upload_multipart(api_key: str, file_key, archive_path, show_progress, out_progress_to_err,
349-
progress_desc_text, metadata):
365+
def _upload_multipart(api_key: str, file_key, file_size, max_chunk_size, archive_path, show_progress, out_progress_to_err,
366+
progress_desc_text, metadata, byte_data_flag, byte_data):
350367
request_url = ServerConstants.get_presigned_multi_part_url()
351368

352-
file_size = os.path.getsize(archive_path)
353-
354-
max_chunk_size = 20 * 1024 * 1024
355-
356369
num_chunks = _get_num_chunks(file_size, max_chunk_size)
357370

358371
upload_id = ""
@@ -379,8 +392,12 @@ def _upload_multipart(api_key: str, file_key, archive_path, show_progress, out_p
379392
upload_id = data['uploadId']
380393
presigned_urls = data['urls']
381394

382-
parts = []
383-
chunks = get_chunks(archive_path, max_chunk_size)
395+
if byte_data_flag:
396+
byte_data.seek(0)
397+
chunks = get_chunks_from_byte_data(byte_data, max_chunk_size)
398+
else:
399+
chunks = get_chunks(archive_path, max_chunk_size)
400+
384401
part_info = []
385402
chunk_count = 0
386403
successful_chunks = 0
@@ -396,7 +413,7 @@ def _upload_multipart(api_key: str, file_key, archive_path, show_progress, out_p
396413
if show_progress:
397414
try:
398415
part_data = _upload_chunk(presigned_url=presigned_url, chunk=chunk, part=part,
399-
pbar=pbar,session=atomic_session)
416+
pbar=pbar,session=atomic_session, byte_data_flag = byte_data_flag)
400417
part_info.append(part_data)
401418
successful_chunks += 1
402419
except Exception as e:
@@ -474,8 +491,11 @@ def _get_storage_service(service):
474491
else:
475492
raise NotImplementedError(f"Service {service} not implemented")
476493

477-
def _get_data_type(data_path):
478-
if os.path.isdir(data_path):
494+
495+
def _get_data_type(data_path, byte_data_flag):
496+
if byte_data_flag:
497+
return DataType.BYTE
498+
elif os.path.isdir(data_path):
479499
return DataType.DIRECTORY
480500
elif os.path.isfile(data_path):
481501
return DataType.FILE

0 commit comments

Comments
 (0)