diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -224,9 +224,14 @@
 }
 
+/// Returns true if 'accType' and 'resType' match the shape implied by lhs/rhs
+/// types and the indexing maps of 'op'. 'expected' is set to the inferred
+/// vector type once computed and is left unmodified on early-exit paths.
 static bool verifyOutputShape(
-    VectorType lhsType, VectorType rhsType, Type accType, Type resType,
+    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
+    Type resType,
     const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
-    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
+    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap,
+    VectorType &expected) {
   DenseSet<int64_t> lhsContractingDimSet;
   DenseSet<int64_t> rhsContractingDimSet;
   for (auto &dimPair : contractingDimMap) {
@@ -257,7 +262,6 @@
     // No batch or free dimension implies a scalar result.
     if (resType.isa<VectorType>() || accType.isa<VectorType>())
       return false;
-
   } else {
     // At least one batch or free dimension implies a vector result.
     auto resVectorType = resType.dyn_cast<VectorType>();
@@ -265,15 +269,41 @@
     if (!resVectorType || !accVectorType)
       return false;
 
-    // Verify dimension from 'resType' against 'expectedResultDims'.
-    if (resVectorType.getShape().size() != expectedResultDims.size() ||
-        accVectorType.getShape().size() != expectedResultDims.size())
-      return false;
-    for (int64_t i = 0, e = resVectorType.getRank(); i < e; ++i) {
-      if (resVectorType.getDimSize(i) != expectedResultDims[i] ||
-          accVectorType.getDimSize(i) != expectedResultDims[i])
-        return false;
+    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
+    // types fully define the result vector type. This assumes the affine maps
+    // are well-formed, which must have been verified already.
+    MLIRContext *ctx = op.getContext();
+    AffineMap lhsMap = op.getIndexingMaps()[0];
+    AffineMap rhsMap = op.getIndexingMaps()[1];
+    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
+    for (auto pair :
+         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
+      VectorType v = pair.first;
+      auto map = pair.second;
+      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
+        unsigned pos = map.getResult(idx).cast<AffineDimExpr>().getPosition();
+        if (!extents[pos])
+          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
+      }
     }
+    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }));
+
+    AffineMap resMap = op.getIndexingMaps()[2];
+    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
+                                     /*symCount=*/0, extents, ctx);
+    // Compose the resMap with the extentsMap, which is a constant map.
+    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
+    assert(llvm::all_of(expectedMap.getResults(), [](AffineExpr e) {
+      return e.isa<AffineConstantExpr>();
+    }));
+    // Extract the expected shape and build the type.
+    auto expectedShape = llvm::to_vector<4>(
+        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
+          return e.cast<AffineConstantExpr>().getValue();
+        }));
+    expected = VectorType::get(expectedShape, resVectorType.getElementType());
+    if (resVectorType != expected || accVectorType != expected)
+      return false;
   }
   return true;
 }
@@ -329,9 +359,14 @@
     return op.emitOpError("invalid batch dimension map");
 
   // Verify 'accType' and 'resType' shape.
-  if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
-                         batchDimMap))
-    return op.emitOpError("invalid accumulator/result vector shape");
+  VectorType expected;
+  if (!verifyOutputShape(op, lhsType, rhsType, accType, resType,
+                         contractingDimMap, batchDimMap, expected)) {
+    if (!expected)
+      return op.emitOpError("invalid accumulator/result vector shape");
+    return op.emitOpError("invalid accumulator/result vector shape, expected: ")
+           << expected;
+  }
 
   // Verify that either two vector masks are set or none are set.
   auto lhsMaskType = op.getLHSVectorMaskType();
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1503,34 +1503,8 @@
 ///   %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
 /// ```
 ///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-void ContractionOpToOuterProductOpLowering::rewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
-  VectorType lhsType = op.getLhsType();
-  // TODO(ntv) other modes.
-  // We know we are in row-major.
-  bool transposeLhs = false;
-  unsigned reductionSize =
-      transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1];
-
-  // If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to
-  // transpose it to extract the proper vector. Otherwise, just take
-  // the lhs.
-  Value lhs = transposeLhs
-                  ? op.lhs()
-                  : rewriter.create<vector::TransposeOp>(
-                        op.getLoc(), op.lhs(), ArrayRef<int64_t>{1, 0});
-  Value res = op.acc();
-  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
-  for (unsigned k = 0; k < reductionSize; ++k) {
-    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
-    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), op.rhs(), k);
-    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
-  }
-  rewriter.replaceOp(op, res);
-}
-
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
+/// otherwise supports any layout permutation of the matrix-multiply.
 LogicalResult
 ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
   // TODO(ajcbik): implement masks
@@ -1538,12 +1512,99 @@
     return failure();
 
   if (vectorTransformsOptions.vectorContractLowering !=
-          vector::VectorContractLowering::OuterProduct ||
-      !isRowMajorMatmul(op.indexing_maps()))
+      vector::VectorContractLowering::OuterProduct)
+    return failure();
+
+  // Only match 2-D matmul-like contractions over (m, n, k). Every layout
+  // permutation of lhs, rhs and result is accepted here; the transposes and
+  // operand swaps they require are materialized in rewrite().
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(op.getContext(), m, n, k);
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  // When lowering to outerproduct we can support all permutations.
+  if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
+      maps != infer({{m, k}, {n, k}, {m, n}}) &&
+      maps != infer({{k, m}, {k, n}, {m, n}}) &&
+      maps != infer({{k, m}, {n, k}, {m, n}}) &&
+      maps != infer({{m, k}, {k, n}, {n, m}}) &&
+      maps != infer({{m, k}, {n, k}, {n, m}}) &&
+      maps != infer({{k, m}, {k, n}, {n, m}}) &&
+      maps != infer({{k, m}, {n, k}, {n, m}}))
     return failure();
   return success();
 }
 
+void ContractionOpToOuterProductOpLowering::rewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  unsigned reductionSize = 0;
+  VectorType lhsType = op.getLhsType();
+  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
+
+  // Transpose arguments to make them ready for lowering to OuterProduct. The
+  // constraint to match is that we must load full rows at a time with
+  // vector::ExtractOp.
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(rewriter.getContext(), m, n, k);
+  SmallVector<int64_t, 2> perm{1, 0};
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  // First batch of cases, no need to output permute.
+  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+    // This is the classical row-major matmul. Just permute the lhs.
+    reductionSize = lhsType.getShape()[1];
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    reductionSize = lhsType.getShape()[1];
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+    // No need to permute anything.
+    reductionSize = lhsType.getShape()[0];
+  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+    // Just permute the rhs.
+    reductionSize = lhsType.getShape()[0];
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  }
+  // Second batch of cases, reshuffle to avoid output permute.
+  else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+    // Row-major matmul, transposed result: swap operands and permute the lhs.
+    reductionSize = lhsType.getShape()[1];
+    Value tmp = rhs;
+    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    lhs = tmp;
+  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    reductionSize = lhsType.getShape()[1];
+    Value tmp = rhs;
+    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+    // No need to permute anything, but still swap lhs and rhs.
+    reductionSize = lhsType.getShape()[0];
+    std::swap(lhs, rhs);
+  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+    // Swap lhs and rhs, and permute the (original) rhs.
+    reductionSize = lhsType.getShape()[0];
+    Value tmp = lhs;
+    lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    rhs = tmp;
+  }
+  assert(reductionSize > 0);
+
+  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
+  for (unsigned k = 0; k < reductionSize; ++k) {
+    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
+    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
+    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
+  }
+  rewriter.replaceOp(op, res);
+}
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -767,6 +767,26 @@
 
 // -----
 
+#contraction_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#contraction_trait = {
+  indexing_maps = #contraction_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<3x2xf32>
+{
+// expected-error@+1 {{invalid accumulator/result vector shape, expected: 'vector<3x2xf32>'}}
+  %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// -----
+
 func @create_mask() {
   %c2 = constant 2 : index
   %c3 = constant 3 : index
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -160,9 +160,11 @@
   indexing_maps = #contraction_accesses0,
   iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
 }
-#contraction_accesses1 = [
+#contraction_accesses1 = [ // 7, 8, 16, 15
   affine_map<(f0, f1, f2, f3, c0, c1) -> (c0, f0, c1, f2)>,
+  // 8, 16, 7, 5
   affine_map<(f0, f1, f2, f3, c0, c1) -> (f1, c1, c0, f3)>,
+  // 8, 8, 15, 5
   affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)>
 ]
 #contraction_trait1 = {
@@ -172,7 +174,7 @@
 }
 // CHECK-LABEL: contraction
 func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
-                  %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
+                  %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
                   %arg4 : index) {
   // Test contraction with batch and contracting dims.
   // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
@@ -181,16 +183,16 @@
   // Test contraction with only contracting dims. In this case the lhs/rhs
   // dimension of size 8 will be considered a parallel dim for lhs/rhs and will
   // appear twice in the output.
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
-      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   // Test contraction with optional vector mask arguments.
   %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
   %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
                                            %rhs_mask
-      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -681,3 +681,219 @@
   %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
   return %0 : vector<2x3xi1>
 }
+
+#matmat_accesses_0 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_0 = {
+  indexing_maps = #matmat_accesses_0,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_0
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_1 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_1 = {
+  indexing_maps = #matmat_accesses_1,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_1
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_2 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_2 = {
+  indexing_maps = #matmat_accesses_2,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_2
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_3 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_3 = {
+  indexing_maps = #matmat_accesses_3,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_3
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#matmat_accesses_4 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_4 = {
+  indexing_maps = #matmat_accesses_4,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_4
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_5 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_5 = {
+  indexing_maps = #matmat_accesses_5,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_5
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<3x1xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_6 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_6 = {
+  indexing_maps = #matmat_accesses_6,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_6
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_6(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<1x3xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+#matmat_accesses_7 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_7 = {
+  indexing_maps = #matmat_accesses_7,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// OUTERPRODUCT-LABEL: func @matmul_7
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
+func @matmul_7(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<3x2xf32>)
+-> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<3x1xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}