Skip to content

Commit

Permalink
factor out some shared pervasion code
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalii committed Oct 3, 2024
1 parent 29c9555 commit 38c8e2f
Showing 1 changed file with 83 additions and 59 deletions.
142 changes: 83 additions & 59 deletions src/algorithm/pervade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,44 +92,59 @@ where
let _a_depth = a_depth.min(a.rank());
let _b_depth = b_depth.min(b.rank());

let new_rank = a.rank().max(b.rank());
let mut new_shape = Shape::with_capacity(new_rank);
let a_fill = env.scalar_fill::<A>();
let b_fill = env.scalar_fill::<B>();
for i in 0..new_rank {
let c = match (a.shape.get(i).copied(), b.shape.get(i).copied()) {
(None, None) => unreachable!(),
(Some(a), None) => a,
(None, Some(b)) => b,
(Some(ad), Some(bd)) => {
if ad == bd || ad == 1 || bd == 1 {
pervade_dim(ad, bd)
} else if ad < bd {
match &a_fill {
Ok(_) => pervade_dim(ad, bd),
Err(e) => {
return Err(env.error(format!(
"Shapes {} and {} are not compatible{e}",
a.shape, b.shape
)))
fn derive_new_shape(
ash: &Shape,
bsh: &Shape,
a_fill_err: Option<&'static str>,
b_fill_err: Option<&'static str>,
env: &Uiua,
) -> UiuaResult<Shape> {
let new_rank = ash.len().max(bsh.len());
let mut new_shape = Shape::with_capacity(new_rank);
for i in 0..new_rank {
let c = match (ash.get(i).copied(), bsh.get(i).copied()) {
(None, None) => unreachable!(),
(Some(a), None) => a,
(None, Some(b)) => b,
(Some(ad), Some(bd)) => {
if ad == bd || ad == 1 || bd == 1 {
pervade_dim(ad, bd)
} else if ad < bd {
match a_fill_err {
None => pervade_dim(ad, bd),
Some(e) => {
return Err(env.error(format!(
"Shapes {ash} and {bsh} are not compatible{e}"
)))
}
}
}
} else {
match &b_fill {
Ok(_) => pervade_dim(ad, bd),
Err(e) => {
return Err(env.error(format!(
"Shapes {} and {} are not compatible{e}",
a.shape, b.shape
)))
} else {
match b_fill_err {
None => pervade_dim(ad, bd),
Some(e) => {
return Err(env.error(format!(
"Shapes {ash} and {bsh} are not compatible{e}"
)))
}
}
}
}
}
};
new_shape.push(c);
};
new_shape.push(c);
}
Ok(new_shape)
}

let a_fill = env.scalar_fill::<A>();
let b_fill = env.scalar_fill::<B>();
let new_shape = derive_new_shape(
&a.shape,
&b.shape,
a_fill.as_ref().err().copied(),
b_fill.as_ref().err().copied(),
env,
)?;

// dbg!(&a.shape, &b.shape, &new_shape);

let mut new_data = eco_vec![C::default(); new_shape.elements()];
Expand Down Expand Up @@ -273,38 +288,47 @@ where
let _a_depth = a_depth.min(a.rank());
let _b_depth = b_depth.min(b.rank());

let new_rank = a.rank().max(b.rank());
let mut new_shape = Shape::with_capacity(new_rank);
let fill = env.scalar_fill::<T>();
let mut requires_fill = false;
for i in 0..new_rank {
let c = match (a.shape.get(i).copied(), b.shape.get(i).copied()) {
(None, None) => unreachable!(),
(Some(a), None) => a,
(None, Some(b)) => b,
(Some(ad), Some(bd)) => {
if ad == bd || ad == 1 || bd == 1 {
requires_fill |= ad != bd && fill.is_ok();
pervade_dim(ad, bd)
} else {
match &fill {
Ok(_) => {
requires_fill = true;
pervade_dim(ad, bd)
}
Err(e) => {
return Err(env.error(format!(
"Shapes {} and {} are not compatible{e}",
a.shape, b.shape
)))
fn derive_new_shape(
ash: &Shape,
bsh: &Shape,
fill_err: Option<&str>,
env: &Uiua,
) -> UiuaResult<(Shape, bool)> {
let new_rank = ash.len().max(bsh.len());
let mut new_shape = Shape::with_capacity(new_rank);
let mut requires_fill = false;
for i in 0..new_rank {
let c = match (ash.get(i).copied(), bsh.get(i).copied()) {
(None, None) => unreachable!(),
(Some(a), None) => a,
(None, Some(b)) => b,
(Some(ad), Some(bd)) => {
if ad == bd || ad == 1 || bd == 1 {
requires_fill |= ad != bd && fill_err.is_none();
pervade_dim(ad, bd)
} else {
match &fill_err {
None => {
requires_fill = true;
pervade_dim(ad, bd)
}
Some(e) => {
return Err(env.error(format!(
"Shapes {ash} and {bsh} are not compatible{e}"
)))
}
}
}
}
}
};
new_shape.push(c);
};
new_shape.push(c);
}
Ok((new_shape, requires_fill))
}

let fill = env.scalar_fill::<T>();
let (new_shape, requires_fill) =
derive_new_shape(a.shape(), b.shape(), fill.as_ref().err().copied(), env)?;
let fill = if requires_fill { fill.ok() } else { None };

fn reuse_no_fill<T: ArrayValue + Copy>(
Expand Down

0 comments on commit 38c8e2f

Please sign in to comment.