Merge remote-tracking branch 'origin/main' into sam2
parthchadha committed Nov 18, 2024
2 parents 18cda08 + d37e4f8 commit 9c9893c
Showing 38 changed files with 618 additions and 177 deletions.
@@ -95,7 +95,7 @@ def CanonicalizeShapesPass : Pass<"stablehlo-ext-canonicalize-shapes", "ModuleOp
}];

let options = [
Option<"maxIterations", "max-iterations", "int64_t", "4",
Option<"maxIterations", "max-iterations", "int64_t", "8",
"the maximum number of iterations to run the dynamism simplification and "
"shape refinement if a fixed-point is not reached">
];
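
The option above is the knob touched in this hunk (the default appears to be raised from 4 to 8). As a hedged illustration only, not code from this commit, the sketch below shows how the option can be overridden when driving the pass from C++; the pass name `stablehlo-ext-canonicalize-shapes` and option key `max-iterations` come from the tablegen definition above, while the helper name and the `builtin.module` nesting are assumptions. The same pipeline string should also work with `mlir-tensorrt-opt --pass-pipeline=...` on the command line.

#include <cstdint>
#include <string>

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"

// Hypothetical helper (assumption, not part of this commit): run the shape
// canonicalization pipeline with an explicit iteration budget instead of
// relying on the default.
static mlir::LogicalResult runCanonicalizeShapes(mlir::ModuleOp module,
                                                 int64_t maxIterations) {
  mlir::PassManager pm(module.getContext());
  // Pass and option names are taken from the tablegen definition above.
  std::string pipeline =
      "builtin.module(stablehlo-ext-canonicalize-shapes{max-iterations=" +
      std::to_string(maxIterations) + "})";
  if (mlir::failed(mlir::parsePassPipeline(pipeline, pm)))
    return mlir::failure();
  return pm.run(module);
}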
@@ -361,7 +361,6 @@ struct SimplifyExtractOfReshape : public OpRewritePattern<tensor::ExtractOp> {
if (!reshapeOp)
return failure();

// Skip if either shape has dynamic dimensions
if (!reshapeOp.getOperand().getType().hasStaticShape())
return failure();

@@ -1065,7 +1065,6 @@ struct AbsorbTensorCastProducer : public RewritePattern {
};
} // namespace


/// Populates patterns that are temporarily reproduced here from upstream
/// commits we have not yet integrated.
static void populateFutureUpstreamPatterns(RewritePatternSet &patterns);
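
As a rough, hedged sketch of how a populate function like the one declared above is typically consumed in the same translation unit (the wrapper below is hypothetical and not part of this change): build a `RewritePatternSet`, populate it, and hand it to the greedy rewrite driver.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical wrapper (assumption): apply the temporarily vendored upstream
// patterns to a module until the greedy driver reaches a fixed point.
static mlir::LogicalResult applyFutureUpstreamPatterns(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  populateFutureUpstreamPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}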
@@ -236,3 +236,42 @@ func.func @broadcast_elim_matmul_vector(%arg0: tensor<?x?x128xf32>, %arg1: tenso
// CHECK: return %[[v0]] : tensor<?x?x100xf32>


// -----

func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor<?x?x1x1xi1>, %arg1: tensor<?x1xf16>, %arg2: tensor<?x?x256x256xf16>, %arg3: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
%0 = tensorrt.broadcast %arg0 broadcast_dims<0, 1, 2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x?x1x1xi1> to tensor<?x?x256x256xi1>
%1 = tensorrt.broadcast %arg1 broadcast_dims<2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x1xf16> to tensor<?x?x256x256xf16>
%2 = tensorrt.select ins(%0, %arg2, %1 : tensor<?x?x256x256xi1>, tensor<?x?x256x256xf16>, tensor<?x?x256x256xf16>)
-> tensor<?x?x256x256xf16>
return %2 : tensor<?x?x256x256xf16>
}

// CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression
// CHECK-SAME: (%[[arg0:.+]]: tensor<?x?x1x1xi1>, %[[arg1:.+]]: tensor<?x1xf16>, %[[arg2:.+]]: tensor<?x?x256x256xf16>, %[[arg3:.+]]: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg1]] : tensor<?x1xf16> to tensor<1x1x?x1xf16>
// CHECK: %[[v1:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v0]] : tensor<?x?x1x1xi1>, tensor<?x?x256x256xf16>, tensor<1x1x?x1xf16>) -> tensor<?x?x256x256xf16>
// CHECK: return %[[v1]] : tensor<?x?x256x256xf16>

// -----

func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor<?x?x1x1xi1>, %arg1: tensor<?x1x?xf16>, %arg2: tensor<?x?x256x256xf16>, %arg3: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
%0 = tensorrt.broadcast %arg0 broadcast_dims<0, 1, 2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x?x1x1xi1> to tensor<?x?x256x256xi1>
%1 = tensorrt.broadcast %arg1 broadcast_dims<3, 2, 1> shape(%arg3 : tensor<4xi32>) : tensor<?x1x?xf16> to tensor<?x?x256x256xf16>
%2 = tensorrt.select ins(%0, %arg2, %1 : tensor<?x?x256x256xi1>, tensor<?x?x256x256xf16>, tensor<?x?x256x256xf16>)
-> tensor<?x?x256x256xf16>
return %2 : tensor<?x?x256x256xf16>
}

// CHECK: #[[$map:.+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
// CHECK: module {
// CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression
// CHECK-SAME: (%[[arg0:.+]]: tensor<?x?x1x1xi1>, %[[arg1:.+]]: tensor<?x1x?xf16>, %[[arg2:.+]]: tensor<?x?x256x256xf16>, %[[arg3:.+]]: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<1> : tensor<1xi32>
// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg1]] : tensor<?x1x?xf16> to tensor<?x1x?xf16>
// CHECK: %[[v1:.+]] = tensorrt.shape %[[v0]] : tensor<?x1x?xf16> -> tensor<3xi32>
// CHECK: %[[v2:.+]] = tensorrt.slice %[[v1]][0][1][1] : tensor<3xi32> to tensor<1xi32>
// CHECK: %[[v3:.+]] = tensorrt.slice %[[v1]][2][1][1] : tensor<3xi32> to tensor<1xi32>
// CHECK: %[[v4:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32]], %[[v2]], %[[cst_i32]], %[[v3]] : tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
// CHECK: %[[v5:.+]] = tensorrt.reshape %[[v0]] shape(%[[v4]]: tensor<4xi32>) : tensor<?x1x?xf16> to tensor<1x?x1x?xf16>
// CHECK: %[[v6:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v5]] : tensor<?x?x1x1xi1>, tensor<?x?x256x256xf16>, tensor<1x?x1x?xf16>) -> tensor<?x?x256x256xf16>
// CHECK: return %[[v6]] : tensor<?x?x256x256xf16>
@@ -1088,3 +1088,27 @@ func.func @reduce_window_dynamic_input(%arg0: tensor<?x?x?x?xf32> {tensorrt.shap
// CHECK-DAG: %[[v2:.+]] = arith.maxsi %[[dim]], %[[c0]] : index
// CHECK-DAG: %[[v3:.+]] = plan.with_shape %[[v1]](%[[v2]], %[[c3]], %[[c512]], %[[c512]]) :
// CHECK-DAG: return %[[v3]]

// -----

func.func @simplify_extract_of_reshape_negative(%arg0: tensor<1x?x3x4xf32>) -> f32 {
%c0 = arith.constant 0: index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%1 = stablehlo.reshape %arg0 : (tensor<1x?x3x4xf32>) -> tensor<1x6x4xf32>
%2 = tensor.extract %1[%c0, %c1, %c2] : tensor<1x6x4xf32>
return %2 : f32
}

// CHECK-LABEL: simplify_extract_of_reshape_negative
// CHECK-SAME: (%[[arg0:.+]]: tensor<1x?x3x4xf32>)
// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index
// CHECK-NEXT: %[[c3:.+]] = arith.constant 3 : index
// CHECK-NEXT: %[[c2:.+]] = arith.constant 2 : index
// CHECK-NEXT: %[[c1:.+]] = arith.constant 1 : index
// CHECK-NEXT: %[[c0:.+]] = arith.constant 0 : index
// CHECK-NEXT: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] : tensor<1x?x3x4xf32>
// CHECK-NEXT: %[[v0:.+]] = plan.with_shape %[[arg0]](%[[c1]], %[[dim]], %[[c3]], %[[c4]])
// CHECK-NEXT: %[[v1:.+]] = stablehlo.reshape %[[v0]]
// CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v1]][%[[c0]], %[[c1]], %[[c2]]]
// CHECK-NEXT: return %extracted
16 changes: 16 additions & 0 deletions mlir-tensorrt/test/Dialect/StableHloExt/constant-folding.mlir
@@ -402,6 +402,22 @@ func.func @concat_simplify_single_operand_requires_cast(%arg0: tensor<4xi32>) ->

// -----

func.func @concat_slice_concat(%arg0: tensor<1xi32>, %arg1: tensor<3xi32>, %arg2: tensor<1xi32>) -> tensor<5xi32> {
%0 = stablehlo.concatenate %arg0, %arg1, %arg2, dim = 0 : (tensor<1xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<5xi32>
%1 = stablehlo.slice %0 [1:5] : (tensor<5xi32>) -> tensor<4xi32>
%2 = stablehlo.constant dense<1> : tensor<1xi32>
%3 = stablehlo.concatenate %2, %1, dim = 0 : (tensor<1xi32>, tensor<4xi32>) -> tensor<5xi32>
return %3 : tensor<5xi32>
}

// CHECK-LABEL: func.func @concat_slice_concat
// CHECK-SAME: (%[[arg0:.+]]: tensor<1xi32>, %[[arg1:.+]]: tensor<3xi32>, %[[arg2:.+]]: tensor<1xi32>) -> tensor<5xi32>
// CHECK: %[[c:.+]] = stablehlo.constant dense<1> : tensor<1xi32>
// CHECK: %[[v0:.+]] = stablehlo.concatenate %[[c]], %[[arg1]], %[[arg2]], dim = 0
// CHECK: return %[[v0]] : tensor<5xi32>

// -----

func.func @bitwise_or_fold_lhs(%arg0: tensor<5xi8>, %arg1: tensor<5xi1>, %arg2: tensor<5xi32>) -> (tensor<5xi8>, tensor<5xi1>, tensor<5xi32>, tensor<5xi32>){
%0 = stablehlo.constant dense<[255, 255, 255, 255, 255]> : tensor<5xi8>
%1 = stablehlo.or %0, %arg0 : tensor<5xi8>
86 changes: 86 additions & 0 deletions mlir-tensorrt/test/Dialect/StableHloExt/refine-shapes.mlir
@@ -0,0 +1,86 @@
// RUN: mlir-tensorrt-opt %s -split-input-file -stablehlo-ext-refine-shapes | FileCheck %s

func.func @check_type_refinement() -> tensor<?xf32> {
%c = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi32>
%c_0 = stablehlo.constant dense<3> : tensor<i32>
%c_1 = stablehlo.constant dense<1> : tensor<1xi32>
%c_2 = stablehlo.constant dense<3> : tensor<1xi32>
%c_3 = stablehlo.constant dense<1> : tensor<i32>
%c_4 = stablehlo.constant dense<1> : tensor<1xi32>
%c_5 = stablehlo.constant dense<0> : tensor<i32>
%c_6 = stablehlo.constant dense<1> : tensor<i32>
%c_7 = stablehlo.constant dense<0> : tensor<1xi32>
%c_8 = stablehlo.constant dense<1> : tensor<1xi32>
%0 = stablehlo.compare LE, %c_7, %c_8 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
%1 = stablehlo.select %0, %c_7, %c_8 : tensor<1xi1>, tensor<1xi32>
%c_9 = stablehlo.constant dense<1> : tensor<1xi32>
%2 = stablehlo.real_dynamic_slice %c_4, %1, %c_8, %c_9 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
%c_10 = stablehlo.constant dense<> : tensor<0xi32>
%3 = stablehlo.dynamic_reshape %2, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
%c_11 = stablehlo.constant dense<-1> : tensor<i32>
%c_12 = stablehlo.constant dense<> : tensor<0xi32>
%4 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
%5 = stablehlo.select %4, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
%6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
%7 = stablehlo.dynamic_broadcast_in_dim %c_11, %5, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
%8 = stablehlo.add %6, %7 : tensor<i32>
%c_13 = stablehlo.constant dense<0> : tensor<1xi32>
%c_14 = stablehlo.constant dense<1> : tensor<1xi32>
%9 = stablehlo.compare LE, %c_13, %c_14 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
%10 = stablehlo.select %9, %c_13, %c_14 : tensor<1xi1>, tensor<1xi32>
%c_15 = stablehlo.constant dense<1> : tensor<1xi32>
%11 = stablehlo.real_dynamic_slice %c_4, %10, %c_14, %c_15 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
%12 = stablehlo.dynamic_reshape %11, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
%13 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
%14 = stablehlo.select %13, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
%15 = stablehlo.dynamic_broadcast_in_dim %12, %14, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
%16 = stablehlo.dynamic_broadcast_in_dim %c_11, %14, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
%17 = stablehlo.add %15, %16 : tensor<i32>
%18 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
%19 = stablehlo.select %18, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
%20 = stablehlo.dynamic_broadcast_in_dim %17, %19, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
%21 = stablehlo.dynamic_broadcast_in_dim %c_6, %19, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
%22 = stablehlo.add %20, %21 : tensor<i32>
%23 = stablehlo.reshape %8 : (tensor<i32>) -> tensor<1xi32>
%24 = stablehlo.reshape %22 : (tensor<i32>) -> tensor<1xi32>
%25 = stablehlo.compare LE, %23, %24 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
%26 = stablehlo.select %25, %23, %24 : tensor<1xi1>, tensor<1xi32>
%c_16 = stablehlo.constant dense<1> : tensor<1xi32>
%27 = stablehlo.real_dynamic_slice %c_2, %26, %24, %c_16 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
%28 = stablehlo.dynamic_reshape %27, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
%29 = stablehlo.dynamic_broadcast_in_dim %28, %c_1, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%30 = stablehlo.dynamic_broadcast_in_dim %cst, %29, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
return %30 : tensor<?xf32>
}

// CHECK-LABEL: func.func @check_type_refinement
// CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<-1> : tensor<i32>
// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<> : tensor<0xi32>
// CHECK-DAG: %[[c_1:.+]] = stablehlo.constant dense<1> : tensor<1xi32>
// CHECK-DAG: %[[c_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32>
// CHECK-DAG: %[[c_3:.+]] = stablehlo.constant dense<1> : tensor<i32>
// CHECK-DAG: %[[c_4:.+]] = stablehlo.constant dense<0> : tensor<1xi32>
// CHECK-DAG: %[[v0:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK-DAG: %[[v1:.+]] = stablehlo.dynamic_reshape %[[v0]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v2:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v1]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v3:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v4:.+]] = stablehlo.add %[[v2]], %[[v3]] : tensor<i32>
// CHECK-DAG: %[[v5:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK-DAG: %[[v6:.+]] = stablehlo.dynamic_reshape %[[v5]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v7:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v6]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v8:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v9:.+]] = stablehlo.add %[[v7]], %[[v8]] : tensor<i32>
// CHECK-DAG: %[[v10:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v9]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v11:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c_3]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v12:.+]] = stablehlo.add %[[v10]], %[[v11]] : tensor<i32>
// CHECK-DAG: %[[v13:.+]] = stablehlo.reshape %[[v4]] : (tensor<i32>) -> tensor<1xi32>
// CHECK-DAG: %[[v14:.+]] = stablehlo.reshape %[[v12]] : (tensor<i32>) -> tensor<1xi32>
// CHECK-DAG: %[[v15:.+]] = stablehlo.compare LE, %[[v13]], %[[v14]] : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK-DAG: %[[v16:.+]] = stablehlo.select %[[v15]], %[[v13]], %[[v14]] : tensor<1xi1>, tensor<1xi32>
// CHECK-DAG: %[[v17:.+]] = stablehlo.real_dynamic_slice %[[c_2]], %[[v16]], %[[v14]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
// CHECK-DAG: %[[v18:.+]] = stablehlo.dynamic_reshape %[[v17]], %[[c_0]] : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG: %[[v19:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v18]], %[[c_1]], dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK-DAG: %[[v20:.+]] = stablehlo.dynamic_broadcast_in_dim %[[cst]], %[[v19]], dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
// CHECK-DAG: return %[[v20]] : tensor<?xf32>
2 changes: 1 addition & 1 deletion mlir-tensorrt/test/models/bert.stablehlo.elided.mlir
@@ -1,4 +1,4 @@
-module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @bert attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<32x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<32x8x768xf16> {mhlo.layout_mode = "default"}, tensor<32x768xf16> {mhlo.layout_mode = "default"}) {
%0 = stablehlo.constant dense_resource<__elided__> : tensor<30522x768xf32>
%1 = stablehlo.constant dense_resource<__elided__> : tensor<512x768xf32>
2 changes: 1 addition & 1 deletion mlir-tensorrt/test/models/gpt2.stablehlo.bs2.elided.mlir
@@ -1,4 +1,4 @@
-module @jit_generate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @gpt2_bs2 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x6xi32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<2x6xi32> {mhlo.sharding = "{replicated}"}) -> (tensor<2x20xi32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<0> : tensor<1xi32>
%1 = stablehlo.constant dense<768> : tensor<i32>
2 changes: 1 addition & 1 deletion mlir-tensorrt/test/models/gpt2.stablehlo.elided.mlir
@@ -1,4 +1,4 @@
-module @jit_generate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @gpt_bs1 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x7xi32> {jax.arg_info = "inputs['attention_mask']", mhlo.sharding = "{replicated}"}, %arg1: tensor<1x7xi32> {jax.arg_info = "inputs['input_ids']", mhlo.sharding = "{replicated}"}) -> (tensor<1x20xi32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense_resource<__elided__> : tensor<50257x768xf16>
%1 = stablehlo.constant dense_resource<__elided__> : tensor<1024x768xf16>
2 changes: 1 addition & 1 deletion mlir-tensorrt/test/models/llama-68m.stablehlo.elided.mlir
@@ -1,4 +1,4 @@
-module @jit_generate attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} {
+module @llama_68m attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<1x9xi32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<1x9xi32> {mhlo.sharding = "{replicated}"}) -> tensor<1x20xi32> {
%0 = stablehlo.constant dense<1.000000e+00> : tensor<1x1x3072xf32>
%1 = stablehlo.constant dense<-3.40282347E+38> : tensor<1x1x1x20xf32>
2 changes: 1 addition & 1 deletion mlir-tensorrt/test/models/llama-v2.stablehlo.elided.mlir
@@ -1,4 +1,4 @@
-module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @llama_v2 attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x27xf32> {mhlo.layout_mode = "default"}) -> (tensor<1x27x32000xf32> {mhlo.layout_mode = "default"}) {
%0 = stablehlo.constant dense_resource<__elided__> : tensor<32000x4096xf16>
%1 = stablehlo.constant dense_resource<__elided__> : tensor<4096xf16>
2 changes: 1 addition & 1 deletion mlir-tensorrt/test/models/resnet50.stablehlo.elided.mlir
@@ -1,4 +1,4 @@
-module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @resnet50 attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<16x3x224x224xf16> {mhlo.layout_mode = "default"}) -> (tensor<16x1000xf16> {mhlo.layout_mode = "default"}) {
%0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>
%1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
2 changes: 1 addition & 1 deletion mlir-tensorrt/test/models/swin.stablehlo.elided.mlir
@@ -1,4 +1,4 @@
-module @jit_run attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} {
+module @swin attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1000xf32> {
%cst = stablehlo.constant dense_resource<__elided__> : tensor<1x1000xf32>
%cst_0 = stablehlo.constant dense_resource<__elided__> : tensor<1024x1000xf32>
@@ -1,4 +1,4 @@
-module @jit_generate_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @whisper_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1x80x3000xf32> {jax.arg_info = "input_features", mhlo.sharding = "{replicated}"}) -> (tensor<1x448xi32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense_resource<__elided__> : tensor<3x80x384xf32>
%1 = stablehlo.constant dense_resource<__elided__> : tensor<384xf32>
5 changes: 2 additions & 3 deletions tripy/docs/post0_developer_guides/how-to-add-new-ops.md
@@ -325,10 +325,9 @@ import tripy as tp

def test_multi_dimensional():
output = tp.theta([2, 3], dim=1)
-expected = np.broadcast_to(np.arange(0, 3, dtype=np.float32), (2, 3))

-assert np.array_equal(cp.from_dlpack(output).get(), expected)
+expected = tp.Tensor([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], dtype=tp.float32)

+assert tp.equal(output, expected)
```

## Done!