diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -53,8 +53,6 @@
   Matmul = 1,
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
-  /// Lower to series of AXPY chained through FMA.
-  AXPY = 3,
 };
 /// Enum to control the lowering of `vector.transpose` operations.
 enum class VectorTransposeLowering {
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -177,33 +177,6 @@
   vector::VectorTransformsOptions vectorTransformsOptions;
 };
 
-/// Progressive lowering of a `vector.contract %a, %b, %c` with
-/// matvec semantics to series of AXPY operations that are chained
-/// through FMA operations.
-///
-/// This only kicks in when VectorTransformsOptions is set to AXPY.
-//
-// TODO: this is very similar, but not quite the same as the outerproduct
-// lowering above; merge the two?
-class ContractionOpToAXPYLowering
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-  ContractionOpToAXPYLowering(
-      vector::VectorTransformsOptions vectorTransformsOptions,
-      MLIRContext *context)
-      : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
-
-  LogicalResult match(vector::ContractionOp op) const override;
-  void rewrite(vector::ContractionOp op,
-               PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformsOptions;
-};
-
 /// Progressive lowering of ContractionOp.
 ///
 /// One:
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1638,7 +1638,7 @@
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
 LogicalResult
-ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
+ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
   // TODO: implement masks
   if (llvm::size(op.masks()) != 0)
     return failure();
@@ -1647,30 +1647,46 @@
       vector::VectorContractLowering::OuterProduct)
     return failure();
 
-  // Transpose arguments to make them ready for lowering to OuterProduct. The
-  // constraint to match is that we must load full rows at a time with
-  // vector::ExtractOp.
+  // Determine if the parallel/reduction structure matches something
+  // that can be expressed as a reduction_size unrolled sequence.
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(op.getContext(), m, n, k);
   auto iteratorTypes = op.iterator_types().getValue();
-  if (!isParallelIterator(iteratorTypes[0]) ||
-      !isParallelIterator(iteratorTypes[1]) ||
-      !isReductionIterator(iteratorTypes[2]))
-    return failure();
   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  // When lowering to outerproduct we can support all permutations.
-  if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
-      maps != infer({{m, k}, {n, k}, {m, n}}) &&
-      maps != infer({{k, m}, {k, n}, {m, n}}) &&
-      maps != infer({{k, m}, {n, k}, {m, n}}) &&
-      maps != infer({{m, k}, {k, n}, {n, m}}) &&
-      maps != infer({{m, k}, {n, k}, {n, m}}) &&
-      maps != infer({{k, m}, {k, n}, {n, m}}) &&
-      maps != infer({{k, m}, {n, k}, {n, m}}))
-    return failure();
-  return success();
+  if (isParallelIterator(iteratorTypes[0]) &&
+      isParallelIterator(iteratorTypes[1]) &&
+      isReductionIterator(iteratorTypes[2])) {
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    // When lowering to outerproduct we can support all permutations.
+    //
+    if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
+        maps != infer({{m, k}, {n, k}, {m, n}}) &&
+        maps != infer({{k, m}, {k, n}, {m, n}}) &&
+        maps != infer({{k, m}, {n, k}, {m, n}}) &&
+        maps != infer({{m, k}, {k, n}, {n, m}}) &&
+        maps != infer({{m, k}, {n, k}, {n, m}}) &&
+        maps != infer({{k, m}, {k, n}, {n, m}}) &&
+        maps != infer({{k, m}, {n, k}, {n, m}}))
+      return failure();
+    return success();
+  } else if (isParallelIterator(iteratorTypes[0]) &&
+             isReductionIterator(iteratorTypes[1])) {
+    //
+    // One outer parallel, one inner reduction (matvec flavor).
+    // See if a series of AXPY operations chained through FMA operations
+    // could replace the default DOT implementation.
+    //
+    if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec
+        maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec
+        maps != infer({{n}, {m, n}, {m}}) && // vec-mat
+        maps != infer({{n}, {n, m}, {m}}))   // vec-mat-trans
+      return failure();
+    return success();
+  }
+  return failure();
 }
 
 void ContractionOpToOuterProductOpLowering::rewrite(
@@ -1680,61 +1696,87 @@
   VectorType lhsType = op.getLhsType();
   Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
 
-  // Transpose arguments to make them ready for lowering to OuterProduct. The
-  // constraint to match is that we must load full rows at a time with
-  // vector::ExtractOp.
+  // Set up the parallel/reduction structure in the right form.
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(rewriter.getContext(), m, n, k);
   SmallVector<int64_t, 2> perm{1, 0};
+  auto iteratorTypes = op.iterator_types().getValue();
   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  // First batch of cases, no need to output permute.
-  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-    // This is the classical row-major matmul. Just permute the lhs.
-    reductionSize = lhsType.getDimSize(1);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
-    // TODO: may be better to fail and use some vector -> scalar reduction.
-    reductionSize = lhsType.getDimSize(1);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-    // No need to permute anything.
-    reductionSize = lhsType.getDimSize(0);
-  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
-    // Just permute the rhs.
-    reductionSize = lhsType.getDimSize(0);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  }
-  // Second batch of cases, reshuffle to avoid output permute.
-  else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-    // This is the classical row-major matmul. Just permute the lhs.
-    reductionSize = lhsType.getDimSize(1);
-    Value tmp = rhs;
-    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    lhs = tmp;
-  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-    // TODO: may be better to fail and use some vector -> scalar reduction.
-    reductionSize = lhsType.getDimSize(1);
-    Value tmp = rhs;
-    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-    // No need to permute anything, but still swap lhs and rhs.
-    reductionSize = lhsType.getDimSize(0);
-    std::swap(lhs, rhs);
-  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-    // Just permute the rhs.
-    reductionSize = lhsType.getDimSize(0);
-    Value tmp = lhs;
-    lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    rhs = tmp;
+  if (isParallelIterator(iteratorTypes[0]) &&
+      isParallelIterator(iteratorTypes[1]) &&
+      isReductionIterator(iteratorTypes[2])) {
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    //
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      // This is the classical row-major matmul. Just permute the lhs.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+      // TODO: may be better to fail and use some vector -> scalar reduction.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      // No need to permute anything.
+      reductionSize = lhsType.getDimSize(0);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      // Just permute the rhs.
+      reductionSize = lhsType.getDimSize(0);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      // This is the classical row-major matmul. Just permute the lhs.
+      reductionSize = lhsType.getDimSize(1);
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = tmp;
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      // TODO: may be better to fail and use some vector -> scalar reduction.
+      reductionSize = lhsType.getDimSize(1);
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      // No need to permute anything, but still swap lhs and rhs.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      // Just permute the rhs.
+      reductionSize = lhsType.getDimSize(0);
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = tmp;
+    }
+  } else {
+    //
+    // One outer parallel, one inner reduction (matvec flavor)
+    //
+    assert(isParallelIterator(iteratorTypes[0]) &&
+           isReductionIterator(iteratorTypes[1]));
+    if (maps == infer({{m, n}, {n}, {m}})) {
+      // Case mat-vec: transpose.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n, m}, {n}, {m}})) {
+      // Case mat-trans-vec: ready to go.
+      reductionSize = lhsType.getDimSize(0);
+    } else if (maps == infer({{n}, {m, n}, {m}})) {
+      // Case vec-mat: swap and transpose.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n}, {n, m}, {m}})) {
+      // Case vec-mat-trans: swap and ready to go.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+    }
   }
   assert(reductionSize > 0);
-  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
-  for (unsigned k = 0; k < reductionSize; ++k) {
+  // Unroll outer-products along reduction.
+  for (int64_t k = 0; k < reductionSize; ++k) {
     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
     Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
     res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
@@ -1742,88 +1784,6 @@
   rewriter.replaceOp(op, res);
 }
 
-/// Progressive lowering of a `vector.contract %a, %b, %c` with
-/// matvec semantics to series of AXPY operations that are chained
-/// through FMA operations.
-///
-/// This only kicks in when VectorTransformsOptions is set to AXPY.
-//
-// TODO: this is very similar, but not quite the same as the outerproduct
-// lowering above; merge the two?
-LogicalResult
-ContractionOpToAXPYLowering::match(vector::ContractionOp op) const {
-  // TODO: implement masks
-  if (llvm::size(op.masks()) != 0)
-    return failure();
-
-  if (vectorTransformsOptions.vectorContractLowering !=
-      vector::VectorContractLowering::AXPY)
-    return failure();
-
-  auto iteratorTypes = op.iterator_types().getValue();
-  if (!isParallelIterator(iteratorTypes[0]) ||
-      !isReductionIterator(iteratorTypes[1]))
-    return failure();
-
-  // See if a series of AXPY operations chained through FMA operations
-  // could replace the default DOT implementation.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m, n;
-  bindDims(op.getContext(), m, n);
-  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec
-      maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec
-      maps != infer({{n}, {m, n}, {m}}) && // vec-mat
-      maps != infer({{n}, {n, m}, {m}}))   // vec-mat-trans
-    return failure();
-  return success();
-}
-
-void ContractionOpToAXPYLowering::rewrite(vector::ContractionOp op,
-                                          PatternRewriter &rewriter) const {
-  Location loc = op.getLoc();
-  VectorType lhsType = op.getLhsType();
-  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
-
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m, n;
-  bindDims(op.getContext(), m, n);
-  SmallVector<int64_t, 2> perm{1, 0};
-  //
-  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  int64_t reductionSize = 0;
-  if (maps == infer({{m, n}, {n}, {m}})) {
-    // Case mat-vec: transpose.
-    reductionSize = lhsType.getDimSize(1);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{n, m}, {n}, {m}})) {
-    // Case mat-trans-vec: ready to go.
-    reductionSize = lhsType.getDimSize(0);
-  } else if (maps == infer({{n}, {m, n}, {m}})) {
-    // Case vec-mat: swap and transpose.
-    reductionSize = lhsType.getDimSize(0);
-    std::swap(lhs, rhs);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{n}, {n, m}, {m}})) {
-    // Case vec-mat-trans: swap and ready to go.
-    reductionSize = lhsType.getDimSize(0);
-    std::swap(lhs, rhs);
-  }
-  assert(reductionSize > 0);
-
-  // A direct series of AXPY operations, chained through FMA.
-  Type resType = op.getResultType();
-  for (int64_t k = 0; k < reductionSize; ++k) {
-    Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
-    Value s = rewriter.create<vector::ExtractOp>(loc, rhs, k);
-    Value b = rewriter.create<SplatOp>(loc, resType, s);
-    res = rewriter.create<vector::FMAOp>(loc, a, b, res);
-  }
-  rewriter.replaceOp(op, res);
-}
-
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1862,9 +1822,6 @@
   ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
   if (succeeded(pat2.match(op)))
     return failure();
-  ContractionOpToAXPYLowering pat3(vectorTransformsOptions, ctx);
-  if (succeeded(pat3.match(op)))
-    return failure();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
@@ -2050,7 +2007,6 @@
   patterns.insert<TransposeOpLowering,
                   ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,
-                  ContractionOpToOuterProductOpLowering,
-                  ContractionOpToAXPYLowering>(parameters, context);
+                  ContractionOpToOuterProductOpLowering>(parameters, context);
   // clang-format on
 }
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-axpy=1 | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s
 
 #matvec_accesses = [
   affine_map<(i, j) -> (i, j)>,
@@ -44,27 +44,18 @@
 // CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
 // CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
 // CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
 // CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
 // CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
-// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
-// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
-// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
-// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
-// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
-// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
-// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
-// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
-// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
+// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
 func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                 %arg2: memref<vector<2xf32>>) {
   %A = load %arg0[] : memref<vector<2x2xf32>>
@@ -84,13 +75,12 @@
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
 // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
-// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
-// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
+// CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
 func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                      %arg2: memref<vector<2xf32>>) {
   %A = load %arg0[] : memref<vector<2x2xf32>>
@@ -105,27 +95,18 @@
 // CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
 // CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
 // CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
 // CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
 // CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
-// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
-// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
-// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
-// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
-// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
-// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
-// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
-// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
-// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
+// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
 func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                 %arg2: memref<vector<2xf32>>) {
   %A = load %arg0[] : memref<vector<2x2xf32>>
@@ -145,13 +126,12 @@
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
 // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
-// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
-// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
+// CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
 func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                      %arg2: memref<vector<2xf32>>) {
   %A = load %arg0[] : memref<vector<2x2xf32>>
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -59,9 +59,6 @@
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
       llvm::cl::init(false)};
-  Option<bool> lowerToAXPY{*this, "vector-axpy",
-                           llvm::cl::desc("Lower vector.contract to AXPY"),
-                           llvm::cl::init(false)};
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
@@ -80,8 +77,6 @@
     VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)
       contractLowering = VectorContractLowering::Matmul;
-    else if (lowerToAXPY)
-      contractLowering = VectorContractLowering::AXPY;
     VectorTransposeLowering transposeLowering =
         VectorTransposeLowering::EltWise;
     if (lowerToFlatTranspose)
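
Note on the net effect of this change: the matvec flavors keep their AXPY code
shape, because a `vector.outerproduct` whose second operand is a scalar is an
AXPY in disguise (the existing outer-product lowering expands it into a single
`vector.fma`). A minimal before/after sketch of the mat-vec case, pieced
together from the CHECK lines above (the `#matvec_trait` name and the SSA value
names are illustrative only, not taken from the test):

  // Before: mat-vec written as a contraction with one parallel and one
  // reduction iterator, indexing maps {(i, j), (j), (i)}.
  %0 = vector.contract #matvec_trait %A, %x, %b
      : vector<2x2xf32>, vector<2xf32> into vector<2xf32>

  // After -test-vector-contraction-conversion=vector-outerproduct=1:
  // transpose the matrix once, then unroll the reduction dimension into a
  // chain of outer products that accumulate into the result.
  %At = vector.transpose %A, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
  %a0 = vector.extract %At[0] : vector<2x2xf32>
  %x0 = vector.extract %x[0] : vector<2xf32>
  %r0 = vector.outerproduct %a0, %x0, %b : vector<2xf32>, f32
  %a1 = vector.extract %At[1] : vector<2x2xf32>
  %x1 = vector.extract %x[1] : vector<2xf32>
  %r1 = vector.outerproduct %a1, %x1, %r0 : vector<2xf32>, f32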