Skip to content

Commit 9fe6cdf

Browse files
Merge pull request #2987 from hx89:hx/nvfp4_no_rht
PiperOrigin-RevId: 866117394
2 parents cdf4e6b + f0e1eb4 commit 9fe6cdf

2 files changed

Lines changed: 2 additions & 0 deletions

File tree

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,7 @@ class QuantizationType(str, Enum):
84 84
TE_FP8_CS = "te_fp8_currentscaling"
85 85
TE_MXFP8 = "te_mxfp8"
86 86
TE_NVFP4 = "te_nvfp4"
87+
TE_NVFP4_NO_RHT = "te_nvfp4_no_rht"
87 88

88 89

89 90
class KvQuantAxis(str, Enum):

src/MaxText/layers/quantizations.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -749,6 +749,7 @@ def _get_recipe(recipe_name: str):
749 749
"te_fp8_currentscaling": recipe.Float8CurrentScaling,
750 750
"te_mxfp8": recipe.MXFP8BlockScaling,
751 751
"te_nvfp4": recipe.NVFP4BlockScaling, # pytype: disable=module-attr
752+
"te_nvfp4_no_rht": functools.partial(recipe.NVFP4BlockScaling, disable_rht=True), # pytype: disable=module-attr
752 753
}
753 754
if recipe_name not in RECIPES:
754 755
raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")

0 commit comments

Comments
 (0)