@@ -30,26 +30,29 @@ def __init__(self, data: dict):
3030class DataType (Enum ):
3131 FILE = "file"
3232 DIRECTORY = "directory"
33+ BYTE = "byte"
3334 INVALID = "invalid"
3435
36+
37+
3538# Todo (alaydshah): Store service name in metadata
3639# Todo (alaydshah): If data already exists, don't upload again. Instead suggest to use update command
3740# Todo (bhargav) : Discuss and remove the service variable. Maybe needed sometime later.
3841def upload (data_path , api_key , name , description , tag_list , service , show_progress , out_progress_to_err , progress_desc ,
39- metadata ) -> FedMLResponse :
42+ metadata , byte_data_flag = False , byte_data = None ) -> FedMLResponse :
4043 api_key = authenticate (api_key )
4144
4245 user_id , message = _get_user_id_from_api_key (api_key )
4346
4447 if user_id is None :
4548 return FedMLResponse (code = ResponseCode .FAILURE , message = message )
4649
47- data_type = _get_data_type (data_path )
50+ data_type = _get_data_type (data_path , byte_data_flag )
4851
49- if ( data_type == DataType .INVALID ) :
52+ if data_type == DataType .INVALID :
5053 return FedMLResponse (code = ResponseCode .FAILURE ,message = "Invalid data path" )
5154
52- if ( data_type == DataType .DIRECTORY ) :
55+ if data_type == DataType .DIRECTORY :
5356 to_upload_path , message = _archive_data (data_path )
5457 name = os .path .splitext (os .path .basename (to_upload_path ))[0 ] if name is None else name
5558 file_name = name + ".zip"
@@ -67,18 +70,24 @@ def upload(data_path, api_key, name, description, tag_list, service, show_progre
6770
6871 file_name = name
6972
70- if not to_upload_path :
73+ if not to_upload_path and not byte_data_flag :
7174 return FedMLResponse (code = ResponseCode .FAILURE , message = message )
7275
7376 #TODO(bhargav191098) - Better done on the backend. Remove and pass file_name once completed on backend.
7477 dest_path = os .path .join (user_id , file_name )
75- file_size = os .path .getsize (to_upload_path )
78+ max_chunk_size = 20 * 1024 * 1024
79+
80+ if byte_data_flag :
81+ file_size = sum (len (chunk ) for chunk in get_chunks_from_byte_data (byte_data , max_chunk_size ))
7682
77- file_uploaded_url , message = _upload_multipart (api_key , dest_path , to_upload_path , show_progress ,
83+ else :
84+ file_size = os .path .getsize (to_upload_path )
85+
86+ file_uploaded_url , message = _upload_multipart (api_key , dest_path , file_size , max_chunk_size , to_upload_path , show_progress ,
7887 out_progress_to_err ,
79- progress_desc , metadata )
88+ progress_desc , metadata , byte_data_flag , byte_data )
8089
81- if ( data_type == "dir" ) :
90+ if data_type == "dir" :
8291 os .remove (to_upload_path )
8392 if not file_uploaded_url :
8493 return FedMLResponse (code = ResponseCode .FAILURE , message = f"Failed to upload file: { to_upload_path } " )
@@ -262,6 +271,13 @@ def get_chunks(file_path, chunk_size):
262271 break
263272 yield chunk
264273
274+ def get_chunks_from_byte_data (byte_data , chunk_size ):
275+ while True :
276+ chunk = byte_data .read (chunk_size )
277+ if not chunk :
278+ break
279+ yield chunk
280+
265281
266282def _get_presigned_url (api_key , request_url , file_name , part_number = None ):
267283 cert_path = MLOpsConfigs .get_cert_path_with_version ()
@@ -287,7 +303,7 @@ def _upload_part(url,part_data,session):
287303 return response
288304
289305
290- def _upload_chunk (presigned_url , chunk , part , pbar = None , max_retries = 20 ,session = None ):
306+ def _upload_chunk (presigned_url , chunk , part , pbar = None , max_retries = 20 ,session = None , byte_data_flag = False ):
291307 for retry_attempt in range (max_retries ):
292308 try :
293309 response = _upload_part (presigned_url ,chunk ,session )
@@ -297,11 +313,12 @@ def _upload_chunk(presigned_url, chunk, part, pbar=None, max_retries=20,session=
297313 else :
298314 raise requests .exceptions .RequestException
299315
300- if ( pbar is not None ) :
301- pbar .update (chunk . __sizeof__ ( ))
316+ if pbar is not None :
317+ pbar .update (len ( chunk ))
302318 return {'etag' : response .headers ['ETag' ], 'partNumber' : part }
303319 raise requests .exceptions .RequestException
304320
321+
305322def _process_post_response (response ):
306323 if response .status_code != 200 :
307324 message = (f"Failed to complete multipart upload with status code = { response .status_code } , "
@@ -345,14 +362,10 @@ def _complete_multipart_upload(api_key, file_key, part_info, upload_id):
345362 return _process_post_response (complete_multipart_response )
346363
347364
348- def _upload_multipart (api_key : str , file_key , archive_path , show_progress , out_progress_to_err ,
349- progress_desc_text , metadata ):
365+ def _upload_multipart (api_key : str , file_key , file_size , max_chunk_size , archive_path , show_progress , out_progress_to_err ,
366+ progress_desc_text , metadata , byte_data_flag , byte_data ):
350367 request_url = ServerConstants .get_presigned_multi_part_url ()
351368
352- file_size = os .path .getsize (archive_path )
353-
354- max_chunk_size = 20 * 1024 * 1024
355-
356369 num_chunks = _get_num_chunks (file_size , max_chunk_size )
357370
358371 upload_id = ""
@@ -379,8 +392,12 @@ def _upload_multipart(api_key: str, file_key, archive_path, show_progress, out_p
379392 upload_id = data ['uploadId' ]
380393 presigned_urls = data ['urls' ]
381394
382- parts = []
383- chunks = get_chunks (archive_path , max_chunk_size )
395+ if byte_data_flag :
396+ byte_data .seek (0 )
397+ chunks = get_chunks_from_byte_data (byte_data , max_chunk_size )
398+ else :
399+ chunks = get_chunks (archive_path , max_chunk_size )
400+
384401 part_info = []
385402 chunk_count = 0
386403 successful_chunks = 0
@@ -396,7 +413,7 @@ def _upload_multipart(api_key: str, file_key, archive_path, show_progress, out_p
396413 if show_progress :
397414 try :
398415 part_data = _upload_chunk (presigned_url = presigned_url , chunk = chunk , part = part ,
399- pbar = pbar ,session = atomic_session )
416+ pbar = pbar ,session = atomic_session , byte_data_flag = byte_data_flag )
400417 part_info .append (part_data )
401418 successful_chunks += 1
402419 except Exception as e :
@@ -474,8 +491,11 @@ def _get_storage_service(service):
474491 else :
475492 raise NotImplementedError (f"Service { service } not implemented" )
476493
477- def _get_data_type (data_path ):
478- if os .path .isdir (data_path ):
494+
495+ def _get_data_type (data_path , byte_data_flag ):
496+ if byte_data_flag :
497+ return DataType .BYTE
498+ elif os .path .isdir (data_path ):
479499 return DataType .DIRECTORY
480500 elif os .path .isfile (data_path ):
481501 return DataType .FILE
0 commit comments