From 2248e868bca91ac361bd8e48cb3dc302208fc900 Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Tue, 17 Dec 2024 12:24:21 +0000 Subject: [PATCH] allow device to be set from constructor --- pyterrier_t5/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pyterrier_t5/__init__.py b/pyterrier_t5/__init__.py index 0627585..e098e48 100644 --- a/pyterrier_t5/__init__.py +++ b/pyterrier_t5/__init__.py @@ -17,10 +17,13 @@ def __init__(self, model='castorini/monot5-base-msmarco', batch_size=4, text_field='text', + device=None, verbose=True): self.verbose = verbose self.batch_size = batch_size - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = torch.device(device) self.tokenizer = T5Tokenizer.from_pretrained(tok_model) self.model_name = model self.model = T5ForConditionalGeneration.from_pretrained(model) @@ -69,11 +72,14 @@ def __init__(self, model='castorini/duot5-base-msmarco', batch_size=4, text_field='text', + device=None, verbose=True, agg='sum'): self.verbose = verbose self.batch_size = batch_size - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = torch.device(device) self.tokenizer = T5Tokenizer.from_pretrained(tok_model) self.model_name = model self.model = T5ForConditionalGeneration.from_pretrained(model)