VYPR
High severity · CVSS 7.2 · OSV Advisory · Published Feb 2, 2026 · Updated Apr 15, 2026

CVE-2026-1777

CVE-2026-1777

Description

The Amazon SageMaker Python SDK before v3.2.0 and v2.256.0 includes the ModelBuilder HMAC signing key in the cleartext response elements of the DescribeTrainingJob function. A third party with permissions to both call this API and permissions to modify objects in the Training Jobs S3 output location may have the ability to upload arbitrary artifacts which are executed the next time the Training Job is invoked.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package | Affected versions | Patched versions
sagemaker (PyPI) | >= 3.0, < 3.2.0 | 3.2.0
sagemaker (PyPI) | < 2.256.0 | 2.256.0

Affected products

1

Patches

2
fb0d789db4fd

Bug fix for hmac key for V3 (#5379)

https://github.com/aws/sagemaker-python-sdk · aviruthen · Dec 15, 2025 · via GHSA
18 files changed · +1867 −1108
  • sagemaker-core/src/sagemaker/core/remote_function/client.py · +6 −6 · modified
    @@ -366,7 +366,7 @@ def wrapper(*args, **kwargs):
                                 s3_uri=s3_path_join(
                                     job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
                                 ),
    -                            hmac_key=job.hmac_key,
    +                            
                             )
                         except ServiceError as serr:
                             chained_e = serr.__cause__
    @@ -403,7 +403,7 @@ def wrapper(*args, **kwargs):
                     return serialization.deserialize_obj_from_s3(
                         sagemaker_session=job_settings.sagemaker_session,
                         s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
    -                    hmac_key=job.hmac_key,
    +                    
                     )
     
                 if job.describe()["TrainingJobStatus"] == "Stopped":
    @@ -983,7 +983,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
                     job_return = serialization.deserialize_obj_from_s3(
                         sagemaker_session=sagemaker_session,
                         s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
    -                    hmac_key=job.hmac_key,
    +                    
                     )
                 except DeserializationError as e:
                     client_exception = e
    @@ -995,7 +995,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
                     job_exception = serialization.deserialize_exception_from_s3(
                         sagemaker_session=sagemaker_session,
                         s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
    -                    hmac_key=job.hmac_key,
    +                    
                     )
                 except ServiceError as serr:
                     chained_e = serr.__cause__
    @@ -1085,7 +1085,7 @@ def result(self, timeout: float = None) -> Any:
                         self._return = serialization.deserialize_obj_from_s3(
                             sagemaker_session=self._job.sagemaker_session,
                             s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
    -                        hmac_key=self._job.hmac_key,
    +                        
                         )
                         self._state = _FINISHED
                         return self._return
    @@ -1094,7 +1094,7 @@ def result(self, timeout: float = None) -> Any:
                             self._exception = serialization.deserialize_exception_from_s3(
                                 sagemaker_session=self._job.sagemaker_session,
                                 s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
    -                            hmac_key=self._job.hmac_key,
    +                            
                             )
                         except ServiceError as serr:
                             chained_e = serr.__cause__
    
  • sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py · +0 −6 · modified
    @@ -164,7 +164,6 @@ class _DelayedReturnResolver:
         def __init__(
             self,
             delayed_returns: List[_DelayedReturn],
    -        hmac_key: str,
             properties_resolver: _PropertiesResolver,
             parameter_resolver: _ParameterResolver,
             execution_variable_resolver: _ExecutionVariableResolver,
    @@ -175,7 +174,6 @@ def __init__(
     
             Args:
                 delayed_returns: list of delayed returns to resolve.
    -            hmac_key: key used to encrypt serialized and deserialized function and arguments.
                 properties_resolver: resolver used to resolve step properties.
                 parameter_resolver: resolver used to pipeline parameters.
                 execution_variable_resolver: resolver used to resolve execution variables.
    @@ -197,7 +195,6 @@ def deserialization_task(uri):
                 return uri, deserialize_obj_from_s3(
                     sagemaker_session=settings["sagemaker_session"],
                     s3_uri=uri,
    -                hmac_key=hmac_key,
                 )
     
             with ThreadPoolExecutor() as executor:
    @@ -247,7 +244,6 @@ def resolve_pipeline_variables(
         context: Context,
         func_args: Tuple,
         func_kwargs: Dict,
    -    hmac_key: str,
         s3_base_uri: str,
         **settings,
     ):
    @@ -257,7 +253,6 @@ def resolve_pipeline_variables(
             context: context for the execution.
             func_args: function args.
             func_kwargs: function kwargs.
    -        hmac_key: key used to encrypt serialized and deserialized function and arguments.
             s3_base_uri: the s3 base uri of the function step that the serialized artifacts
                 will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
             **settings: settings to pass to the deserialization function.
    @@ -280,7 +275,6 @@ def resolve_pipeline_variables(
         properties_resolver = _PropertiesResolver(context)
         delayed_return_resolver = _DelayedReturnResolver(
             delayed_returns=delayed_returns,
    -        hmac_key=hmac_key,
             properties_resolver=properties_resolver,
             parameter_resolver=parameter_resolver,
             execution_variable_resolver=execution_variable_resolver,
    
  • sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py · +16 −28 · modified
    @@ -19,7 +19,6 @@
     import io
     
     import sys
    -import hmac
     import hashlib
     import pickle
     
    @@ -156,15 +155,14 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
     
     # TODO: use dask serializer in case dask distributed is installed in users' environment.
     def serialize_func_to_s3(
    -    func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
    +    func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
     ):
         """Serializes function and uploads it to S3.
     
         Args:
             sagemaker_session (sagemaker.core.helper.session.Session):
                 The underlying Boto3 session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
             func: function to be serialized and persisted
         Raises:
    @@ -173,14 +171,13 @@ def serialize_func_to_s3(
     
         _upload_payload_and_metadata_to_s3(
             bytes_to_upload=CloudpickleSerializer.serialize(func),
    -        hmac_key=hmac_key,
             s3_uri=s3_uri,
             sagemaker_session=sagemaker_session,
             s3_kms_key=s3_kms_key,
         )
     
     
    -def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
    +def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable:
         """Downloads from S3 and then deserializes data objects.
     
         This method downloads the serialized training job outputs to a temporary directory and
    @@ -190,7 +187,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
             sagemaker_session (sagemaker.core.helper.session.Session):
                 The underlying sagemaker session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
         Returns :
             The deserialized function.
         Raises:
    @@ -203,14 +199,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
         bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
     
         _perform_integrity_check(
    -        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
    +        expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
         )
     
         return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
     
     
     def serialize_obj_to_s3(
    -    obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
    +    obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
     ):
         """Serializes data object and uploads it to S3.
     
    @@ -219,15 +215,13 @@ def serialize_obj_to_s3(
                 The underlying Boto3 session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
             obj: object to be serialized and persisted
         Raises:
             SerializationError: when fail to serialize object to bytes.
         """
     
         _upload_payload_and_metadata_to_s3(
             bytes_to_upload=CloudpickleSerializer.serialize(obj),
    -        hmac_key=hmac_key,
             s3_uri=s3_uri,
             sagemaker_session=sagemaker_session,
             s3_kms_key=s3_kms_key,
    @@ -274,14 +268,13 @@ def json_serialize_obj_to_s3(
         )
     
     
    -def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
    +def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
         """Downloads from S3 and then deserializes data objects.
     
         Args:
             sagemaker_session (sagemaker.core.helper.session.Session):
                 The underlying sagemaker session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
         Returns :
             Deserialized python objects.
         Raises:
    @@ -295,14 +288,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
         bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
     
         _perform_integrity_check(
    -        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
    +        expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
         )
     
         return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
     
     
     def serialize_exception_to_s3(
    -    exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
    +    exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
     ):
         """Serializes exception with traceback and uploads it to S3.
     
    @@ -311,7 +304,6 @@ def serialize_exception_to_s3(
                 The underlying Boto3 session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
             exc: Exception to be serialized and persisted
         Raises:
             SerializationError: when fail to serialize object to bytes.
    @@ -320,7 +312,6 @@ def serialize_exception_to_s3(
     
         _upload_payload_and_metadata_to_s3(
             bytes_to_upload=CloudpickleSerializer.serialize(exc),
    -        hmac_key=hmac_key,
             s3_uri=s3_uri,
             sagemaker_session=sagemaker_session,
             s3_kms_key=s3_kms_key,
    @@ -329,7 +320,6 @@ def serialize_exception_to_s3(
     
     def _upload_payload_and_metadata_to_s3(
         bytes_to_upload: Union[bytes, io.BytesIO],
    -    hmac_key: str,
         s3_uri: str,
         sagemaker_session: Session,
         s3_kms_key,
    @@ -338,15 +328,14 @@ def _upload_payload_and_metadata_to_s3(
     
         Args:
             bytes_to_upload (bytes): Serialized bytes to upload.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
             sagemaker_session (sagemaker.core.helper.session.Session):
                 The underlying Boto3 session which AWS service calls are delegated to.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
         """
         _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
     
    -    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
    +    sha256_hash = _compute_hash(bytes_to_upload)
     
         _upload_bytes_to_s3(
             _MetaData(sha256_hash).to_json(),
    @@ -356,14 +345,13 @@ def _upload_payload_and_metadata_to_s3(
         )
     
     
    -def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
    +def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
         """Downloads from S3 and then deserializes exception.
     
         Args:
             sagemaker_session (sagemaker.core.helper.session.Session):
                 The underlying sagemaker session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
         Returns :
             Deserialized exception with traceback.
         Raises:
    @@ -377,7 +365,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
         bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
     
         _perform_integrity_check(
    -        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
    +        expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
         )
     
         return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
    @@ -403,19 +391,19 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
             ) from e
     
     
    -def _compute_hash(buffer: bytes, secret_key: str) -> str:
    -    """Compute the hmac-sha256 hash"""
    -    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
    +def _compute_hash(buffer: bytes) -> str:
    +    """Compute the sha256 hash"""
    +    return hashlib.sha256(buffer).hexdigest()
     
     
    -def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
    +def _perform_integrity_check(expected_hash_value: str, buffer: bytes):
         """Performs integrity checks for serialized code/arguments uploaded to s3.
     
         Verifies whether the hash read from s3 matches the hash calculated
         during remote function execution.
         """
    -    actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
    -    if not hmac.compare_digest(expected_hash_value, actual_hash_value):
    +    actual_hash_value = _compute_hash(buffer=buffer)
    +    if expected_hash_value != actual_hash_value:
             raise DeserializationError(
                 "Integrity check for the serialized function or data failed. "
                 "Please restrict access to your S3 bucket"
    
  • sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py · +8 −11 · modified
    @@ -55,7 +55,6 @@ def __init__(
             self,
             sagemaker_session: Session,
             s3_base_uri: str,
    -        hmac_key: str,
             s3_kms_key: str = None,
             context: Context = Context(),
         ):
    @@ -66,13 +65,11 @@ def __init__(
                     AWS service calls are delegated to.
                 s3_base_uri: the base uri to which serialized artifacts will be uploaded.
                 s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
    -            hmac_key: Key used to encrypt serialized and deserialized function and arguments.
                 context: Build or run context of a pipeline step.
             """
             self.sagemaker_session = sagemaker_session
             self.s3_base_uri = s3_base_uri
             self.s3_kms_key = s3_kms_key
    -        self.hmac_key = hmac_key
             self.context = context
     
             # For pipeline steps, function code is at: base/step_name/build_timestamp/
    @@ -114,7 +111,7 @@ def save(self, func, *args, **kwargs):
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
                 s3_kms_key=self.s3_kms_key,
    -            hmac_key=self.hmac_key,
    +            
             )
     
             logger.info(
    @@ -126,7 +123,7 @@ def save(self, func, *args, **kwargs):
                 obj=(args, kwargs),
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
    -            hmac_key=self.hmac_key,
    +            
                 s3_kms_key=self.s3_kms_key,
             )
     
    @@ -144,7 +141,7 @@ def save_pipeline_step_function(self, serialized_data):
             )
             serialization._upload_payload_and_metadata_to_s3(
                 bytes_to_upload=serialized_data.func,
    -            hmac_key=self.hmac_key,
    +            
                 s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
                 sagemaker_session=self.sagemaker_session,
                 s3_kms_key=self.s3_kms_key,
    @@ -156,7 +153,7 @@ def save_pipeline_step_function(self, serialized_data):
             )
             serialization._upload_payload_and_metadata_to_s3(
                 bytes_to_upload=serialized_data.args,
    -            hmac_key=self.hmac_key,
    +            
                 s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
                 sagemaker_session=self.sagemaker_session,
                 s3_kms_key=self.s3_kms_key,
    @@ -172,7 +169,7 @@ def load_and_invoke(self) -> Any:
             func = serialization.deserialize_func_from_s3(
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
    -            hmac_key=self.hmac_key,
    +            
             )
     
             logger.info(
    @@ -182,15 +179,15 @@ def load_and_invoke(self) -> Any:
             args, kwargs = serialization.deserialize_obj_from_s3(
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
    -            hmac_key=self.hmac_key,
    +            
             )
     
             logger.info("Resolving pipeline variables")
             resolved_args, resolved_kwargs = resolve_pipeline_variables(
                 self.context,
                 args,
                 kwargs,
    -            hmac_key=self.hmac_key,
    +            
                 s3_base_uri=self.s3_base_uri,
                 sagemaker_session=self.sagemaker_session,
             )
    @@ -206,7 +203,7 @@ def load_and_invoke(self) -> Any:
                 obj=result,
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER),
    -            hmac_key=self.hmac_key,
    +            
                 s3_kms_key=self.s3_kms_key,
             )
     
    
  • sagemaker-core/src/sagemaker/core/remote_function/errors.py · +1 −3 · modified
    @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg):
                 f.write(failure_msg)
     
     
    -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> int:
    +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int:
         """Handle all exceptions raised during remote function execution.
     
         Args:
    @@ -79,7 +79,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) ->
                  AWS service calls are delegated to.
             s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    -        hmac_key (str): Key used to calculate hmac hash of the serialized exception.
         Returns :
             exit_code (int): Exit code to terminate current job.
         """
    @@ -97,7 +96,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) ->
             exc=error,
             sagemaker_session=sagemaker_session,
             s3_uri=s3_path_join(s3_base_uri, "exception"),
    -        hmac_key=hmac_key,
             s3_kms_key=s3_kms_key,
         )
     
    
  • sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py · +1 −6 · modified
    @@ -98,7 +98,7 @@ def _load_pipeline_context(args) -> Context:
     
     
     def _execute_remote_function(
    -    sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context
    +    sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, context
     ):
         """Execute stored remote function"""
         from sagemaker.core.remote_function.core.stored_function import StoredFunction
    @@ -107,7 +107,6 @@ def _execute_remote_function(
             sagemaker_session=sagemaker_session,
             s3_base_uri=s3_base_uri,
             s3_kms_key=s3_kms_key,
    -        hmac_key=hmac_key,
             context=context,
         )
     
    @@ -138,15 +137,12 @@ def main(sys_args=None):
             run_in_context = args.run_in_context
             pipeline_context = _load_pipeline_context(args)
     
    -        hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY")
    -
             sagemaker_session = _get_sagemaker_session(region)
             _execute_remote_function(
                 sagemaker_session=sagemaker_session,
                 s3_base_uri=s3_base_uri,
                 s3_kms_key=s3_kms_key,
                 run_in_context=run_in_context,
    -            hmac_key=hmac_key,
                 context=pipeline_context,
             )
     
    @@ -162,7 +158,6 @@ def main(sys_args=None):
                 sagemaker_session=sagemaker_session,
                 s3_base_uri=s3_uri,
                 s3_kms_key=s3_kms_key,
    -            hmac_key=hmac_key,
             )
         finally:
             sys.exit(exit_code)
    
  • sagemaker-core/src/sagemaker/core/remote_function/job.py · +2 −21 · modified
    @@ -17,7 +17,6 @@
     import json
     import os
     import re
    -import secrets
     import shutil
     import sys
     import time
    @@ -621,11 +620,6 @@ def __init__(
                 {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name}
             )
     
    -        # The following will be overridden by the _Job.compile method.
    -        # However, it needs to be kept here for feature store SDK.
    -        # TODO: update the feature store SDK to set the HMAC key there.
    -        self.environment_variables.update({"REMOTE_FUNCTION_SECRET_KEY": secrets.token_hex(32)})
    -
             if spark_config and image_uri:
                 raise ValueError("spark_config and image_uri cannot be specified at the same time!")
     
    @@ -839,19 +833,17 @@ def _get_default_spark_image(session):
     class _Job:
         """Helper class that interacts with the SageMaker training service."""
     
    -    def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, hmac_key: str):
    +    def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session):
             """Initialize a _Job object.
     
             Args:
                 job_name (str): The training job name.
                 s3_uri (str): The training job output S3 uri.
                 sagemaker_session (Session): SageMaker boto session.
    -            hmac_key (str): Remote function secret key.
             """
             self.job_name = job_name
             self.s3_uri = s3_uri
             self.sagemaker_session = sagemaker_session
    -        self.hmac_key = hmac_key
             self._last_describe_response = None
     
         @staticmethod
    @@ -867,9 +859,8 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
             """
             job_name = describe_training_job_response["TrainingJobName"]
             s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]
    -        hmac_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"]
     
    -        job = _Job(job_name, s3_uri, sagemaker_session, hmac_key)
    +        job = _Job(job_name, s3_uri, sagemaker_session)
             job._last_describe_response = describe_training_job_response
             return job
     
    @@ -907,7 +898,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
                 job_name,
                 s3_base_uri,
                 job_settings.sagemaker_session,
    -            training_job_request["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
             )
     
         @staticmethod
    @@ -935,26 +925,18 @@ def compile(
     
             jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT[:]
     
    -        # generate hmac key for integrity check
    -        if step_compilation_context is None:
    -            hmac_key = secrets.token_hex(32)
    -        else:
    -            hmac_key = step_compilation_context.function_step_secret_token
    -
             # serialize function and arguments
             if step_compilation_context is None:
                 stored_function = StoredFunction(
                     sagemaker_session=job_settings.sagemaker_session,
                     s3_base_uri=s3_base_uri,
    -                hmac_key=hmac_key,
                     s3_kms_key=job_settings.s3_kms_key,
                 )
                 stored_function.save(func, *func_args, **func_kwargs)
             else:
                 stored_function = StoredFunction(
                     sagemaker_session=job_settings.sagemaker_session,
                     s3_base_uri=s3_base_uri,
    -                hmac_key=hmac_key,
                     s3_kms_key=job_settings.s3_kms_key,
                     context=Context(
                         step_name=step_compilation_context.step_name,
    @@ -1114,7 +1096,6 @@ def compile(
             request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances
     
             request_dict["Environment"] = job_settings.environment_variables
    -        request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})
     
             extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
             extended_request = _extend_mpirun_to_request(extended_request, job_settings)
    
  • sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py · +554 −423 · modified
    @@ -10,20 +10,23 @@
     # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
     # ANY KIND, either express or implied. See the License for the specific
     # language governing permissions and limitations under the License.
    +"""Tests for bootstrap_runtime_environment module."""
    +from __future__ import absolute_import
     
    -import pytest
    -from unittest.mock import Mock, patch, mock_open, MagicMock
     import json
    -import sys
    +import os
    +import pytest
    +import subprocess
    +from unittest.mock import patch, MagicMock, mock_open, call
     
     from sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment import (
    +    _parse_args,
         _bootstrap_runtime_env_for_remote_function,
         _bootstrap_runtime_env_for_pipeline_step,
         _handle_pre_exec_scripts,
         _install_dependencies,
         _unpack_user_workspace,
         _write_failure_reason_file,
    -    _parse_args,
         log_key_value,
         log_env_variables,
         mask_sensitive_info,
    @@ -35,6 +38,11 @@
         main,
         SUCCESS_EXIT_CODE,
         DEFAULT_FAILURE_CODE,
    +    FAILURE_REASON_PATH,
    +    REMOTE_FUNCTION_WORKSPACE,
    +    BASE_CHANNEL_PATH,
    +    JOB_REMOTE_FUNCTION_WORKSPACE,
    +    SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME,
         SENSITIVE_KEYWORDS,
         HIDDEN_VALUE,
     )
    @@ -43,506 +51,629 @@
     )
     
     
    -class TestBootstrapRuntimeEnvironment:
    -    """Test cases for bootstrap runtime environment functions"""
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies"
    -    )
    -    def test_bootstrap_runtime_env_for_remote_function(
    -        self, mock_install, mock_handle, mock_unpack
    -    ):
    -        """Test _bootstrap_runtime_env_for_remote_function"""
    -        mock_unpack.return_value = "/workspace"
    -        dependency_settings = _DependencySettings(dependency_file="requirements.txt")
    -
    -        _bootstrap_runtime_env_for_remote_function(
    -            client_python_version="3.8", conda_env="myenv", dependency_settings=dependency_settings
    -        )
    -
    -        mock_unpack.assert_called_once()
    -        mock_handle.assert_called_once_with("/workspace")
    -        mock_install.assert_called_once()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace"
    -    )
    -    def test_bootstrap_runtime_env_for_remote_function_no_workspace(self, mock_unpack):
    -        """Test _bootstrap_runtime_env_for_remote_function with no workspace"""
    -        mock_unpack.return_value = None
    -
    -        _bootstrap_runtime_env_for_remote_function(client_python_version="3.8")
    -
    -        mock_unpack.assert_called_once()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.mkdir"
    -    )
    -    def test_bootstrap_runtime_env_for_pipeline_step(self, mock_mkdir, mock_exists, mock_unpack):
    -        """Test _bootstrap_runtime_env_for_pipeline_step"""
    -        mock_unpack.return_value = None
    -        mock_exists.return_value = False
    -
    -        _bootstrap_runtime_env_for_pipeline_step(
    -            client_python_version="3.8", func_step_workspace="workspace"
    -        )
    -
    -        mock_mkdir.assert_called_once()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile"
    -    )
    -    def test_handle_pre_exec_scripts_exists(self, mock_isfile, mock_manager_class):
    -        """Test _handle_pre_exec_scripts when script exists"""
    -        mock_isfile.return_value = True
    -        mock_manager = Mock()
    -        mock_manager_class.return_value = mock_manager
    -
    -        _handle_pre_exec_scripts("/workspace")
    -
    -        mock_manager.run_pre_exec_script.assert_called_once()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile"
    -    )
    -    def test_handle_pre_exec_scripts_not_exists(self, mock_isfile, mock_manager_class):
    -        """Test _handle_pre_exec_scripts when script doesn't exist"""
    -        mock_isfile.return_value = False
    -        mock_manager = Mock()
    -        mock_manager_class.return_value = mock_manager
    -
    -        _handle_pre_exec_scripts("/workspace")
    -
    -        mock_manager.run_pre_exec_script.assert_not_called()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.join"
    -    )
    -    def test_install_dependencies_with_file(self, mock_join, mock_manager_class):
    -        """Test _install_dependencies with dependency file"""
    -        mock_join.return_value = "/workspace/requirements.txt"
    -        mock_manager = Mock()
    -        mock_manager_class.return_value = mock_manager
    -
    -        dependency_settings = _DependencySettings(dependency_file="requirements.txt")
    -
    -        _install_dependencies(
    -            dependency_file_dir="/workspace",
    -            conda_env="myenv",
    -            client_python_version="3.8",
    -            channel_name="channel",
    -            dependency_settings=dependency_settings,
    -        )
    -
    -        mock_manager.bootstrap.assert_called_once()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager"
    -    )
    -    def test_install_dependencies_no_file(self, mock_manager_class):
    -        """Test _install_dependencies with no dependency file"""
    -        mock_manager = Mock()
    -        mock_manager_class.return_value = mock_manager
    -
    -        dependency_settings = _DependencySettings(dependency_file=None)
    -
    -        _install_dependencies(
    -            dependency_file_dir="/workspace",
    -            conda_env=None,
    -            client_python_version="3.8",
    -            channel_name="channel",
    -            dependency_settings=dependency_settings,
    -        )
    -
    -        mock_manager.bootstrap.assert_not_called()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.shutil.unpack_archive"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.pathlib.Path"
    -    )
    -    def test_unpack_user_workspace_success(self, mock_path, mock_unpack, mock_isfile, mock_exists):
    -        """Test _unpack_user_workspace successfully unpacks workspace"""
    -        mock_exists.return_value = True
    -        mock_isfile.return_value = True
    -        mock_path.return_value.absolute.return_value = "/workspace"
    -
    -        result = _unpack_user_workspace()
    -
    -        assert result is not None
    -        mock_unpack.assert_called_once()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists"
    -    )
    -    def test_unpack_user_workspace_no_directory(self, mock_exists):
    -        """Test _unpack_user_workspace when directory doesn't exist"""
    -        mock_exists.return_value = False
    -
    -        result = _unpack_user_workspace()
    -
    -        assert result is None
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists"
    -    )
    -    @patch("builtins.open", new_callable=mock_open)
    -    def test_write_failure_reason_file(self, mock_file, mock_exists):
    -        """Test _write_failure_reason_file"""
    -        mock_exists.return_value = False
    -
    -        _write_failure_reason_file("Test error message")
    -
    -        mock_file.assert_called_once()
    -        mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message")
    -
    -    def test_parse_args(self):
    -        """Test _parse_args"""
    -        args = _parse_args(
    -            [
    -                "--job_conda_env",
    -                "myenv",
    -                "--client_python_version",
    -                "3.8",
    -                "--dependency_settings",
    -                '{"dependency_file": "requirements.txt"}',
    -            ]
    -        )
    -
    -        assert args.job_conda_env == "myenv"
    -        assert args.client_python_version == "3.8"
    -        assert args.dependency_settings == '{"dependency_file": "requirements.txt"}'
    -
    -
    -class TestLoggingFunctions:
    -    """Test cases for logging functions"""
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger"
    -    )
    -    def test_log_key_value_normal(self, mock_logger):
    -        """Test log_key_value with normal key"""
    -        log_key_value("MY_KEY", "my_value")
    -
    +class TestParseArgs:
    +    """Test _parse_args function."""
    +
    +    def test_parse_required_args(self):
    +        """Test parsing required arguments."""
    +        args = [
    +            "--client_python_version", "3.8",
    +        ]
    +        parsed = _parse_args(args)
    +        assert parsed.client_python_version == "3.8"
    +
    +    def test_parse_all_args(self):
    +        """Test parsing all arguments."""
    +        args = [
    +            "--job_conda_env", "my-env",
    +            "--client_python_version", "3.9",
    +            "--client_sagemaker_pysdk_version", "2.100.0",
    +            "--pipeline_execution_id", "exec-123",
    +            "--dependency_settings", '{"dependency_file": "requirements.txt"}',
    +            "--func_step_s3_dir", "s3://bucket/func",
    +            "--distribution", "torchrun",
    +            "--user_nproc_per_node", "4",
    +        ]
    +        parsed = _parse_args(args)
    +        assert parsed.job_conda_env == "my-env"
    +        assert parsed.client_python_version == "3.9"
    +        assert parsed.client_sagemaker_pysdk_version == "2.100.0"
    +        assert parsed.pipeline_execution_id == "exec-123"
    +        assert parsed.dependency_settings == '{"dependency_file": "requirements.txt"}'
    +        assert parsed.func_step_s3_dir == "s3://bucket/func"
    +        assert parsed.distribution == "torchrun"
    +        assert parsed.user_nproc_per_node == "4"
    +
    +    def test_parse_default_values(self):
    +        """Test default values for optional arguments."""
    +        args = [
    +            "--client_python_version", "3.8",
    +        ]
    +        parsed = _parse_args(args)
    +        assert parsed.job_conda_env is None
    +        assert parsed.client_sagemaker_pysdk_version is None
    +        assert parsed.pipeline_execution_id is None
    +        assert parsed.dependency_settings is None
    +        assert parsed.func_step_s3_dir is None
    +        assert parsed.distribution is None
    +        assert parsed.user_nproc_per_node is None
    +
    +
    +class TestLogKeyValue:
    +    """Test log_key_value function."""
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger")
    +    def test_logs_regular_value(self, mock_logger):
    +        """Test logs regular key-value pair."""
    +        log_key_value("my_name", "my_value")
    +        mock_logger.info.assert_called_once_with("%s=%s", "my_name", "my_value")
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger")
    +    def test_masks_sensitive_key(self, mock_logger):
    +        """Test masks sensitive keywords."""
    +        for keyword in ["PASSWORD", "SECRET", "TOKEN", "KEY", "PRIVATE", "CREDENTIALS"]:
    +            mock_logger.reset_mock()
    +            log_key_value(f"my_{keyword}", "sensitive_value")
    +            mock_logger.info.assert_called_once_with("%s=%s", f"my_{keyword}", HIDDEN_VALUE)
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger")
    +    def test_logs_dict_value(self, mock_logger):
    +        """Test logs dictionary value."""
    +        value = {"field1": "value1", "field2": "value2"}
    +        log_key_value("my_config", value)
    +        mock_logger.info.assert_called_once_with("%s=%s", "my_config", json.dumps(value))
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger")
    +    def test_logs_json_string_value(self, mock_logger):
    +        """Test logs JSON string value."""
    +        value = '{"key1": "value1"}'
    +        log_key_value("my_key", value)
             mock_logger.info.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger"
    -    )
    -    def test_log_key_value_sensitive(self, mock_logger):
    -        """Test log_key_value with sensitive key"""
    -        log_key_value("MY_PASSWORD", "secret123")
     
    -        mock_logger.info.assert_called_once()
    -        call_args = mock_logger.info.call_args[0]
    -        assert HIDDEN_VALUE in str(call_args)
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger"
    -    )
    -    def test_log_key_value_dict(self, mock_logger):
    -        """Test log_key_value with dictionary value"""
    -        log_key_value("MY_CONFIG", {"key": "value"})
    -
    -        mock_logger.info.assert_called_once()
    +class TestLogEnvVariables:
    +    """Test log_env_variables function."""
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ",
    -        {"ENV_VAR": "value"},
    -    )
    -    def test_log_env_variables(self, mock_logger):
    -        """Test log_env_variables"""
    -        log_env_variables({"CUSTOM_VAR": "custom_value"})
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_key_value")
    +    @patch.dict("os.environ", {"ENV_VAR1": "value1", "ENV_VAR2": "value2"})
    +    def test_logs_env_and_dict_variables(self, mock_log_kv):
    +        """Test logs both environment and dictionary variables."""
    +        env_dict = {"DICT_VAR1": "dict_value1", "DICT_VAR2": "dict_value2"}
    +        log_env_variables(env_dict)
    +        
    +        # Should be called for env vars and dict vars
    +        assert mock_log_kv.call_count >= 4
     
    -        assert mock_logger.info.call_count >= 2
     
    -    def test_mask_sensitive_info(self):
    -        """Test mask_sensitive_info"""
    -        data = {"username": "user", "password": "secret", "nested": {"api_key": "key123"}}
    +class TestMaskSensitiveInfo:
    +    """Test mask_sensitive_info function."""
     
    +    def test_masks_sensitive_keys_in_dict(self):
    +        """Test masks sensitive keys in dictionary."""
    +        data = {
    +            "username": "user",
    +            "password": "secret123",
    +            "api_key": "key123",
    +        }
             result = mask_sensitive_info(data)
    -
    -        assert result["password"] == HIDDEN_VALUE
    -        assert result["nested"]["api_key"] == HIDDEN_VALUE
             assert result["username"] == "user"
    +        assert result["password"] == HIDDEN_VALUE
    +        assert result["api_key"] == HIDDEN_VALUE
    +
    +    def test_masks_nested_dict(self):
    +        """Test masks sensitive keys in nested dictionary."""
    +        data = {
    +            "config": {
    +                "username": "user",
    +                "secret": "secret123",
    +            }
    +        }
    +        result = mask_sensitive_info(data)
    +        assert result["config"]["username"] == "user"
    +        assert result["config"]["secret"] == HIDDEN_VALUE
     
    +    def test_returns_non_dict_unchanged(self):
    +        """Test returns non-dictionary unchanged."""
    +        data = "string_value"
    +        result = mask_sensitive_info(data)
    +        assert result == "string_value"
     
    -class TestResourceFunctions:
    -    """Test cases for resource detection functions"""
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.multiprocessing.cpu_count"
    -    )
    -    def test_num_cpus(self, mock_cpu_count):
    -        """Test num_cpus"""
    -        mock_cpu_count.return_value = 4
    +class TestNumCpus:
    +    """Test num_cpus function."""
     
    -        result = num_cpus()
    +    @patch("multiprocessing.cpu_count")
    +    def test_returns_cpu_count(self, mock_cpu_count):
    +        """Test returns CPU count."""
    +        mock_cpu_count.return_value = 8
    +        assert num_cpus() == 8
     
    -        assert result == 4
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output"
    -    )
    -    def test_num_gpus_with_gpus(self, mock_check_output):
    -        """Test num_gpus when GPUs are present"""
    -        mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n"
    +class TestNumGpus:
    +    """Test num_gpus function."""
     
    -        result = num_gpus()
    +    @patch("subprocess.check_output")
    +    def test_returns_gpu_count(self, mock_check_output):
    +        """Test returns GPU count."""
    +        mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n"
    +        assert num_gpus() == 2
     
    -        assert result == 2
    +    @patch("subprocess.check_output")
    +    def test_returns_zero_on_error(self, mock_check_output):
    +        """Test returns zero when nvidia-smi fails."""
    +        mock_check_output.side_effect = subprocess.CalledProcessError(1, "nvidia-smi")
    +        assert num_gpus() == 0
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output"
    -    )
    -    def test_num_gpus_no_gpus(self, mock_check_output):
    -        """Test num_gpus when no GPUs are present"""
    +    @patch("subprocess.check_output")
    +    def test_returns_zero_on_os_error(self, mock_check_output):
    +        """Test returns zero when nvidia-smi not found."""
             mock_check_output.side_effect = OSError()
    +        assert num_gpus() == 0
     
    -        result = num_gpus()
     
    -        assert result == 0
    +class TestNumNeurons:
    +    """Test num_neurons function."""
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output"
    -    )
    -    def test_num_neurons_with_neurons(self, mock_check_output):
    -        """Test num_neurons when neurons are present"""
    -        mock_check_output.return_value = b'[{"nc_count": 2}, {"nc_count": 2}]'
    +    @patch("subprocess.check_output")
    +    def test_returns_neuron_count(self, mock_check_output):
    +        """Test returns neuron core count."""
    +        mock_output = json.dumps([{"nc_count": 2}, {"nc_count": 4}])
    +        mock_check_output.return_value = mock_output.encode("utf-8")
    +        assert num_neurons() == 6
     
    -        result = num_neurons()
    -
    -        assert result == 4
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output"
    -    )
    -    def test_num_neurons_no_neurons(self, mock_check_output):
    -        """Test num_neurons when no neurons are present"""
    +    @patch("subprocess.check_output")
    +    def test_returns_zero_on_os_error(self, mock_check_output):
    +        """Test returns zero when neuron-ls not found."""
             mock_check_output.side_effect = OSError()
    +        assert num_neurons() == 0
     
    -        result = num_neurons()
    -
    -        assert result == 0
    -
    -
    -class TestSerializationFunctions:
    -    """Test cases for serialization functions"""
    -
    -    def test_safe_serialize_string(self):
    -        """Test safe_serialize with string"""
    -        result = safe_serialize("test_string")
    -
    -        assert result == "test_string"
    +    @patch("subprocess.check_output")
    +    def test_returns_zero_on_called_process_error(self, mock_check_output):
    +        """Test returns zero when neuron-ls fails."""
    +        error = subprocess.CalledProcessError(1, "neuron-ls")
    +        error.output = b"error=No neuron devices found"
    +        mock_check_output.side_effect = error
    +        assert num_neurons() == 0
     
    -    def test_safe_serialize_dict(self):
    -        """Test safe_serialize with dictionary"""
    -        result = safe_serialize({"key": "value"})
     
    -        assert result == '{"key": "value"}'
    +class TestSafeSerialize:
    +    """Test safe_serialize function."""
     
    -    def test_safe_serialize_list(self):
    -        """Test safe_serialize with list"""
    -        result = safe_serialize([1, 2, 3])
    +    def test_returns_string_as_is(self):
    +        """Test returns string without quotes."""
    +        assert safe_serialize("test_string") == "test_string"
     
    -        assert result == "[1, 2, 3]"
    +    def test_serializes_dict(self):
    +        """Test serializes dictionary."""
    +        data = {"key": "value"}
    +        assert safe_serialize(data) == '{"key": "value"}'
     
    -    def test_safe_serialize_non_serializable(self):
    -        """Test safe_serialize with non-serializable object"""
    +    def test_serializes_list(self):
    +        """Test serializes list."""
    +        data = [1, 2, 3]
    +        assert safe_serialize(data) == "[1, 2, 3]"
     
    -        class CustomObject:
    +    def test_returns_str_for_non_serializable(self):
    +        """Test returns str() for non-serializable objects."""
    +        class CustomObj:
                 def __str__(self):
                     return "custom_object"
    -
    -        result = safe_serialize(CustomObject())
    -
    -        assert "custom_object" in result
    +        
    +        obj = CustomObj()
    +        assert safe_serialize(obj) == "custom_object"
     
     
     class TestSetEnv:
    -    """Test cases for set_env function"""
    +    """Test set_env function."""
     
         @patch("builtins.open", new_callable=mock_open)
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ",
    -        {"TRAINING_JOB_NAME": "test-job"},
    -    )
    -    def test_set_env_basic(self, mock_neurons, mock_gpus, mock_cpus, mock_file):
    -        """Test set_env with basic configuration"""
    -        mock_cpus.return_value = 4
    -        mock_gpus.return_value = 0
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables")
    +    @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"})
    +    def test_sets_basic_env_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file):
    +        """Test sets basic environment variables."""
    +        mock_cpus.return_value = 8
    +        mock_gpus.return_value = 2
             mock_neurons.return_value = 0
    -
    +        
             resource_config = {
                 "current_host": "algo-1",
    -            "current_instance_type": "ml.m5.xlarge",
    -            "hosts": ["algo-1"],
    +            "current_instance_type": "ml.p3.2xlarge",
    +            "hosts": ["algo-1", "algo-2"],
                 "network_interface_name": "eth0",
             }
    -
    +        
             set_env(resource_config)
    -
    +        
             mock_file.assert_called_once()
    +        mock_log_env.assert_called_once()
     
         @patch("builtins.open", new_callable=mock_open)
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ",
    -        {"TRAINING_JOB_NAME": "test-job"},
    -    )
    -    def test_set_env_with_torchrun(self, mock_neurons, mock_gpus, mock_cpus, mock_file):
    -        """Test set_env with torchrun distribution"""
    -        mock_cpus.return_value = 4
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables")
    +    @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"})
    +    def test_sets_torchrun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file):
    +        """Test sets torchrun distribution environment variables."""
    +        mock_cpus.return_value = 8
             mock_gpus.return_value = 2
             mock_neurons.return_value = 0
    -
    +        
             resource_config = {
                 "current_host": "algo-1",
    -            "current_instance_type": "ml.p3.2xlarge",
    -            "hosts": ["algo-1", "algo-2"],
    +            "current_instance_type": "ml.p4d.24xlarge",
    +            "hosts": ["algo-1"],
                 "network_interface_name": "eth0",
             }
    -
    +        
             set_env(resource_config, distribution="torchrun")
    -
    +        
    +        # Verify file was written
             mock_file.assert_called_once()
     
         @patch("builtins.open", new_callable=mock_open)
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ",
    -        {"TRAINING_JOB_NAME": "test-job"},
    -    )
    -    def test_set_env_with_mpirun(self, mock_neurons, mock_gpus, mock_cpus, mock_file):
    -        """Test set_env with mpirun distribution"""
    -        mock_cpus.return_value = 4
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables")
    +    @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"})
    +    def test_sets_mpirun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file):
    +        """Test sets mpirun distribution environment variables."""
    +        mock_cpus.return_value = 8
             mock_gpus.return_value = 2
             mock_neurons.return_value = 0
    -
    +        
             resource_config = {
                 "current_host": "algo-1",
                 "current_instance_type": "ml.p3.2xlarge",
                 "hosts": ["algo-1", "algo-2"],
                 "network_interface_name": "eth0",
             }
    -
    +        
             set_env(resource_config, distribution="mpirun")
    +        
    +        mock_file.assert_called_once()
     
    +    @patch("builtins.open", new_callable=mock_open)
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables")
    +    @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"})
    +    def test_uses_user_nproc_per_node(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file):
    +        """Test uses user-specified nproc_per_node."""
    +        mock_cpus.return_value = 8
    +        mock_gpus.return_value = 2
    +        mock_neurons.return_value = 0
    +        
    +        resource_config = {
    +            "current_host": "algo-1",
    +            "current_instance_type": "ml.p3.2xlarge",
    +            "hosts": ["algo-1"],
    +            "network_interface_name": "eth0",
    +        }
    +        
    +        set_env(resource_config, user_nproc_per_node="4")
    +        
             mock_file.assert_called_once()
     
     
    +class TestWriteFailureReasonFile:
    +    """Test _write_failure_reason_file function."""
    +
    +    @patch("builtins.open", new_callable=mock_open)
    +    @patch("os.path.exists")
    +    def test_writes_failure_file(self, mock_exists, mock_file):
    +        """Test writes failure reason file."""
    +        mock_exists.return_value = False
    +        
    +        _write_failure_reason_file("Test error message")
    +        
    +        mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w")
    +        mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message")
    +
    +    @patch("builtins.open", new_callable=mock_open)
    +    @patch("os.path.exists")
    +    def test_does_not_write_if_exists(self, mock_exists, mock_file):
    +        """Test does not write if failure file already exists."""
    +        mock_exists.return_value = True
    +        
    +        _write_failure_reason_file("Test error message")
    +        
    +        mock_file.assert_not_called()
    +
    +
    +class TestUnpackUserWorkspace:
    +    """Test _unpack_user_workspace function."""
    +
    +    @patch("os.path.exists")
    +    def test_returns_none_if_dir_not_exists(self, mock_exists):
    +        """Test returns None if workspace directory doesn't exist."""
    +        mock_exists.return_value = False
    +        
    +        result = _unpack_user_workspace()
    +        
    +        assert result is None
    +
    +    @patch("os.path.isfile")
    +    @patch("os.path.exists")
    +    def test_returns_none_if_archive_not_exists(self, mock_exists, mock_isfile):
    +        """Test returns None if workspace archive doesn't exist."""
    +        mock_exists.return_value = True
    +        mock_isfile.return_value = False
    +        
    +        result = _unpack_user_workspace()
    +        
    +        assert result is None
    +
    +    @patch("shutil.unpack_archive")
    +    @patch("os.path.isfile")
    +    @patch("os.path.exists")
    +    @patch("os.getcwd")
    +    def test_unpacks_workspace_successfully(self, mock_getcwd, mock_exists, mock_isfile, mock_unpack):
    +        """Test unpacks workspace successfully."""
    +        mock_getcwd.return_value = "/tmp/workspace"
    +        mock_exists.return_value = True
    +        mock_isfile.return_value = True
    +        
    +        result = _unpack_user_workspace()
    +        
    +        mock_unpack.assert_called_once()
    +        assert result is not None
    +
    +
    +class TestHandlePreExecScripts:
    +    """Test _handle_pre_exec_scripts function."""
    +
    +    @patch("os.path.isfile")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager")
    +    def test_runs_pre_exec_script(self, mock_manager_class, mock_isfile):
    +        """Test runs pre-execution script."""
    +        mock_isfile.return_value = True
    +        mock_manager = MagicMock()
    +        mock_manager_class.return_value = mock_manager
    +        
    +        _handle_pre_exec_scripts("/tmp/scripts")
    +        
    +        mock_manager.run_pre_exec_script.assert_called_once()
    +
    +
    +class TestInstallDependencies:
    +    """Test _install_dependencies function."""
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager")
    +    def test_installs_with_dependency_settings(self, mock_manager_class):
    +        """Test installs dependencies with dependency settings."""
    +        mock_manager = MagicMock()
    +        mock_manager_class.return_value = mock_manager
    +        
    +        dep_settings = _DependencySettings(dependency_file="requirements.txt")
    +        
    +        _install_dependencies(
    +            "/tmp/deps",
    +            "my-env",
    +            "3.8",
    +            "channel",
    +            dep_settings
    +        )
    +        
    +        mock_manager.bootstrap.assert_called_once()
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager")
    +    def test_skips_if_no_dependency_file(self, mock_manager_class):
    +        """Test skips installation if no dependency file."""
    +        mock_manager = MagicMock()
    +        mock_manager_class.return_value = mock_manager
    +        
    +        dep_settings = _DependencySettings(dependency_file=None)
    +        
    +        _install_dependencies(
    +            "/tmp/deps",
    +            "my-env",
    +            "3.8",
    +            "channel",
    +            dep_settings
    +        )
    +        
    +        mock_manager.bootstrap.assert_not_called()
    +
    +    @patch("os.listdir")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager")
    +    def test_finds_dependency_file_legacy(self, mock_manager_class, mock_listdir):
    +        """Test finds dependency file in legacy mode."""
    +        mock_manager = MagicMock()
    +        mock_manager_class.return_value = mock_manager
    +        mock_listdir.return_value = ["requirements.txt", "script.py"]
    +        
    +        _install_dependencies(
    +            "/tmp/deps",
    +            "my-env",
    +            "3.8",
    +            "channel",
    +            None
    +        )
    +        
    +        mock_manager.bootstrap.assert_called_once()
    +
    +
    +class TestBootstrapRuntimeEnvForRemoteFunction:
    +    """Test _bootstrap_runtime_env_for_remote_function function."""
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace")
    +    def test_bootstraps_successfully(self, mock_unpack, mock_handle_scripts, mock_install):
    +        """Test bootstraps runtime environment successfully."""
    +        mock_unpack.return_value = "/tmp/workspace"
    +        
    +        _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None)
    +        
    +        mock_unpack.assert_called_once()
    +        mock_handle_scripts.assert_called_once()
    +        mock_install.assert_called_once()
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace")
    +    def test_returns_early_if_no_workspace(self, mock_unpack):
    +        """Test returns early if no workspace to unpack."""
    +        mock_unpack.return_value = None
    +        
    +        _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None)
    +        
    +        mock_unpack.assert_called_once()
    +
    +
    +class TestBootstrapRuntimeEnvForPipelineStep:
    +    """Test _bootstrap_runtime_env_for_pipeline_step function."""
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts")
    +    @patch("shutil.copy")
    +    @patch("os.listdir")
    +    @patch("os.path.exists")
    +    @patch("os.mkdir")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace")
    +    def test_bootstraps_with_workspace(self, mock_unpack, mock_mkdir, mock_exists, mock_listdir, mock_copy, mock_handle_scripts, mock_install):
    +        """Test bootstraps pipeline step with workspace."""
    +        mock_unpack.return_value = "/tmp/workspace"
    +        mock_exists.return_value = True
    +        mock_listdir.return_value = ["requirements.txt"]
    +        
    +        _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None)
    +        
    +        mock_unpack.assert_called_once()
    +        mock_handle_scripts.assert_called_once()
    +        mock_install.assert_called_once()
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts")
    +    @patch("os.path.exists")
    +    @patch("os.mkdir")
    +    @patch("os.getcwd")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace")
    +    def test_creates_workspace_if_none(self, mock_unpack, mock_getcwd, mock_mkdir, mock_exists, mock_handle_scripts, mock_install):
    +        """Test creates workspace directory if none exists."""
    +        mock_unpack.return_value = None
    +        mock_getcwd.return_value = "/tmp"
    +        mock_exists.return_value = False
    +        
    +        _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None)
    +        
    +        mock_mkdir.assert_called_once()
    +
    +
     class TestMain:
    -    """Test cases for main function"""
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.getpass.getuser"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists"
    -    )
    -    def test_main_success(
    -        self, mock_exists, mock_getuser, mock_manager_class, mock_bootstrap, mock_parse
    -    ):
    -        """Test main function successful execution"""
    -        mock_args = Mock()
    +    """Test main function."""
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env")
    +    @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}')
    +    @patch("os.path.exists")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function")
    +    @patch("getpass.getuser")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args")
    +    def test_main_success(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env):
    +        """Test main function successful execution."""
    +        mock_getuser.return_value = "root"
    +        mock_exists.return_value = True
    +        mock_manager = MagicMock()
    +        mock_manager_class.return_value = mock_manager
    +        
    +        # Mock parsed args
    +        mock_args = MagicMock()
             mock_args.client_python_version = "3.8"
    -        mock_args.client_sagemaker_pysdk_version = "2.0.0"
    +        mock_args.client_sagemaker_pysdk_version = None
             mock_args.job_conda_env = None
             mock_args.pipeline_execution_id = None
             mock_args.dependency_settings = None
             mock_args.func_step_s3_dir = None
             mock_args.distribution = None
             mock_args.user_nproc_per_node = None
    -        mock_parse.return_value = mock_args
    -
    -        mock_getuser.return_value = "root"
    -        mock_exists.return_value = False
    -
    -        mock_manager = Mock()
    -        mock_manager_class.return_value = mock_manager
    -
    +        mock_parse_args.return_value = mock_args
    +        
    +        args = [
    +            "--client_python_version", "3.8",
    +        ]
    +        
             with pytest.raises(SystemExit) as exc_info:
    -            main([])
    -
    +            main(args)
    +        
             assert exc_info.value.code == SUCCESS_EXIT_CODE
    +        mock_bootstrap.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file"
    -    )
    -    def test_main_failure(self, mock_write_failure, mock_parse):
    -        """Test main function with failure"""
    -        mock_parse.side_effect = Exception("Test error")
    -
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager")
    +    @patch("getpass.getuser")
    +    def test_main_handles_exception(self, mock_getuser, mock_manager_class, mock_write_failure):
    +        """Test main function handles exceptions."""
    +        mock_getuser.return_value = "root"
    +        mock_manager = MagicMock()
    +        mock_manager._validate_python_version.side_effect = Exception("Test error")
    +        mock_manager_class.return_value = mock_manager
    +        
    +        args = [
    +            "--client_python_version", "3.8",
    +        ]
    +        
             with pytest.raises(SystemExit) as exc_info:
    -            main([])
    -
    +            main(args)
    +        
             assert exc_info.value.code == DEFAULT_FAILURE_CODE
             mock_write_failure.assert_called_once()
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env")
    +    @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}')
    +    @patch("os.path.exists")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_pipeline_step")
    +    @patch("getpass.getuser")
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args")
    +    def test_main_pipeline_execution(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env):
    +        """Test main function for pipeline execution."""
    +        mock_getuser.return_value = "root"
    +        mock_exists.return_value = True
    +        mock_manager = MagicMock()
    +        mock_manager_class.return_value = mock_manager
    +        
    +        # Mock parsed args
    +        mock_args = MagicMock()
    +        mock_args.client_python_version = "3.8"
    +        mock_args.client_sagemaker_pysdk_version = None
    +        mock_args.job_conda_env = None
    +        mock_args.pipeline_execution_id = "exec-123"
    +        mock_args.dependency_settings = None
    +        mock_args.func_step_s3_dir = "s3://bucket/func"
    +        mock_args.distribution = None
    +        mock_args.user_nproc_per_node = None
    +        mock_parse_args.return_value = mock_args
    +        
    +        args = [
    +            "--client_python_version", "3.8",
    +            "--pipeline_execution_id", "exec-123",
    +            "--func_step_s3_dir", "s3://bucket/func",
    +        ]
    +        
    +        with pytest.raises(SystemExit) as exc_info:
    +            main(args)
    +        
    +        assert exc_info.value.code == SUCCESS_EXIT_CODE
    +        mock_bootstrap.assert_called_once()
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager")
    +    @patch("getpass.getuser")
    +    def test_main_non_root_user(self, mock_getuser, mock_manager_class):
    +        """Test main function with non-root user."""
    +        mock_getuser.return_value = "ubuntu"
    +        mock_manager = MagicMock()
    +        mock_manager_class.return_value = mock_manager
    +        
    +        args = [
    +            "--client_python_version", "3.8",
    +        ]
    +        
    +        with pytest.raises(SystemExit):
    +            main(args)
    +        
    +        mock_manager.change_dir_permission.assert_called_once()
    
  • sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py+290 232 modified
    @@ -10,10 +10,14 @@
     # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
     # ANY KIND, either express or implied. See the License for the specific
     # language governing permissions and limitations under the License.
    +"""Tests for mpi_utils_remote module."""
    +from __future__ import absolute_import
     
    +import os
     import pytest
    -from unittest.mock import Mock, patch, MagicMock, mock_open
     import subprocess
    +import time
    +from unittest.mock import patch, MagicMock, mock_open, call
     import paramiko
     
     from sagemaker.core.remote_function.runtime_environment.mpi_utils_remote import (
    @@ -32,335 +36,389 @@
         main,
         SUCCESS_EXIT_CODE,
         DEFAULT_FAILURE_CODE,
    +    FAILURE_REASON_PATH,
         FINISHED_STATUS_FILE,
         READY_FILE,
         DEFAULT_SSH_PORT,
     )
     
     
     class TestCustomHostKeyPolicy:
    -    """Test cases for CustomHostKeyPolicy class"""
    +    """Test CustomHostKeyPolicy class."""
     
    -    def test_missing_host_key_algo_hostname(self):
    -        """Test missing_host_key accepts algo-* hostnames"""
    +    def test_accepts_algo_hostname(self):
    +        """Test accepts hostnames starting with algo-."""
             policy = CustomHostKeyPolicy()
    -        client = Mock()
    -        client.get_host_keys.return_value = Mock()
    -        key = Mock()
    -        key.get_name.return_value = "ssh-rsa"
    -
    +        mock_client = MagicMock()
    +        mock_hostname = "algo-1234"
    +        mock_key = MagicMock()
    +        mock_key.get_name.return_value = "ssh-rsa"
    +        
             # Should not raise exception
    -        policy.missing_host_key(client, "algo-1", key)
    -
    -        client.get_host_keys().add.assert_called_once()
    +        policy.missing_host_key(mock_client, mock_hostname, mock_key)
    +        
    +        mock_client.get_host_keys().add.assert_called_once_with(mock_hostname, "ssh-rsa", mock_key)
     
    -    def test_missing_host_key_unknown_hostname(self):
    -        """Test missing_host_key rejects unknown hostnames"""
    +    def test_rejects_non_algo_hostname(self):
    +        """Test rejects hostnames not starting with algo-."""
             policy = CustomHostKeyPolicy()
    -        client = Mock()
    -        key = Mock()
    +        mock_client = MagicMock()
    +        mock_hostname = "unknown-host"
    +        mock_key = MagicMock()
    +        
    +        with pytest.raises(paramiko.SSHException):
    +            policy.missing_host_key(mock_client, mock_hostname, mock_key)
    +
    +
    +class TestParseArgs:
    +    """Test _parse_args function."""
     
    -        with pytest.raises(paramiko.SSHException, match="Unknown host key"):
    -            policy.missing_host_key(client, "unknown-host", key)
    +    def test_parse_default_args(self):
    +        """Test parsing with default arguments."""
    +        args = []
    +        parsed = _parse_args(args)
    +        assert parsed.job_ended == "0"
     
    +    def test_parse_job_ended_true(self):
    +        """Test parsing with job_ended set to true."""
    +        args = ["--job_ended", "1"]
    +        parsed = _parse_args(args)
    +        assert parsed.job_ended == "1"
     
    -class TestConnectionFunctions:
    -    """Test cases for connection functions"""
    +    def test_parse_job_ended_false(self):
    +        """Test parsing with job_ended set to false."""
    +        args = ["--job_ended", "0"]
    +        parsed = _parse_args(args)
    +        assert parsed.job_ended == "0"
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.paramiko.SSHClient")
    +
    +class TestCanConnect:
    +    """Test _can_connect function."""
    +
    +    @patch("paramiko.SSHClient")
         def test_can_connect_success(self, mock_ssh_client_class):
    -        """Test _can_connect when connection succeeds"""
    -        mock_client = Mock()
    +        """Test successful connection."""
    +        mock_client = MagicMock()
             mock_ssh_client_class.return_value.__enter__.return_value = mock_client
    -
    +        
             result = _can_connect("algo-1", DEFAULT_SSH_PORT)
    -
    +        
             assert result is True
             mock_client.connect.assert_called_once_with("algo-1", port=DEFAULT_SSH_PORT)
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.paramiko.SSHClient")
    +    @patch("paramiko.SSHClient")
         def test_can_connect_failure(self, mock_ssh_client_class):
    -        """Test _can_connect when connection fails"""
    -        mock_client = Mock()
    +        """Test failed connection."""
    +        mock_client = MagicMock()
             mock_client.connect.side_effect = Exception("Connection failed")
             mock_ssh_client_class.return_value.__enter__.return_value = mock_client
    -
    +        
             result = _can_connect("algo-1", DEFAULT_SSH_PORT)
    -
    +        
             assert result is False
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.run")
    -    def test_write_file_to_host_success(self, mock_run):
    -        """Test _write_file_to_host when write succeeds"""
    -        mock_run.return_value = Mock()
    +    @patch("paramiko.SSHClient")
    +    def test_can_connect_uses_custom_port(self, mock_ssh_client_class):
    +        """Test connection with custom port."""
    +        mock_client = MagicMock()
    +        mock_ssh_client_class.return_value.__enter__.return_value = mock_client
    +        
    +        _can_connect("algo-1", 2222)
    +        
    +        mock_client.connect.assert_called_once_with("algo-1", port=2222)
     
    -        result = _write_file_to_host("algo-1", "/tmp/status")
     
    +class TestWriteFileToHost:
    +    """Test _write_file_to_host function."""
    +
    +    @patch("subprocess.run")
    +    def test_write_file_success(self, mock_run):
    +        """Test successful file write."""
    +        mock_run.return_value = MagicMock(returncode=0)
    +        
    +        result = _write_file_to_host("algo-1", "/tmp/status")
    +        
             assert result is True
             mock_run.assert_called_once()
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.run")
    -    def test_write_file_to_host_failure(self, mock_run):
    -        """Test _write_file_to_host when write fails"""
    +    @patch("subprocess.run")
    +    def test_write_file_failure(self, mock_run):
    +        """Test failed file write."""
             mock_run.side_effect = subprocess.CalledProcessError(1, "ssh")
    -
    +        
             result = _write_file_to_host("algo-1", "/tmp/status")
    -
    +        
             assert result is False
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists")
    +
    +class TestWriteFailureReasonFile:
    +    """Test _write_failure_reason_file function."""
    +
         @patch("builtins.open", new_callable=mock_open)
    -    def test_write_failure_reason_file(self, mock_file, mock_exists):
    -        """Test _write_failure_reason_file"""
    +    @patch("os.path.exists")
    +    def test_writes_failure_file(self, mock_exists, mock_file):
    +        """Test writes failure reason file."""
             mock_exists.return_value = False
    +        
    +        _write_failure_reason_file("Test error message")
    +        
    +        mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w")
    +        mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message")
     
    -        _write_failure_reason_file("Test error")
    -
    -        mock_file.assert_called_once()
    -        mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error")
    +    @patch("builtins.open", new_callable=mock_open)
    +    @patch("os.path.exists")
    +    def test_does_not_write_if_exists(self, mock_exists, mock_file):
    +        """Test does not write if failure file already exists."""
    +        mock_exists.return_value = True
    +        
    +        _write_failure_reason_file("Test error message")
    +        
    +        mock_file.assert_not_called()
     
     
    -class TestWaitFunctions:
    -    """Test cases for wait functions"""
    +class TestWaitForMaster:
    +    """Test _wait_for_master function."""
     
    +    @patch("time.sleep")
         @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep")
    -    def test_wait_for_master_success(self, mock_sleep, mock_can_connect):
    -        """Test _wait_for_master when master becomes available"""
    -        mock_can_connect.side_effect = [False, False, True]
    -
    +    def test_wait_for_master_success(self, mock_can_connect, mock_sleep):
    +        """Test successful wait for master."""
    +        mock_can_connect.return_value = True
    +        
             _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300)
    +        
    +        mock_can_connect.assert_called_once_with("algo-1", DEFAULT_SSH_PORT)
     
    -        assert mock_can_connect.call_count == 3
    -
    +    @patch("time.time")
    +    @patch("time.sleep")
         @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.time")
    -    def test_wait_for_master_timeout(self, mock_time, mock_sleep, mock_can_connect):
    -        """Test _wait_for_master when timeout occurs"""
    +    def test_wait_for_master_timeout(self, mock_can_connect, mock_sleep, mock_time):
    +        """Test timeout waiting for master."""
             mock_can_connect.return_value = False
    -        mock_time.side_effect = [0, 100, 200, 301, 301]
    -
    -        with pytest.raises(TimeoutError, match="Timed out waiting for master"):
    +        # Need enough values for all time.time() calls in the loop
    +        mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)]  # Simulate time passing
    +        
    +        with pytest.raises(TimeoutError):
                 _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300)
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep")
    -    def test_wait_for_status_file(self, mock_sleep, mock_exists):
    -        """Test _wait_for_status_file"""
    -        mock_exists.side_effect = [False, False, True]
    +    @patch("time.time")
    +    @patch("time.sleep")
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect")
    +    def test_wait_for_master_retries(self, mock_can_connect, mock_sleep, mock_time):
    +        """Test retries before successful connection."""
    +        mock_can_connect.side_effect = [False, False, True]
    +        # Return value instead of side_effect for time.time()
    +        mock_time.return_value = 0
    +        
    +        _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300)
    +        
    +        assert mock_can_connect.call_count == 3
     
    +
    +class TestWaitForStatusFile:
    +    """Test _wait_for_status_file function."""
    +
    +    @patch("time.sleep")
    +    @patch("os.path.exists")
    +    def test_wait_for_status_file_exists(self, mock_exists, mock_sleep):
    +        """Test wait for status file that exists."""
    +        mock_exists.return_value = True
    +        
             _wait_for_status_file("/tmp/status")
    +        
    +        mock_exists.assert_called_once_with("/tmp/status")
     
    +    @patch("time.sleep")
    +    @patch("os.path.exists")
    +    def test_wait_for_status_file_waits(self, mock_exists, mock_sleep):
    +        """Test waits until status file exists."""
    +        mock_exists.side_effect = [False, False, True]
    +        
    +        _wait_for_status_file("/tmp/status")
    +        
             assert mock_exists.call_count == 3
    +        assert mock_sleep.call_count == 2
    +
    +
    +class TestWaitForWorkers:
    +    """Test _wait_for_workers function."""
    +
    +    @patch("os.path.exists")
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect")
    +    def test_wait_for_workers_empty_list(self, mock_can_connect, mock_exists):
    +        """Test wait for workers with empty list."""
    +        _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300)
    +        
    +        mock_can_connect.assert_not_called()
     
    +    @patch("time.sleep")
    +    @patch("os.path.exists")
         @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep")
    -    def test_wait_for_workers_success(self, mock_sleep, mock_exists, mock_can_connect):
    -        """Test _wait_for_workers when all workers become available"""
    +    def test_wait_for_workers_success(self, mock_can_connect, mock_exists, mock_sleep):
    +        """Test successful wait for workers."""
             mock_can_connect.return_value = True
             mock_exists.return_value = True
    -
    +        
             _wait_for_workers(["algo-2", "algo-3"], DEFAULT_SSH_PORT, timeout=300)
    -
    +        
             assert mock_can_connect.call_count == 2
     
    +    @patch("time.time")
    +    @patch("time.sleep")
    +    @patch("os.path.exists")
         @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.time")
    -    def test_wait_for_workers_timeout(self, mock_time, mock_sleep, mock_can_connect):
    -        """Test _wait_for_workers when timeout occurs"""
    +    def test_wait_for_workers_timeout(self, mock_can_connect, mock_exists, mock_sleep, mock_time):
    +        """Test timeout waiting for workers."""
             mock_can_connect.return_value = False
    -        mock_time.side_effect = [0, 100, 200, 301, 301]
    -
    -        with pytest.raises(TimeoutError, match="Timed out waiting for workers"):
    +        mock_exists.return_value = False
    +        # Need enough values for all time.time() calls in the loop
    +        mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)]
    +        
    +        with pytest.raises(TimeoutError):
                 _wait_for_workers(["algo-2"], DEFAULT_SSH_PORT, timeout=300)
     
    -    def test_wait_for_workers_no_workers(self):
    -        """Test _wait_for_workers with no workers"""
    -        # Should not raise exception
    -        _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300)
    -
     
    -class TestBootstrapFunctions:
    -    """Test cases for bootstrap functions"""
    +class TestBootstrapMasterNode:
    +    """Test bootstrap_master_node function."""
     
         @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_workers")
         def test_bootstrap_master_node(self, mock_wait):
    -        """Test bootstrap_master_node"""
    -        bootstrap_master_node(["algo-2", "algo-3"])
    +        """Test bootstrap master node."""
    +        worker_hosts = ["algo-2", "algo-3"]
    +        
    +        bootstrap_master_node(worker_hosts)
    +        
    +        mock_wait.assert_called_once_with(worker_hosts)
     
    -        mock_wait.assert_called_once_with(["algo-2", "algo-3"])
     
    +class TestBootstrapWorkerNode:
    +    """Test bootstrap_worker_node function."""
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file")
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host")
         @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_master")
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file"
    -    )
    -    def test_bootstrap_worker_node(self, mock_wait_status, mock_write, mock_wait_master):
    -        """Test bootstrap_worker_node"""
    +    def test_bootstrap_worker_node(self, mock_wait_master, mock_write, mock_wait_status):
    +        """Test bootstrap worker node."""
             bootstrap_worker_node("algo-1", "algo-2", "/tmp/status")
    -
    +        
             mock_wait_master.assert_called_once_with("algo-1")
             mock_write.assert_called_once()
             mock_wait_status.assert_called_once_with("/tmp/status")
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists")
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.Popen")
    -    def test_start_sshd_daemon_success(self, mock_popen, mock_exists):
    -        """Test start_sshd_daemon when sshd exists"""
    -        mock_exists.return_value = True
     
    -        start_sshd_daemon()
    +class TestStartSshdDaemon:
    +    """Test start_sshd_daemon function."""
     
    -        mock_popen.assert_called_once()
    +    @patch("subprocess.Popen")
    +    @patch("os.path.exists")
    +    def test_starts_sshd_successfully(self, mock_exists, mock_popen):
    +        """Test starts SSH daemon successfully."""
    +        mock_exists.return_value = True
    +        
    +        start_sshd_daemon()
    +        
    +        mock_popen.assert_called_once_with(["/usr/sbin/sshd", "-D"])
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists")
    -    def test_start_sshd_daemon_not_found(self, mock_exists):
    -        """Test start_sshd_daemon when sshd not found"""
    +    @patch("os.path.exists")
    +    def test_raises_error_if_sshd_not_found(self, mock_exists):
    +        """Test raises error if SSH daemon not found."""
             mock_exists.return_value = False
    -
    -        with pytest.raises(RuntimeError, match="SSH daemon not found"):
    +        
    +        with pytest.raises(RuntimeError):
                 start_sshd_daemon()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host"
    -    )
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep")
    -    def test_write_status_file_to_workers_success(self, mock_sleep, mock_write):
    -        """Test write_status_file_to_workers when writes succeed"""
    -        mock_write.return_value = True
     
    -        write_status_file_to_workers(["algo-2", "algo-3"], "/tmp/status")
    +class TestWriteStatusFileToWorkers:
    +    """Test write_status_file_to_workers function."""
     
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host")
    +    def test_writes_to_all_workers(self, mock_write):
    +        """Test writes status file to all workers."""
    +        mock_write.return_value = True
    +        worker_hosts = ["algo-2", "algo-3"]
    +        
    +        write_status_file_to_workers(worker_hosts, "/tmp/status")
    +        
             assert mock_write.call_count == 2
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host"
    -    )
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep")
    -    def test_write_status_file_to_workers_timeout(self, mock_sleep, mock_write):
    -        """Test write_status_file_to_workers when timeout occurs"""
    +    @patch("time.sleep")
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host")
    +    def test_retries_on_failure(self, mock_write, mock_sleep):
    +        """Test retries writing status file on failure."""
    +        mock_write.side_effect = [False, False, True]
    +        worker_hosts = ["algo-2"]
    +        
    +        write_status_file_to_workers(worker_hosts, "/tmp/status")
    +        
    +        assert mock_write.call_count == 3
    +        assert mock_sleep.call_count == 2
    +
    +    @patch("time.sleep")
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host")
    +    def test_raises_timeout_after_retries(self, mock_write, mock_sleep):
    +        """Test raises timeout after max retries."""
             mock_write.return_value = False
    -
    -        with pytest.raises(TimeoutError, match="Timed out waiting"):
    -            write_status_file_to_workers(["algo-2"], "/tmp/status")
    -
    -
    -class TestParseArgs:
    -    """Test cases for _parse_args function"""
    -
    -    def test_parse_args_job_ended_false(self):
    -        """Test _parse_args with job_ended=0"""
    -        args = _parse_args(["--job_ended", "0"])
    -
    -        assert args.job_ended == "0"
    -
    -    def test_parse_args_job_ended_true(self):
    -        """Test _parse_args with job_ended=1"""
    -        args = _parse_args(["--job_ended", "1"])
    -
    -        assert args.job_ended == "1"
    -
    -    def test_parse_args_default(self):
    -        """Test _parse_args with default values"""
    -        args = _parse_args([])
    -
    -        assert args.job_ended == "0"
    +        worker_hosts = ["algo-2"]
    +        
    +        with pytest.raises(TimeoutError):
    +            write_status_file_to_workers(worker_hosts, "/tmp/status")
     
     
     class TestMain:
    -    """Test cases for main function"""
    +    """Test main function."""
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args")
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node")
         @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon")
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ",
    -        {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"},
    -    )
    -    def test_main_worker_node_job_running(self, mock_bootstrap_worker, mock_start_sshd, mock_parse):
    -        """Test main for worker node when job is running"""
    -        mock_args = Mock()
    -        mock_args.job_ended = "0"
    -        mock_parse.return_value = mock_args
    -
    -        main([])
    -
    +    @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"})
    +    def test_main_worker_node_running(self, mock_start_sshd, mock_bootstrap_worker):
    +        """Test main function for worker node during job run."""
    +        args = ["--job_ended", "0"]
    +        
    +        main(args)
    +        
             mock_start_sshd.assert_called_once()
             mock_bootstrap_worker.assert_called_once()
     
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args")
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node")
         @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon")
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node"
    -    )
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.json.loads")
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ",
    -        {
    -            "SM_MASTER_ADDR": "algo-1",
    -            "SM_CURRENT_HOST": "algo-1",
    -            "SM_HOSTS": '["algo-1", "algo-2", "algo-3"]',
    -        },
    -    )
    -    def test_main_master_node_job_running(
    -        self, mock_json_loads, mock_bootstrap_master, mock_start_sshd, mock_parse
    -    ):
    -        """Test main for master node when job is running"""
    -        mock_args = Mock()
    -        mock_args.job_ended = "0"
    -        mock_parse.return_value = mock_args
    -        mock_json_loads.return_value = ["algo-1", "algo-2", "algo-3"]
    -
    -        main([])
    -
    +    @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'})
    +    def test_main_master_node_running(self, mock_start_sshd, mock_bootstrap_master):
    +        """Test main function for master node during job run."""
    +        args = ["--job_ended", "0"]
    +        
    +        main(args)
    +        
             mock_start_sshd.assert_called_once()
    -        mock_bootstrap_master.assert_called_once_with(["algo-2", "algo-3"])
    -
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args")
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers"
    -    )
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.json.loads")
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ",
    -        {
    -            "SM_MASTER_ADDR": "algo-1",
    -            "SM_CURRENT_HOST": "algo-1",
    -            "SM_HOSTS": '["algo-1", "algo-2"]',
    -        },
    -    )
    -    def test_main_master_node_job_ended(self, mock_json_loads, mock_write_status, mock_parse):
    -        """Test main for master node when job has ended"""
    -        mock_args = Mock()
    -        mock_args.job_ended = "1"
    -        mock_parse.return_value = mock_args
    -        mock_json_loads.return_value = ["algo-1", "algo-2"]
    -
    -        main([])
    -
    -        mock_write_status.assert_called_once_with(["algo-2"])
    -
    -    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args")
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ",
    -        {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"},
    -    )
    -    def test_main_with_exception(self, mock_write_failure, mock_parse):
    -        """Test main when exception occurs"""
    -        mock_parse.side_effect = Exception("Test error")
    -
    +        mock_bootstrap_master.assert_called_once()
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers")
    +    @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'})
    +    def test_main_master_node_job_ended(self, mock_write_status):
    +        """Test main function for master node after job ends."""
    +        args = ["--job_ended", "1"]
    +        
    +        main(args)
    +        
    +        mock_write_status.assert_called_once()
    +
    +    @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"})
    +    def test_main_worker_node_job_ended(self):
    +        """Test main function for worker node after job ends."""
    +        args = ["--job_ended", "1"]
    +        
    +        # Should not raise any exceptions
    +        main(args)
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file")
    +    @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon")
    +    @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"})
    +    def test_main_handles_exception(self, mock_start_sshd, mock_write_failure):
    +        """Test main function handles exceptions."""
    +        mock_start_sshd.side_effect = Exception("Test error")
    +        args = ["--job_ended", "0"]
    +        
             with pytest.raises(SystemExit) as exc_info:
    -            main([])
    -
    +            main(args)
    +        
             assert exc_info.value.code == DEFAULT_FAILURE_CODE
             mock_write_failure.assert_called_once()
    
  • sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py +361 −351 modified
    @@ -10,16 +10,20 @@
     # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
     # ANY KIND, either express or implied. See the License for the specific
     # language governing permissions and limitations under the License.
    +"""Tests for runtime_environment_manager module."""
    +from __future__ import absolute_import
     
    -import pytest
    -from unittest.mock import Mock, patch, MagicMock, mock_open
    +import json
    +import os
     import subprocess
     import sys
    +import pytest
    +from unittest.mock import patch, MagicMock, mock_open, call
     
     from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import (
    +    _DependencySettings,
         RuntimeEnvironmentManager,
         RuntimeEnvironmentError,
    -    _DependencySettings,
         get_logger,
         _run_and_get_output_shell_cmd,
         _run_pre_execution_command_script,
    @@ -31,467 +35,465 @@
     
     
     class TestDependencySettings:
    -    """Test cases for _DependencySettings class"""
    +    """Test _DependencySettings class."""
    +
    +    def test_init_with_no_file(self):
    +        """Test initialization without dependency file."""
    +        settings = _DependencySettings()
    +        assert settings.dependency_file is None
     
         def test_init_with_file(self):
    -        """Test initialization with dependency file"""
    +        """Test initialization with dependency file."""
             settings = _DependencySettings(dependency_file="requirements.txt")
    -
             assert settings.dependency_file == "requirements.txt"
     
    -    def test_init_without_file(self):
    -        """Test initialization without dependency file"""
    -        settings = _DependencySettings()
    -
    -        assert settings.dependency_file is None
    -
         def test_to_string(self):
    -        """Test to_string method"""
    +        """Test converts to JSON string."""
             settings = _DependencySettings(dependency_file="requirements.txt")
    -
             result = settings.to_string()
    +        assert result == '{"dependency_file": "requirements.txt"}'
     
    -        assert "requirements.txt" in result
    -
    -    def test_from_string(self):
    -        """Test from_string method"""
    +    def test_from_string_with_file(self):
    +        """Test creates from JSON string with file."""
             json_str = '{"dependency_file": "requirements.txt"}'
    -
             settings = _DependencySettings.from_string(json_str)
    -
             assert settings.dependency_file == "requirements.txt"
     
    -    def test_from_string_none(self):
    -        """Test from_string with None"""
    +    def test_from_string_with_none(self):
    +        """Test creates from None."""
             settings = _DependencySettings.from_string(None)
    -
             assert settings is None
     
    -    def test_from_dependency_file_path(self):
    -        """Test from_dependency_file_path method"""
    -        settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt")
    -
    -        assert settings.dependency_file == "requirements.txt"
    +    def test_from_dependency_file_path_with_none(self):
    +        """Test creates from None file path."""
    +        settings = _DependencySettings.from_dependency_file_path(None)
    +        assert settings.dependency_file is None
     
    -    def test_from_dependency_file_path_auto_capture(self):
    -        """Test from_dependency_file_path with auto_capture"""
    +    def test_from_dependency_file_path_with_auto_capture(self):
    +        """Test creates from auto_capture."""
             settings = _DependencySettings.from_dependency_file_path("auto_capture")
    -
             assert settings.dependency_file == "env_snapshot.yml"
     
    -    def test_from_dependency_file_path_none(self):
    -        """Test from_dependency_file_path with None"""
    -        settings = _DependencySettings.from_dependency_file_path(None)
    +    def test_from_dependency_file_path_with_path(self):
    +        """Test creates from file path."""
    +        settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt")
    +        assert settings.dependency_file == "requirements.txt"
     
    -        assert settings.dependency_file is None
    +
    +class TestGetLogger:
    +    """Test get_logger function."""
    +
    +    def test_returns_logger(self):
    +        """Test returns logger instance."""
    +        logger = get_logger()
    +        assert logger is not None
    +        assert logger.name == "sagemaker.remote_function"
     
     
     class TestRuntimeEnvironmentManager:
    -    """Test cases for RuntimeEnvironmentManager class"""
    +    """Test RuntimeEnvironmentManager class."""
     
         def test_init(self):
    -        """Test initialization"""
    +        """Test initialization."""
             manager = RuntimeEnvironmentManager()
    -
             assert manager is not None
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile"
    -    )
    -    def test_snapshot_with_requirements_txt(self, mock_isfile):
    -        """Test snapshot with requirements.txt"""
    -        mock_isfile.return_value = True
    +    @patch("os.path.isfile")
    +    def test_snapshot_returns_none_for_none(self, mock_isfile):
    +        """Test snapshot returns None when dependencies is None."""
             manager = RuntimeEnvironmentManager()
    +        result = manager.snapshot(None)
    +        assert result is None
     
    -        result = manager.snapshot("requirements.txt")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._capture_from_local_runtime")
    +    def test_snapshot_auto_capture(self, mock_capture):
    +        """Test snapshot with auto_capture."""
    +        mock_capture.return_value = "/path/to/env_snapshot.yml"
    +        manager = RuntimeEnvironmentManager()
    +        result = manager.snapshot("auto_capture")
    +        assert result == "/path/to/env_snapshot.yml"
    +        mock_capture.assert_called_once()
     
    +    @patch("os.path.isfile")
    +    def test_snapshot_with_txt_file(self, mock_isfile):
    +        """Test snapshot with requirements.txt file."""
    +        mock_isfile.return_value = True
    +        manager = RuntimeEnvironmentManager()
    +        result = manager.snapshot("requirements.txt")
             assert result == "requirements.txt"
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile"
    -    )
    -    def test_snapshot_with_conda_yml(self, mock_isfile):
    -        """Test snapshot with conda environment.yml"""
    +    @patch("os.path.isfile")
    +    def test_snapshot_with_yml_file(self, mock_isfile):
    +        """Test snapshot with conda.yml file."""
             mock_isfile.return_value = True
             manager = RuntimeEnvironmentManager()
    -
             result = manager.snapshot("environment.yml")
    -
             assert result == "environment.yml"
     
    -    @patch.object(RuntimeEnvironmentManager, "_capture_from_local_runtime")
    -    def test_snapshot_with_auto_capture(self, mock_capture):
    -        """Test snapshot with auto_capture"""
    -        mock_capture.return_value = "env_snapshot.yml"
    -        manager = RuntimeEnvironmentManager()
    -
    -        result = manager.snapshot("auto_capture")
    -
    -        assert result == "env_snapshot.yml"
    -        mock_capture.assert_called_once()
    -
    -    def test_snapshot_with_none(self):
    -        """Test snapshot with None"""
    -        manager = RuntimeEnvironmentManager()
    -
    -        result = manager.snapshot(None)
    -
    -        assert result is None
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile"
    -    )
    -    def test_snapshot_with_invalid_file(self, mock_isfile):
    -        """Test snapshot with invalid file"""
    +    @patch("os.path.isfile")
    +    def test_snapshot_raises_error_for_invalid_file(self, mock_isfile):
    +        """Test snapshot raises error for invalid file."""
             mock_isfile.return_value = False
             manager = RuntimeEnvironmentManager()
    +        with pytest.raises(ValueError):
    +            manager.snapshot("requirements.txt")
     
    -        with pytest.raises(ValueError, match="No dependencies file named"):
    -            manager.snapshot("invalid.txt")
    -
    -    @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_name")
    -    @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_prefix")
    -    @patch.object(RuntimeEnvironmentManager, "_export_conda_env_from_prefix")
    -    def test_capture_from_local_runtime_with_conda_env(self, mock_export, mock_prefix, mock_name):
    -        """Test _capture_from_local_runtime with conda environment"""
    -        mock_name.return_value = "myenv"
    -        mock_prefix.return_value = "/opt/conda/envs/myenv"
    +    def test_snapshot_raises_error_for_invalid_format(self):
    +        """Test snapshot raises error for invalid format."""
             manager = RuntimeEnvironmentManager()
    +        with pytest.raises(ValueError):
    +            manager.snapshot("invalid.json")
     
    -        result = manager._capture_from_local_runtime()
    -
    -        assert "env_snapshot.yml" in result
    -        mock_export.assert_called_once()
    -
    -    @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_name")
    -    @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_prefix")
    -    def test_capture_from_local_runtime_no_conda_env(self, mock_prefix, mock_name):
    -        """Test _capture_from_local_runtime without conda environment"""
    -        mock_name.return_value = None
    -        mock_prefix.return_value = None
    -        manager = RuntimeEnvironmentManager()
    -
    -        with pytest.raises(ValueError, match="No conda environment"):
    -            manager._capture_from_local_runtime()
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.getenv"
    -    )
    +    @patch("os.getenv")
         def test_get_active_conda_env_prefix(self, mock_getenv):
    -        """Test _get_active_conda_env_prefix"""
    +        """Test gets active conda environment prefix."""
             mock_getenv.return_value = "/opt/conda/envs/myenv"
             manager = RuntimeEnvironmentManager()
    -
             result = manager._get_active_conda_env_prefix()
    -
             assert result == "/opt/conda/envs/myenv"
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.getenv"
    -    )
    +    @patch("os.getenv")
         def test_get_active_conda_env_name(self, mock_getenv):
    -        """Test _get_active_conda_env_name"""
    +        """Test gets active conda environment name."""
             mock_getenv.return_value = "myenv"
             manager = RuntimeEnvironmentManager()
    -
             result = manager._get_active_conda_env_name()
    -
             assert result == "myenv"
     
    -    @patch.object(RuntimeEnvironmentManager, "_install_req_txt_in_conda_env")
    -    @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file")
    -    def test_bootstrap_with_requirements_txt_and_conda_env(self, mock_write, mock_install):
    -        """Test bootstrap with requirements.txt and conda environment"""
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._export_conda_env_from_prefix")
    +    @patch("os.getcwd")
    +    @patch("os.getenv")
    +    def test_capture_from_local_runtime(self, mock_getenv, mock_getcwd, mock_export):
    +        """Test captures from local runtime."""
    +        mock_getenv.side_effect = lambda x: "myenv" if x == "CONDA_DEFAULT_ENV" else "/opt/conda/envs/myenv"
    +        mock_getcwd.return_value = "/tmp"
             manager = RuntimeEnvironmentManager()
    +        result = manager._capture_from_local_runtime()
    +        assert result == "/tmp/env_snapshot.yml"
    +        mock_export.assert_called_once()
     
    -        manager.bootstrap(
    -            local_dependencies_file="requirements.txt",
    -            client_python_version="3.8",
    -            conda_env="myenv",
    -        )
    -
    -        mock_install.assert_called_once_with("myenv", "requirements.txt")
    -        mock_write.assert_called_once_with("myenv")
    -
    -    @patch.object(RuntimeEnvironmentManager, "_install_requirements_txt")
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._python_executable"
    -    )
    -    def test_bootstrap_with_requirements_txt_no_conda_env(self, mock_python_exec, mock_install):
    -        """Test bootstrap with requirements.txt without conda environment"""
    -        mock_python_exec.return_value = "/usr/bin/python3"
    +    @patch("os.getenv")
    +    def test_capture_from_local_runtime_raises_error_no_conda(self, mock_getenv):
    +        """Test raises error when no conda environment active."""
    +        mock_getenv.return_value = None
             manager = RuntimeEnvironmentManager()
    +        with pytest.raises(ValueError):
    +            manager._capture_from_local_runtime()
     
    -        manager.bootstrap(local_dependencies_file="requirements.txt", client_python_version="3.8")
    -
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_requirements_txt")
    +    def test_bootstrap_with_txt_file_no_conda(self, mock_install):
    +        """Test bootstrap with requirements.txt without conda."""
    +        manager = RuntimeEnvironmentManager()
    +        manager.bootstrap("requirements.txt", "3.8", None)
             mock_install.assert_called_once()
     
    -    @patch.object(RuntimeEnvironmentManager, "_update_conda_env")
    -    @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file")
    -    def test_bootstrap_with_conda_yml_and_conda_env(self, mock_write, mock_update):
    -        """Test bootstrap with conda yml and existing conda environment"""
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_req_txt_in_conda_env")
    +    def test_bootstrap_with_txt_file_with_conda(self, mock_install, mock_write):
    +        """Test bootstrap with requirements.txt with conda."""
             manager = RuntimeEnvironmentManager()
    +        manager.bootstrap("requirements.txt", "3.8", "myenv")
    +        mock_install.assert_called_once()
    +        mock_write.assert_called_once()
     
    -        manager.bootstrap(
    -            local_dependencies_file="environment.yml",
    -            client_python_version="3.8",
    -            conda_env="myenv",
    -        )
    -
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._update_conda_env")
    +    def test_bootstrap_with_yml_file_with_conda(self, mock_update, mock_write):
    +        """Test bootstrap with conda.yml with existing conda env."""
    +        manager = RuntimeEnvironmentManager()
    +        manager.bootstrap("environment.yml", "3.8", "myenv")
             mock_update.assert_called_once()
             mock_write.assert_called_once()
     
    -    @patch.object(RuntimeEnvironmentManager, "_create_conda_env")
    -    @patch.object(RuntimeEnvironmentManager, "_validate_python_version")
    -    @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file")
    -    def test_bootstrap_with_conda_yml_no_conda_env(self, mock_write, mock_validate, mock_create):
    -        """Test bootstrap with conda yml without existing conda environment"""
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._validate_python_version")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._create_conda_env")
    +    def test_bootstrap_with_yml_file_without_conda(self, mock_create, mock_validate, mock_write):
    +        """Test bootstrap with conda.yml without existing conda env."""
             manager = RuntimeEnvironmentManager()
    -
    -        manager.bootstrap(local_dependencies_file="environment.yml", client_python_version="3.8")
    -
    +        manager.bootstrap("environment.yml", "3.8", None)
             mock_create.assert_called_once()
             mock_validate.assert_called_once()
             mock_write.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script"
    -    )
    -    def test_run_pre_exec_script_exists(self, mock_run_script, mock_isfile):
    -        """Test run_pre_exec_script when script exists"""
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script")
    +    @patch("os.path.isfile")
    +    def test_run_pre_exec_script_exists(self, mock_isfile, mock_run_script):
    +        """Test runs pre-execution script when it exists."""
             mock_isfile.return_value = True
             mock_run_script.return_value = (0, "")
             manager = RuntimeEnvironmentManager()
    -
             manager.run_pre_exec_script("/path/to/script.sh")
    -
             mock_run_script.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script"
    -    )
    -    def test_run_pre_exec_script_fails(self, mock_run_script, mock_isfile):
    -        """Test run_pre_exec_script when script fails"""
    +    @patch("os.path.isfile")
    +    def test_run_pre_exec_script_not_exists(self, mock_isfile):
    +        """Test handles pre-execution script not existing."""
    +        mock_isfile.return_value = False
    +        manager = RuntimeEnvironmentManager()
    +        # Should not raise exception
    +        manager.run_pre_exec_script("/path/to/script.sh")
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script")
    +    @patch("os.path.isfile")
    +    def test_run_pre_exec_script_raises_error_on_failure(self, mock_isfile, mock_run_script):
    +        """Test raises error when pre-execution script fails."""
             mock_isfile.return_value = True
             mock_run_script.return_value = (1, "Error message")
             manager = RuntimeEnvironmentManager()
    -
    -        with pytest.raises(RuntimeEnvironmentError, match="Encountered error"):
    +        with pytest.raises(RuntimeEnvironmentError):
                 manager.run_pre_exec_script("/path/to/script.sh")
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.run"
    -    )
    +    @patch("subprocess.run")
         def test_change_dir_permission_success(self, mock_run):
    -        """Test change_dir_permission successfully"""
    +        """Test changes directory permissions successfully."""
             manager = RuntimeEnvironmentManager()
    -
             manager.change_dir_permission(["/tmp/dir1", "/tmp/dir2"], "777")
    -
             mock_run.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.run"
    -    )
    -    def test_change_dir_permission_failure(self, mock_run):
    -        """Test change_dir_permission with failure"""
    -        mock_run.side_effect = subprocess.CalledProcessError(
    -            1, "chmod", stderr=b"Permission denied"
    -        )
    +    @patch("subprocess.run")
    +    def test_change_dir_permission_raises_error_on_failure(self, mock_run):
    +        """Test raises error when permission change fails."""
    +        mock_run.side_effect = subprocess.CalledProcessError(1, "chmod", stderr=b"Permission denied")
             manager = RuntimeEnvironmentManager()
    +        with pytest.raises(RuntimeEnvironmentError):
    +            manager.change_dir_permission(["/tmp/dir1"], "777")
     
    +    @patch("subprocess.run")
    +    def test_change_dir_permission_raises_error_no_sudo(self, mock_run):
    +        """Test raises error when sudo not found."""
    +        mock_run.side_effect = FileNotFoundError("[Errno 2] No such file or directory: 'sudo'")
    +        manager = RuntimeEnvironmentManager()
             with pytest.raises(RuntimeEnvironmentError):
    -            manager.change_dir_permission(["/tmp/dir"], "777")
    +            manager.change_dir_permission(["/tmp/dir1"], "777")
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd"
    -    )
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd")
         def test_install_requirements_txt(self, mock_run_cmd):
    -        """Test _install_requirements_txt"""
    +        """Test installs requirements.txt."""
             manager = RuntimeEnvironmentManager()
    -
    -        manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python3")
    -
    +        manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python")
             mock_run_cmd.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd"
    -    )
    -    @patch.object(RuntimeEnvironmentManager, "_get_conda_exe")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe")
         def test_create_conda_env(self, mock_get_conda, mock_run_cmd):
    -        """Test _create_conda_env"""
    +        """Test creates conda environment."""
             mock_get_conda.return_value = "conda"
             manager = RuntimeEnvironmentManager()
    -
             manager._create_conda_env("myenv", "/path/to/environment.yml")
    +        mock_run_cmd.assert_called_once()
     
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe")
    +    def test_install_req_txt_in_conda_env(self, mock_get_conda, mock_run_cmd):
    +        """Test installs requirements.txt in conda environment."""
    +        mock_get_conda.return_value = "conda"
    +        manager = RuntimeEnvironmentManager()
    +        manager._install_req_txt_in_conda_env("myenv", "/path/to/requirements.txt")
             mock_run_cmd.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd"
    -    )
    -    @patch.object(RuntimeEnvironmentManager, "_get_conda_exe")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe")
         def test_update_conda_env(self, mock_get_conda, mock_run_cmd):
    -        """Test _update_conda_env"""
    +        """Test updates conda environment."""
             mock_get_conda.return_value = "conda"
             manager = RuntimeEnvironmentManager()
    -
             manager._update_conda_env("myenv", "/path/to/environment.yml")
    -
             mock_run_cmd.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen"
    -    )
    -    def test_get_conda_exe_mamba(self, mock_popen):
    -        """Test _get_conda_exe returns mamba"""
    -        mock_process = Mock()
    +    @patch("builtins.open", new_callable=mock_open)
    +    @patch("subprocess.Popen")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe")
    +    def test_export_conda_env_from_prefix(self, mock_get_conda, mock_popen, mock_file):
    +        """Test exports conda environment."""
    +        mock_get_conda.return_value = "conda"
    +        mock_process = MagicMock()
    +        mock_process.communicate.return_value = (b"env output", b"")
             mock_process.wait.return_value = 0
             mock_popen.return_value = mock_process
    +        
             manager = RuntimeEnvironmentManager()
    +        manager._export_conda_env_from_prefix("/opt/conda/envs/myenv", "/tmp/env.yml")
    +        
    +        mock_popen.assert_called_once()
    +        mock_file.assert_called_once_with("/tmp/env.yml", "w")
     
    +    @patch("builtins.open", new_callable=mock_open)
    +    @patch("os.getcwd")
    +    def test_write_conda_env_to_file(self, mock_getcwd, mock_file):
    +        """Test writes conda environment name to file."""
    +        mock_getcwd.return_value = "/tmp"
    +        manager = RuntimeEnvironmentManager()
    +        manager._write_conda_env_to_file("myenv")
    +        mock_file.assert_called_once_with("/tmp/remote_function_conda_env.txt", "w")
    +        mock_file().write.assert_called_once_with("myenv")
    +
    +    @patch("subprocess.Popen")
    +    def test_get_conda_exe_returns_mamba(self, mock_popen):
    +        """Test returns mamba when available."""
    +        mock_popen.return_value.wait.side_effect = [0, 1]  # mamba exists, conda doesn't
    +        manager = RuntimeEnvironmentManager()
             result = manager._get_conda_exe()
    -
             assert result == "mamba"
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen"
    -    )
    -    def test_get_conda_exe_conda(self, mock_popen):
    -        """Test _get_conda_exe returns conda"""
    -        mock_process = Mock()
    -        mock_process.wait.side_effect = [1, 0]  # mamba not found, conda found
    -        mock_popen.return_value = mock_process
    +    @patch("subprocess.Popen")
    +    def test_get_conda_exe_returns_conda(self, mock_popen):
    +        """Test returns conda when mamba not available."""
    +        mock_popen.return_value.wait.side_effect = [1, 0]  # mamba doesn't exist, conda does
             manager = RuntimeEnvironmentManager()
    -
             result = manager._get_conda_exe()
    -
             assert result == "conda"
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen"
    -    )
    -    def test_get_conda_exe_not_found(self, mock_popen):
    -        """Test _get_conda_exe when neither mamba nor conda found"""
    -        mock_process = Mock()
    -        mock_process.wait.return_value = 1
    -        mock_popen.return_value = mock_process
    +    @patch("subprocess.Popen")
    +    def test_get_conda_exe_raises_error(self, mock_popen):
    +        """Test raises error when neither conda nor mamba available."""
    +        mock_popen.return_value.wait.return_value = 1
             manager = RuntimeEnvironmentManager()
    -
    -        with pytest.raises(ValueError, match="Neither conda nor mamba"):
    +        with pytest.raises(ValueError):
                 manager._get_conda_exe()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.check_output"
    -    )
    -    @patch.object(RuntimeEnvironmentManager, "_get_conda_exe")
    +    @patch("subprocess.check_output")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe")
         def test_python_version_in_conda_env(self, mock_get_conda, mock_check_output):
    -        """Test _python_version_in_conda_env"""
    +        """Test gets Python version in conda environment."""
             mock_get_conda.return_value = "conda"
             mock_check_output.return_value = b"Python 3.8.10"
             manager = RuntimeEnvironmentManager()
    -
             result = manager._python_version_in_conda_env("myenv")
    -
             assert result == "3.8"
     
    -    def test_current_python_version(self):
    -        """Test _current_python_version"""
    +    @patch("subprocess.check_output")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe")
    +    def test_python_version_in_conda_env_raises_error(self, mock_get_conda, mock_check_output):
    +        """Test raises error when getting Python version fails."""
    +        mock_get_conda.return_value = "conda"
    +        mock_check_output.side_effect = subprocess.CalledProcessError(1, "conda", output=b"Error")
             manager = RuntimeEnvironmentManager()
    +        with pytest.raises(RuntimeEnvironmentError):
    +            manager._python_version_in_conda_env("myenv")
     
    +    def test_current_python_version(self):
    +        """Test gets current Python version."""
    +        manager = RuntimeEnvironmentManager()
             result = manager._current_python_version()
    +        expected = f"{sys.version_info.major}.{sys.version_info.minor}"
    +        assert result == expected
     
    -        assert result == f"{sys.version_info.major}.{sys.version_info.minor}"
    -
    -    @patch.object(RuntimeEnvironmentManager, "_python_version_in_conda_env")
    -    def test_validate_python_version_match(self, mock_python_version):
    -        """Test _validate_python_version when versions match"""
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env")
    +    def test_validate_python_version_with_conda(self, mock_python_version):
    +        """Test validates Python version with conda environment."""
             mock_python_version.return_value = "3.8"
             manager = RuntimeEnvironmentManager()
    +        # Should not raise exception
    +        manager._validate_python_version("3.8", "myenv")
     
    -        # Should not raise error
    -        manager._validate_python_version("3.8", conda_env="myenv")
    -
    -    @patch.object(RuntimeEnvironmentManager, "_python_version_in_conda_env")
    -    def test_validate_python_version_mismatch(self, mock_python_version):
    -        """Test _validate_python_version when versions don't match"""
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env")
    +    def test_validate_python_version_mismatch_with_conda(self, mock_python_version):
    +        """Test raises error on Python version mismatch with conda."""
             mock_python_version.return_value = "3.9"
             manager = RuntimeEnvironmentManager()
    +        with pytest.raises(RuntimeEnvironmentError):
    +            manager._validate_python_version("3.8", "myenv")
     
    -        with pytest.raises(RuntimeEnvironmentError, match="does not match"):
    -            manager._validate_python_version("3.8", conda_env="myenv")
    -
    -    @patch.object(RuntimeEnvironmentManager, "_current_sagemaker_pysdk_version")
    -    def test_validate_sagemaker_pysdk_version_match(self, mock_version):
    -        """Test _validate_sagemaker_pysdk_version when versions match"""
    -        mock_version.return_value = "2.0.0"
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version")
    +    def test_validate_python_version_without_conda(self, mock_current_version):
    +        """Test validates Python version without conda environment."""
    +        mock_current_version.return_value = "3.8"
             manager = RuntimeEnvironmentManager()
    +        # Should not raise exception
    +        manager._validate_python_version("3.8", None)
     
    -        # Should not raise error, just log warning
    -        manager._validate_sagemaker_pysdk_version("2.0.0")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version")
    +    def test_validate_python_version_mismatch_without_conda(self, mock_current_version):
    +        """Test raises error on Python version mismatch without conda."""
    +        mock_current_version.return_value = "3.9"
    +        manager = RuntimeEnvironmentManager()
    +        with pytest.raises(RuntimeEnvironmentError):
    +            manager._validate_python_version("3.8", None)
     
    -    @patch.object(RuntimeEnvironmentManager, "_current_sagemaker_pysdk_version")
    -    def test_validate_sagemaker_pysdk_version_mismatch(self, mock_version):
    -        """Test _validate_sagemaker_pysdk_version when versions don't match"""
    -        mock_version.return_value = "2.1.0"
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version")
    +    def test_validate_sagemaker_pysdk_version_match(self, mock_current_version):
    +        """Test validates matching SageMaker SDK version."""
    +        mock_current_version.return_value = "2.100.0"
             manager = RuntimeEnvironmentManager()
    +        # Should not raise exception or warning
    +        manager._validate_sagemaker_pysdk_version("2.100.0")
     
    -        # Should log warning but not raise error
    -        manager._validate_sagemaker_pysdk_version("2.0.0")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version")
    +    def test_validate_sagemaker_pysdk_version_mismatch(self, mock_current_version):
    +        """Test logs warning on SageMaker SDK version mismatch."""
    +        mock_current_version.return_value = "2.101.0"
    +        manager = RuntimeEnvironmentManager()
    +        # Should log warning but not raise exception
    +        manager._validate_sagemaker_pysdk_version("2.100.0")
     
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version")
    +    def test_validate_sagemaker_pysdk_version_none(self, mock_current_version):
    +        """Test handles None client version."""
    +        mock_current_version.return_value = "2.100.0"
    +        manager = RuntimeEnvironmentManager()
    +        # Should not raise exception
    +        manager._validate_sagemaker_pysdk_version(None)
     
    -class TestHelperFunctions:
    -    """Test cases for helper functions"""
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.check_output"
    -    )
    -    def test_run_and_get_output_shell_cmd(self, mock_check_output):
    -        """Test _run_and_get_output_shell_cmd"""
    -        mock_check_output.return_value = b"output"
    +class TestRunAndGetOutputShellCmd:
    +    """Test _run_and_get_output_shell_cmd function."""
     
    +    @patch("subprocess.check_output")
    +    def test_runs_command_successfully(self, mock_check_output):
    +        """Test runs command and returns output."""
    +        mock_check_output.return_value = b"command output"
             result = _run_and_get_output_shell_cmd("echo test")
    +        assert result == "command output"
    +
     
    -        assert result == "output"
    -
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error"
    -    )
    -    def test_run_pre_execution_command_script(self, mock_log_error, mock_log_output, mock_popen):
    -        """Test _run_pre_execution_command_script"""
    -        mock_process = Mock()
    +class TestRunPreExecutionCommandScript:
    +    """Test _run_pre_execution_command_script function."""
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output")
    +    @patch("subprocess.Popen")
    +    @patch("os.path.dirname")
    +    def test_runs_script_successfully(self, mock_dirname, mock_popen, mock_log_output, mock_log_error):
    +        """Test runs script successfully."""
    +        mock_dirname.return_value = "/tmp"
    +        mock_process = MagicMock()
             mock_process.wait.return_value = 0
             mock_popen.return_value = mock_process
             mock_log_error.return_value = ""
    +        
    +        return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh")
    +        
    +        assert return_code == 0
    +        assert error_logs == ""
    +
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output")
    +    @patch("subprocess.Popen")
    +    @patch("os.path.dirname")
    +    def test_runs_script_with_error(self, mock_dirname, mock_popen, mock_log_output, mock_log_error):
    +        """Test runs script that returns error."""
    +        mock_dirname.return_value = "/tmp"
    +        mock_process = MagicMock()
    +        mock_process.wait.return_value = 1
    +        mock_popen.return_value = mock_process
    +        mock_log_error.return_value = "Error message"
    +        
    +        return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh")
    +        
    +        assert return_code == 1
    +        assert error_logs == "Error message"
     
    -        return_code, error_logs = _run_pre_execution_command_script("/path/to/script.sh")
     
    -        assert return_code == 0
    +class TestRunShellCmd:
    +    """Test _run_shell_cmd function."""
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error"
    -    )
    -    def test_run_shell_cmd_success(self, mock_log_error, mock_log_output, mock_popen):
    -        """Test _run_shell_cmd with successful command"""
    -        mock_process = Mock()
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output")
    +    @patch("subprocess.Popen")
    +    def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_error):
    +        """Test runs command successfully."""
    +        mock_process = MagicMock()
             mock_process.wait.return_value = 0
             mock_popen.return_value = mock_process
             mock_log_error.return_value = ""
    @@ -500,63 +502,71 @@ def test_run_shell_cmd_success(self, mock_log_error, mock_log_output, mock_popen
             
             mock_popen.assert_called_once()
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output"
    -    )
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error"
    -    )
    -    def test_run_shell_cmd_failure(self, mock_log_error, mock_log_output, mock_popen):
    -        """Test _run_shell_cmd with failed command"""
    -        mock_process = Mock()
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output")
    +    @patch("subprocess.Popen")
    +    def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output, mock_log_error):
    +        """Test raises error when command fails."""
    +        mock_process = MagicMock()
             mock_process.wait.return_value = 1
             mock_popen.return_value = mock_process
             mock_log_error.return_value = "Error message"
    -
    -        with pytest.raises(RuntimeEnvironmentError, match="Encountered error"):
    +        
    +        with pytest.raises(RuntimeEnvironmentError):
                 _run_shell_cmd(["false"])
     
    -    def test_python_executable(self):
    -        """Test _python_executable"""
    -        result = _python_executable()
     
    -        assert result == sys.executable
    +class TestLogOutput:
    +    """Test _log_output function."""
     
    -    @patch(
    -        "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.sys.executable",
    -        None,
    -    )
    -    def test_python_executable_not_found(self):
    -        """Test _python_executable when not found"""
    -        with pytest.raises(RuntimeEnvironmentError, match="Failed to retrieve"):
    -            _python_executable()
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger")
    +    def test_logs_output(self, mock_logger):
    +        """Test logs process output."""
    +        from io import BytesIO
    +        mock_process = MagicMock()
    +        mock_process.stdout = BytesIO(b"line1\nline2\n")
    +        
    +        _log_output(mock_process)
    +        
    +        assert mock_logger.info.call_count == 2
     
     
    -class TestRuntimeEnvironmentError:
    -    """Test cases for RuntimeEnvironmentError exception"""
    +class TestLogError:
    +    """Test _log_error function."""
     
    -    def test_init(self):
    -        """Test initialization"""
    -        error = RuntimeEnvironmentError("Test error message")
    +    @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger")
    +    def test_logs_error(self, mock_logger):
    +        """Test logs process errors."""
    +        from io import BytesIO
    +        mock_process = MagicMock()
    +        mock_process.stderr = BytesIO(b"ERROR: error message\nwarning message\n")
    +        
    +        error_logs = _log_error(mock_process)
    +        
    +        assert "ERROR: error message" in error_logs
    +        assert "warning message" in error_logs
     
    -        assert error.message == "Test error message"
    -        assert str(error) == "Test error message"
     
    -    def test_raise(self):
    -        """Test raising the exception"""
    -        with pytest.raises(RuntimeEnvironmentError, match="Test error"):
    -            raise RuntimeEnvironmentError("Test error")
    +class TestPythonExecutable:
    +    """Test _python_executable function."""
     
    +    def test_returns_python_executable(self):
    +        """Test returns Python executable path."""
    +        result = _python_executable()
    +        assert result == sys.executable
     
    -class TestGetLogger:
    -    """Test cases for get_logger function"""
    +    @patch("sys.executable", None)
    +    def test_raises_error_if_no_executable(self):
    +        """Test raises error if no Python executable."""
    +        with pytest.raises(RuntimeEnvironmentError):
    +            _python_executable()
     
    -    def test_get_logger(self):
    -        """Test get_logger returns logger"""
    -        logger = get_logger()
     
    -        assert logger is not None
    -        assert logger.name == "sagemaker.remote_function"
    +class TestRuntimeEnvironmentError:
    +    """Test RuntimeEnvironmentError class."""
    +
    +    def test_creates_error_with_message(self):
    +        """Test creates error with message."""
    +        error = RuntimeEnvironmentError("Test error")
    +        assert str(error) == "Test error"
    +        assert error.message == "Test error"
    
  • sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py+82 0 added
    @@ -0,0 +1,82 @@
    +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    +#
    +# Licensed under the Apache License, Version 2.0 (the "License"). You
    +# may not use this file except in compliance with the License. A copy of
    +# the License is located at
    +#
    +#     http://aws.amazon.com/apache2.0/
    +#
    +# or in the "license" file accompanying this file. This file is
    +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
    +# ANY KIND, either express or implied. See the License for the specific
    +# language governing permissions and limitations under the License.
    +"""Tests for checkpoint_location module."""
    +from __future__ import absolute_import
    +
    +import pytest
    +from sagemaker.core.remote_function.checkpoint_location import (
    +    CheckpointLocation,
    +    _validate_s3_uri_for_checkpoint,
    +    _JOB_CHECKPOINT_LOCATION,
    +)
    +
    +
    +class TestValidateS3Uri:
    +    """Test _validate_s3_uri_for_checkpoint function."""
    +
    +    def test_valid_s3_uri(self):
    +        """Test valid s3:// URI."""
    +        assert _validate_s3_uri_for_checkpoint("s3://my-bucket/path/to/checkpoints")
    +
    +    def test_valid_https_uri(self):
    +        """Test valid https:// URI."""
    +        assert _validate_s3_uri_for_checkpoint("https://my-bucket.s3.amazonaws.com/path")
    +
    +    def test_valid_s3_uri_no_path(self):
    +        """Test valid s3:// URI without path."""
    +        assert _validate_s3_uri_for_checkpoint("s3://my-bucket")
    +
    +    def test_invalid_uri_no_protocol(self):
    +        """Test invalid URI without protocol."""
    +        assert not _validate_s3_uri_for_checkpoint("my-bucket/path")
    +
    +    def test_invalid_uri_wrong_protocol(self):
    +        """Test invalid URI with wrong protocol."""
    +        assert not _validate_s3_uri_for_checkpoint("http://my-bucket/path")
    +
    +    def test_invalid_uri_empty(self):
    +        """Test invalid empty URI."""
    +        assert not _validate_s3_uri_for_checkpoint("")
    +
    +
    +class TestCheckpointLocation:
    +    """Test CheckpointLocation class."""
    +
    +    def test_init_with_valid_s3_uri(self):
    +        """Test initialization with valid s3 URI."""
    +        s3_uri = "s3://my-bucket/checkpoints"
    +        checkpoint_loc = CheckpointLocation(s3_uri)
    +        assert checkpoint_loc._s3_uri == s3_uri
    +
    +    def test_init_with_valid_https_uri(self):
    +        """Test initialization with valid https URI."""
    +        s3_uri = "https://my-bucket.s3.amazonaws.com/checkpoints"
    +        checkpoint_loc = CheckpointLocation(s3_uri)
    +        assert checkpoint_loc._s3_uri == s3_uri
    +
    +    def test_init_with_invalid_uri_raises_error(self):
    +        """Test initialization with invalid URI raises ValueError."""
    +        with pytest.raises(ValueError, match="CheckpointLocation should be specified with valid s3 URI"):
    +            CheckpointLocation("invalid-uri")
    +
    +    def test_fspath_returns_local_path(self):
    +        """Test __fspath__ returns the job local path."""
    +        checkpoint_loc = CheckpointLocation("s3://my-bucket/checkpoints")
    +        assert checkpoint_loc.__fspath__() == _JOB_CHECKPOINT_LOCATION
    +
    +    def test_can_be_used_as_pathlike(self):
    +        """Test CheckpointLocation can be used as os.PathLike."""
    +        import os
    +        checkpoint_loc = CheckpointLocation("s3://my-bucket/checkpoints")
    +        path = os.fspath(checkpoint_loc)
    +        assert path == _JOB_CHECKPOINT_LOCATION
    
  • sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py+169 0 added
    @@ -0,0 +1,169 @@
    +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    +#
    +# Licensed under the Apache License, Version 2.0 (the "License"). You
    +# may not use this file except in compliance with the License. A copy of
    +# the License is located at
    +#
    +#     http://aws.amazon.com/apache2.0/
    +#
    +# or in the "license" file accompanying this file. This file is
    +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
    +# ANY KIND, either express or implied. See the License for the specific
    +# language governing permissions and limitations under the License.
    +"""Tests for custom_file_filter module."""
    +from __future__ import absolute_import
    +
    +import os
    +import tempfile
    +import shutil
    +from unittest.mock import patch, MagicMock
    +import pytest
    +
    +from sagemaker.core.remote_function.custom_file_filter import (
    +    CustomFileFilter,
    +    resolve_custom_file_filter_from_config_file,
    +    copy_workdir,
    +)
    +
    +
    +class TestCustomFileFilter:
    +    """Test CustomFileFilter class."""
    +
    +    def test_init_with_no_patterns(self):
    +        """Test initialization without ignore patterns."""
    +        filter_obj = CustomFileFilter()
    +        assert filter_obj.ignore_name_patterns == []
    +        assert filter_obj.workdir == os.getcwd()
    +
    +    def test_init_with_patterns(self):
    +        """Test initialization with ignore patterns."""
    +        patterns = ["*.pyc", "__pycache__", "*.log"]
    +        filter_obj = CustomFileFilter(ignore_name_patterns=patterns)
    +        assert filter_obj.ignore_name_patterns == patterns
    +
    +    def test_ignore_name_patterns_property(self):
    +        """Test ignore_name_patterns property."""
    +        patterns = ["*.txt", "temp*"]
    +        filter_obj = CustomFileFilter(ignore_name_patterns=patterns)
    +        assert filter_obj.ignore_name_patterns == patterns
    +
    +    def test_workdir_property(self):
    +        """Test workdir property."""
    +        filter_obj = CustomFileFilter()
    +        assert filter_obj.workdir == os.getcwd()
    +
    +
    +class TestResolveCustomFileFilterFromConfigFile:
    +    """Test resolve_custom_file_filter_from_config_file function."""
    +
    +    def test_returns_direct_input_when_provided_as_filter(self):
    +        """Test returns direct input when CustomFileFilter is provided."""
    +        filter_obj = CustomFileFilter(ignore_name_patterns=["*.pyc"])
    +        result = resolve_custom_file_filter_from_config_file(direct_input=filter_obj)
    +        assert result is filter_obj
    +
    +    def test_returns_direct_input_when_provided_as_callable(self):
    +        """Test returns direct input when callable is provided."""
    +        def custom_filter(path, names):
    +            return []
    +        result = resolve_custom_file_filter_from_config_file(direct_input=custom_filter)
    +        assert result is custom_filter
    +
    +    @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config")
    +    def test_returns_none_when_no_config(self, mock_resolve):
    +        """Test returns None when no config is found."""
    +        mock_resolve.return_value = None
    +        result = resolve_custom_file_filter_from_config_file()
    +        assert result is None
    +
    +    @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config")
    +    def test_creates_filter_from_config(self, mock_resolve):
    +        """Test creates CustomFileFilter from config."""
    +        patterns = ["*.pyc", "*.log"]
    +        mock_resolve.return_value = patterns
    +        result = resolve_custom_file_filter_from_config_file()
    +        assert isinstance(result, CustomFileFilter)
    +        assert result.ignore_name_patterns == patterns
    +
    +    @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config")
    +    def test_passes_sagemaker_session_to_resolve(self, mock_resolve):
    +        """Test passes sagemaker_session to resolve_value_from_config."""
    +        mock_session = MagicMock()
    +        mock_resolve.return_value = None
    +        resolve_custom_file_filter_from_config_file(sagemaker_session=mock_session)
    +        mock_resolve.assert_called_once()
    +        assert mock_resolve.call_args[1]["sagemaker_session"] == mock_session
    +
    +
    +class TestCopyWorkdir:
    +    """Test copy_workdir function."""
    +
    +    def setup_method(self):
    +        """Set up test fixtures."""
    +        self.temp_src = tempfile.mkdtemp()
    +        self.temp_dst = tempfile.mkdtemp()
    +        
    +        # Create test files
    +        with open(os.path.join(self.temp_src, "test.py"), "w") as f:
    +            f.write("print('test')")
    +        with open(os.path.join(self.temp_src, "test.txt"), "w") as f:
    +            f.write("text file")
    +        os.makedirs(os.path.join(self.temp_src, "__pycache__"))
    +        with open(os.path.join(self.temp_src, "__pycache__", "test.pyc"), "w") as f:
    +            f.write("compiled")
    +
    +    def teardown_method(self):
    +        """Clean up test fixtures."""
    +        if os.path.exists(self.temp_src):
    +            shutil.rmtree(self.temp_src)
    +        if os.path.exists(self.temp_dst):
    +            shutil.rmtree(self.temp_dst)
    +
    +    @patch("os.getcwd")
    +    def test_copy_workdir_without_filter_only_python_files(self, mock_getcwd):
    +        """Test copy_workdir without filter copies only Python files."""
    +        mock_getcwd.return_value = self.temp_src
    +        dst = os.path.join(self.temp_dst, "output")
    +        
    +        copy_workdir(dst)
    +        
    +        assert os.path.exists(os.path.join(dst, "test.py"))
    +        assert not os.path.exists(os.path.join(dst, "test.txt"))
    +        assert not os.path.exists(os.path.join(dst, "__pycache__"))
    +
    +    @patch("os.getcwd")
    +    def test_copy_workdir_with_callable_filter(self, mock_getcwd):
    +        """Test copy_workdir with callable filter."""
    +        mock_getcwd.return_value = self.temp_src
    +        dst = os.path.join(self.temp_dst, "output")
    +        
    +        def custom_filter(path, names):
    +            return ["test.txt"]
    +        
    +        copy_workdir(dst, custom_file_filter=custom_filter)
    +        
    +        assert os.path.exists(os.path.join(dst, "test.py"))
    +        assert not os.path.exists(os.path.join(dst, "test.txt"))
    +
    +    def test_copy_workdir_with_custom_file_filter_object(self):
    +        """Test copy_workdir with CustomFileFilter object."""
    +        filter_obj = CustomFileFilter(ignore_name_patterns=["*.py"])
    +        filter_obj._workdir = self.temp_src
    +        dst = os.path.join(self.temp_dst, "output")
    +        
    +        copy_workdir(dst, custom_file_filter=filter_obj)
    +        
    +        assert not os.path.exists(os.path.join(dst, "test.py"))
    +        assert os.path.exists(os.path.join(dst, "test.txt"))
    +
    +    def test_copy_workdir_with_pattern_matching(self):
    +        """Test copy_workdir with pattern matching in CustomFileFilter."""
    +        filter_obj = CustomFileFilter(ignore_name_patterns=["*.txt", "__pycache__"])
    +        filter_obj._workdir = self.temp_src
    +        dst = os.path.join(self.temp_dst, "output")
    +        
    +        copy_workdir(dst, custom_file_filter=filter_obj)
    +        
    +        assert os.path.exists(os.path.join(dst, "test.py"))
    +        assert not os.path.exists(os.path.join(dst, "test.txt"))
    +        assert not os.path.exists(os.path.join(dst, "__pycache__"))
    
  • sagemaker-core/tests/unit/remote_function/test_invoke_function.py  +280 −0  added
    @@ -0,0 +1,280 @@
    +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    +#
    +# Licensed under the Apache License, Version 2.0 (the "License"). You
    +# may not use this file except in compliance with the License. A copy of
    +# the License is located at
    +#
    +#     http://aws.amazon.com/apache2.0/
    +#
    +# or in the "license" file accompanying this file. This file is
    +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
    +# ANY KIND, either express or implied. See the License for the specific
    +# language governing permissions and limitations under the License.
    +"""Tests for invoke_function module."""
    +from __future__ import absolute_import
    +
    +import json
    +import pytest
    +from unittest.mock import patch, MagicMock, call
    +
    +from sagemaker.core.remote_function.invoke_function import (
    +    _parse_args,
    +    _get_sagemaker_session,
    +    _load_run_object,
    +    _load_pipeline_context,
    +    _execute_remote_function,
    +    main,
    +    SUCCESS_EXIT_CODE,
    +)
    +from sagemaker.core.remote_function.job import KEY_EXPERIMENT_NAME, KEY_RUN_NAME
    +
    +
    +class TestParseArgs:
    +    """Test _parse_args function."""
    +
    +    def test_parse_required_args(self):
    +        """Test parsing required arguments."""
    +        args = [
    +            "--region", "us-west-2",
    +            "--s3_base_uri", "s3://my-bucket/path",
    +        ]
    +        parsed = _parse_args(args)
    +        assert parsed.region == "us-west-2"
    +        assert parsed.s3_base_uri == "s3://my-bucket/path"
    +
    +    def test_parse_all_args(self):
    +        """Test parsing all arguments."""
    +        args = [
    +            "--region", "us-east-1",
    +            "--s3_base_uri", "s3://bucket/path",
    +            "--s3_kms_key", "key-123",
    +            "--run_in_context", '{"experiment": "exp1"}',
    +            "--pipeline_step_name", "step1",
    +            "--pipeline_execution_id", "exec-123",
    +            "--property_references", "prop1", "val1", "prop2", "val2",
    +            "--serialize_output_to_json", "true",
    +            "--func_step_s3_dir", "s3://bucket/func",
    +        ]
    +        parsed = _parse_args(args)
    +        assert parsed.region == "us-east-1"
    +        assert parsed.s3_base_uri == "s3://bucket/path"
    +        assert parsed.s3_kms_key == "key-123"
    +        assert parsed.run_in_context == '{"experiment": "exp1"}'
    +        assert parsed.pipeline_step_name == "step1"
    +        assert parsed.pipeline_execution_id == "exec-123"
    +        assert parsed.property_references == ["prop1", "val1", "prop2", "val2"]
    +        assert parsed.serialize_output_to_json is True
    +        assert parsed.func_step_s3_dir == "s3://bucket/func"
    +
    +    def test_parse_serialize_output_false(self):
    +        """Test parsing serialize_output_to_json as false."""
    +        args = [
    +            "--region", "us-west-2",
    +            "--s3_base_uri", "s3://bucket/path",
    +            "--serialize_output_to_json", "false",
    +        ]
    +        parsed = _parse_args(args)
    +        assert parsed.serialize_output_to_json is False
    +
    +    def test_parse_default_values(self):
    +        """Test default values for optional arguments."""
    +        args = [
    +            "--region", "us-west-2",
    +            "--s3_base_uri", "s3://bucket/path",
    +        ]
    +        parsed = _parse_args(args)
    +        assert parsed.s3_kms_key is None
    +        assert parsed.run_in_context is None
    +        assert parsed.pipeline_step_name is None
    +        assert parsed.pipeline_execution_id is None
    +        assert parsed.property_references == []
    +        assert parsed.serialize_output_to_json is False
    +        assert parsed.func_step_s3_dir is None
    +
    +
    +class TestGetSagemakerSession:
    +    """Test _get_sagemaker_session function."""
    +
    +    @patch("sagemaker.core.remote_function.invoke_function.boto3.session.Session")
    +    @patch("sagemaker.core.remote_function.invoke_function.Session")
    +    def test_creates_session_with_region(self, mock_session_class, mock_boto_session):
    +        """Test creates SageMaker session with correct region."""
    +        mock_boto = MagicMock()
    +        mock_boto_session.return_value = mock_boto
    +        
    +        _get_sagemaker_session("us-west-2")
    +        
    +        mock_boto_session.assert_called_once_with(region_name="us-west-2")
    +        mock_session_class.assert_called_once_with(boto_session=mock_boto)
    +
    +
    +class TestLoadRunObject:
    +    """Test _load_run_object function."""
    +
    +    @patch("sagemaker.core.experiments.run.Run")
    +    def test_loads_run_from_json(self, mock_run_class):
    +        """Test loads Run object from JSON string."""
    +        run_dict = {
    +            KEY_EXPERIMENT_NAME: "my-experiment",
    +            KEY_RUN_NAME: "my-run",
    +        }
    +        run_json = json.dumps(run_dict)
    +        mock_session = MagicMock()
    +        
    +        _load_run_object(run_json, mock_session)
    +        
    +        mock_run_class.assert_called_once_with(
    +            experiment_name="my-experiment",
    +            run_name="my-run",
    +            sagemaker_session=mock_session,
    +        )
    +
    +
    +class TestLoadPipelineContext:
    +    """Test _load_pipeline_context function."""
    +
    +    def test_loads_context_with_all_fields(self):
    +        """Test loads pipeline context with all fields."""
    +        args = MagicMock()
    +        args.pipeline_step_name = "step1"
    +        args.pipeline_execution_id = "exec-123"
    +        args.property_references = ["prop1", "val1", "prop2", "val2"]
    +        args.serialize_output_to_json = True
    +        args.func_step_s3_dir = "s3://bucket/func"
    +        
    +        context = _load_pipeline_context(args)
    +        
    +        assert context.step_name == "step1"
    +        assert context.execution_id == "exec-123"
    +        assert context.property_references == {"prop1": "val1", "prop2": "val2"}
    +        assert context.serialize_output_to_json is True
    +        assert context.func_step_s3_dir == "s3://bucket/func"
    +
    +    def test_loads_context_with_empty_property_references(self):
    +        """Test loads pipeline context with empty property references."""
    +        args = MagicMock()
    +        args.pipeline_step_name = "step1"
    +        args.pipeline_execution_id = "exec-123"
    +        args.property_references = []
    +        args.serialize_output_to_json = False
    +        args.func_step_s3_dir = None
    +        
    +        context = _load_pipeline_context(args)
    +        
    +        assert context.property_references == {}
    +
    +
    +class TestExecuteRemoteFunction:
    +    """Test _execute_remote_function function."""
    +
    +    @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction")
    +    def test_executes_without_run_context(self, mock_stored_function_class):
    +        """Test executes stored function without run context."""
    +        mock_stored_func = MagicMock()
    +        mock_stored_function_class.return_value = mock_stored_func
    +        mock_session = MagicMock()
    +        mock_context = MagicMock()
    +        
    +        _execute_remote_function(
    +            sagemaker_session=mock_session,
    +            s3_base_uri="s3://bucket/path",
    +            s3_kms_key="key-123",
    +            run_in_context=None,
    +            context=mock_context,
    +        )
    +        
    +        mock_stored_function_class.assert_called_once_with(
    +            sagemaker_session=mock_session,
    +            s3_base_uri="s3://bucket/path",
    +            s3_kms_key="key-123",
    +            context=mock_context,
    +        )
    +        mock_stored_func.load_and_invoke.assert_called_once()
    +
    +    @patch("sagemaker.core.remote_function.invoke_function._load_run_object")
    +    @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction")
    +    def test_executes_with_run_context(self, mock_stored_function_class, mock_load_run):
    +        """Test executes stored function with run context."""
    +        mock_stored_func = MagicMock()
    +        mock_stored_function_class.return_value = mock_stored_func
    +        mock_run = MagicMock()
    +        mock_load_run.return_value = mock_run
    +        mock_session = MagicMock()
    +        mock_context = MagicMock()
    +        run_json = '{"experiment": "exp1"}'
    +        
    +        _execute_remote_function(
    +            sagemaker_session=mock_session,
    +            s3_base_uri="s3://bucket/path",
    +            s3_kms_key=None,
    +            run_in_context=run_json,
    +            context=mock_context,
    +        )
    +        
    +        # Verify run object was loaded and used as context manager
    +        mock_load_run.assert_called_once_with(run_json, mock_session)
    +        mock_run.__enter__.assert_called_once()
    +        mock_run.__exit__.assert_called_once()
    +
    +
    +class TestMain:
    +    """Test main function."""
    +
    +    @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function")
    +    @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session")
    +    @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context")
    +    @patch("sagemaker.core.remote_function.invoke_function._parse_args")
    +    def test_main_success(self, mock_parse, mock_load_context, mock_get_session, mock_execute):
    +        """Test main function successful execution."""
    +        mock_args = MagicMock()
    +        mock_args.region = "us-west-2"
    +        mock_args.s3_base_uri = "s3://bucket/path"
    +        mock_args.s3_kms_key = None
    +        mock_args.run_in_context = None
    +        mock_parse.return_value = mock_args
    +        
    +        mock_context = MagicMock()
    +        mock_context.step_name = None
    +        mock_load_context.return_value = mock_context
    +        
    +        mock_session = MagicMock()
    +        mock_get_session.return_value = mock_session
    +        
    +        with pytest.raises(SystemExit) as exc_info:
    +            main(["--region", "us-west-2", "--s3_base_uri", "s3://bucket/path"])
    +        
    +        assert exc_info.value.code == SUCCESS_EXIT_CODE
    +        mock_execute.assert_called_once()
    +
    +    @patch("sagemaker.core.remote_function.invoke_function.handle_error")
    +    @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function")
    +    @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session")
    +    @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context")
    +    @patch("sagemaker.core.remote_function.invoke_function._parse_args")
    +    def test_main_handles_exception(
    +        self, mock_parse, mock_load_context, mock_get_session, mock_execute, mock_handle_error
    +    ):
    +        """Test main function handles exceptions."""
    +        mock_args = MagicMock()
    +        mock_args.region = "us-west-2"
    +        mock_args.s3_base_uri = "s3://bucket/path"
    +        mock_args.s3_kms_key = None
    +        mock_args.run_in_context = None
    +        mock_parse.return_value = mock_args
    +        
    +        mock_context = MagicMock()
    +        mock_context.step_name = None
    +        mock_load_context.return_value = mock_context
    +        
    +        mock_session = MagicMock()
    +        mock_get_session.return_value = mock_session
    +        
    +        test_exception = Exception("Test error")
    +        mock_execute.side_effect = test_exception
    +        mock_handle_error.return_value = 1
    +        
    +        with pytest.raises(SystemExit) as exc_info:
    +            main(["--region", "us-west-2", "--s3_base_uri", "s3://bucket/path"])
    +        
    +        assert exc_info.value.code == 1
    +        mock_handle_error.assert_called_once()
    
  • sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py  +5 −7  modified
    @@ -144,17 +144,15 @@ def test_from_describe_response(self, mock_session):
             response = {
                 "TrainingJobName": "test-job",
                 "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"},
    -            "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"},
             }
             job = _Job.from_describe_response(response, mock_session)
             assert job.job_name == "test-job"
             assert job.s3_uri == "s3://bucket/output"
    -        assert job.hmac_key == "test-key"
             assert job._last_describe_response == response
     
         def test_describe_cached_completed(self, mock_session):
             """Test lines 865-871: describe with cached completed job."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             job._last_describe_response = {"TrainingJobStatus": "Completed"}
     
             result = job.describe()
    @@ -163,7 +161,7 @@ def test_describe_cached_completed(self, mock_session):
     
         def test_describe_cached_failed(self, mock_session):
             """Test lines 865-871: describe with cached failed job."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             job._last_describe_response = {"TrainingJobStatus": "Failed"}
     
             result = job.describe()
    @@ -172,7 +170,7 @@ def test_describe_cached_failed(self, mock_session):
     
         def test_describe_cached_stopped(self, mock_session):
             """Test lines 865-871: describe with cached stopped job."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             job._last_describe_response = {"TrainingJobStatus": "Stopped"}
     
             result = job.describe()
    @@ -181,7 +179,7 @@ def test_describe_cached_stopped(self, mock_session):
     
         def test_stop(self, mock_session):
             """Test lines 886-887: stop method."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             job.stop()
             mock_session.sagemaker_client.stop_training_job.assert_called_once_with(
                 TrainingJobName="test-job"
    @@ -190,7 +188,7 @@ def test_stop(self, mock_session):
         @patch("sagemaker.core.remote_function.job._logs_for_job")
         def test_wait(self, mock_logs, mock_session):
             """Test lines 889-903: wait method."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             mock_logs.return_value = {"TrainingJobStatus": "Completed"}
     
             job.wait(timeout=100)
    
  • sagemaker-core/tests/unit/remote_function/test_job.py  +6 −9  modified
    @@ -143,26 +143,23 @@ class TestJob:
     
         def test_init(self, mock_session):
             """Test _Job initialization."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             assert job.job_name == "test-job"
             assert job.s3_uri == "s3://bucket/output"
    -        assert job.hmac_key == "test-key"
     
         def test_from_describe_response(self, mock_session):
             """Test creating _Job from describe response."""
             response = {
                 "TrainingJobName": "test-job",
                 "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"},
    -            "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"},
             }
             job = _Job.from_describe_response(response, mock_session)
             assert job.job_name == "test-job"
             assert job.s3_uri == "s3://bucket/output"
    -        assert job.hmac_key == "test-key"
     
         def test_describe_returns_cached_response(self, mock_session):
             """Test that describe returns cached response for completed jobs."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             job._last_describe_response = {"TrainingJobStatus": "Completed"}
     
             result = job.describe()
    @@ -171,7 +168,7 @@ def test_describe_returns_cached_response(self, mock_session):
     
         def test_describe_calls_api_for_in_progress_jobs(self, mock_session):
             """Test that describe calls API for in-progress jobs."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             mock_session.sagemaker_client.describe_training_job.return_value = {
                 "TrainingJobStatus": "InProgress"
             }
    @@ -182,7 +179,7 @@ def test_describe_calls_api_for_in_progress_jobs(self, mock_session):
     
         def test_stop(self, mock_session):
             """Test stopping a job."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             job.stop()
             mock_session.sagemaker_client.stop_training_job.assert_called_once_with(
                 TrainingJobName="test-job"
    @@ -191,7 +188,7 @@ def test_stop(self, mock_session):
         @patch("sagemaker.core.remote_function.job._logs_for_job")
         def test_wait(self, mock_logs, mock_session):
             """Test waiting for job completion."""
    -        job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
    +        job = _Job("test-job", "s3://bucket/output", mock_session)
             mock_logs.return_value = {"TrainingJobStatus": "Completed"}
     
             job.wait(timeout=100)
    @@ -882,7 +879,7 @@ def test_start(self, mock_get_name, mock_compile, mock_session):
             mock_get_name.return_value = "test-job"
             mock_compile.return_value = {
                 "TrainingJobName": "test-job",
    -            "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"},
    +            "Environment": {},
             }
     
             job_settings = Mock()
    
  • sagemaker-core/tests/unit/remote_function/test_logging_config.py  +86 −0  added
    @@ -0,0 +1,86 @@
    +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    +#
    +# Licensed under the Apache License, Version 2.0 (the "License"). You
    +# may not use this file except in compliance with the License. A copy of
    +# the License is located at
    +#
    +#     http://aws.amazon.com/apache2.0/
    +#
    +# or in the "license" file accompanying this file. This file is
    +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
    +# ANY KIND, either express or implied. See the License for the specific
    +# language governing permissions and limitations under the License.
    +"""Tests for logging_config module."""
    +from __future__ import absolute_import
    +
    +import logging
    +import time
    +from unittest.mock import patch
    +from sagemaker.core.remote_function.logging_config import _UTCFormatter, get_logger
    +
    +
    +class TestUTCFormatter:
    +    """Test _UTCFormatter class."""
    +
    +    def test_converter_is_gmtime(self):
    +        """Test that converter is set to gmtime."""
    +        formatter = _UTCFormatter()
    +        assert formatter.converter == time.gmtime
    +
    +    def test_formats_time_in_utc(self):
    +        """Test that time is formatted in UTC."""
    +        formatter = _UTCFormatter("%(asctime)s")
    +        record = logging.LogRecord(
    +            name="test",
    +            level=logging.INFO,
    +            pathname="",
    +            lineno=0,
    +            msg="test message",
    +            args=(),
    +            exc_info=None,
    +        )
    +        formatted = formatter.format(record)
    +        # Should contain UTC time format
    +        assert formatted
    +
    +
    +class TestGetLogger:
    +    """Test get_logger function."""
    +
    +    def test_returns_logger_with_correct_name(self):
    +        """Test that logger has correct name."""
    +        logger = get_logger()
    +        assert logger.name == "sagemaker.remote_function"
    +
    +    def test_logger_has_info_level(self):
    +        """Test that logger is set to INFO level."""
    +        logger = get_logger()
    +        assert logger.level == logging.INFO
    +
    +    def test_logger_has_handler(self):
    +        """Test that logger has at least one handler."""
    +        logger = get_logger()
    +        assert len(logger.handlers) > 0
    +
    +    def test_logger_handler_has_utc_formatter(self):
    +        """Test that logger handler uses UTC formatter."""
    +        logger = get_logger()
    +        handler = logger.handlers[0]
    +        # Check that formatter has gmtime converter (UTC formatter characteristic)
    +        assert handler.formatter.converter == time.gmtime
    +
    +    def test_logger_does_not_propagate(self):
    +        """Test that logger does not propagate to root logger."""
    +        logger = get_logger()
    +        assert logger.propagate == 0
    +
    +    def test_get_logger_is_idempotent(self):
    +        """Test that calling get_logger multiple times returns same logger."""
    +        logger1 = get_logger()
    +        logger2 = get_logger()
    +        assert logger1 is logger2
    +
    +    def test_logger_handler_is_stream_handler(self):
    +        """Test that logger uses StreamHandler."""
    +        logger = get_logger()
    +        assert isinstance(logger.handlers[0], logging.StreamHandler)
    
  • sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py  +0 −1  modified
    @@ -1036,7 +1036,6 @@ def get_function_step_result(
             return deserialize_obj_from_s3(
                 sagemaker_session=sagemaker_session,
                 s3_uri=s3_uri,
    -            hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
             )
     
         raise RemoteFunctionError(_ERROR_MSG_OF_STEP_INCOMPLETE)
    
  • sagemaker-mlops/tests/unit/workflow/test_pipeline.py  +0 −4  modified
    @@ -360,7 +360,6 @@ def test_get_function_step_result_incomplete_job(mock_session):
             "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT},
             "OutputDataConfig": {"S3OutputPath": "s3://bucket/path"},
             "TrainingJobStatus": "Failed",
    -        "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"}
         }
         
         with pytest.raises(RemoteFunctionError, match="not in Completed status"):
    @@ -376,7 +375,6 @@ def test_get_function_step_result_success(mock_session):
             "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT},
             "OutputDataConfig": {"S3OutputPath": "s3://bucket/path/exec-id/step1/results"},
             "TrainingJobStatus": "Completed",
    -        "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"}
         }
         
         with patch("sagemaker.mlops.workflow.pipeline.deserialize_obj_from_s3", return_value="result"):
    @@ -443,7 +441,6 @@ def test_pipeline_execution_result_terminal_failure(mock_session):
             "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT},
             "OutputDataConfig": {"S3OutputPath": "s3://bucket/path/exec-id/step1/results"},
             "TrainingJobStatus": "Completed",
    -        "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"}
         }
         
         with patch.object(execution, "wait", side_effect=WaiterError("name", "Waiter encountered a terminal failure state", {})):
    @@ -461,7 +458,6 @@ def test_get_function_step_result_obsolete_s3_path(mock_session):
             "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT},
             "OutputDataConfig": {"S3OutputPath": "s3://bucket/different/path"},
             "TrainingJobStatus": "Completed",
    -        "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"}
         }
         
         with patch("sagemaker.mlops.workflow.pipeline.deserialize_obj_from_s3", return_value="result"):
    
708c7b2f4135

Bug fix for hmac key (#5348)

https://github.com/aws/sagemaker-python-sdk · aviruthen · Dec 15, 2025 · via ghsa
19 files changed · +83 −281
  • src/sagemaker/feature_store/feature_processor/_config_uploader.py  +0 −3  modified
    @@ -120,9 +120,6 @@ def _prepare_and_upload_callable(
             stored_function = StoredFunction(
                 sagemaker_session=sagemaker_session,
                 s3_base_uri=s3_base_uri,
    -            hmac_key=self.remote_decorator_config.environment_variables[
    -                "REMOTE_FUNCTION_SECRET_KEY"
    -            ],
                 s3_kms_key=self.remote_decorator_config.s3_kms_key,
             )
             stored_function.save(func)
    
  • src/sagemaker/remote_function/client.py  +0 −6  modified
    @@ -362,7 +362,6 @@ def wrapper(*args, **kwargs):
                                 s3_uri=s3_path_join(
                                     job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
                                 ),
    -                            hmac_key=job.hmac_key,
                             )
                         except ServiceError as serr:
                             chained_e = serr.__cause__
    @@ -399,7 +398,6 @@ def wrapper(*args, **kwargs):
                     return serialization.deserialize_obj_from_s3(
                         sagemaker_session=job_settings.sagemaker_session,
                         s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
    -                    hmac_key=job.hmac_key,
                     )
     
                 if job.describe()["TrainingJobStatus"] == "Stopped":
    @@ -979,7 +977,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
                     job_return = serialization.deserialize_obj_from_s3(
                         sagemaker_session=sagemaker_session,
                         s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
    -                    hmac_key=job.hmac_key,
                     )
                 except DeserializationError as e:
                     client_exception = e
    @@ -991,7 +988,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
                     job_exception = serialization.deserialize_exception_from_s3(
                         sagemaker_session=sagemaker_session,
                         s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
    -                    hmac_key=job.hmac_key,
                     )
                 except ServiceError as serr:
                     chained_e = serr.__cause__
    @@ -1081,7 +1077,6 @@ def result(self, timeout: float = None) -> Any:
                         self._return = serialization.deserialize_obj_from_s3(
                             sagemaker_session=self._job.sagemaker_session,
                             s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
    -                        hmac_key=self._job.hmac_key,
                         )
                         self._state = _FINISHED
                         return self._return
    @@ -1090,7 +1085,6 @@ def result(self, timeout: float = None) -> Any:
                             self._exception = serialization.deserialize_exception_from_s3(
                                 sagemaker_session=self._job.sagemaker_session,
                                 s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
    -                            hmac_key=self._job.hmac_key,
                             )
                         except ServiceError as serr:
                             chained_e = serr.__cause__
    
  • src/sagemaker/remote_function/core/pipeline_variables.py · +0 −6 · modified
    @@ -164,7 +164,6 @@ class _DelayedReturnResolver:
         def __init__(
             self,
             delayed_returns: List[_DelayedReturn],
    -        hmac_key: str,
             properties_resolver: _PropertiesResolver,
             parameter_resolver: _ParameterResolver,
             execution_variable_resolver: _ExecutionVariableResolver,
    @@ -175,7 +174,6 @@ def __init__(
     
             Args:
                 delayed_returns: list of delayed returns to resolve.
    -            hmac_key: key used to encrypt serialized and deserialized function and arguments.
                 properties_resolver: resolver used to resolve step properties.
                 parameter_resolver: resolver used to pipeline parameters.
                 execution_variable_resolver: resolver used to resolve execution variables.
    @@ -197,7 +195,6 @@ def deserialization_task(uri):
                 return uri, deserialize_obj_from_s3(
                     sagemaker_session=settings["sagemaker_session"],
                     s3_uri=uri,
    -                hmac_key=hmac_key,
                 )
     
             with ThreadPoolExecutor() as executor:
    @@ -247,7 +244,6 @@ def resolve_pipeline_variables(
         context: Context,
         func_args: Tuple,
         func_kwargs: Dict,
    -    hmac_key: str,
         s3_base_uri: str,
         **settings,
     ):
    @@ -257,7 +253,6 @@ def resolve_pipeline_variables(
             context: context for the execution.
             func_args: function args.
             func_kwargs: function kwargs.
    -        hmac_key: key used to encrypt serialized and deserialized function and arguments.
             s3_base_uri: the s3 base uri of the function step that the serialized artifacts
                 will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
             **settings: settings to pass to the deserialization function.
    @@ -280,7 +275,6 @@ def resolve_pipeline_variables(
         properties_resolver = _PropertiesResolver(context)
         delayed_return_resolver = _DelayedReturnResolver(
             delayed_returns=delayed_returns,
    -        hmac_key=hmac_key,
             properties_resolver=properties_resolver,
             parameter_resolver=parameter_resolver,
             execution_variable_resolver=execution_variable_resolver,
    
  • src/sagemaker/remote_function/core/serialization.py · +15 −34 · modified
    @@ -152,15 +152,14 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
     
     # TODO: use dask serializer in case dask distributed is installed in users' environment.
     def serialize_func_to_s3(
    -    func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
    +    func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
     ):
         """Serializes function and uploads it to S3.
     
         Args:
             sagemaker_session (sagemaker.session.Session):
                 The underlying Boto3 session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
             func: function to be serialized and persisted
         Raises:
    @@ -169,14 +168,13 @@ def serialize_func_to_s3(
     
         _upload_payload_and_metadata_to_s3(
             bytes_to_upload=CloudpickleSerializer.serialize(func),
    -        hmac_key=hmac_key,
             s3_uri=s3_uri,
             sagemaker_session=sagemaker_session,
             s3_kms_key=s3_kms_key,
         )
     
     
    -def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
    +def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable:
         """Downloads from S3 and then deserializes data objects.
     
         This method downloads the serialized training job outputs to a temporary directory and
    @@ -186,7 +184,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
             sagemaker_session (sagemaker.session.Session):
                 The underlying sagemaker session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
         Returns :
             The deserialized function.
         Raises:
    @@ -198,32 +195,26 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
     
         bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
     
    -    _perform_integrity_check(
    -        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
    -    )
    +    _perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize)
     
         return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
     
     
    -def serialize_obj_to_s3(
    -    obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
    -):
    +def serialize_obj_to_s3(obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None):
         """Serializes data object and uploads it to S3.
     
         Args:
             sagemaker_session (sagemaker.session.Session):
                 The underlying Boto3 session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
             obj: object to be serialized and persisted
         Raises:
             SerializationError: when fail to serialize object to bytes.
         """
     
         _upload_payload_and_metadata_to_s3(
             bytes_to_upload=CloudpickleSerializer.serialize(obj),
    -        hmac_key=hmac_key,
             s3_uri=s3_uri,
             sagemaker_session=sagemaker_session,
             s3_kms_key=s3_kms_key,
    @@ -270,14 +261,13 @@ def json_serialize_obj_to_s3(
         )
     
     
    -def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
    +def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
         """Downloads from S3 and then deserializes data objects.
     
         Args:
             sagemaker_session (sagemaker.session.Session):
                 The underlying sagemaker session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
         Returns :
             Deserialized python objects.
         Raises:
    @@ -290,15 +280,13 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
     
         bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
     
    -    _perform_integrity_check(
    -        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
    -    )
    +    _perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize)
     
         return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
     
     
     def serialize_exception_to_s3(
    -    exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
    +    exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
     ):
         """Serializes exception with traceback and uploads it to S3.
     
    @@ -307,7 +295,6 @@ def serialize_exception_to_s3(
                 The underlying Boto3 session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
             exc: Exception to be serialized and persisted
         Raises:
             SerializationError: when fail to serialize object to bytes.
    @@ -316,7 +303,6 @@ def serialize_exception_to_s3(
     
         _upload_payload_and_metadata_to_s3(
             bytes_to_upload=CloudpickleSerializer.serialize(exc),
    -        hmac_key=hmac_key,
             s3_uri=s3_uri,
             sagemaker_session=sagemaker_session,
             s3_kms_key=s3_kms_key,
    @@ -325,7 +311,6 @@ def serialize_exception_to_s3(
     
     def _upload_payload_and_metadata_to_s3(
         bytes_to_upload: Union[bytes, io.BytesIO],
    -    hmac_key: str,
         s3_uri: str,
         sagemaker_session: Session,
         s3_kms_key,
    @@ -334,15 +319,14 @@ def _upload_payload_and_metadata_to_s3(
     
         Args:
             bytes_to_upload (bytes): Serialized bytes to upload.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
             sagemaker_session (sagemaker.session.Session):
                 The underlying Boto3 session which AWS service calls are delegated to.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
         """
         _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
     
    -    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
    +    sha256_hash = _compute_hash(bytes_to_upload)
     
         _upload_bytes_to_s3(
             _MetaData(sha256_hash).to_json(),
    @@ -352,14 +336,13 @@ def _upload_payload_and_metadata_to_s3(
         )
     
     
    -def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
    +def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
         """Downloads from S3 and then deserializes exception.
     
         Args:
             sagemaker_session (sagemaker.session.Session):
                 The underlying sagemaker session which AWS service calls are delegated to.
             s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
    -        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
         Returns :
             Deserialized exception with traceback.
         Raises:
    @@ -372,9 +355,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
     
         bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
     
    -    _perform_integrity_check(
    -        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
    -    )
    +    _perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize)
     
         return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
     
    @@ -399,18 +380,18 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
             ) from e
     
     
    -def _compute_hash(buffer: bytes, secret_key: str) -> str:
    -    """Compute the hmac-sha256 hash"""
    -    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
    +def _compute_hash(buffer: bytes) -> str:
    +    """Compute the sha256 hash"""
    +    return hashlib.sha256(buffer).hexdigest()
     
     
    -def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
    +def _perform_integrity_check(expected_hash_value: str, buffer: bytes):
         """Performs integrity checks for serialized code/arguments uploaded to s3.
     
         Verifies whether the hash read from s3 matches the hash calculated
         during remote function execution.
         """
    -    actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
    +    actual_hash_value = _compute_hash(buffer=buffer)
         if not hmac.compare_digest(expected_hash_value, actual_hash_value):
             raise DeserializationError(
                 "Integrity check for the serialized function or data failed. "
    
  • src/sagemaker/remote_function/core/stored_function.py · +0 −11 · modified
    @@ -52,7 +52,6 @@ def __init__(
             self,
             sagemaker_session: Session,
             s3_base_uri: str,
    -        hmac_key: str,
             s3_kms_key: str = None,
             context: Context = Context(),
         ):
    @@ -63,13 +62,11 @@ def __init__(
                     AWS service calls are delegated to.
                 s3_base_uri: the base uri to which serialized artifacts will be uploaded.
                 s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
    -            hmac_key: Key used to encrypt serialized and deserialized function and arguments.
                 context: Build or run context of a pipeline step.
             """
             self.sagemaker_session = sagemaker_session
             self.s3_base_uri = s3_base_uri
             self.s3_kms_key = s3_kms_key
    -        self.hmac_key = hmac_key
             self.context = context
     
             self.func_upload_path = s3_path_join(
    @@ -98,7 +95,6 @@ def save(self, func, *args, **kwargs):
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
                 s3_kms_key=self.s3_kms_key,
    -            hmac_key=self.hmac_key,
             )
     
             logger.info(
    @@ -110,7 +106,6 @@ def save(self, func, *args, **kwargs):
                 obj=(args, kwargs),
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
    -            hmac_key=self.hmac_key,
                 s3_kms_key=self.s3_kms_key,
             )
     
    @@ -128,7 +123,6 @@ def save_pipeline_step_function(self, serialized_data):
             )
             serialization._upload_payload_and_metadata_to_s3(
                 bytes_to_upload=serialized_data.func,
    -            hmac_key=self.hmac_key,
                 s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
                 sagemaker_session=self.sagemaker_session,
                 s3_kms_key=self.s3_kms_key,
    @@ -140,7 +134,6 @@ def save_pipeline_step_function(self, serialized_data):
             )
             serialization._upload_payload_and_metadata_to_s3(
                 bytes_to_upload=serialized_data.args,
    -            hmac_key=self.hmac_key,
                 s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
                 sagemaker_session=self.sagemaker_session,
                 s3_kms_key=self.s3_kms_key,
    @@ -156,7 +149,6 @@ def load_and_invoke(self) -> Any:
             func = serialization.deserialize_func_from_s3(
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
    -            hmac_key=self.hmac_key,
             )
     
             logger.info(
    @@ -166,15 +158,13 @@ def load_and_invoke(self) -> Any:
             args, kwargs = serialization.deserialize_obj_from_s3(
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
    -            hmac_key=self.hmac_key,
             )
     
             logger.info("Resolving pipeline variables")
             resolved_args, resolved_kwargs = resolve_pipeline_variables(
                 self.context,
                 args,
                 kwargs,
    -            hmac_key=self.hmac_key,
                 s3_base_uri=self.s3_base_uri,
                 sagemaker_session=self.sagemaker_session,
             )
    @@ -190,7 +180,6 @@ def load_and_invoke(self) -> Any:
                 obj=result,
                 sagemaker_session=self.sagemaker_session,
                 s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER),
    -            hmac_key=self.hmac_key,
                 s3_kms_key=self.s3_kms_key,
             )
     
    
  • src/sagemaker/remote_function/errors.py · +1 −3 · modified
    @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg):
                 f.write(failure_msg)
     
     
    -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> int:
    +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int:
         """Handle all exceptions raised during remote function execution.
     
         Args:
    @@ -79,7 +79,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) ->
                  AWS service calls are delegated to.
             s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded.
             s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    -        hmac_key (str): Key used to calculate hmac hash of the serialized exception.
         Returns :
             exit_code (int): Exit code to terminate current job.
         """
    @@ -97,7 +96,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) ->
             exc=error,
             sagemaker_session=sagemaker_session,
             s3_uri=s3_path_join(s3_base_uri, "exception"),
    -        hmac_key=hmac_key,
             s3_kms_key=s3_kms_key,
         )
     
    
  • src/sagemaker/remote_function/invoke_function.py · +1 −9 · modified
    @@ -17,7 +17,6 @@
     import argparse
     import sys
     import json
    -import os
     from typing import TYPE_CHECKING
     
     import boto3
    @@ -97,17 +96,14 @@ def _load_pipeline_context(args) -> Context:
         )
     
     
    -def _execute_remote_function(
    -    sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context
    -):
    +def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, context):
         """Execute stored remote function"""
         from sagemaker.remote_function.core.stored_function import StoredFunction
     
         stored_function = StoredFunction(
             sagemaker_session=sagemaker_session,
             s3_base_uri=s3_base_uri,
             s3_kms_key=s3_kms_key,
    -        hmac_key=hmac_key,
             context=context,
         )
     
    @@ -138,15 +134,12 @@ def main(sys_args=None):
             run_in_context = args.run_in_context
             pipeline_context = _load_pipeline_context(args)
     
    -        hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY")
    -
             sagemaker_session = _get_sagemaker_session(region)
             _execute_remote_function(
                 sagemaker_session=sagemaker_session,
                 s3_base_uri=s3_base_uri,
                 s3_kms_key=s3_kms_key,
                 run_in_context=run_in_context,
    -            hmac_key=hmac_key,
                 context=pipeline_context,
             )
     
    @@ -162,7 +155,6 @@ def main(sys_args=None):
                 sagemaker_session=sagemaker_session,
                 s3_base_uri=s3_uri,
                 s3_kms_key=s3_kms_key,
    -            hmac_key=hmac_key,
             )
         finally:
             sys.exit(exit_code)
    
  • src/sagemaker/remote_function/job.py · +2 −21 · modified
    @@ -19,7 +19,6 @@
     import shutil
     import sys
     import json
    -import secrets
     
     from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
     from urllib.parse import urlparse
    @@ -583,11 +582,6 @@ def __init__(
                 {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name}
             )
     
    -        # The following will be overridden by the _Job.compile method.
    -        # However, it needs to be kept here for feature store SDK.
    -        # TODO: update the feature store SDK to set the HMAC key there.
    -        self.environment_variables.update({"REMOTE_FUNCTION_SECRET_KEY": secrets.token_hex(32)})
    -
             if spark_config and image_uri:
                 raise ValueError("spark_config and image_uri cannot be specified at the same time!")
     
    @@ -799,19 +793,17 @@ def _get_default_spark_image(session):
     class _Job:
         """Helper class that interacts with the SageMaker training service."""
     
    -    def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, hmac_key: str):
    +    def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session):
             """Initialize a _Job object.
     
             Args:
                 job_name (str): The training job name.
                 s3_uri (str): The training job output S3 uri.
                 sagemaker_session (Session): SageMaker boto session.
    -            hmac_key (str): Remote function secret key.
             """
             self.job_name = job_name
             self.s3_uri = s3_uri
             self.sagemaker_session = sagemaker_session
    -        self.hmac_key = hmac_key
             self._last_describe_response = None
     
         @staticmethod
    @@ -827,9 +819,8 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
             """
             job_name = describe_training_job_response["TrainingJobName"]
             s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]
    -        hmac_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"]
     
    -        job = _Job(job_name, s3_uri, sagemaker_session, hmac_key)
    +        job = _Job(job_name, s3_uri, sagemaker_session)
             job._last_describe_response = describe_training_job_response
             return job
     
    @@ -867,7 +858,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
                 job_name,
                 s3_base_uri,
                 job_settings.sagemaker_session,
    -            training_job_request["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
             )
     
         @staticmethod
    @@ -892,26 +882,18 @@ def compile(
     
             jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT[:]
     
    -        # generate hmac key for integrity check
    -        if step_compilation_context is None:
    -            hmac_key = secrets.token_hex(32)
    -        else:
    -            hmac_key = step_compilation_context.function_step_secret_token
    -
             # serialize function and arguments
             if step_compilation_context is None:
                 stored_function = StoredFunction(
                     sagemaker_session=job_settings.sagemaker_session,
                     s3_base_uri=s3_base_uri,
    -                hmac_key=hmac_key,
                     s3_kms_key=job_settings.s3_kms_key,
                 )
                 stored_function.save(func, *func_args, **func_kwargs)
             else:
                 stored_function = StoredFunction(
                     sagemaker_session=job_settings.sagemaker_session,
                     s3_base_uri=s3_base_uri,
    -                hmac_key=hmac_key,
                     s3_kms_key=job_settings.s3_kms_key,
                     context=Context(
                         step_name=step_compilation_context.step_name,
    @@ -1061,7 +1043,6 @@ def compile(
             request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances
     
             request_dict["Environment"] = job_settings.environment_variables
    -        request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})
     
             extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
             extended_request = _extend_mpirun_to_request(extended_request, job_settings)
    
  • src/sagemaker/workflow/pipeline.py · +0 −1 · modified
    @@ -1084,7 +1084,6 @@ def get_function_step_result(
             return deserialize_obj_from_s3(
                 sagemaker_session=sagemaker_session,
                 s3_uri=s3_uri,
    -            hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
             )
     
         raise RemoteFunctionError(_ERROR_MSG_OF_STEP_INCOMPLETE)
    
  • tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py · +0 −3 · modified
    @@ -68,7 +68,6 @@ def remote_decorator_config(sagemaker_session):
             pre_execution_commands="some_commands",
             pre_execution_script="some_path",
             python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH,
    -        environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"},
             custom_file_filter=None,
         )
     
    @@ -91,7 +90,6 @@ def remote_decorator_config_with_filter(sagemaker_session):
             pre_execution_commands="some_commands",
             pre_execution_script="some_path",
             python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH,
    -        environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"},
             custom_file_filter=custom_file_filter,
         )
     
    @@ -103,7 +101,6 @@ def test_prepare_and_upload_callable(mock_stored_function, config_uploader, wrap
         assert mock_stored_function.called_once_with(
             s3_base_uri="s3_base_uri",
             s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key,
    -        hmac_key="some_secret_key",
             sagemaker_session=sagemaker_session,
         )
     
    
  • tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py · +0 −3 · modified
    @@ -309,9 +309,6 @@ def test_to_pipeline(
             input_mode="File",
             environment={
                 "AWS_DEFAULT_REGION": "us-west-2",
    -            "REMOTE_FUNCTION_SECRET_KEY": job_settings.environment_variables[
    -                "REMOTE_FUNCTION_SECRET_KEY"
    -            ],
                 "scheduled_time": Parameter(
                     name="scheduled_time", parameter_type=ParameterTypeEnum.STRING
                 ),
    
  • tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py · +4 −9 · modified
    @@ -83,13 +83,12 @@ def test_resolve_delayed_returns(mock_deserializer):
             }
         )
         resolver = _DelayedReturnResolver(
    -        delayed_returns,
    -        "1234",
    +        delayed_returns=delayed_returns,
             properties_resolver=_PropertiesResolver(context),
             parameter_resolver=_ParameterResolver(context),
             execution_variable_resolver=_ExecutionVariableResolver(context),
    -        sagemaker_session=None,
             s3_base_uri=f"s3://my-bucket/{PIPELINE_NAME}",
    +        sagemaker_session=None,
         )
     
         assert resolver.resolve(delayed_returns[0]) == 1
    @@ -122,13 +121,12 @@ def test_deserializer_fails(mock_deserializer):
         )
         with pytest.raises(Exception, match="Something went wrong"):
             _DelayedReturnResolver(
    -            delayed_returns,
    -            "1234",
    +            delayed_returns=delayed_returns,
                 properties_resolver=_PropertiesResolver(context),
                 parameter_resolver=_ParameterResolver(context),
                 execution_variable_resolver=_ExecutionVariableResolver(context),
    -            sagemaker_session=None,
                 s3_base_uri=f"s3://my-bucket/{PIPELINE_NAME}",
    +            sagemaker_session=None,
             )
     
     
    @@ -149,7 +147,6 @@ def test_no_pipeline_variables_to_resolve(mock_deserializer, func_args, func_kwa
             Context(),
             func_args,
             func_kwargs,
    -        hmac_key="1234",
             s3_base_uri="s3://my-bucket",
             sagemaker_session=None,
         )
    @@ -275,7 +272,6 @@ def test_resolve_pipeline_variables(
             context,
             func_args,
             func_kwargs,
    -        hmac_key="1234",
             s3_base_uri=s3_base_uri,
             sagemaker_session=None,
         )
    @@ -285,7 +281,6 @@ def test_resolve_pipeline_variables(
         mock_deserializer.assert_called_once_with(
             sagemaker_session=None,
             s3_uri=s3_results_uri,
    -        hmac_key="1234",
         )
     
     
    
  • tests/unit/sagemaker/remote_function/core/test_serialization.py · +26 −69 · modified
    @@ -32,7 +32,6 @@
     from tblib import pickling_support
     
     KMS_KEY = "kms-key"
    -HMAC_KEY = "some-hmac-key"
     
     
     mock_s3 = {}
    @@ -66,15 +65,11 @@ def square(x):
             return x * x
     
         s3_uri = random_s3_uri()
    -    serialize_func_to_s3(
    -        func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -    )
    +    serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
         del square
     
    -    deserialized = deserialize_func_from_s3(
    -        sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY
    -    )
    +    deserialized = deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
         assert deserialized(3) == 9
     
    @@ -89,12 +84,9 @@ def test_serialize_deserialize_lambda():
             sagemaker_session=Mock(),
             s3_uri=s3_uri,
             s3_kms_key=KMS_KEY,
    -        hmac_key=HMAC_KEY,
         )
     
    -    deserialized = deserialize_func_from_s3(
    -        sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY
    -    )
    +    deserialized = deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
         assert deserialized(3) == 9
     
    @@ -126,7 +118,6 @@ def train(x):
                 sagemaker_session=Mock(),
                 s3_uri=s3_uri,
                 s3_kms_key=KMS_KEY,
    -            hmac_key=HMAC_KEY,
             )
     
     
    @@ -153,7 +144,6 @@ def func(x):
                 sagemaker_session=Mock(),
                 s3_uri=s3_uri,
                 s3_kms_key=KMS_KEY,
    -            hmac_key=HMAC_KEY,
             )
     
     
    @@ -177,7 +167,6 @@ def square(x):
                 sagemaker_session=Mock(),
                 s3_uri=s3_uri,
                 s3_kms_key=KMS_KEY,
    -            hmac_key=HMAC_KEY,
             )
     
     
    @@ -192,9 +181,7 @@ def square(x):
     
         s3_uri = random_s3_uri()
     
    -    serialize_func_to_s3(
    -        func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -    )
    +    serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
         del square
     
    @@ -204,7 +191,7 @@ def square(x):
             + r"RuntimeError\('some failure when loads'\). "
             + r"NOTE: this may be caused by inconsistent sagemaker python sdk versions",
         ):
    -        deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
    +        deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
     
     @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload)
    @@ -215,15 +202,13 @@ def square(x):
     
         s3_uri = random_s3_uri()
     
    -    serialize_func_to_s3(
    -        func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -    )
    +    serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
         mock_s3[f"{s3_uri}/metadata.json"] = b"not json serializable"
     
         del square
     
         with pytest.raises(DeserializationError, match=r"Corrupt metadata file."):
    -        deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
    +        deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
     
     @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload)
    @@ -233,16 +218,16 @@ def square(x):
             return x * x
     
         s3_uri = random_s3_uri()
    -    serialize_func_to_s3(
    -        func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -    )
    +    serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
    +    # Tamper with the payload to trigger integrity check failure
    +    mock_s3[f"{s3_uri}/payload.pkl"] = b"tampered data"
     
         del square
     
         with pytest.raises(
             DeserializationError, match=r"Integrity check for the serialized function or data failed."
         ):
    -        deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key="invalid_key")
    +        deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
     
     @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload)
    @@ -255,16 +240,12 @@ def __init__(self, x):
         my_data = MyData(10)
     
         s3_uri = random_s3_uri()
    -    serialize_obj_to_s3(
    -        my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -    )
    +    serialize_obj_to_s3(my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
         del my_data
         del MyData
     
    -    deserialized = deserialize_obj_from_s3(
    -        sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY
    -    )
    +    deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
         assert deserialized.x == 10
     
    @@ -276,15 +257,11 @@ def test_serialize_deserialize_data_built_in_types():
         my_data = {"a": [10]}
     
         s3_uri = random_s3_uri()
    -    serialize_obj_to_s3(
    -        my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -    )
    +    serialize_obj_to_s3(my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
         del my_data
     
    -    deserialized = deserialize_obj_from_s3(
    -        sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY
    -    )
    +    deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
         assert deserialized == {"a": [10]}
     
    @@ -294,13 +271,9 @@ def test_serialize_deserialize_data_built_in_types():
     def test_serialize_deserialize_none():
     
         s3_uri = random_s3_uri()
    -    serialize_obj_to_s3(
    -        None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -    )
    +    serialize_obj_to_s3(None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
    -    deserialized = deserialize_obj_from_s3(
    -        sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY
    -    )
    +    deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
         assert deserialized is None
     
    @@ -327,7 +300,6 @@ def test_serialize_run(sagemaker_session, *args, **kwargs):
                     sagemaker_session=Mock(),
                     s3_uri=s3_uri,
                     s3_kms_key=KMS_KEY,
    -                hmac_key=HMAC_KEY,
                 )
     
     
    @@ -351,7 +323,6 @@ def test_serialize_pipeline_variables(pipeline_variable):
                 sagemaker_session=Mock(),
                 s3_uri=s3_uri,
                 s3_kms_key=KMS_KEY,
    -            hmac_key=HMAC_KEY,
             )
     
     
    @@ -377,7 +348,6 @@ def __init__(self, x):
                 sagemaker_session=Mock(),
                 s3_uri=s3_uri,
                 s3_kms_key=KMS_KEY,
    -            hmac_key=HMAC_KEY,
             )
     
     
    @@ -394,9 +364,7 @@ def __init__(self, x):
         my_data = MyData(10)
         s3_uri = random_s3_uri()
     
    -    serialize_obj_to_s3(
    -        obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -    )
    +    serialize_obj_to_s3(obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
         del my_data
         del MyData
    @@ -407,7 +375,7 @@ def __init__(self, x):
             + r"RuntimeError\('some failure when loads'\). "
             + r"NOTE: this may be caused by inconsistent sagemaker python sdk versions",
         ):
    -        deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
    +        deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
     
     @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_error)
    @@ -427,7 +395,6 @@ def test_serialize_deserialize_service_error():
                 sagemaker_session=Mock(),
                 s3_uri=s3_uri,
                 s3_kms_key=KMS_KEY,
    -            hmac_key=HMAC_KEY,
             )
     
         del my_func
    @@ -437,7 +404,7 @@ def test_serialize_deserialize_service_error():
             match=rf"Failed to read serialized bytes from {s3_uri}/metadata.json: "
             + r"RuntimeError\('some failure when read_bytes'\)",
         ):
    -        deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
    +        deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
     
     
     @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload)
    @@ -460,12 +427,10 @@ def func_b():
             func_b()
         except Exception as e:
             pickling_support.install()
    -        serialize_obj_to_s3(
    -            e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -        )
    +        serialize_obj_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
         with pytest.raises(CustomError, match="Some error") as exc_info:
    -        raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
    +        raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
         assert type(exc_info.value.__cause__) is TypeError
     
     
    @@ -488,14 +453,10 @@ def func_b():
         try:
             func_b()
         except Exception as e:
    -        serialize_exception_to_s3(
    -            e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -        )
    +        serialize_exception_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
         with pytest.raises(CustomError, match="Some error") as exc_info:
    -        raise deserialize_exception_from_s3(
    -            sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY
    -        )
    +        raise deserialize_exception_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
         assert type(exc_info.value.__cause__) is TypeError
     
     
    @@ -518,12 +479,8 @@ def func_b():
         try:
             func_b()
         except Exception as e:
    -        serialize_exception_to_s3(
    -            e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    -        )
    +        serialize_exception_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
     
         with pytest.raises(ServiceError, match="Some error") as exc_info:
    -        raise deserialize_exception_from_s3(
    -            sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY
    -        )
    +        raise deserialize_exception_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
         assert type(exc_info.value.__cause__) is TypeError
    
  • tests/unit/sagemaker/remote_function/core/test_stored_function.py+10 20 modified
    @@ -50,7 +50,6 @@
     )
     
     KMS_KEY = "kms-key"
    -HMAC_KEY = "some-hmac-key"
     FUNCTION_FOLDER = "function"
     ARGUMENT_FOLDER = "arguments"
     RESULT_FOLDER = "results"
    @@ -96,14 +95,14 @@ def test_save_and_load(s3_source_dir_download, s3_source_dir_upload, args, kwarg
         s3_base_uri = random_s3_uri()
     
         stored_function = StoredFunction(
    -        sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    +        sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY
         )
         stored_function.save(quadratic, *args, **kwargs)
         stored_function.load_and_invoke()
     
    -    assert deserialize_obj_from_s3(
    -        session, s3_uri=f"{s3_base_uri}/results", hmac_key=HMAC_KEY
    -    ) == quadratic(*args, **kwargs)
    +    assert deserialize_obj_from_s3(session, s3_uri=f"{s3_base_uri}/results") == quadratic(
    +        *args, **kwargs
    +    )
     
     
     @patch(
    @@ -139,7 +138,7 @@ def test_save_with_parameter_of_run_type(
             sagemaker_session=session,
         )
         stored_function = StoredFunction(
    -        sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
    +        sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY
         )
         with pytest.raises(SerializationError) as e:
             stored_function.save(log_bigger, 1, 2, run)
    @@ -165,7 +164,6 @@ def test_save_s3_paths_verification(
             sagemaker_session=session,
             s3_base_uri=s3_base_uri,
             s3_kms_key=KMS_KEY,
    -        hmac_key=HMAC_KEY,
             context=Context(
                 step_name=step_name,
                 execution_id=execution_id,
    @@ -180,13 +178,11 @@ def test_save_s3_paths_verification(
             sagemaker_session=session,
             s3_uri=(upload_path + FUNCTION_FOLDER),
             s3_kms_key=KMS_KEY,
    -        hmac_key=HMAC_KEY,
         )
         serialize_obj.assert_called_once_with(
             obj=((3,), {}),
             sagemaker_session=session,
             s3_uri=(upload_path + ARGUMENT_FOLDER),
    -        hmac_key=HMAC_KEY,
             s3_kms_key=KMS_KEY,
         )
     
    @@ -226,7 +222,6 @@ def test_load_and_invoke_s3_paths_verification(
             sagemaker_session=session,
             s3_base_uri=s3_base_uri,
             s3_kms_key=KMS_KEY,
    -        hmac_key=HMAC_KEY,
             context=Context(
                 step_name=step_name,
                 execution_id=execution_id,
    @@ -237,12 +232,11 @@ def test_load_and_invoke_s3_paths_verification(
         stored_function.load_and_invoke()
     
         deserialize_func.assert_called_once_with(
    -        sagemaker_session=session, s3_uri=(download_path + FUNCTION_FOLDER), hmac_key=HMAC_KEY
    +        sagemaker_session=session, s3_uri=(download_path + FUNCTION_FOLDER)
         )
         deserialize_obj.assert_called_once_with(
             sagemaker_session=session,
             s3_uri=(download_path + ARGUMENT_FOLDER),
    -        hmac_key=HMAC_KEY,
         )
     
         result = deserialize_func.return_value(
    @@ -253,7 +247,6 @@ def test_load_and_invoke_s3_paths_verification(
             obj=result,
             sagemaker_session=session,
             s3_uri=(upload_path + RESULT_FOLDER),
    -        hmac_key=HMAC_KEY,
             s3_kms_key=KMS_KEY,
         )
     
    @@ -283,7 +276,6 @@ def test_load_and_invoke_json_serialization(
             sagemaker_session=session,
             s3_base_uri=s3_base_uri,
             s3_kms_key=KMS_KEY,
    -        hmac_key=HMAC_KEY,
             context=Context(
                 serialize_output_to_json=serialize_output_to_json,
             ),
    @@ -318,13 +310,12 @@ def test_save_and_load_with_pipeline_variable(monkeypatch):
     
         function_step = _FunctionStep(name="func_1", display_name=None, description=None)
         x = DelayedReturn(function_step=function_step)
    -    serialize_obj_to_s3(3.0, session, func1_result_path, HMAC_KEY, KMS_KEY)
    +    serialize_obj_to_s3(3.0, session, func1_result_path, KMS_KEY)
     
         stored_function = StoredFunction(
             sagemaker_session=session,
             s3_base_uri=s3_base_uri,
             s3_kms_key=KMS_KEY,
    -        hmac_key=HMAC_KEY,
             context=Context(
                 property_references={
                     "Parameters.a": "1.0",
    @@ -355,9 +346,9 @@ def test_save_and_load_with_pipeline_variable(monkeypatch):
         stored_function.load_and_invoke()
     
         func2_result_path = f"{s3_base_uri}/execution-id/func_2/results"
    -    assert deserialize_obj_from_s3(
    -        session, s3_uri=func2_result_path, hmac_key=HMAC_KEY
    -    ) == quadratic(3.0, a=1.0, b=2.0, c=3.0)
    +    assert deserialize_obj_from_s3(session, s3_uri=func2_result_path) == quadratic(
    +        3.0, a=1.0, b=2.0, c=3.0
    +    )
     
     
     @patch("sagemaker.remote_function.core.serialization._upload_payload_and_metadata_to_s3")
    @@ -371,7 +362,6 @@ def test_save_pipeline_step_function(mock_job_settings, upload_payload):
             sagemaker_session=session,
             s3_base_uri=s3_base_uri,
             s3_kms_key=KMS_KEY,
    -        hmac_key=HMAC_KEY,
             context=Context(
                 step_name="step_name",
                 execution_id="execution_id",
    
  • tests/unit/sagemaker/remote_function/test_client.py+4 7 modified
    @@ -54,7 +54,7 @@
     S3_URI = f"s3://{BUCKET}/keyprefix"
     EXPECTED_JOB_RESULT = [1, 2, 3]
     PATH_TO_SRC_DIR = "path/to/src/dir"
    -HMAC_KEY = "some-hmac-key"
    +
     ROLE_ARN = "arn:aws:iam::555555555555:role/my_execution_role_arn"
     
     
    @@ -69,7 +69,7 @@ def describe_training_job_response(job_status):
                 "VolumeSizeInGB": 30,
             },
             "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"},
    -        "Environment": {"REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        "Environment": {},
         }
     
     
    @@ -1027,7 +1027,7 @@ def test_future_get_result_from_completed_job(mock_start, mock_deserialize):
     def test_future_get_result_from_failed_job_remote_error_client_function(
         mock_start, mock_deserialize
     ):
    -    mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI, hmac_key=HMAC_KEY)
    +    mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI)
         mock_start.return_value = mock_job
         mock_job.describe.return_value = FAILED_TRAINING_JOB
     
    @@ -1042,9 +1042,7 @@ def test_future_get_result_from_failed_job_remote_error_client_function(
     
         assert future.done()
         mock_job.wait.assert_called_once()
    -    mock_deserialize.assert_called_with(
    -        sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception", hmac_key=HMAC_KEY
    -    )
    +    mock_deserialize.assert_called_with(sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception")
     
     
     @patch("sagemaker.s3.S3Downloader.read_bytes")
    @@ -1374,7 +1372,6 @@ def test_get_future_completed_job_deserialization_error(mock_session, mock_deser
         mock_deserialize.assert_called_with(
             sagemaker_session=ANY,
             s3_uri="s3://sagemaker-123/image_uri/output/results",
    -        hmac_key=HMAC_KEY,
         )
     
     
    
  • tests/unit/sagemaker/remote_function/test_errors.py+0 3 modified
    @@ -20,7 +20,6 @@
     
     TEST_S3_BASE_URI = "s3://my-bucket/"
     TEST_S3_KMS_KEY = "my-kms-key"
    -TEST_HMAC_KEY = "some-hmac-key"
     
     
     class _InvalidErrorNumberException(Exception):
    @@ -76,7 +75,6 @@ def test_handle_error(
             sagemaker_session=sagemaker_session,
             s3_base_uri=TEST_S3_BASE_URI,
             s3_kms_key=TEST_S3_KMS_KEY,
    -        hmac_key=TEST_HMAC_KEY,
         )
     
         assert exit_code == expected_exit_code
    @@ -87,6 +85,5 @@ def test_handle_error(
             exc=err,
             sagemaker_session=sagemaker_session,
             s3_uri=TEST_S3_BASE_URI + "exception",
    -        hmac_key=TEST_HMAC_KEY,
             s3_kms_key=TEST_S3_KMS_KEY,
         )
    
  • tests/unit/sagemaker/remote_function/test_invoke_function.py+0 11 modified
    @@ -12,8 +12,6 @@
     # language governing permissions and limitations under the License.
     from __future__ import absolute_import
     
    -import os
    -
     import pytest
     from mock import patch, Mock, ANY
     from sagemaker.remote_function import invoke_function
    @@ -25,7 +23,6 @@
     TEST_S3_BASE_URI = "s3://my-bucket/"
     TEST_S3_KMS_KEY = "my-kms-key"
     TEST_RUN_IN_CONTEXT = '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}'
    -TEST_HMAC_KEY = "some-hmac-key"
     TEST_STEP_NAME = "training-step"
     TEST_EXECUTION_ID = "some-execution-id"
     FUNC_STEP_S3_DIR = sagemaker_timestamp()
    @@ -89,7 +86,6 @@ def mock_session():
         return_value=mock_session(),
     )
     def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object):
    -    os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY
         invoke_function.main(mock_args())
     
         _get_sagemaker_session.assert_called_with(TEST_REGION)
    @@ -108,7 +104,6 @@ def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _l
     def test_main_success_with_run(
         _get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object
     ):
    -    os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY
         invoke_function.main(mock_args_with_run_in_context())
     
         _get_sagemaker_session.assert_called_with(TEST_REGION)
    @@ -137,7 +132,6 @@ def test_main_success_with_run(
     def test_main_success_with_pipeline_context(
         _get_sagemaker_session, mock_stored_function, _exit_process, _load_run_object, args
     ):
    -    os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY
     
         args_input, expected_serialize_output_to_json = args
         invoke_function.main(args_input)
    @@ -147,7 +141,6 @@ def test_main_success_with_pipeline_context(
             sagemaker_session=ANY,
             s3_base_uri=TEST_S3_BASE_URI,
             s3_kms_key=TEST_S3_KMS_KEY,
    -        hmac_key=TEST_HMAC_KEY,
             context=Context(
                 execution_id=TEST_EXECUTION_ID,
                 step_name=TEST_STEP_NAME,
    @@ -174,7 +167,6 @@ def test_main_success_with_pipeline_context(
     def test_main_failure(
         _get_sagemaker_session, load_and_invoke, _exit_process, handle_error, _load_run_object
     ):
    -    os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY
         ser_err = SerializationError("some failure reason")
         load_and_invoke.side_effect = ser_err
         handle_error.return_value = 1
    @@ -189,7 +181,6 @@ def test_main_failure(
             sagemaker_session=_get_sagemaker_session(),
             s3_base_uri=TEST_S3_BASE_URI,
             s3_kms_key=TEST_S3_KMS_KEY,
    -        hmac_key=TEST_HMAC_KEY,
         )
         _exit_process.assert_called_with(1)
     
    @@ -205,7 +196,6 @@ def test_main_failure(
     def test_main_failure_with_step(
         _get_sagemaker_session, load_and_invoke, _exit_process, handle_error, _load_run_object
     ):
    -    os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY
         ser_err = SerializationError("some failure reason")
         load_and_invoke.side_effect = ser_err
         handle_error.return_value = 1
    @@ -221,6 +211,5 @@ def test_main_failure_with_step(
             sagemaker_session=_get_sagemaker_session(),
             s3_base_uri=s3_uri,
             s3_kms_key=TEST_S3_KMS_KEY,
    -        hmac_key=TEST_HMAC_KEY,
         )
         _exit_process.assert_called_with(1)
    
  • tests/unit/sagemaker/remote_function/test_job.py+20 61 modified
    @@ -68,7 +68,7 @@
     RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap"
     REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws"
     SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies"
    -HMAC_KEY = "some-hmac-key"
    +
     
     EXPECTED_FUNCTION_URI = S3_URI + "/function.pkl"
     EXPECTED_OUTPUT_URI = S3_URI + "/output"
    @@ -249,11 +249,7 @@
     DESCRIBE_TRAINING_JOB_RESPONSE = {
         "TrainingJobArn": TRAINING_JOB_ARN,
         "TrainingJobStatus": "{}",
    -    "ResourceConfig": {
    -        "InstanceCount": 1,
    -        "InstanceType": "ml.c4.xlarge",
    -        "VolumeSizeInGB": 30,
    -    },
    +    "ResourceConfig": {"InstanceCount": 1, "InstanceType": "ml.c4.xlarge", "VolumeSizeInGB": 30},
         "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"},
     }
     
    @@ -359,32 +355,27 @@ def serialized_data():
         return _SerializedData(func=b"serialized_func", args=b"serialized_args")
     
     
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job.Session", return_value=mock_session())
     @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
    -def test_sagemaker_config_job_settings(get_execution_role, session, secret_token):
    +def test_sagemaker_config_job_settings(get_execution_role, session):
     
         job_settings = _JobSettings(image_uri="image_uri", instance_type="ml.m5.xlarge")
         assert job_settings.image_uri == "image_uri"
         assert job_settings.s3_root_uri == f"s3://{BUCKET}"
         assert job_settings.role == DEFAULT_ROLE_ARN
    -    assert job_settings.environment_variables == {
    -        "AWS_DEFAULT_REGION": "us-west-2",
    -        "REMOTE_FUNCTION_SECRET_KEY": "some-hmac-key",
    -    }
    +    assert job_settings.environment_variables == {"AWS_DEFAULT_REGION": "us-west-2"}
         assert job_settings.include_local_workdir is False
         assert job_settings.instance_type == "ml.m5.xlarge"
     
     
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch(
         "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
         return_value="some_image_uri",
     )
     @patch("sagemaker.remote_function.job.Session", return_value=mock_session())
     @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
     def test_sagemaker_config_job_settings_with_spark_config(
    -    get_execution_role, session, mock_get_default_spark_image, secret_token
    +    get_execution_role, session, mock_get_default_spark_image
     ):
     
         spark_config = SparkConfig()
    @@ -393,10 +384,7 @@ def test_sagemaker_config_job_settings_with_spark_config(
         assert job_settings.image_uri == "some_image_uri"
         assert job_settings.s3_root_uri == f"s3://{BUCKET}"
         assert job_settings.role == DEFAULT_ROLE_ARN
    -    assert job_settings.environment_variables == {
    -        "AWS_DEFAULT_REGION": "us-west-2",
    -        "REMOTE_FUNCTION_SECRET_KEY": "some-hmac-key",
    -    }
    +    assert job_settings.environment_variables == {"AWS_DEFAULT_REGION": "us-west-2"}
         assert job_settings.include_local_workdir is False
         assert job_settings.instance_type == "ml.m5.xlarge"
         assert job_settings.spark_config == spark_config
    @@ -434,12 +422,9 @@ def test_sagemaker_config_job_settings_with_not_supported_param_by_spark():
             )
     
     
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job.Session", return_value=mock_session())
     @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
    -def test_sagemaker_config_job_settings_with_configuration_file(
    -    get_execution_role, session, secret_token
    -):
    +def test_sagemaker_config_job_settings_with_configuration_file(get_execution_role, session):
         config_tags = [
             {"Key": "someTagKey", "Value": "someTagValue"},
             {"Key": "someTagKey2", "Value": "someTagValue2"},
    @@ -458,7 +443,6 @@ def test_sagemaker_config_job_settings_with_configuration_file(
         assert job_settings.pre_execution_commands == ["command_1", "command_2"]
         assert job_settings.environment_variables == {
             "AWS_DEFAULT_REGION": "us-west-2",
    -        "REMOTE_FUNCTION_SECRET_KEY": "some-hmac-key",
             "EnvVarKey": "EnvVarValue",
         }
         assert job_settings.job_conda_env == "my_conda_env"
    @@ -542,7 +526,6 @@ def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, sess
     
     
     @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
     @patch(
         "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
    @@ -556,7 +539,6 @@ def test_start(
         mock_runtime_manager,
         mock_script_upload,
         mock_dependency_upload,
    -    secret_token,
     ):
     
         job_settings = _JobSettings(
    @@ -575,7 +557,6 @@ def test_start(
         mock_stored_function.assert_called_once_with(
             sagemaker_session=session(),
             s3_base_uri=f"{S3_URI}/{job.job_name}",
    -        hmac_key=HMAC_KEY,
             s3_kms_key=None,
         )
     
    @@ -662,12 +643,11 @@ def test_start(
             EnableNetworkIsolation=False,
             EnableInterContainerTrafficEncryption=True,
             EnableManagedSpotTraining=False,
    -        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
     @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
     @patch(
         "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
    @@ -681,7 +661,6 @@ def test_start_with_checkpoint_location(
         mock_runtime_manager,
         mock_script_upload,
         mock_user_workspace_upload,
    -    secret_token,
     ):
     
         job_settings = _JobSettings(
    @@ -707,7 +686,6 @@ def test_start_with_checkpoint_location(
         mock_stored_function.assert_called_once_with(
             sagemaker_session=session(),
             s3_base_uri=f"{S3_URI}/{job.job_name}",
    -        hmac_key=HMAC_KEY,
             s3_kms_key=None,
         )
     
    @@ -779,7 +757,7 @@ def test_start_with_checkpoint_location(
             EnableNetworkIsolation=False,
             EnableInterContainerTrafficEncryption=True,
             EnableManagedSpotTraining=False,
    -        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
    @@ -819,7 +797,6 @@ def test_start_with_checkpoint_location_failed_with_multiple_checkpoint_location
             )
     
     
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
     @patch(
         "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
    @@ -833,7 +810,6 @@ def test_start_with_complete_job_settings(
         mock_runtime_manager,
         mock_bootstrap_script_upload,
         mock_user_workspace_upload,
    -    secret_token,
     ):
     
         job_settings = _JobSettings(
    @@ -860,7 +836,6 @@ def test_start_with_complete_job_settings(
         mock_stored_function.assert_called_once_with(
             sagemaker_session=session(),
             s3_base_uri=f"{S3_URI}/{job.job_name}",
    -        hmac_key=HMAC_KEY,
             s3_kms_key=KMS_KEY_ARN,
         )
     
    @@ -913,7 +888,10 @@ def test_start_with_complete_job_settings(
                     },
                 ),
             ],
    -        OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}", "KmsKeyId": KMS_KEY_ARN},
    +        OutputDataConfig={
    +            "S3OutputPath": f"{S3_URI}/{job.job_name}",
    +            "KmsKeyId": KMS_KEY_ARN,
    +        },
             AlgorithmSpecification=dict(
                 TrainingImage=IMAGE,
                 TrainingInputMode="File",
    @@ -949,12 +927,11 @@ def test_start_with_complete_job_settings(
             EnableInterContainerTrafficEncryption=False,
             VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]),
             EnableManagedSpotTraining=False,
    -        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
     @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
    -@patch("secrets.token_hex", MagicMock(return_value=HMAC_KEY))
     @patch(
         "sagemaker.remote_function.job._prepare_dependencies_and_pre_execution_scripts",
         return_value="some_s3_uri",
    @@ -1027,7 +1004,6 @@ def test_get_train_args_under_pipeline_context(
         mock_stored_function_ctr.assert_called_once_with(
             sagemaker_session=session(),
             s3_base_uri=s3_base_uri,
    -        hmac_key="token-from-pipeline",
             s3_kms_key=KMS_KEY_ARN,
             context=Context(
                 step_name=MOCKED_PIPELINE_CONFIG.step_name,
    @@ -1160,14 +1136,10 @@ def test_get_train_args_under_pipeline_context(
             EnableInterContainerTrafficEncryption=False,
             VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]),
             EnableManagedSpotTraining=False,
    -        Environment={
    -            "AWS_DEFAULT_REGION": "us-west-2",
    -            "REMOTE_FUNCTION_SECRET_KEY": "token-from-pipeline",
    -        },
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch(
         "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
         return_value="some_image_uri",
    @@ -1192,7 +1164,6 @@ def test_start_with_spark(
         mock_dependency_upload,
         mock_spark_dependency_upload,
         mock_get_default_spark_image,
    -    secrete_token,
     ):
         spark_config = SparkConfig()
         job_settings = _JobSettings(
    @@ -1301,7 +1272,7 @@ def test_start_with_spark(
             EnableNetworkIsolation=False,
             EnableInterContainerTrafficEncryption=True,
             EnableManagedSpotTraining=False,
    -        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
    @@ -1834,7 +1805,6 @@ def test_extend_spark_config_to_request(
     
     
     @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
     @patch(
         "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
    @@ -1848,7 +1818,6 @@ def test_start_with_torchrun_single_node(
         mock_runtime_manager,
         mock_script_upload,
         mock_dependency_upload,
    -    secret_token,
     ):
     
         job_settings = _JobSettings(
    @@ -1869,7 +1838,6 @@ def test_start_with_torchrun_single_node(
         mock_stored_function.assert_called_once_with(
             sagemaker_session=session(),
             s3_base_uri=f"{S3_URI}/{job.job_name}",
    -        hmac_key=HMAC_KEY,
             s3_kms_key=None,
         )
     
    @@ -1958,12 +1926,11 @@ def test_start_with_torchrun_single_node(
             EnableNetworkIsolation=False,
             EnableInterContainerTrafficEncryption=True,
             EnableManagedSpotTraining=False,
    -        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
     @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
     @patch(
         "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
    @@ -1977,7 +1944,6 @@ def test_start_with_torchrun_multi_node(
         mock_runtime_manager,
         mock_script_upload,
         mock_dependency_upload,
    -    secret_token,
     ):
     
         job_settings = _JobSettings(
    @@ -1999,7 +1965,6 @@ def test_start_with_torchrun_multi_node(
         mock_stored_function.assert_called_once_with(
             sagemaker_session=session(),
             s3_base_uri=f"{S3_URI}/{job.job_name}",
    -        hmac_key=HMAC_KEY,
             s3_kms_key=None,
         )
     
    @@ -2090,7 +2055,7 @@ def test_start_with_torchrun_multi_node(
             EnableNetworkIsolation=False,
             EnableInterContainerTrafficEncryption=True,
             EnableManagedSpotTraining=False,
    -        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
    @@ -2355,7 +2320,6 @@ def test_set_env_multi_node_multi_gpu_mpirun(
     
     
     @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
     @patch(
         "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
    @@ -2369,7 +2333,6 @@ def test_start_with_torchrun_single_node_with_nproc_per_node(
         mock_runtime_manager,
         mock_script_upload,
         mock_dependency_upload,
    -    secret_token,
     ):
     
         job_settings = _JobSettings(
    @@ -2391,7 +2354,6 @@ def test_start_with_torchrun_single_node_with_nproc_per_node(
         mock_stored_function.assert_called_once_with(
             sagemaker_session=session(),
             s3_base_uri=f"{S3_URI}/{job.job_name}",
    -        hmac_key=HMAC_KEY,
             s3_kms_key=None,
         )
     
    @@ -2482,12 +2444,11 @@ def test_start_with_torchrun_single_node_with_nproc_per_node(
             EnableNetworkIsolation=False,
             EnableInterContainerTrafficEncryption=True,
             EnableManagedSpotTraining=False,
    -        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
     @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
    -@patch("secrets.token_hex", return_value=HMAC_KEY)
     @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
     @patch(
         "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
    @@ -2501,7 +2462,6 @@ def test_start_with_mpirun_single_node_with_nproc_per_node(
         mock_runtime_manager,
         mock_script_upload,
         mock_dependency_upload,
    -    secret_token,
     ):
     
         job_settings = _JobSettings(
    @@ -2523,7 +2483,6 @@ def test_start_with_mpirun_single_node_with_nproc_per_node(
         mock_stored_function.assert_called_once_with(
             sagemaker_session=session(),
             s3_base_uri=f"{S3_URI}/{job.job_name}",
    -        hmac_key=HMAC_KEY,
             s3_kms_key=None,
         )
     
    @@ -2614,7 +2573,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node(
             EnableNetworkIsolation=False,
             EnableInterContainerTrafficEncryption=True,
             EnableManagedSpotTraining=False,
    -        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
    +        Environment={"AWS_DEFAULT_REGION": "us-west-2"},
         )
     
     
    
  • tests/unit/sagemaker/workflow/test_pipeline.py+0 1 modified
    @@ -324,7 +324,6 @@ def test_pipeline_execution_result(
             },
             "TrainingJobStatus": "Completed",
             "OutputDataConfig": {"S3OutputPath": s3_output_path},
    -        "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "abcdefg"},
         }
         execution.result("stepA")
     
    

Vulnerability mechanics

Generated automatically (generator name unavailable in this record) on May 9, 2026. Inputs: CWE entries and fix-commit diffs from this CVE's patches. Citations were validated against the advisory bundle.

References

9

News mentions

0

No linked articles in our index yet.