diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -286,6 +286,12 @@
     return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
                             attr_value_iterator<AttrTy>(end()));
   }
+  template <typename AttrTy, typename UnderlyingTy>
+  auto getAsRange() {
+    return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
+      return static_cast<UnderlyingTy>(attr.getValue());
+    });
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -36,8 +36,7 @@
 #define DEBUG_TYPE "linalg-vectorization"
 
-static bool hasMultiplyAddBody(linalg::GenericOp op) {
-  auto &r = op.region();
+static bool hasMultiplyAddBody(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
   if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
     return false;
@@ -59,14 +58,26 @@
 }
 
 // TODO: Should be Tablegen'd from a single source that generates the op itself.
-static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
-  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
-         isRowMajorMatmul(genericOp.indexing_maps()) &&
-         hasMultiplyAddBody(genericOp);
+static LogicalResult isContraction(Operation *op) {
+  // TODO: interface for named ops.
+  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
+          linalg::DotOp>(op))
+    return success();
+
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  if (!genericOp)
+    return failure();
+
+  auto mapRange =
+      genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();
+
+  return success(
+      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+      llvm::all_of(mapRange,
+                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
+      hasMultiplyAddBody(genericOp.region()));
 }
 
-// TODO: This is in fact much more general than just vectorization for matmul
-// and fill ops.
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -76,33 +87,16 @@
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<RankedTensorType>().hasStaticShape())
       return failure();
-  if (isa<linalg::MatmulOp, linalg::FillOp, linalg::CopyOp>(op))
-    return success();
 
-  auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp || !::isRowMajorMatmul(genericOp))
-    return failure();
+  if (isa<linalg::FillOp, linalg::CopyOp>(op))
+    return success();
 
-  // TODO: non-identity layout.
-  auto isStaticMemRefWithIdentityLayout = [](Value v) {
-    auto m = v.getType().dyn_cast<MemRefType>();
-    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
-      return false;
-    return true;
-  };
-  return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
-                              isStaticMemRefWithIdentityLayout));
+  return isContraction(op);
 }
 
 void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   assert(succeeded(vectorizeLinalgOpPrecondition(op)));
 
-  if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
-    // TODO: add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      llvm_unreachable("Unexpected conv with padding");
-  }
-
   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
   (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
@@ -117,33 +111,47 @@
     return;
   }
 
-  // Vectorize other ops as vector contraction (currently only matmul).
+  assert(succeeded(isContraction(op)) && "Expected contraction");
+
+  // Vectorize other ops as vector contraction.
+  // TODO: interface.
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
+  // In the case of 0-D memrefs, return null and special case to scalar load or
+  // store later.
   auto extractVectorTypeFromScalarView = [](Value v) {
     MemRefType mt = v.getType().cast<MemRefType>();
-    return VectorType::get(mt.getShape(), mt.getElementType());
+    return mt.getShape().empty()
+               ? VectorType()
+               : VectorType::get(mt.getShape(), mt.getElementType());
   };
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value viewA = linalgOp.getInput(0);
   Value viewB = linalgOp.getInput(1);
   Value viewC = linalgOp.getOutputBuffer(0);
+  VectorType vtA = extractVectorTypeFromScalarView(viewA);
+  VectorType vtB = extractVectorTypeFromScalarView(viewB);
+  VectorType vtC = extractVectorTypeFromScalarView(viewC);
   Value zero = std_constant_index(0);
-  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
-                                 zero);
-  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
-                                 zero);
-  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
-                                 zero);
-  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
-                                 indicesA);
-  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
-                                 indicesB);
-  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
-                                 indicesC);
+  SmallVector<Value, 4> indicesA, indicesB, indicesC;
+  if (vtA)
+    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
+  if (vtB)
+    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
+  if (vtC)
+    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
+  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
+                : std_load(viewA, indicesA).value;
+  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
+                : std_load(viewB, indicesB).value;
+  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
+                : std_load(viewC, indicesC).value;
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  vector_transfer_write(res, viewC, indicesC);
+  if (vtC)
+    vector_transfer_write(res, viewC, indicesC);
+  else
+    std_store(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION
 
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
              %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
@@ -30,3 +31,38 @@
 // CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 // CHECK: linalg.copy
+
+// VECTOR-CONTRACTION-LABEL: contraction_dot
+func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
+  linalg.dot(%A, %B, %C) : memref<1584xf32>, memref<1584xf32>, memref<f32>
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matvec
+func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+  linalg.matvec %A, %B, %C :
+    (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>)
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matmul
+func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+  linalg.matmul %A, %B, %C :
+    (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>)
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul
+func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+  linalg.batch_matmul %A, %B, %C :
+    (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+  return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -54,6 +54,11 @@
       llvm::cl::desc(
           "Test a fused pass that forwards linalg.copy to vector.transfer"),
       llvm::cl::init(false)};
+  Option<bool> testGenericToVectorPattern{
+      *this, "test-contraction-to-vector-patterns",
+      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
+                     "in vector.contract form"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -300,6 +305,16 @@
   applyPatternsAndFoldGreedily(funcOp, forwardPattern);
 }
 
+static void applyContractionToVectorPatterns(FuncOp funcOp) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgVectorizationPattern<BatchMatmulOp>,
+                  LinalgVectorizationPattern<MatmulOp>,
+                  LinalgVectorizationPattern<MatvecOp>,
+                  LinalgVectorizationPattern<DotOp>,
+                  LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
+  applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
@@ -323,6 +338,8 @@
         testMatmulToVectorPatterns2dTiling);
   if (testVectorTransferForwardingPatterns)
     return applyVectorTransferForwardingPatterns(getFunction());
+  if (testGenericToVectorPattern)
+    return applyContractionToVectorPatterns(getFunction());
 }
 
 namespace mlir {
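
Illustration (not part of the patch): besides the four named ops, the new isContraction predicate accepts any linalg.generic with two inputs, one output, projected-permutation indexing maps, and a multiply-add body. Below is a rough sketch of a qualifying generic op, written in the linalg.generic syntax of this vintage; the names #matmul_accesses, #matmul_trait, and @generic_matmul are illustrative, not taken from the patch.

#matmul_accesses = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmul_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = #matmul_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

func @generic_matmul(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                     %C: memref<8x32xf32>) {
  linalg.generic #matmul_trait %A, %B, %C {
  ^bb(%a: f32, %b: f32, %c: f32):
    // mul, add, yield: exactly the three-op body hasMultiplyAddBody matches.
    %d = mulf %a, %b : f32
    %e = addf %c, %d : f32
    linalg.yield %e : f32
  } : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
  return
}

Each indexing map drops one dimension of the (m, n, k) iteration space, so all three are projected permutations, and the region computes c + a * b; under the patch this op would be rewritten to a single vector.contract the same way the named-op tests above are.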