diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir rename from mlir/integration_test/Dialect/Vector/CPU/test-outerproduct.mlir rename to mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir rename from mlir/integration_test/Dialect/Vector/CPU/test-outerproduct.mlir rename to mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir @@ -3,18 +3,18 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -!vector_type_A = type vector<8xf32> -!vector_type_B = type vector<8xf32> -!vector_type_C = type vector<8x8xf32> +!vector_type_A = type vector<8xi64> +!vector_type_B = type vector<8xi64> +!vector_type_C = type vector<8x8xi64> -!vector_type_X = type vector<2xf32> -!vector_type_Y = type vector<3xf32> -!vector_type_Z = type vector<2x3xf32> +!vector_type_X = type vector<2xi64> +!vector_type_Y = type vector<3xi64> +!vector_type_Z = type vector<2x3xi64> -func @vector_outerproduct_splat_8x8(%fa: f32, %fb: f32, %fc: f32) -> !vector_type_C { - %a = splat %fa: !vector_type_A - %b = splat %fb: !vector_type_B - %c = splat %fc: !vector_type_C +func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C { + %a = splat %ia: !vector_type_A + %b = splat %ib: !vector_type_B + %c = splat %ic: !vector_type_C %d = vector.outerproduct %a, %b, %c : !vector_type_A, !vector_type_B return %d: !vector_type_C } @@ -33,16 +33,16 @@ } func @entry() { - %f1 = constant 1.0: f32 - %f2 = constant 2.0: f32 - %f3 = constant 3.0: f32 - %f4 = constant 4.0: f32 - %f5 = constant 5.0: f32 - %f10 = constant 10.0: f32 + %i1 = constant 1: i64 + %i2 = constant 2: i64 + %i3 = constant 3: i64 + %i4 = constant 4: i64 + %i5 = constant 5: i64 + %i10 = constant 10: i64 // Simple case, splat scalars into vectors, then take outer product. - %v = call @vector_outerproduct_splat_8x8(%f1, %f2, %f10) - : (f32, f32, f32) -> (!vector_type_C) + %v = call @vector_outerproduct_splat_8x8(%i1, %i2, %i10) + : (i64, i64, i64) -> (!vector_type_C) vector.print %v : !vector_type_C // // outer product 8x8: @@ -50,11 +50,11 @@ // CHECK-COUNT-8: ( 12, 12, 12, 12, 12, 12, 12, 12 ) // Direct outerproduct on vectors with different size. - %0 = vector.broadcast %f1 : f32 to !vector_type_X - %x = vector.insert %f2, %0[1] : f32 into !vector_type_X - %1 = vector.broadcast %f3 : f32 to !vector_type_Y - %2 = vector.insert %f4, %1[1] : f32 into !vector_type_Y - %y = vector.insert %f5, %2[2] : f32 into !vector_type_Y + %0 = vector.broadcast %i1 : i64 to !vector_type_X + %x = vector.insert %i2, %0[1] : i64 into !vector_type_X + %1 = vector.broadcast %i3 : i64 to !vector_type_Y + %2 = vector.insert %i4, %1[1] : i64 into !vector_type_Y + %y = vector.insert %i5, %2[2] : i64 into !vector_type_Y %p = call @vector_outerproduct_vec_2x3(%x, %y) : (!vector_type_X, !vector_type_Y) -> (!vector_type_Z)