use std::{fmt, ops};
use serde::{Deserialize, Serialize};
use zng_txt::Txt;
/// Raw bytes of an API extension request or response.
///
/// Payloads built by [`ApiExtensionPayload::serialize`] hold `bincode` data. Payloads built by
/// `unknown_extension` / `invalid_request` hold UTF-8 marker text instead.
/// The `serde_bytes` attribute lets serializers use their compact byte-buffer form for the
/// field instead of a generic sequence of `u8` values.
#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Hash)]
pub struct ApiExtensionPayload(#[serde(with = "serde_bytes")] pub Vec<u8>);
impl ApiExtensionPayload {
    /// Encodes `payload` with `bincode` into a new payload.
    ///
    /// # Errors
    ///
    /// Returns the `bincode` error if `payload` cannot be serialized.
    pub fn serialize<T: Serialize>(payload: &T) -> bincode::Result<Self> {
        let bytes = bincode::serialize(payload)?;
        Ok(Self(bytes))
    }

    /// Decodes the payload into `T`.
    ///
    /// Before decoding, the bytes are checked for the marker payloads produced by
    /// [`invalid_request`](Self::invalid_request) and
    /// [`unknown_extension`](Self::unknown_extension), in that order. Those markers are
    /// reported as errors instead of being decoded.
    ///
    /// # Errors
    ///
    /// * [`ApiExtensionRecvError::InvalidRequest`] for an invalid-request marker.
    /// * [`ApiExtensionRecvError::UnknownExtension`] for an unknown-extension marker.
    /// * [`ApiExtensionRecvError::Deserialize`] if `bincode` fails to decode `T`.
    pub fn deserialize<T: serde::de::DeserializeOwned>(&self) -> Result<T, ApiExtensionRecvError> {
        if let Some((extension_id, error)) = self.parse_invalid_request() {
            return Err(ApiExtensionRecvError::InvalidRequest {
                extension_id,
                error: Txt::from_str(error),
            });
        }
        if let Some(extension_id) = self.parse_unknown_extension() {
            return Err(ApiExtensionRecvError::UnknownExtension { extension_id });
        }
        bincode::deserialize(&self.0).map_err(ApiExtensionRecvError::Deserialize)
    }

    /// Payload with no bytes.
    pub const fn empty() -> Self {
        Self(Vec::new())
    }

    /// Marker payload that signals a request targeted an extension the receiver does not know.
    ///
    /// Recognized by [`parse_unknown_extension`](Self::parse_unknown_extension).
    pub fn unknown_extension(extension_id: ApiExtensionId) -> Self {
        let marker = format!("zng-view-api.unknown_extension;id={extension_id}");
        Self(marker.into_bytes())
    }

    /// Marker payload that signals a request to `extension_id` was rejected with `error`.
    ///
    /// Recognized by [`parse_invalid_request`](Self::parse_invalid_request).
    pub fn invalid_request(extension_id: ApiExtensionId, error: impl fmt::Display) -> Self {
        let marker = format!("zng-view-api.invalid_request;id={extension_id};error={error}");
        Self(marker.into_bytes())
    }

    /// Detects an [`unknown_extension`](Self::unknown_extension) marker.
    ///
    /// Returns `None` if the payload is not such a marker. If it is a marker but the id
    /// section is missing or corrupted, returns `Some(ApiExtensionId::INVALID)`.
    pub fn parse_unknown_extension(&self) -> Option<ApiExtensionId> {
        let rest = self.0.strip_prefix(b"zng-view-api.unknown_extension;")?;
        let id = rest
            .strip_prefix(b"id=")
            .and_then(|digits| std::str::from_utf8(digits).ok())
            .map(|text| match text.parse::<ApiExtensionId>() {
                // `FromStr` reports failures as an ID too (`INVALID`), so both arms yield an ID.
                Ok(id) | Err(id) => id,
            })
            .unwrap_or(ApiExtensionId::INVALID);
        Some(id)
    }

    /// Detects an [`invalid_request`](Self::invalid_request) marker.
    ///
    /// Returns `None` if the payload is not such a marker. Corrupted markers still
    /// return `Some`, with `ApiExtensionId::INVALID` and/or a generic error text filling
    /// in whatever could not be read.
    pub fn parse_invalid_request(&self) -> Option<(ApiExtensionId, &str)> {
        let rest = self.0.strip_prefix(b"zng-view-api.invalid_request;")?;

        // Split `id=<id>;...` into the parsed id and the tail that starts at the `;`.
        let parsed_id = rest.strip_prefix(b"id=").and_then(|after_id| {
            let separator = after_id.iter().position(|&b| b == b';')?;
            let (id_bytes, tail) = after_id.split_at(separator);
            let id_text = std::str::from_utf8(id_bytes).ok()?;
            let id = match id_text.parse::<ApiExtensionId>() {
                Ok(id) | Err(id) => id,
            };
            Some((id, tail))
        });

        let Some((id, tail)) = parsed_id else {
            return Some((
                ApiExtensionId::INVALID,
                "invalid request, corrupted payload, unknown extension_id and error",
            ));
        };

        let error = tail
            .strip_prefix(b";error=")
            .and_then(|message| std::str::from_utf8(message).ok())
            .unwrap_or("invalid request, corrupted payload, unknown error");
        Some((id, error))
    }
}
impl fmt::Debug for ApiExtensionPayload {
    /// Prints only the byte count, never the raw contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let len = self.0.len();
        write!(f, "ExtensionPayload({len} bytes)")
    }
}
/// Validated name of an API extension.
///
/// Construct with [`ApiExtensionName::new`]. The first char must be an ASCII letter, and the
/// remaining chars must be ASCII letters, digits, `_`, `-` or `.`.
#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Hash)]
pub struct ApiExtensionName {
    name: Txt,
}
impl ApiExtensionName {
    /// Validates `name` and wraps it.
    ///
    /// # Errors
    ///
    /// * [`ApiExtensionNameError::NameCannotBeEmpty`] if `name` is empty.
    /// * [`ApiExtensionNameError::NameCannotStartWithChar`] if the first char is not an ASCII letter.
    /// * [`ApiExtensionNameError::NameInvalidChar`] for the first later char that is not an
    ///   ASCII letter, digit, `_`, `-` or `.`.
    pub fn new(name: impl Into<Txt>) -> Result<Self, ApiExtensionNameError> {
        Self::new_impl(name.into())
    }

    // Non-generic body, so it is compiled once regardless of the `Into<Txt>` source type.
    fn new_impl(name: Txt) -> Result<ApiExtensionName, ApiExtensionNameError> {
        let mut chars = name.chars();

        let first = chars.next().ok_or(ApiExtensionNameError::NameCannotBeEmpty)?;
        if !first.is_ascii_alphabetic() {
            return Err(ApiExtensionNameError::NameCannotStartWithChar(first));
        }

        let is_allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.');
        if let Some(bad) = chars.find(|&c| !is_allowed(c)) {
            return Err(ApiExtensionNameError::NameInvalidChar(bad));
        }

        Ok(Self { name })
    }
}
impl fmt::Debug for ApiExtensionName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.name, f)
}
}
impl fmt::Display for ApiExtensionName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.name, f)
}
}
impl ops::Deref for ApiExtensionName {
type Target = str;
fn deref(&self) -> &Self::Target {
self.name.as_str()
}
}
impl From<&'static str> for ApiExtensionName {
    /// Converts a static name, typically a literal, into an extension name.
    ///
    /// # Panics
    ///
    /// Panics if `value` fails the validation done by [`ApiExtensionName::new`]. The panic
    /// message names the rejected value and the validation error, so a bad literal is easy
    /// to locate.
    fn from(value: &'static str) -> Self {
        match Self::new(value) {
            Ok(name) => name,
            Err(e) => panic!("invalid API extension name {value:?}, {e}"),
        }
    }
}
/// Reason an [`ApiExtensionName`] was rejected.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub enum ApiExtensionNameError {
    /// The name was an empty string.
    NameCannotBeEmpty,
    /// The first char (carried) is not an ASCII letter.
    NameCannotStartWithChar(char),
    /// A later char (carried) is not an ASCII letter, digit, `_`, `-` or `.`.
    NameInvalidChar(char),
}
impl fmt::Display for ApiExtensionNameError {
    /// Human-readable validation error.
    ///
    /// The char errors include the accepted name pattern. The pattern matches the rules
    /// enforced by `ApiExtensionName::new`: an ASCII letter first, then any number of ASCII
    /// letters, digits, `_`, `.` or `-`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `*` because names can be any length. `-` is last so it is literal, not a range.
        const PATTERN: &str = "[a-zA-Z][a-zA-Z0-9_.-]*";
        match self {
            ApiExtensionNameError::NameCannotBeEmpty => write!(f, "API extension name cannot be empty"),
            ApiExtensionNameError::NameCannotStartWithChar(c) => {
                write!(f, "API extension name cannot start with '{c}', name pattern `{PATTERN}`")
            }
            ApiExtensionNameError::NameInvalidChar(c) => {
                write!(f, "API extension name cannot contain '{c}', name pattern `{PATTERN}`")
            }
        }
    }
}
impl std::error::Error for ApiExtensionNameError {}
/// Ordered list of API extension names.
///
/// The position of a name in the list defines its [`ApiExtensionId`]. See
/// `ApiExtensions::id` and `ApiExtensions::insert`.
#[derive(serde::Serialize, serde::Deserialize, Default, Clone, Debug)]
pub struct ApiExtensions(Vec<ApiExtensionName>);
impl ops::Deref for ApiExtensions {
    type Target = [ApiExtensionName];

    /// Exposes the registered names as a read-only slice, in insertion order.
    fn deref(&self) -> &[ApiExtensionName] {
        self.0.as_slice()
    }
}
impl ApiExtensions {
pub fn new() -> Self {
Self::default()
}
pub fn id(&self, ext: &ApiExtensionName) -> Option<ApiExtensionId> {
self.0.iter().position(|e| e == ext).map(ApiExtensionId::from_index)
}
pub fn insert(&mut self, ext: ApiExtensionName) -> Result<ApiExtensionId, ApiExtensionId> {
if let Some(key) = self.id(&ext) {
Err(key)
} else {
let key = self.0.len();
self.0.push(ext);
Ok(ApiExtensionId::from_index(key))
}
}
}
/// Identifies an API extension by its position in an [`ApiExtensions`] list.
///
/// Stores `index + 1` so that `0` can serve as [`ApiExtensionId::INVALID`]. With
/// `serde(transparent)` it serializes as the bare inner `u32`.
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(transparent)]
pub struct ApiExtensionId(u32);
impl fmt::Debug for ApiExtensionId {
    /// Prints `ApiExtensionId(<index>)` for valid IDs.
    ///
    /// The invalid ID prints as `INVALID`, or `ApiExtensionId::INVALID` with `{:#?}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if *self != Self::INVALID {
            return write!(f, "ApiExtensionId({})", self.0 - 1);
        }
        let prefix = if f.alternate() { "ApiExtensionId::" } else { "" };
        write!(f, "{prefix}INVALID")
    }
}
impl fmt::Display for ApiExtensionId {
    /// Prints the list index, or `invalid` for the invalid ID.
    ///
    /// This is the text format parsed back by the `FromStr` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The stored value is `index + 1`; only the invalid ID (`0`) has no index.
        match self.0.checked_sub(1) {
            Some(index) => write!(f, "{index}"),
            None => write!(f, "invalid"),
        }
    }
}
impl ApiExtensionId {
    /// Placeholder ID that does not refer to any extension.
    pub const INVALID: Self = Self(0);

    /// Gets the list index this ID refers to.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid id"` if called on [`ApiExtensionId::INVALID`].
    pub fn index(self) -> usize {
        let index = self.0.checked_sub(1).expect("invalid id");
        index as usize
    }

    /// New ID from a list index.
    ///
    /// # Panics
    ///
    /// Panics with `"index out-of-bounds"` if `idx > u32::MAX - 1`, because the stored
    /// value is `idx + 1` and must fit in a `u32`.
    pub fn from_index(idx: usize) -> Self {
        let max_index = (u32::MAX - 1) as usize;
        assert!(idx <= max_index, "index out-of-bounds");
        Self(idx as u32 + 1)
    }
}
impl std::str::FromStr for ApiExtensionId {
type Err = Self;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.parse::<u32>() {
Ok(i) => {
let r = Self::from_index(i as _);
if r == Self::INVALID {
Err(r)
} else {
Ok(r)
}
}
Err(_) => Err(Self::INVALID),
}
}
}
/// Error returned by `ApiExtensionPayload::deserialize`.
#[derive(Debug)]
pub enum ApiExtensionRecvError {
    /// The payload was an unknown-extension marker.
    UnknownExtension {
        /// ID carried by the marker (`ApiExtensionId::INVALID` if it could not be read).
        extension_id: ApiExtensionId,
    },
    /// The payload was an invalid-request marker.
    InvalidRequest {
        /// ID carried by the marker (`ApiExtensionId::INVALID` if it could not be read).
        extension_id: ApiExtensionId,
        /// Error text carried by the marker, or a generic text if it could not be read.
        error: Txt,
    },
    /// `bincode` failed to decode the payload bytes.
    Deserialize(bincode::Error),
}
impl fmt::Display for ApiExtensionRecvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ApiExtensionRecvError::UnknownExtension { extension_id } => write!(f, "invalid API request for unknown id {extension_id:?}"),
ApiExtensionRecvError::InvalidRequest { extension_id, error } => {
write!(f, "invalid API request for extension id {extension_id:?}, {error}")
}
ApiExtensionRecvError::Deserialize(e) => write!(f, "API extension response failed to deserialize, {e}"),
}
}
}
impl std::error::Error for ApiExtensionRecvError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
if let Self::Deserialize(e) = self {
Some(e)
} else {
None
}
}
}