diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -29,9 +29,11 @@ // AVX512 op definitions //===----------------------------------------------------------------------===// +// Operation that is part of the input dialect. class AVX512_Op traits = []> : Op {} +// Intrinsic operation used during lowering to LLVM IR. class AVX512_IntrOp traits = []> : LLVM_IntrOpBase overloadedResults=*/[0], /*list overloadedOperands=*/[], traits, /*numResults=*/1>; + //----------------------------------------------------------------------------// // MaskCompressOp //----------------------------------------------------------------------------// @@ -271,9 +274,17 @@ // AVX op definitions //===----------------------------------------------------------------------===// +// Operation that is part of the input dialect. class AVX_Op traits = []> : Op {} +// Operation that may be part of the input dialect, but whose +// form is somewhere between the user view of the operation +// and the actual lower level intrinsic in LLVM IR. +class AVX_LowOp traits = []> : + Op {} + +// Intrinsic operation used during lowering to LLVM IR. class AVX_IntrOp traits = []> : LLVM_IntrOpBase:$a); } +//----------------------------------------------------------------------------// +// AVX Dot +//----------------------------------------------------------------------------// + +def DotOp : AVX_LowOp<"dot", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Dot"; + let description = [{ + Computes the 4-way dot products of the lower and higher parts of the source + vectors and broadcasts the two results to the lower and higher elements of + the destination vector, respectively. Adding one element of the lower part + to one element of the higher part in the destination vector yields the full + dot product of the two source vectors. + + Example: + + ```mlir + %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> + %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32> + %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32> + %d = addf %1, %2 : f32 + ``` + }]; + let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a, + VectorOfLengthAndType<[8], [F32]>:$b); + let results = (outs VectorOfLengthAndType<[8], [F32]>:$res); + let assemblyFormat = "$a `,` $b attr-dict `:` type($res)"; +} + +def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [NoSideEffect, + AllTypesMatch<["a", "b", "res"]>]> { + let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a, + VectorOfLengthAndType<[8], [F32]>:$b, I8:$c); + let results = (outs VectorOfLengthAndType<[8], [F32]>:$res); +} + #endif // X86VECTOR_OPS diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -104,6 +104,25 @@ } }; +struct DotOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(DotOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + DotOp::Adaptor adaptor(operands); + auto opType = adaptor.a().getType(); + Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); + // Dot product of all elements, broadcasted to all elements. + auto attr = rewriter.getI8IntegerAttr(0xff); + Value scale = + rewriter.create(op.getLoc(), llvmIntType, attr); + rewriter.replaceOpWithNewOp(op, opType, adaptor.a(), adaptor.b(), + scale); + return success(); + } +}; + /// An entry associating the "main" AVX512 op with its instantiations for /// vectors of 32-bit and 64-bit elements. template @@ -145,7 +164,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { Registry::registerPatterns(converter, patterns); - patterns.add(converter); + patterns.add( + converter); } void mlir::configureX86VectorLegalizeForExportTarget( @@ -155,4 +175,6 @@ target.addIllegalOp(); target.addLegalOp(); target.addIllegalOp(); + target.addLegalOp(); + target.addIllegalOp(); } diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -50,3 +50,11 @@ %0 = x86vector.avx.rsqrt %a : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: func @avx_dot +func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) +{ + // CHECK: x86vector.avx.intr.dp.ps.256 + %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> + return %0 : vector<8xf32> +} diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -54,3 +54,11 @@ %0 = x86vector.avx.rsqrt %a : vector<8xf32> return %0 : vector<8xf32> } + +// CHECK-LABEL: func @avx_dot +func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) +{ + // CHECK: x86vector.avx.intr.dot {{.*}} : vector<8xf32> + %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> + return %0 : vector<8xf32> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-dot.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-dot.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-x86vector" -convert-std-to-llvm | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @entry() -> i32 { + %i0 = constant 0 : i32 + %i4 = constant 4 : i32 + + %a = std.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32> + %b = std.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32> + %r = x86vector.avx.intr.dot %a, %b : vector<8xf32> + + %1 = vector.extractelement %r[%i0 : i32]: vector<8xf32> + %2 = vector.extractelement %r[%i4 : i32]: vector<8xf32> + %d = addf %1, %2 : f32 + + // CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 ) + // CHECK: 492 + vector.print %r : vector<8xf32> + vector.print %d : f32 + + return %i0 : i32 +}