diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Note: To run this test, your CPU must support AMX.
+
+// Multiply full size tiles into zero destination.
+//
+// Loads one 16x32 bf16 tile from each of %arg0 and %arg1 (both at offset
+// [0, 0]), multiplies them with amx.tile_mulf into a zero-initialized
+// 16x16 f32 accumulator tile, and stores the result to %arg2 at [0, 0].
+func @kernel(%arg0: memref<16x32xbf16>,
+             %arg1: memref<16x32xbf16>,
+             %arg2: memref<16x16xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+  %3 = amx.tile_zero : vector<16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32>
+  return
+}
+
+// Test driver: fills two bf16 inputs, runs @kernel, and prints the 16x16 f32
+// result one row at a time so FileCheck can verify it. Returns 0.
+func @entry() -> i32 {
+  // Padding operand required by the vector.transfer_read of each result row.
+  %fu = arith.constant -1.0: f32
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1: index
+  %c16 = arith.constant 16: index
+
+  // Setup simple test data. Note that bf16 does not seem to work well with the
+  // tensor type yet, which is why we use vectors and transfer the data into memref.
+  // TODO: use tensors and bufferization.to_memref instead
+  //
+  // Both inputs are all 1.0 except for a single distinct value in each row.
+  %0 = arith.constant dense<[
+    [ 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
+  ]> : vector<16x32xbf16>
+  %1 = arith.constant dense<[
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
+  ]> : vector<16x32xbf16>
+
+  // Set up memory.
+  %a = memref.alloc() : memref<16x32xbf16>
+  %b = memref.alloc() : memref<16x32xbf16>
+  %c = memref.alloc() : memref<16x16xf32>
+  vector.transfer_write %0, %a[%c0, %c0] : vector<16x32xbf16>, memref<16x32xbf16>
+  vector.transfer_write %1, %b[%c0, %c0] : vector<16x32xbf16>, memref<16x32xbf16>
+
+  // Call kernel.
+  call @kernel(%a, %b, %c) : (memref<16x32xbf16>, memref<16x32xbf16>, memref<16x16xf32>) -> ()
+
+  //
+  // Print and verify the 16x16 result.
+  //
+  // The expected values are consistent with
+  //   C[m][n] = sum_k (A[m][2k] * B[k][2n] + A[m][2k+1] * B[k][2n+1])
+  // evaluated on the bf16-rounded inputs (e.g. 1.1 is stored as 1.1015625,
+  // which is why the first entry is 32.1016 rather than 32.1).
+  //
+  // CHECK: ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
+  // CHECK-NEXT: ( 32.2031, 34.5, 34.6094, 33.7969, 33.0284, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
+  // CHECK-NEXT: ( 32.2969, 34.5938, 34.7031, 33.8906, 33.1619, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
+  // CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
+  // CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
+  // CHECK-NEXT: ( 32.6016, 34.8984, 35.0078, 34.3739, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
+  // CHECK-NEXT: ( 32.7031, 35, 35.1094, 34.577, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
+  // CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
+  // CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
+  // CHECK-NEXT: ( 32.7031, 35, 35.4609, 34.2969, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
+  // CHECK-NEXT: ( 32.6016, 34.8984, 35.3697, 34.1953, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
+  // CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
+  // CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
+  // CHECK-NEXT: ( 32.2969, 34.8025, 34.7031, 33.8906, 33.1016, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
+  // CHECK-NEXT: ( 32.2031, 34.6619, 34.6094, 33.7969, 33.0078, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
+  // CHECK-NEXT: ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
+  //
+  scf.for %i = %c0 to %c16 step %c1 {
+    %v = vector.transfer_read %c[%i, %c0], %fu: memref<16x16xf32>, vector<16xf32>
+    vector.print %v : vector<16xf32>
+  }
+
+  // Release resources.
+  memref.dealloc %a : memref<16x32xbf16>
+  memref.dealloc %b : memref<16x32xbf16>
+  memref.dealloc %c : memref<16x16xf32>
+  %i0 = arith.constant 0 : i32
+  return %i0 : i32
+}