diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-rewrite-sparse-dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-rewrite-sparse-dot.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-rewrite-sparse-dot.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -test-sparse-linalg-dot-to-avx512 -allow-unregistered-dialect -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+global_memref "private" @gm_A : memref<16xi64> =
+  dense<[0, 1, 10, 12, 13, 17, 18, 21,
+         51, 52, 57, 61, 62, 82, 98, 99]>
+global_memref "private" @gm_B_32 : memref<16xf32> =
+  dense<[1., 5., 8., 3., 2., 1., 0., 9.,
+         6., 7., 7., 3., 5., 2., 9., 1.]>
+global_memref "private" @gm_B_64 : memref<16xf64> =
+  dense<[1., 5., 8., 3., 2., 1., 0., 9.,
+         6., 7., 7., 3., 5., 2., 9., 1.]>
+global_memref "private" @gm_C : memref<24xi64> =
+  dense<[1, 2, 5, 10, 11, 12, 47, 48,
+         67, 68, 69, 70, 71, 72, 77, 78,
+         79, 82, 83, 84, 85, 90, 91, 98]>
+global_memref "private" @gm_D_32 : memref<24xf32> =
+  dense<[1., 5., 8., 3., 2., 1., 2., 9.,
+         6., 7., 7., 3., 5., 2., 9., 1.,
+         2., 9., 8., 7., 2., 0., 0., 4.]>
+global_memref "private" @gm_D_64 : memref<24xf64> =
+  dense<[1., 5., 8., 3., 2., 1., 2., 9.,
+         6., 7., 7., 3., 5., 2., 9., 1.,
+         2., 9., 8., 7., 2., 0., 0., 4.]>
+
+func @entry() -> i32 {
+  // Create input test data.
+  %m_A_shaped = get_global_memref @gm_A : memref<16xi64>
+  %m_B_32_shaped = get_global_memref @gm_B_32 : memref<16xf32>
+  %m_B_64_shaped = get_global_memref @gm_B_64 : memref<16xf64>
+  %M = constant 16 : index
+  %m_C_shaped = get_global_memref @gm_C : memref<24xi64>
+  %m_D_32_shaped = get_global_memref @gm_D_32 : memref<24xf32>
+  %m_D_64_shaped = get_global_memref @gm_D_64 : memref<24xf64>
+  %N = constant 24 : index
+
+  // Cast to dynamically sized (?) memrefs.
+  %m_A = memref_cast %m_A_shaped : memref<16xi64> to memref<?xi64>
+  %m_B_32 = memref_cast %m_B_32_shaped : memref<16xf32> to memref<?xf32>
+  %m_B_64 = memref_cast %m_B_64_shaped : memref<16xf64> to memref<?xf64>
+  %m_C = memref_cast %m_C_shaped : memref<24xi64> to memref<?xi64>
+  %m_D_32 = memref_cast %m_D_32_shaped : memref<24xf32> to memref<?xf32>
+  %m_D_64 = memref_cast %m_D_64_shaped : memref<24xf64> to memref<?xf64>
+
+  // Test case 1: f64 data, i64 indices.
+
+  // Wrap the indices and data memrefs in a fake op, so that they can be
+  // retrieved by the rewrite pass.
+  %m0_64 = "fake_op_sparse_wrapper"(%m_A, %m_B_64, %M)
+      : (memref<?xi64>, memref<?xf64>, index) -> memref<?xf64>
+  %m1_64 = "fake_op_sparse_wrapper"(%m_C, %m_D_64, %N)
+      : (memref<?xi64>, memref<?xf64>, index) -> memref<?xf64>
+  %m_r_64 = alloc() : memref<f64>
+
+  linalg.dot { __test_sparse__ } ins(%m0_64, %m1_64 : memref<?xf64>, memref<?xf64>)
+                                 outs(%m_r_64 : memref<f64>)
+
+  %result_64 = load %m_r_64[] : memref<f64>
+  vector.print %result_64 : f64
+  // CHECK: 86
+
+  // Test case 2: f32 data, i64 indices.
+
+  %m0_32 = "fake_op_sparse_wrapper"(%m_A, %m_B_32, %M)
+      : (memref<?xi64>, memref<?xf32>, index) -> memref<?xf32>
+  %m1_32 = "fake_op_sparse_wrapper"(%m_C, %m_D_32, %N)
+      : (memref<?xi64>, memref<?xf32>, index) -> memref<?xf32>
+  %m_r_32 = alloc() : memref<f32>
+
+  linalg.dot { __test_sparse__ } ins(%m0_32, %m1_32 : memref<?xf32>, memref<?xf32>)
+                                 outs(%m_r_32 : memref<f32>)
+
+  %result_32 = load %m_r_32[] : memref<f32>
+  vector.print %result_32 : f32
+  // CHECK: 86
+
+  // TODO: Support indices other than i64.
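+
+  // Note on the expected values above: the two sparse vectors share the
+  // indices {1, 10, 12, 82, 98}, so both the f64 and the f32 dot product
+  // evaluate to 5*1 + 8*3 + 3*1 + 2*9 + 9*4 = 86.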
+
+  %r = constant 0 : i32
+  return %r : i32
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
   TestSCFUtils.cpp
+  TestSparseLinalgDotToAVX512.cpp
   TestSparsification.cpp
   TestVectorTransforms.cpp
diff --git a/mlir/test/lib/Transforms/TestSparseLinalgDotToAVX512.cpp b/mlir/test/lib/Transforms/TestSparseLinalgDotToAVX512.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestSparseLinalgDotToAVX512.cpp
@@ -0,0 +1,258 @@
+//===- TestSparseLinalgDotToAVX512.cpp - Test sparse dot to AVX512 -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+static const char *kFakeOpName = "fake_op_sparse_wrapper";
+static const char *kSparseAttr = "__test_sparse__";
+
+struct TestSparseLinalgDotToAVX512
+    : public PassWrapper<TestSparseLinalgDotToAVX512, FunctionPass> {
+  void runOnFunction() override;
+
+  void replaceOperation(linalg::DotOp op);
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<avx512::AVX512Dialect, scf::SCFDialect,
+                    vector::VectorDialect, StandardOpsDialect>();
+  }
+};
+} // namespace
+
+void TestSparseLinalgDotToAVX512::runOnFunction() {
+  // TODO: Decompose into smaller composable pieces, instead of going from
+  // linalg all the way to LLVMIR/AVX512 in one step.
+
+  // Find linalg.dot ops tagged as sparse and replace them with nested loops.
+  getFunction().walk([&](linalg::DotOp dotOp) {
+    if (dotOp->hasAttr(kSparseAttr)) {
+      replaceOperation(dotOp);
+    }
+  });
+
+  // Remove fake op wrappers.
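+  // The wrappers only exist to hand the (indices, data, size) triples to the
+  // rewrite; once the tagged dot ops have been rewritten they should have no
+  // remaining uses and can simply be erased.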
+  getFunction().walk([&](Operation *op) {
+    if (op->getName().getStringRef().equals(kFakeOpName)) {
+      op->erase();
+    }
+  });
+}
+
+void TestSparseLinalgDotToAVX512::replaceOperation(linalg::DotOp op) {
+  Value lhs = op.inputs()[0];
+  Value rhs = op.inputs()[1];
+  Value dotOut = op.outputs()[0];
+
+  auto *lhsFakeOp = lhs.getDefiningOp();
+  auto *rhsFakeOp = rhs.getDefiningOp();
+
+  if (!lhsFakeOp->getName().getStringRef().equals(kFakeOpName) ||
+      !rhsFakeOp->getName().getStringRef().equals(kFakeOpName)) {
+    llvm::errs() << "Input to " << kSparseAttr << " dot op must be a "
+                 << kFakeOpName << " op.\n";
+    return;
+  }
+
+  if (lhsFakeOp->getNumOperands() != 3) {
+    llvm::errs() << "Expected 3 operands for " << kFakeOpName << " op.\n";
+    return;
+  }
+
+  Value idxA = lhsFakeOp->getOperand(0);
+  Value idxB = rhsFakeOp->getOperand(0);
+  Value dataA = lhsFakeOp->getOperand(1);
+  Value dataB = rhsFakeOp->getOperand(1);
+  Value sizeA = lhsFakeOp->getOperand(2);
+  Value sizeB = rhsFakeOp->getOperand(2);
+
+  auto idxAType = idxA.getType().dyn_cast<MemRefType>();
+  auto idxBType = idxB.getType().dyn_cast<MemRefType>();
+  auto dataAType = dataA.getType().dyn_cast<MemRefType>();
+  auto dataBType = dataB.getType().dyn_cast<MemRefType>();
+
+  if (!idxAType || !dataAType || !idxBType || !dataBType) {
+    llvm::errs() << "First and second operands of " << kFakeOpName
+                 << " must be memrefs.\n";
+    return;
+  }
+
+  if (idxAType.getElementType() != idxBType.getElementType()) {
+    llvm::errs() << "Indices memrefs must have equal element type.\n";
+    return;
+  }
+
+  if (dataAType.getElementType() != dataBType.getElementType()) {
+    llvm::errs() << "Data memrefs must have equal element type.\n";
+    return;
+  }
+
+  if (!sizeA.getType().isIndex() || !sizeB.getType().isIndex()) {
+    llvm::errs() << "Third operand of " << kFakeOpName << " must be index.\n";
+    return;
+  }
+
+  // TODO: Support other integer types for indices.
+  if (!idxAType.getElementType().isInteger(64)) {
+    llvm::errs() << "Indices must be i64.\n";
+    return;
+  }
+
+  if (!dataAType.getElementType().isF32() &&
+      !dataAType.getElementType().isF64()) {
+    llvm::errs() << "Data must be f32 or f64.\n";
+    return;
+  }
+
+  OpBuilder builder(op);
+  auto loc = op.getLoc();
+
+  // Scalar types.
+  auto tyI1 = builder.getI1Type();
+  auto tyI64 = builder.getI64Type();
+  auto tyIdx = builder.getIndexType();
+
+  // Input data element type and the number of elements per 512-bit vector.
+  auto tyDataEl = dataAType.getElementType().dyn_cast<FloatType>();
+  int dataVecSize = 512 / tyDataEl.getWidth();
+
+  // Vector types.
+  auto tyVec8I1 = VectorType::get({8}, tyI1);
+  auto tyVec16I1 = VectorType::get({16}, tyI1);
+  auto tyVec8I64 = VectorType::get({8}, tyI64);
+  auto tyVecF = VectorType::get({dataVecSize}, tyDataEl);
+
+  auto indexPadding = builder.create<ConstantOp>(
+      loc, tyI64, IntegerAttr::get(tyI64, 0x7FFFFFFFFFFFFFFFULL));
+  auto dataZero = builder.create<ConstantOp>(loc, tyDataEl,
+                                             FloatAttr::get(tyDataEl, 0.0));
+  auto attrVecFZero =
+      SplatElementsAttr::get(tyVecF, builder.getZeroAttr(tyDataEl));
+
+  auto constIdx0 =
+      builder.create<ConstantOp>(loc, tyIdx, IntegerAttr::get(tyIdx, 0));
+  auto constIdx8 =
+      builder.create<ConstantOp>(loc, tyIdx, IntegerAttr::get(tyIdx, 8));
+  auto constLoopStep = builder.create<ConstantOp>(
+      loc, tyIdx, IntegerAttr::get(tyIdx, dataVecSize));
+
+  // Outer for loop: Iterate over (idxA, dataA).
+  auto result = builder.create<scf::ForOp>(
+      op.getLoc(), constIdx0, sizeA, constLoopStep, ValueRange({dataZero}),
+      [&](OpBuilder &builderA, Location locA, Value ivA, ValueRange argsA) {
+        // Read indices from idxA.
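+        // Out-of-bounds lanes are padded with INT64_MAX so that (for this
+        // test's data) they never match a real index; the corresponding data
+        // lanes are padded with 0.0 below, so even a spurious match between
+        // two padded lanes would contribute nothing to the sum.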
+        auto vIdxA0 = builderA.create<vector::TransferReadOp>(
+            locA, tyVec8I64, idxA, ivA, indexPadding);
+        // If indices are i64 and data is f32: Read two indices vectors.
+        Value vIdxA1;
+        if (dataVecSize == 16) {
+          auto ivA1 = builderA.create<AddIOp>(locA, ivA, constIdx8);
+          vIdxA1 = builderA.create<vector::TransferReadOp>(
+              locA, tyVec8I64, idxA, ValueRange({ivA1}), indexPadding);
+        }
+
+        // Read data from dataA.
+        auto vDataA = builderA.create<vector::TransferReadOp>(
+            locA, tyVecF, dataA, ivA, dataZero);
+
+        // Inner for loop: Iterate over (idxB, dataB).
+        auto sum = builderA.create<scf::ForOp>(
+            op.getLoc(), constIdx0, sizeB, constLoopStep, argsA,
+            [&](OpBuilder &builderB, Location locB, Value ivB,
+                ValueRange argsB) {
+              // Read indices from idxB.
+              auto vIdxB0 = builderB.create<vector::TransferReadOp>(
+                  locB, tyVec8I64, idxB, ivB, indexPadding);
+              // If indices are i64 and data is f32: Read two indices vectors.
+              Value vIdxB1;
+              if (dataVecSize == 16) {
+                auto ivB1 = builderB.create<AddIOp>(locB, ivB, constIdx8);
+                vIdxB1 = builderB.create<vector::TransferReadOp>(
+                    locB, tyVec8I64, idxB, ValueRange({ivB1}), indexPadding);
+              }
+
+              // Read data from dataB.
+              auto vDataB = builderB.create<vector::TransferReadOp>(
+                  locB, tyVecF, dataB, ivB, dataZero);
+
+              // Intersect indices vectors.
+              auto is00 = builderB.create<avx512::Vp2IntersectOp>(
+                  locB, tyVec8I1, tyVec8I1, vIdxA0, vIdxB0);
+              Value k0, k1;
+              if (dataVecSize == 16) {
+                // There are two indices vectors for each of vDataA and vDataB.
+                // Intersect all pairs (4 intersections in total).
+                auto is01 = builderB.create<avx512::Vp2IntersectOp>(
+                    locB, tyVec8I1, tyVec8I1, vIdxA0, vIdxB1);
+                auto is10 = builderB.create<avx512::Vp2IntersectOp>(
+                    locB, tyVec8I1, tyVec8I1, vIdxA1, vIdxB0);
+                auto is11 = builderB.create<avx512::Vp2IntersectOp>(
+                    locB, tyVec8I1, tyVec8I1, vIdxA1, vIdxB1);
+
+                auto k00 = builderB.create<OrOp>(locB, is00.getResult(0),
+                                                 is01.getResult(0));
+                auto k01 = builderB.create<OrOp>(locB, is10.getResult(0),
+                                                 is11.getResult(0));
+                auto k10 = builderB.create<OrOp>(locB, is00.getResult(1),
+                                                 is10.getResult(1));
+                auto k11 = builderB.create<OrOp>(locB, is01.getResult(1),
+                                                 is11.getResult(1));
+
+                // Concatenate results to get vectors of 16 bits.
+                const int64_t shuffleMask[16] = {0, 1, 2,  3,  4,  5,  6,  7,
+                                                 8, 9, 10, 11, 12, 13, 14, 15};
+                auto maskAttr = builderB.getI64ArrayAttr(shuffleMask);
+                k0 = builderB.create<vector::ShuffleOp>(locB, tyVec16I1, k00,
+                                                        k01, maskAttr);
+                k1 = builderB.create<vector::ShuffleOp>(locB, tyVec16I1, k10,
+                                                        k11, maskAttr);
+              } else {
+                k0 = is00.getResult(0);
+                k1 = is00.getResult(1);
+              }
+
+              // Filter and compress data vectors.
+              auto compA = builderB.create<avx512::MaskCompressOp>(
+                  loc, tyVecF, k0, vDataA, Value(), attrVecFZero);
+              auto compB = builderB.create<avx512::MaskCompressOp>(
+                  loc, tyVecF, k1, vDataB, Value(), attrVecFZero);
+
+              // Multiply and reduce data vectors.
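+              // Assuming both index lists are sorted (as in this test),
+              // compressing vDataA by k0 and vDataB by k1 lines up matching
+              // elements in the same lanes (unmatched lanes are filled with
+              // 0.0), so an elementwise multiply followed by an add-reduction
+              // yields this block pair's contribution to the dot product.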
+              auto mulAB = builderB.create<MulFOp>(loc, compA, compB);
+              auto dotResult = builderB.create<vector::ReductionOp>(
+                  loc, tyDataEl, builder.getStringAttr("add"), mulAB,
+                  dataZero.getResult());
+              auto sum = builderB.create<AddFOp>(loc, argsB[0], dotResult);
+              builderB.create<scf::YieldOp>(locB, sum.getResult());
+            });
+        builderA.create<scf::YieldOp>(locA, sum.getResult(0));
+      });
+
+  builder.create<StoreOp>(loc, result.getResult(0), dotOut);
+
+  op->erase();
+}
+
+namespace mlir {
+namespace test {
+void registerTestSparseLinalgDotToAVX512() {
+  PassRegistration<TestSparseLinalgDotToAVX512> pass(
+      "test-sparse-linalg-dot-to-avx512",
+      "Test sparse linalg.dot to AVX512 conversion");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -97,6 +97,7 @@
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSparsification();
+void registerTestSparseLinalgDotToAVX512();
 void registerTestVectorConversions();
 } // namespace test
 } // namespace mlir
@@ -172,6 +173,7 @@
   test::registerTestRecursiveTypesPass();
   test::registerTestSCFUtilsPass();
   test::registerTestSparsification();
+  test::registerTestSparseLinalgDotToAVX512();
   test::registerTestVectorConversions();
 }
 #endif