diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 5fe1555f2a..7aca6f0a7e 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -965,6 +965,23 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, new_tensor) +@implements_torch_function(torch.Tensor.t) +def _(func, types, args, kwargs): + assert len(args) == 1 + self = args[0] + assert len(self.block_size) == 2 + new_tensor = self.__class__( + self.qdata.t(), + self.scale.t(), + (self.block_size[1], self.block_size[0]), + self.mm_config, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype, + ) + return new_tensor + + @implements(aten.split.Tensor) def _(func, types, args, kwargs): tensor, split_size_or_sections, dim = args