diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -86,10 +86,11 @@
 
 //===----------------------------------------------------------------------===//
 // Linalg to vector patterns precondition and DRR.
 //===----------------------------------------------------------------------===//
-def PreconditionVectorizeGenericLinalgOp : CPred<
-    "succeeded(vectorizeGenericLinalgOpPrecondition(op))">;
-def VectorizeGenericLinalgOp : NativeCodeCall<
-    "vectorizeGenericLinalgOp($_builder, op)">;
+def PreconditionVectorizeLinalgOp : CPred<
+    "succeeded(vectorizeLinalgOpPrecondition(op))">;
+def VectorizeLinalgOp : NativeCodeCall<
+    "vectorizeLinalgOp($_builder, op)">;
+
 //===----------------------------------------------------------------------===//
 // Linalg generic permutation patterns precondition and DRR.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -79,9 +79,9 @@
 LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
 
 /// Rewrite a linalg.generic into a suitable vector.contraction op.
-LogicalResult vectorizeGenericLinalgOpPrecondition(Operation *op);
-SmallVector<Value, 0> vectorizeGenericLinalgOp(PatternRewriter &rewriter,
-                                               Operation *op);
+LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
+SmallVector<Value, 0> vectorizeLinalgOp(PatternRewriter &rewriter,
+                                        Operation *op);
 
 /// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
 /// and `iterator_types` permutated according to `permutation`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -158,10 +158,20 @@
          genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
 }
 
-LogicalResult
-mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
-  // TODO(ntv): This is in fact much more general than just vectorization for
-  // matmul ops.
+// TODO(ntv): This is in fact much more general than just vectorization for
+// matmul ops.
+LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  // All types must be static shape to go to vector.
+  for (Value operand : linalgOp.getInputsAndOutputBuffers())
+    if (!operand.getType().cast<ShapedType>().hasStaticShape())
+      return failure();
+  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
+    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
+      return failure();
+  if (isa<linalg::MatmulOp>(op))
+    return success();
+
   auto genericOp = dyn_cast<linalg::GenericOp>(op);
   if (!genericOp || !isMatmul(genericOp))
     return failure();
@@ -179,30 +189,29 @@
   return success();
 }
 
-SmallVector<Value, 0>
-mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
-                                       Operation *op) {
+SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
+                                                      Operation *op) {
   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                        "]: Rewrite linalg op as vector.contract: "
                     << *op << ":\n");
-  assert(succeeded(vectorizeGenericLinalgOpPrecondition(op)) &&
+  assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
          "DRR failure case must be a precondition");
-  auto genericOp = cast<linalg::GenericOp>(op);
-  assert(genericOp.hasBufferSemantics() &&
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   edsc::ScopedContext scope(rewriter, op->getLoc());
   using edsc::intrinsics::std_load;
   using edsc::intrinsics::std_store;
   using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
   using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
-  auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
-  auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
-  auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0));
+  auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
+  auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
+  auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
   auto vC = std_load(vectorMemRefC);
-  auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
-                              genericOp.iterator_types());
+  auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
+                              linalgOp.iterator_types());
   std_store(vRes, vectorMemRefC);
   return {};
 }
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -173,7 +173,7 @@
     affine_map<(m, n, k) -> (m, n)>
   ],
   iterator_types = ["parallel", "parallel", "reduction"],
-  __internal_linalg_transform__ = "_marked_matmul_"
+  __internal_linalg_transform__ = "VECTORIZE"
 }
 func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
@@ -185,7 +185,6 @@
   } : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
   return
 }
-
 // CHECK-LABEL: func @vectorization_test
 // CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
 // CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
@@ -195,6 +194,17 @@
 // CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
 // CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
 // CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+
+func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
+                           %C: memref<8x32xf32>) {
+  linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "VECTORIZE"} :
+    memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
+  return
+}
+// CHECK-LABEL: func @vectorization_test_2
+// CHECK: vector.contract {{.*}} :
+//          vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+
 func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
   %d = mulf %a, %b: f32
   %e = addf %c, %d: f32
@@ -213,7 +223,6 @@
   library_call = "linalg_matmul",
   iterator_types = ["parallel", "parallel", "reduction"]
 }
-
 func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                       %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                       %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -99,11 +99,17 @@
 //===----------------------------------------------------------------------===//
 // Linalg to vector contraction patterns.
 //===----------------------------------------------------------------------===//
+def : Pattern<(MatmulOp:$op $_, $_, $_),
+              [(VectorizeLinalgOp)],
+              [(Constraint<And<[HasLinalgTransformMarker<"VECTORIZE">,
+                  PreconditionVectorizeLinalgOp
+                ]>>)]>;
 def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
-              [(VectorizeGenericLinalgOp)],
-              [(Constraint<And<[HasLinalgTransformMarker<"_marked_matmul_">,
-                  PreconditionVectorizeGenericLinalgOp
+              [(VectorizeLinalgOp)],
+              [(Constraint<And<[HasLinalgTransformMarker<"VECTORIZE">,
+                  PreconditionVectorizeLinalgOp
                 ]>>)]>;
 
 //===----------------------------------------------------------------------===//
@@ -135,4 +141,5 @@
                   HasOperandsOfType<"SubViewOp">,
                   HasLinalgTransformMarker<"_promote_views_">]>>
     )]>;
+
 #endif // TEST_LINALG_TRANSFORMS_PATTERNS
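
Illustration (not part of the patch): the new MatmulOp path produces the same vector.contract sequence that the @vectorization_test CHECK lines above already spell out for linalg.generic. A sketch of the expected rewrite for @vectorization_test_2, where the SSA value names are illustrative and #mk/#kn/#mn stand for the matmul maps (m, n, k) -> (m, k), (k, n), (m, n) from the test trait:

  // Input, tagged with the new marker so the DRR pattern fires:
  linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "VECTORIZE"} :
    memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>

  // After vectorizeLinalgOp: cast each buffer to a 0-d memref of vector,
  // load the full vectors, contract, and store the result back into C.
  %0 = vector.type_cast %A : memref<8x16xf32> to memref<vector<8x16xf32>>
  %1 = load %0[] : memref<vector<8x16xf32>>
  %2 = vector.type_cast %B : memref<16x32xf32> to memref<vector<16x32xf32>>
  %3 = load %2[] : memref<vector<16x32xf32>>
  %4 = vector.type_cast %C : memref<8x32xf32> to memref<vector<8x32xf32>>
  %5 = load %4[] : memref<vector<8x32xf32>>
  %6 = vector.contract {indexing_maps = [#mk, #kn, #mn],
                        iterator_types = ["parallel", "parallel", "reduction"]}
         %1, %3, %5 : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
  store %6, %4[] : memref<vector<8x32xf32>>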