# WebLLM
```python
import asyncio

import panel as pn
import param

from panel.custom import JSComponent, ESMEvent

pn.extension('mathjax', template='material')
```
This example demonstrates how to wrap an external library (specifically WebLLM) as a `JSComponent` and interface it with the `ChatInterface`.
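The key mechanism is the `JSComponent` custom-message channel: Python sends messages to the browser with `self._send_msg(...)` and receives replies in `_handle_msg`, while the JavaScript `render` function listens via `model.on('msg:custom', ...)` and answers with `model.send_msg(...)`. The minimal sketch below illustrates just that round trip; the `EchoComponent` name and the `'echo'` message type are made up for illustration and are not part of the example:

```python
import param
from panel.custom import JSComponent

class EchoComponent(JSComponent):
    """Hypothetical component demonstrating the Python <-> JS message round trip."""

    last = param.String(default="")

    _esm = """
    export function render({ model }) {
      // Handle messages sent from Python via _send_msg
      model.on("msg:custom", (event) => {
        if (event.type === 'echo') {
          // Reply to Python; the payload arrives in _handle_msg
          model.send_msg({'echoed': event.text.toUpperCase()})
        }
      })
    }
    """

    def send_text(self, text):
        # Forward a custom message to the browser
        self._send_msg({'type': 'echo', 'text': text})

    def _handle_msg(self, msg):
        # Called for every model.send_msg(...) from the browser
        self.last = msg['echoed']
```

The `WebLLM` component below uses the same channel to forward chat messages to the in-browser MLC engine and to stream completion chunks back.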
```python
MODELS = {
    'SmolLM (130MB)': 'SmolLM-135M-Instruct-q4f16_1-MLC',
    'TinyLlama-1.1B-Chat (675MB)': 'TinyLlama-1.1B-Chat-v1.0-q4f16_1-MLC-1k',
    'Gemma-2b (2GB)': 'gemma-2-2b-it-q4f16_1-MLC',
    'Llama-3.2-3B-Instruct (2.2GB)': 'Llama-3.2-3B-Instruct-q4f16_1-MLC',
    'Mistral-7b-Instruct (5GB)': 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC',
}


class WebLLM(JSComponent):

    loaded = param.Boolean(default=False, doc="""
        Whether the model is loaded.""")

    history = param.Integer(default=3)

    status = param.Dict(default={'text': '', 'progress': 0})

    load_model = param.Event()

    model = param.Selector(default='SmolLM-135M-Instruct-q4f16_1-MLC', objects=MODELS)

    running = param.Boolean(default=False, doc="""
        Whether the LLM is currently running.""")

    temperature = param.Number(default=1, bounds=(0, 2), doc="""
        Temperature of the model completions.""")

    _esm = """
    import * as webllm from "https://esm.run/@mlc-ai/web-llm";

    const engines = new Map()

    export async function render({ model }) {
      model.on("msg:custom", async (event) => {
        if (event.type === 'load') {
          if (!engines.has(model.model)) {
            const initProgressCallback = (status) => {
              model.status = status
            }
            const mlc = await webllm.CreateMLCEngine(
              model.model,
              {initProgressCallback}
            )
            engines.set(model.model, mlc)
          }
          model.loaded = true
        } else if (event.type === 'completion') {
          const engine = engines.get(model.model)
          if (engine == null) {
            model.send_msg({'finish_reason': 'error'})
            return
          }
          const chunks = await engine.chat.completions.create({
            messages: event.messages,
            temperature: model.temperature,
            stream: true,
          })
          model.running = true
          for await (const chunk of chunks) {
            if (!model.running) {
              break
            }
            model.send_msg(chunk.choices[0])
          }
        }
      })
    }
    """

    def __init__(self, **params):
        super().__init__(**params)
        if pn.state.location:
            # Sync the selected model with the URL query string
            pn.state.location.sync(self, {'model': 'model'})
        self._buffer = []

    @param.depends('load_model', watch=True)
    def _load_model(self):
        self.loading = True
        self._send_msg({'type': 'load'})

    @param.depends('loaded', watch=True)
    def _loaded(self):
        self.loading = False

    @param.depends('model', watch=True)
    def _update_load_model(self):
        self.loaded = False

    def _handle_msg(self, msg):
        # Chunks streamed from the browser via model.send_msg arrive here
        if self.running:
            self._buffer.insert(0, msg)

    async def create_completion(self, msgs):
        self._send_msg({'type': 'completion', 'messages': msgs})
        # Poll the buffer until the stream reports a finish_reason
        while True:
            await asyncio.sleep(0.01)
            if not self._buffer:
                continue
            choice = self._buffer.pop()
            yield choice
            reason = choice['finish_reason']
            if reason == 'error':
                raise RuntimeError('Model not loaded')
            elif reason:
                return

    async def callback(self, contents: str, user: str):
        if not self.loaded:
            if self.loading:
                yield pn.pane.Markdown(
                    f'## `{self.model}`\n\n' + self.param.status.rx()['text']
                )
            else:
                yield 'Load the model'
            return
        self.running = False
        self._buffer.clear()
        message = ""
        async for chunk in self.create_completion([{'role': 'user', 'content': contents}]):
            message += chunk['delta'].get('content', '')
            yield message

    def menu(self):
        status = self.param.status.rx()
        return pn.Column(
            pn.widgets.Select.from_param(self.param.model, sizing_mode='stretch_width'),
            pn.widgets.FloatSlider.from_param(self.param.temperature, sizing_mode='stretch_width'),
            pn.widgets.Button.from_param(
                self.param.load_model, sizing_mode='stretch_width',
                disabled=self.param.loaded.rx().rx.or_(self.param.loading)
            ),
            pn.indicators.Progress(
                value=(status['progress']*100).rx.pipe(int), visible=self.param.loading,
                sizing_mode='stretch_width'
            ),
            pn.pane.Markdown(status['text'], visible=self.param.loading)
        )
```
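To summarize the data flow: `create_completion` sends a `'completion'` message to the browser, the JavaScript handler streams chunks back with `model.send_msg`, `_handle_msg` pushes them onto an internal buffer, and `create_completion` polls that buffer, yielding each chunk until a `finish_reason` is set. The `callback` method consumes this async generator to build up the chat response incrementally. As a hypothetical standalone illustration (it assumes the component is rendered in a live browser session and a model has already been loaded):

```python
async def print_stream(llm):
    # Stream a single completion and print the tokens as they arrive
    async for choice in llm.create_completion([{'role': 'user', 'content': 'Hello!'}]):
        print(choice.get('delta', {}).get('content', ''), end='', flush=True)
```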
Having implemented the `WebLLM` component, we can render the WebLLM UI:
```python
llm = WebLLM()

intro = pn.pane.Alert("""
`WebLLM` runs large-language models entirely in your browser.
When visiting the application for the first time, the model has
to be downloaded and loaded into memory, which may take
some time. Models are ordered by size (and capability),
e.g. SmolLM is very quick to download but produces poor
quality output while Mistral-7b will take a while to
download but produces much higher quality output.
""".replace('\n', ' '))

pn.Column(
    llm.menu(),
    intro,
    llm
).servable(area='sidebar')
```
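Because `pn.extension` was initialized with `template='material'`, `.servable(area='sidebar')` places the model selector, temperature slider, load button and progress indicator in the template's sidebar, leaving the main area free for the chat interface.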
And connect it to a `ChatInterface`:
```python
chat_interface = pn.chat.ChatInterface(callback=llm.callback)
chat_interface.send(
    "Load a model and start chatting.",
    user="System",
    respond=False,
)

def _notify_loaded(event):
    # Only announce transitions to loaded=True (loaded is reset to False on model change)
    if event.new:
        chat_interface.send(
            f'Loaded `{event.obj.model}`, start chatting!', user='System', respond=False
        )

llm.param.watch(_notify_loaded, 'loaded')

pn.Row(chat_interface).servable(title='WebLLM')
```
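To try the example, save the complete script (named `app.py` here purely for illustration) and launch it with `panel serve app.py`; the model download and all inference then run client-side in the visitor's browser.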