This commit is contained in:
Darren Burns 2023-02-07 13:38:21 +00:00
parent 3cf010ebe7
commit cc9e342b40
No known key found for this signature in database
GPG key ID: B0939B45037DC345
2 changed files with 11 additions and 24 deletions

View file

@ -7,3 +7,4 @@ exclude_lines =
if TYPE_CHECKING: if TYPE_CHECKING:
if __name__ == "__main__": if __name__ == "__main__":
@overload @overload
__rich_repr__

View file

@ -13,6 +13,7 @@ from typing import (
NamedTuple, NamedTuple,
Callable, Callable,
Sequence, Sequence,
Any,
) )
import rich.repr import rich.repr
@ -1429,35 +1430,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def sort( def sort(
self, self,
column: str | ColumnKey | Sequence[str] | Sequence[ColumnKey], *columns: ColumnKey | str,
reverse: bool = False, reverse: bool = False,
) -> None: ) -> None:
if isinstance(column, (str, ColumnKey)): def sort_by_column_keys(
column = (column,) row: tuple[RowKey, dict[ColumnKey | str, CellType]]
indices = [self._column_locations.get(key) for key in column] ) -> Any:
ordered_keys = sorted(self.rows, key=itemgetter(*indices), reverse=reverse) _, row_data = row
self._row_locations = TwoWayDict( return itemgetter(*columns)(row_data)
{key: new_index for new_index, key in enumerate(ordered_keys)}
)
self._update_count += 1
self.refresh()
def sort_columns( ordered_rows = sorted(
self, key: Callable[[ColumnKey | str], str] = None, reverse: bool = False self.data.items(), key=sort_by_column_keys, reverse=reverse
) -> None:
ordered_keys = sorted(self.columns.keys(), key=key, reverse=reverse)
self._column_locations = TwoWayDict(
{key: new_index for new_index, key in enumerate(ordered_keys)}
) )
self._update_count += 1
self.refresh()
def sort_rows(
self, key: Callable[[RowKey | str], str] = None, reverse: bool = False
):
ordered_keys = sorted(self.rows.keys(), key=key, reverse=reverse)
self._row_locations = TwoWayDict( self._row_locations = TwoWayDict(
{key: new_index for new_index, key in enumerate(ordered_keys)} {key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
) )
self._update_count += 1 self._update_count += 1
self.refresh() self.refresh()