diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -16,10 +16,12 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include <type_traits> @@ -156,8 +158,8 @@ genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp); } -// TODO(ntv): This is in fact much more general than just vectorization for -// matmul ops. +// TODO(ntv, ataei): This is in fact much more general than just vectorization +// for matmul and fill ops. LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { auto linalgOp = cast<linalg::LinalgOp>(op); // All types must be static shape to go to vector. 
@@ -167,7 +169,7 @@ for (Type outputTensorType : linalgOp.getOutputTensorTypes()) if (!outputTensorType.cast<RankedTensorType>().hasStaticShape()) return failure(); - if (isa<linalg::MatmulOp>(op)) + if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op)) return success(); auto genericOp = dyn_cast<linalg::GenericOp>(op); @@ -189,28 +191,41 @@ SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter, Operation *op) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg op as vector.contract: " - << *op << ":\n"); + using edsc::intrinsics::std_load; + using edsc::intrinsics::std_store; + using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>; + using vector_broadcast = edsc::intrinsics::ValueBuilder<vector::BroadcastOp>; + using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>; assert(succeeded(vectorizeLinalgOpPrecondition(op)) && "DRR failure case must be a precondition"); - auto linalgOp = cast<linalg::LinalgOp>(op); assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); edsc::ScopedContext scope(rewriter, op->getLoc()); - using edsc::intrinsics::std_load; - using edsc::intrinsics::std_store; - using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>; - using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>; - auto vA = std_load(vector_type_cast(linalgOp.getInput(0))); - auto vB = std_load(vector_type_cast(linalgOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0)); - auto vC = std_load(vectorMemRefC); - auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(), - linalgOp.iterator_types()); - std_store(vRes, vectorMemRefC); + + if (auto fillOp = dyn_cast<linalg::FillOp>(op)) { + // Vectorize fill as a vector.broadcast. + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg.fill as vector.broadcast: " + << *op << ":\n"); + auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0)); + auto dstVec = std_load(dstMemrefVec); + auto resVec = vector_broadcast(dstVec, fillOp.value()); + std_store(resVec, dstMemrefVec); + } else { + // Vectorize other ops as vector contraction (currently only matmul). 
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg op as vector.contract: " + << *op << ":\n"); + auto vA = std_load(vector_type_cast(linalgOp.getInput(0))); + auto vB = std_load(vector_type_cast(linalgOp.getInput(1))); + auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0)); + auto vC = std_load(vectorMemRefC); + auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(), + linalgOp.iterator_types()); + std_store(vRes, vectorMemRefC); + } return {}; } diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -205,6 +205,13 @@ // CHECK: vector.contract {{.*}} : // vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> +func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { + linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, f32 + return +} +// CHECK-LABEL: func @test_vectorize_fill +// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32> + func @fma(%a: f32, %b: f32, %c: f32) -> f32 { %d = mulf %a, %b: f32 %e = addf %c, %d: f32 diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -105,6 +105,12 @@ HasLinalgTransformMarker<"VECTORIZE">, PreconditionVectorizeLinalgOp ]>>)]>; +def : Pattern<(FillOp:$op $_, $_), + [(VectorizeLinalgOp)], + [(Constraint<And<[HasLinalgTransformMarker<"VECTORIZE">, + PreconditionVectorizeLinalgOp + ]>>)]>; def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), [(VectorizeLinalgOp)], [(Constraint<And<[HasLinalgTransformMarker<"VECTORIZE">, PreconditionVectorizeLinalgOp]>>)]>; + //===----------------------------------------------------------------------===// // Linalg generic permutation patterns. 
//===----------------------------------------------------------------------===//