Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -84,6 +84,241 @@ hasMultiplyAddBody(genericOp.region())); } +// Allow list of operations supported by vectorization. +static bool hasOnlySupportedElementwiseOp(Region &r) { + if (!llvm::hasSingleElement(r)) + return false; + for (Operation &op : r.front()) { + if (!isa(op)) + return false; + } + return true; +} + +// Return true if the op is an element-wise linalg op. +static bool isElementwise(Operation *op) { + auto genericOp = dyn_cast(op); + if (!genericOp) + return false; + if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) + return false; + for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) { + if (!genericOp.getOutputIndexingMap(i).isIdentity()) + return false; + } + for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) { + if (!genericOp.getInputIndexingMap(i).isMinorIdentity()) + return false; + } + return hasOnlySupportedElementwiseOp(genericOp.getRegion()); +} + +// Returns a vector type with the shape and element type of memref `v`, or a +// null VectorType for 0-D memrefs (callers fall back to scalar load/store). +static VectorType extractVectorTypeFromScalarView(Value v) { + MemRefType mt = v.getType().cast(); + return mt.getShape().empty() + ?
VectorType() + : VectorType::get(mt.getShape(), mt.getElementType()); +} + +// Loads `memref` as a vector with the memref's shape, optionally through +// `permutationMap`, or emits a scalar load for 0-D memrefs. +static Value loadVector(OpBuilder &builder, Value memref, + AffineMap permutationMap = AffineMap()) { + edsc::ScopedContext scope(builder, builder.getInsertionPoint()->getLoc()); + auto memrefType = memref.getType().cast(); + if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) { + SmallVector indices(memrefType.getRank(), std_constant_index(0)); + if (permutationMap) + return vector_transfer_read(vectorType, memref, indices, permutationMap); + return vector_transfer_read(vectorType, memref, indices); + } + return std_load(memref); +} + +// Stores `value` into `memref`, or emits a scalar store for 0-D memrefs. +// Unpermuted stores broadcast `value` if its type differs from the memref's. +static void storeVector(OpBuilder &builder, Value value, Value memref, + AffineMap permutationMap = AffineMap()) { + edsc::ScopedContext scope(builder, builder.getInsertionPoint()->getLoc()); + auto memrefType = memref.getType().cast(); + if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) { + // A permuted write stores `value` as-is (no broadcast). + if (!permutationMap && vectorType != value.getType()) + value = vector_broadcast(vectorType, value); + SmallVector indices(memrefType.getRank(), std_constant_index(0)); + if (permutationMap) + vector_transfer_write(value, memref, indices, permutationMap); + else + vector_transfer_write(value, memref, indices); + } else { + std_store(value, memref); + } +} + +namespace { +// Transforms scalar operations into their vectorized counterparts, +// while using the provided generic op to map: +// * Its arguments to reads from the views of the generic op. +// * linalg.yield ops to writes to the views of the generic op. 
+class GenericVectorizer { +public: + GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic) + : builder(builder), generic(generic) {} + + // Takes a scalar operation and builds its vectorized counterpart or + // counterparts using the underlying builder. + // If operands of the scalar operation are referring to previously vectorized + // operations, then in their vectorized form these operands will be referring + // to previous vectorization results.
+ void vectorize(Operation &scalarOp) { + auto yieldOp = dyn_cast(scalarOp); + if (yieldOp) { + for (auto outputAndMemref : + llvm::zip(yieldOp.values(), generic.getOutputBuffers())) { + Value vectorValue = vectorize(std::get<0>(outputAndMemref)); + storeVector(builder, vectorValue, std::get<1>(outputAndMemref)); + } + } else { + Operation *vectorOp = uncachedVectorize(scalarOp); + assert(scalarOp.getNumResults() == vectorOp->getNumResults()); + for (auto result : + llvm::zip(scalarOp.getResults(), vectorOp->getResults())) { + valueCache[std::get<0>(result)] = std::get<1>(result); + } + } + } + +private: + // Transforms a scalar value into its vectorized counterpart, recursively + // vectorizing operations as necessary using the underlying builder. + // Keeps track of previously vectorized values and reuses vectorization + // results if these values come up again. + Value vectorize(Value scalarValue) { + // Don't vectorize values coming from outside the region. + if (scalarValue.getParentRegion() != &generic.region()) + return scalarValue; + auto vectorValueIt = valueCache.find(scalarValue); + if (vectorValueIt != valueCache.end()) + return vectorValueIt->second; + + // If the value is from the region but not in the cache it means it is a + // block argument. + auto scalarArg = scalarValue.cast(); + assert(scalarArg.getOwner() == &generic.region().front()); + Value vector_arg = + generic.getInputsAndOutputBuffers()[scalarArg.getArgNumber()]; + Value vectorResult = loadVector(builder, vector_arg); + valueCache[scalarArg] = vectorResult; + return vectorResult; + } + + // If the operands have different shapes, broadcast the lower-rank operand + // to the other's shape (element counts may tie, e.g. for 8 vs. 1x8).
+ void broadcastIfNeeded(Value &a, Value &b) { + if (a.getType() == b.getType()) + return; + auto aVecType = a.getType().dyn_cast(); + auto bVecType = b.getType().dyn_cast(); + assert(aVecType || bVecType); + if (aVecType && bVecType && aVecType.getShape() == bVecType.getShape()) + return; + if (aVecType == nullptr || + (bVecType != nullptr && + aVecType.getRank() < bVecType.getRank())) { + VectorType newType = + VectorType::get(bVecType.getShape(), + aVecType ? aVecType.getElementType() : a.getType()); + a = builder.create( + builder.getInsertionPoint()->getLoc(), newType, a); + + } else { + VectorType newType = + VectorType::get(aVecType.getShape(), + bVecType ? bVecType.getElementType() : b.getType()); + b = builder.create( + builder.getInsertionPoint()->getLoc(), newType, b); + } + } + + // Takes a scalar operation and builds its vectorized counterpart or + // counterparts using underlying builder without involving any caches. + Operation *uncachedVectorize(Operation &base_scalarOp) { + if (auto scalarOp = dyn_cast(base_scalarOp)) { + return uncachedVectorizeBinaryArithmeticOp(scalarOp); + } + if (auto scalarOp = dyn_cast(base_scalarOp)) { + Value vectorLhs = vectorize(scalarOp.lhs()); + Value vectorRhs = vectorize(scalarOp.rhs()); + broadcastIfNeeded(vectorLhs, vectorRhs); + return builder.create(scalarOp.getLoc(), scalarOp.predicate(), + vectorLhs, vectorRhs); + } + if (auto scalarOp = dyn_cast(base_scalarOp)) { + return builder.create(scalarOp.getLoc(), scalarOp.getValue()); + } + if (auto scalarOp = dyn_cast(base_scalarOp)) { + return uncachedVectorizeBinaryArithmeticOp(scalarOp); + } + if (auto scalarOp = dyn_cast(base_scalarOp)) { + return uncachedVectorizeUnaryArithmeticOp(scalarOp); + } + if (auto scalarOp = dyn_cast(base_scalarOp)) { + return uncachedVectorizeBinaryArithmeticOp(scalarOp); + } + if (auto scalarOp = dyn_cast(base_scalarOp)) { + return uncachedVectorizeUnaryArithmeticOp(scalarOp); + } + if (auto scalarOp =
dyn_cast(base_scalarOp)) { + Value vector_condition = vectorize(scalarOp.condition()); + Value vector_true_value = vectorize(scalarOp.true_value()); + Value vector_false_value = vectorize(scalarOp.false_value()); + broadcastIfNeeded(vector_true_value, vector_false_value); + broadcastIfNeeded(vector_condition, vector_false_value); + broadcastIfNeeded(vector_true_value, vector_condition); + return builder.create(scalarOp.getLoc(), vector_condition, + vector_true_value, vector_false_value); + } + if (auto scalarOp = dyn_cast(base_scalarOp)) { + return uncachedVectorizeBinaryArithmeticOp(scalarOp); + } + if (auto scalarOp = dyn_cast(base_scalarOp)) { + return uncachedVectorizeUnaryArithmeticOp(scalarOp); + } + llvm_unreachable("Unsupported op"); + } + + template + Operation *uncachedVectorizeBinaryArithmeticOp(T scalarOp) { + Value vectorLhs = vectorize(scalarOp.lhs()); + Value vectorRhs = vectorize(scalarOp.rhs()); + broadcastIfNeeded(vectorLhs, vectorRhs); + return builder.create(scalarOp.getLoc(), vectorLhs, vectorRhs); + } + + template + Operation *uncachedVectorizeUnaryArithmeticOp(T scalarOp) { + Value vectorOperand = vectorize(scalarOp.operand()); + return builder.create(scalarOp.getLoc(), vectorOperand); + } + + OpBuilder &builder; + linalg::GenericOp generic; + llvm::DenseMap valueCache; +}; +} // namespace + +// Vectorizes an element-wise linalg.generic op by promoting each scalar op in +// its body to the equivalent vector op, reading/writing its views as vectors. +static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) { + GenericVectorizer vectorizer(builder, op); + for (Operation &scalarOp : op.region().front()) { + vectorizer.vectorize(scalarOp); + } +} + LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { auto linalgOp = cast(op); // All types must be static shape to go to vector.
@@ -96,7 +331,8 @@ if (isa(op)) return success(); - + if (isElementwise(op)) + return success(); return isContraction(op); } @@ -108,28 +344,9 @@ edsc::ScopedContext scope(builder, op->getLoc()); - // In the case of 0-D memrefs, return null and special case to scalar load or - // store later. - auto extractVectorTypeFromScalarView = [](Value v) { - MemRefType mt = v.getType().cast(); - return mt.getShape().empty() - ? VectorType() - : VectorType::get(mt.getShape(), mt.getElementType()); - }; if (auto fillOp = dyn_cast(op)) { // Vectorize fill as a vector.broadcast. LLVM_DEBUG(dbgs() << dbgPref << "Rewrite linalg.fill as vector.broadcast: " << *op); - Value viewOutput = fillOp.output(); - if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) { - auto vecType = - VectorType::get(fillOp.getOutputBufferType(0).getShape(), - fillOp.getOutputBufferType(0).getElementType()); - Value vector = vector_broadcast(vecType, fillOp.value()); - Value zero = std_constant_index(0); - SmallVector indicesOutput(outputType.getRank(), zero); - vector_transfer_write(vector, viewOutput, indicesOutput); - } else { - std_store(fillOp.value(), viewOutput); - } + storeVector(builder, fillOp.value(), fillOp.output()); return; } if (auto copyOp = dyn_cast(op)) { @@ -138,36 +355,22 @@ << "Rewrite linalg.copy as vector.transfer_read + " "vector.transfer_write: " << *op); - Value zero = std_constant_index(0); - Value viewInput = copyOp.input(); - Value viewOutput = copyOp.output(); - Value vector; - if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) { - SmallVector indicesInput(inputType.getRank(), zero); - if (copyOp.inputPermutation()) - vector = vector_transfer_read( - extractVectorTypeFromScalarView(viewInput), viewInput, indicesInput, - copyOp.inputPermutation().getValue()); - else - vector = - vector_transfer_read(extractVectorTypeFromScalarView(viewInput), - viewInput, indicesInput); - } else { - vector = std_load(viewInput).value; - } - if (VectorType
outputType = extractVectorTypeFromScalarView(viewOutput)) { - SmallVector indicesOutput(outputType.getRank(), zero); - if (copyOp.outputPermutation()) - vector_transfer_write(vector, viewOutput, indicesOutput, - copyOp.outputPermutation().getValue()); - else - vector_transfer_write(vector, viewOutput, indicesOutput); - } else { - std_store(vector, viewOutput); - } + // Forward the copy's optional input/output permutations. + AffineMap inputMap = copyOp.inputPermutation().getValueOr(AffineMap()); + AffineMap outputMap = copyOp.outputPermutation().getValueOr(AffineMap()); + Value vector = loadVector(builder, copyOp.input(), inputMap); + storeVector(builder, vector, copyOp.output(), outputMap); return; } + if (isElementwise(op)) { + LLVM_DEBUG(dbgs() << dbgPref + << "Rewrite linalg op as vector.transfer_read + " + "vector_op + vector.transfer_write: " + << *op); + return vectorizeElementwise(cast(op), builder); + } + assert(succeeded(isContraction(op)) && "Expected contraction"); // Vectorize other ops as vector contraction. Index: mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir =================================================================== --- mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir +++ mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir @@ -1,6 +1,5 @@ // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, @@ -26,40 +25,3 @@ // CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> // // CHECK: linalg.copy - -// VECTOR-CONTRACTION-LABEL: contraction_dot -func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { - // VECTOR-CONTRACTION: 
vector.contract - // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32 - linalg.dot ins(%A, %B: memref<1584xf32>,
memref<1584xf32>) - outs(%C: memref) - return -} - -// VECTOR-CONTRACTION-LABEL: contraction_matvec -func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { - // VECTOR-CONTRACTION: vector.contract - // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32> - linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>) - outs(%C: memref<1584xf32>) - return -} - -// VECTOR-CONTRACTION-LABEL: contraction_matmul -func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { - // VECTOR-CONTRACTION: vector.contract - // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32> - linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>) - outs(%C: memref<1584x1584xf32>) - return -} - -// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul -func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { - // VECTOR-CONTRACTION: vector.contract - // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32> - linalg.batch_matmul - ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>) - outs(%C: memref<1584x1584x1584xf32>) - return -} Index: mlir/test/Dialect/Linalg/transform-patterns.mlir =================================================================== --- mlir/test/Dialect/Linalg/transform-patterns.mlir +++ mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -5,9 +5,7 @@ // CHECK-DAG: #[[$STRIDED_2D_u_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // Map corresponding to a 2D memory access where the stride along all dims are unknown.
// CHECK-DAG: #[[$STRIDED_2D:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> -// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-DAG: #[[$nm:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)> // CHECK-DAG: #[[$km:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)> @@ -92,99 +90,6 @@ // CHECK: ins({{.*}}, {{.*}}: memref, memref) // CHECK: outs({{.*}}: memref) -#matmul_trait = { - args_in = 2, - args_out = 1, - indexing_maps = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)> - ], - iterator_types = ["parallel", "parallel", "reduction"], - __internal_linalg_transform__ = "VECTORIZE" -} -func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, - %C: memref<8x32xf32>) { - linalg.generic #matmul_trait - ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>) - outs(%C : memref<8x32xf32>) { - ^bb(%a: f32, %b: f32, %c: f32) : - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 - linalg.yield %e : f32 - } - return -} -// CHECK-LABEL: func @vectorization_test -// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> -// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> -// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> -// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> - -func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, - %C: memref<8x32xi32>) { - linalg.generic #matmul_trait - ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>) - outs(%C : memref<8x32xi32>) { - ^bb(%a: i32, %b: i32, %c: i32) : - %d = muli %a, %b:
i32 - %e = addi %c, %d: i32 - linalg.yield %e : i32 - } - return -} -// CHECK-LABEL: func @vectorization_test_integer -// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> -// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> -// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> -// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> - -func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, - %C: memref<8x32xf32>) { - linalg.matmul { __internal_linalg_transform__ = "VECTORIZE"} - ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) - outs(%C: memref<8x32xf32>) - return -} -// CHECK-LABEL: func @vectorization_test_2 -// CHECK: vector.contract {{.*}} : -// vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> - -func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { - linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, f32 - return -} -// CHECK-LABEL: func @test_vectorize_fill -// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> -// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> - -func @test_vectorize_fill_scalar(%A : memref, %arg0 : f32) { - linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref, f32 - return -} -// CHECK-LABEL: func @test_vectorize_fill -// CHECK-SAME: (%[[M:.*]]: memref, %[[V:.*]]: f32) -// CHECK: store %[[V]], %[[M]][] : memref - -func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { - linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, memref<8x16xf32> - return -} -// CHECK-LABEL: func @test_vectorize_copy -// CHECK: %[[V:.*]] =
vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> -// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> - -func @test_vectorize_copy_scalar(%A : memref, %B : memref) { - linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref, memref - return -} -// CHECK-LABEL: func @test_vectorize_copy_scalar -// CHECK: %[[V:.*]] = load {{.*}} : memref -// CHECK: store %[[V]], {{.*}} : memref - - #matmul_accesses = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, Index: mlir/test/Dialect/Linalg/vectorization.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Linalg/vectorization.mlir @@ -0,0 +1,210 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns | FileCheck %s + +// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: contraction_dot +func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { + // CHECK: vector.contract + // CHECK-SAME: vector<1584xf32>, vector<1584xf32> into f32 + linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>) + outs(%C: memref) + return +} + +// CHECK-LABEL: contraction_matvec +func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { + // CHECK: vector.contract + // CHECK-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32> + linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>) + outs(%C: memref<1584xf32>) + return +} + +// CHECK-LABEL: contraction_matmul +func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { + // CHECK: vector.contract + // CHECK-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32> + linalg.matmul ins(%A, %B:
memref<1584x1584xf32>, memref<1584x1584xf32>) + outs(%C: memref<1584x1584xf32>) + return +} + +// CHECK-LABEL: contraction_batch_matmul +func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { + // CHECK: vector.contract + // CHECK-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32> + linalg.batch_matmul + ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>) + outs(%C: memref<1584x1584x1584xf32>) + return +} + +#matmul_trait = { + args_in = 2, + args_out = 1, + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] +} +func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, + %C: memref<8x32xf32>) { + linalg.generic #matmul_trait + ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>) + outs(%C : memref<8x32xf32>) { + ^bb(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b: f32 + %e = addf %c, %d: f32 + linalg.yield %e : f32 + } + return +} +// CHECK-LABEL: func @vectorization_test +// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> +// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> +// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> +// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> + +func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, + %C: memref<8x32xi32>) { + linalg.generic #matmul_trait + ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>) + outs(%C : memref<8x32xi32>) { + ^bb(%a: i32, %b: i32, %c: i32) : + %d = muli %a, %b: i32 + %e = addi %c, %d:
i32 + linalg.yield %e : i32 + } + return +} +// CHECK-LABEL: func @vectorization_test_integer +// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> +// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> +// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> +// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> + +func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, + %C: memref<8x32xf32>) { + linalg.matmul + ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) + outs(%C: memref<8x32xf32>) + return +} +// CHECK-LABEL: func @vectorization_test_2 +// CHECK: vector.contract {{.*}} : +// CHECK-SAME: vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> + +func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { + linalg.fill(%A, %arg0) : memref<8x16xf32>, f32 + return +} +// CHECK-LABEL: func @test_vectorize_fill +// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> +// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + +func @test_vectorize_fill_scalar(%A : memref, %arg0 : f32) { + linalg.fill(%A, %arg0) : memref, f32 + return +} +// CHECK-LABEL: func @test_vectorize_fill_scalar +// CHECK-SAME: (%[[M:.*]]: memref, %[[V:.*]]: f32) +// CHECK: store %[[V]], %[[M]][] : memref + +func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { + linalg.copy(%A, %B) : memref<8x16xf32>, memref<8x16xf32> + return +} +// CHECK-LABEL: func @test_vectorize_copy +// CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> +// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + +func @test_vectorize_copy_scalar(%A : memref, %B : memref) { +
linalg.copy(%A, %B) : memref, memref + return +} +// CHECK-LABEL: func @test_vectorize_copy_scalar +// CHECK: %[[V:.*]] = load {{.*}} : memref +// CHECK: store %[[V]], {{.*}} : memref + +func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>, + %arg2: memref<256xf32>, %i: f32) { + %c1_f32 = constant 1.0 : f32 + linalg.generic { + args_in = 2 : i64, + args_out = 10 : i64, + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg1, %arg2: memref<4x256xf32>, memref<256xf32>) + outs( + %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 : + memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, + memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, + memref<4x256xf32>, memref<4x256xf32>) { + ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, + %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32, + %arg14 : f32): + %6 = addf %arg4, %arg6 : f32 + %7 = cmpf "ogt", %arg3, %arg6 : f32 + %8 = constant 2.0 : f32 + %9 = divf %arg5, %i : f32 + %10 = exp2 %arg5 : f32 + %11 = mulf %arg5, %8 : f32 + %12 = rsqrt %arg5 : f32 + %13 = select %7, %arg5, %arg6 : f32 + %14 = subf %arg5, %arg6 : f32 + %15 = tanh %arg5 : f32 + linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32 + } + return +} + +// CHECK-LABEL: func @generic_vectorize +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>,
%[[ARG3:.*]]: f32) +// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32> +// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]] +// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> +// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32> +// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]] +// CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32> +// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> +// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> +// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> +// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> +// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> +// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> +// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V1]] : vector<4x256xf32> +// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> +// CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[TAN]],
%[[ARG0]][%[[C0]], %[[C0]]] Index: mlir/test/lib/Transforms/TestLinalgTransforms.cpp =================================================================== --- mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -71,7 +71,7 @@ "Test a fused pass that forwards linalg.copy to vector.transfer"), llvm::cl::init(false)}; Option testGenericToVectorPattern{ - *this, "test-contraction-to-vector-patterns", - llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " - "in vector.contract form"), + *this, "test-linalg-to-vector-patterns", + llvm::cl::desc("Test a set of patterns that rewrite linalg ops " + "(contractions, fill, copy, element-wise) in vector form"), llvm::cl::init(false)}; @@ -464,14 +464,15 @@ applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); } -static void applyContractionToVectorPatterns(FuncOp funcOp) { +static void applyLinalgToVectorPatterns(FuncOp funcOp) { OwningRewritePatternList patterns; - patterns.insert, - LinalgVectorizationPattern, - LinalgVectorizationPattern, - LinalgVectorizationPattern, - LinalgVectorizationPattern, - LinalgVectorizationPattern>(funcOp.getContext()); + patterns.insert< + LinalgVectorizationPattern, + LinalgVectorizationPattern, + LinalgVectorizationPattern, + LinalgVectorizationPattern, LinalgVectorizationPattern, + LinalgVectorizationPattern, LinalgVectorizationPattern, + LinalgVectorizationPattern>(funcOp.getContext()); applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } @@ -516,7 +517,7 @@ if (testVectorTransferForwardingPatterns) return applyVectorTransferForwardingPatterns(getFunction()); if (testGenericToVectorPattern) - return applyContractionToVectorPatterns(getFunction()); + return applyLinalgToVectorPatterns(getFunction()); if (testAffineMinSCFCanonicalizationPatterns) return applyAffineMinSCFCanonicalizationPatterns(getFunction()); }