diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -87,57 +87,6 @@
   std::unique_ptr builder;
 };
 
-enum class IterType { Parallel, Reduction };
-
-inline StringRef toString(IterType t) {
-  switch (t) {
-  case IterType::Parallel:
-    return getParallelIteratorTypeName();
-  case IterType::Reduction:
-    return getReductionIteratorTypeName();
-  }
-  llvm_unreachable("Unsupported IterType");
-}
-
-/// A StructuredIndexed represents an indexable quantity that is either:
-/// 1. a captured value, which is suitable for buffer and tensor operands, or;
-/// 2. a captured type, which is suitable for tensor return values.
-///
-/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
-/// It enable an idiomatic syntax for index expressions such as:
-///
-/// ```
-///    StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
-///      C(buffer_value_or_tensor_type);
-///    makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
-/// ```
-struct StructuredIndexed : public ValueHandle {
-  StructuredIndexed(Type type) : ValueHandle(type) {}
-  StructuredIndexed(Value value) : ValueHandle(value) {}
-  StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
-  StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
-    return StructuredIndexed(*this, indexings);
-  }
-
-  ArrayRef<AffineExpr> getExprs() { return exprs; }
-
-private:
-  StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
-      : ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
-    assert(t.isa<RankedTensorType>() && "RankedTensor expected");
-  }
-  StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
-      : ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
-    assert((v.getType().isa<MemRefType>() ||
-            v.getType().isa<RankedTensorType>()) &&
-           "MemRef or RankedTensor expected");
-  }
-  StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
-      : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
-
-  SmallVector<AffineExpr, 4> exprs;
-};
-
 inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
 
 /// Build a `linalg.generic` op with the specified `inputs`, `outputs` and
@@ -157,7 +106,7 @@
 /// restriction output tensor results would need to be reordered, which would
 /// result in surprising behavior when combined with region definition.
 Operation *makeGenericLinalgOp(
-    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+    ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
     function_ref<void(ArrayRef<BlockArgument>)> regionBuilder =
        defaultRegionBuilder,
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -100,6 +100,19 @@
   return res;
 }
 
+/// Typed representation for loop type strings.
+enum class IteratorType { Parallel, Reduction };
+
+inline StringRef toString(IteratorType t) {
+  switch (t) {
+  case IteratorType::Parallel:
+    return getParallelIteratorTypeName();
+  case IteratorType::Reduction:
+    return getReductionIteratorTypeName();
+  }
+  llvm_unreachable("Unsupported IteratorType");
+}
+
 } // end namespace mlir
 
 #endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H
diff --git a/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h
@@ -0,0 +1,52 @@
+//===- Builders.h - MLIR Declarative Vector Builders -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides intuitive composable interfaces for building structured MLIR
+// snippets in a declarative fashion.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
+#define MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+
+namespace mlir {
+namespace edsc {
+namespace ops {
+
+/// Build a generic vector contraction, that is a `vector.contract` op with
+/// specified `iteratorTypes`. The client is responsible for specifying proper
+/// indexings when creating the StructuredIndexed.
+/// The computation represents a notional (A * B + C) where the indexings
+/// specify which dimensions are reduced and reordered.
+/// Returns the result of the `vector.contract` op.
+///
+/// Prerequisites:
+/// A, B and C capture values of proper vector types, and indexing expressions
+/// that match semantics of the `vector.contract` op.
+Value vector_contraction(StructuredIndexed A, StructuredIndexed B,
+                         StructuredIndexed C,
+                         ArrayRef<IteratorType> iteratorTypes);
+
+/// Build a generic vector contraction that computes a matmul on vectors.
+/// Returns the result of C(i, j) + sum_k {A(i, k) * B(k, j)} on vectors.
+///
+/// Prerequisites:
+/// A, B and C capture values of proper vector types. For instance
+/// `A: vector<4x8xf32>`, `B: vector<8x16xf32>` and `C: vector<4x16xf32>`.
+Value vector_matmul(Value A, Value B, Value C);
+
+} // namespace ops
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
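// Illustrative sketch (not part of the patch): the generic `vector_contraction`
// builder declared above also covers contractions that `vector_matmul` does not,
// e.g. a matmul with a transposed RHS, C(m, n) += A(m, k) * B(n, k). The helper
// name `matmul_tB` and the operands `a`, `b`, `c` are hypothetical; an enclosing
// EDSC ScopedContext is assumed, and `a`, `b`, `c` are assumed to hold values of
// types vector<MxKxf32>, vector<NxKxf32> and vector<MxNxf32> respectively.
static mlir::Value matmul_tB(mlir::Value a, mlir::Value b, mlir::Value c) {
  using namespace mlir;
  using namespace mlir::edsc;
  using namespace mlir::edsc::ops;
  AffineExpr m, n, k;
  bindDims(ScopedContext::getContext(), m, n, k);
  // Only the indexing of `b` differs from vector_matmul: {n, k} instead of
  // {k, n}; the iterator types stay (parallel, parallel, reduction).
  return vector_contraction(StructuredIndexed(a, {m, k}),
                            StructuredIndexed(b, {n, k}),
                            StructuredIndexed(c, {m, n}),
                            {IteratorType::Parallel, IteratorType::Parallel,
                             IteratorType::Reduction});
}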
diff --git a/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h
@@ -0,0 +1,25 @@
+//===- Intrinsics.h - MLIR EDSC Intrinsics for VectorOps -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
+#define MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
+
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace edsc {
+namespace intrinsics {
+
+using vector_contract = ValueBuilder<vector::ContractionOp>;
+
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -141,7 +141,11 @@
   }];
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value lhs, Value rhs, "
-    "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
+    "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">,
+  OpBuilder<
+    "Builder *builder, OperationState &result, Value lhs, Value rhs, "
+    "Value acc, ArrayRef<ArrayRef<AffineExpr>> indexingExprs, "
+    "ArrayRef<StringRef> iteratorTypes">];
   let extraClassDeclaration = [{
     VectorType getLhsType() {
       return lhs().getType().cast<VectorType>();
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Transforms/FoldUtils.h"
 
@@ -493,6 +494,47 @@
   mlir::Block *block;
 };
 
+/// A StructuredIndexed represents an indexable quantity that is either:
+/// 1. a captured value, which is suitable for buffer and tensor operands, or;
+/// 2. a captured type, which is suitable for tensor return values.
+///
+/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
+/// It enables an idiomatic syntax for index expressions such as:
+///
+/// ```
+///    StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
+///      C(buffer_value_or_tensor_type);
+///    makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
+/// ```
+struct StructuredIndexed : public ValueHandle {
+  StructuredIndexed(Type type) : ValueHandle(type) {}
+  StructuredIndexed(Value value) : ValueHandle(value) {}
+  StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
+  StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
+    return this->hasValue() ? StructuredIndexed(this->getValue(), indexings)
+                            : StructuredIndexed(this->getType(), indexings);
+  }
+
+  StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
+      : ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
+    assert(t.isa<RankedTensorType>() && "RankedTensor expected");
+  }
+  StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
+      : ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
+    assert((v.getType().isa<MemRefType>() ||
+            v.getType().isa<RankedTensorType>() ||
+            v.getType().isa<VectorType>()) &&
+           "MemRef, RankedTensor or Vector expected");
+  }
+  StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
+      : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
+
+  ArrayRef<AffineExpr> getExprs() { return exprs; }
+
+private:
+  SmallVector<AffineExpr, 4> exprs;
+};
+
 template <typename... Args>
 OperationHandle OperationHandle::create(Args... args) {
   return OperationHandle(ScopedContext::getBuilder()
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -63,6 +63,14 @@
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
 
+  /// Returns a vector of AffineMaps; each map has as many results as
+  /// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
+  /// symbols as the largest symbol in `exprs`.
+  static SmallVector<AffineMap, 4>
+  inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList);
+  static SmallVector<AffineMap, 4>
+  inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList);
+
   MLIRContext *getContext() const;
 
   explicit operator bool() { return map != nullptr; }
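// Illustrative sketch (not part of the patch): what `inferFromExprList` infers
// for matmul-style indexings. With m, n, k mapping to dims 0, 1, 2, the largest
// dim position is 2 and no symbols occur, so every map gets 3 dims, 0 symbols:
//   {m, k} -> (d0, d1, d2) -> (d0, d2)
//   {k, n} -> (d0, d1, d2) -> (d2, d1)
//   {m, n} -> (d0, d1, d2) -> (d0, d1)
// `ctx` is an assumed MLIRContext pointer.
mlir::AffineExpr m = mlir::getAffineDimExpr(0, ctx);
mlir::AffineExpr n = mlir::getAffineDimExpr(1, ctx);
mlir::AffineExpr k = mlir::getAffineDimExpr(2, ctx);
auto maps = mlir::AffineMap::inferFromExprList(
    llvm::ArrayRef<llvm::ArrayRef<mlir::AffineExpr>>{{m, k}, {k, n}, {m, n}});
// maps[0..2] are exactly the three affine_maps matched by the CHECK lines of
// the vector_matmul test added at the end of this patch.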
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -131,20 +132,8 @@
 } // namespace edsc
 } // namespace mlir
 
-static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
-                           unsigned &pos) {
-  for (auto sidx : structuredIndices) {
-    for (auto expr : sidx.getExprs()) {
-      expr.walk([&pos](AffineExpr e) {
-        if (auto d = e.dyn_cast<AffineDimExpr>())
-          pos = std::max(pos, d.getPosition());
-      });
-    }
-  }
-}
-
 Operation *mlir::edsc::makeGenericLinalgOp(
-    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+    ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
     function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
     ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
@@ -156,20 +145,16 @@
   auto *ctx = builder.getContext();
   unsigned nInputs = inputs.size();
   unsigned nOutputs = outputs.size();
-  unsigned maxPos = 0;
-  getMaxDimIndex(inputs, maxPos);
-  getMaxDimIndex(outputs, maxPos);
-  // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
-  unsigned nDims = maxPos + 1;
-
-  SmallVector<AffineMap, 4> maps;
-  maps.reserve(nInputs + nOutputs);
-  for (auto in : inputs)
-    maps.push_back(
-        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
-  for (auto out : outputs)
-    maps.push_back(
-        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
+
+  SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
+  exprsList.reserve(nInputs + nOutputs);
+  for (auto structuredIndexed : inputs)
+    exprsList.emplace_back(structuredIndexed.getExprs().begin(),
+                           structuredIndexed.getExprs().end());
+  for (auto structuredIndexed : outputs)
+    exprsList.emplace_back(structuredIndexed.getExprs().begin(),
+                           structuredIndexed.getExprs().end());
+  auto maps = AffineMap::inferFromExprList(exprsList);
 
   unsigned nViews = nInputs + nOutputs;
   SmallVector<Value, 4> values;
@@ -240,8 +225,8 @@
 Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                              StructuredIndexed I,
                                              StructuredIndexed O) {
-  SmallVector<IterType, 4> iterTypes(O.getExprs().size(),
-                                     edsc::IterType::Parallel);
+  SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
+                                         IteratorType::Parallel);
   if (O.getType().isa<RankedTensorType>()) {
     auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
       assert(args.size() == 1 && "expected 1 block arguments");
@@ -270,8 +255,8 @@
                                              StructuredIndexed I1,
                                              StructuredIndexed I2,
                                              StructuredIndexed O) {
-  SmallVector<IterType, 4> iterTypes(O.getExprs().size(),
-                                     edsc::IterType::Parallel);
+  SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
+                                         IteratorType::Parallel);
   if (O.getType().isa<RankedTensorType>()) {
     auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
      assert(args.size() == 2 && "expected 2 block arguments");
@@ -315,7 +300,7 @@
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(vC);
   return makeGenericLinalgOp(
-      {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+      {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
       {A({m, k}), B({k, n})},
       {C({m, n})},
       macRegionBuilder);
@@ -329,7 +314,7 @@
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(tC);
   return makeGenericLinalgOp(
-      {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+      {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
       {A({m, k}), B({k, n})},
       {C({m, n})},
       mulRegionBuilder);
@@ -343,7 +328,7 @@
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(vC), D(tD);
   return makeGenericLinalgOp(
-      {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+      {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
       {A({m, k}), B({k, n}), C({m, n})},
       {D({m, n})},
       macRegionBuilder);
@@ -360,8 +345,8 @@
   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
 
   // Some short names.
-  auto par = IterType::Parallel;
-  auto red = IterType::Reduction;
+  auto par = IteratorType::Parallel;
+  auto red = IteratorType::Reduction;
   auto s = strides;
   auto d = dilations;
 
@@ -393,8 +378,8 @@
   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
 
   // Some short names.
-  auto par = IterType::Parallel;
-  auto red = IterType::Reduction;
+  auto par = IteratorType::Parallel;
+  auto red = IteratorType::Reduction;
   auto s = strides;
   auto d = dilations;
diff --git a/mlir/lib/Dialect/VectorOps/CMakeLists.txt b/mlir/lib/Dialect/VectorOps/CMakeLists.txt
--- a/mlir/lib/Dialect/VectorOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/VectorOps/CMakeLists.txt
@@ -3,6 +3,7 @@
   VectorOps.cpp
   VectorTransforms.cpp
   VectorUtils.cpp
+  EDSC/Builders.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps
diff --git a/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp b/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp
@@ -0,0 +1,41 @@
+//===- Builders.cpp - MLIR Declarative Vector Builders -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/VectorOps/EDSC/Builders.h"
+#include "mlir/Dialect/VectorOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/Functional.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::edsc::ops;
+
+Value mlir::edsc::ops::vector_contraction(
+    StructuredIndexed A, StructuredIndexed B, StructuredIndexed C,
+    ArrayRef<IteratorType> iteratorTypes) {
+  using IndexingExprs = ArrayRef<ArrayRef<AffineExpr>>;
+  return vector_contract(
+      A.getValue(), B.getValue(), C.getValue(),
+      IndexingExprs{A.getExprs(), B.getExprs(), C.getExprs()},
+      ArrayRef<StringRef>{functional::map(toString, iteratorTypes)});
+}
+
+Value mlir::edsc::ops::vector_matmul(Value A, Value B, Value C) {
+  AffineExpr m, n, k;
+  bindDims(ScopedContext::getContext(), m, n, k);
+  return vector_contraction(StructuredIndexed(A, {m, k}),
+                            StructuredIndexed(B, {k, n}),
+                            StructuredIndexed(C, {m, n}),
+                            {IteratorType::Parallel, IteratorType::Parallel,
+                             IteratorType::Reduction});
+}
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -64,6 +64,19 @@
 // ContractionOp
 //===----------------------------------------------------------------------===//
 
+void vector::ContractionOp::build(Builder *builder, OperationState &result,
+                                  Value lhs, Value rhs, Value acc,
+                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
+                                  ArrayRef<StringRef> iteratorTypes) {
+  result.addOperands({lhs, rhs, acc});
+  result.addTypes(acc.getType());
+  result.addAttribute(getIndexingMapsAttrName(),
+                      builder->getAffineMapArrayAttr(
+                          AffineMap::inferFromExprList(indexingExprs)));
+  result.addAttribute(getIteratorTypesAttrName(),
+                      builder->getStrArrayAttr(iteratorTypes));
+}
+
 void vector::ContractionOp::build(Builder *builder, OperationState &result,
                                   Value lhs, Value rhs, Value acc,
                                   ArrayAttr indexingMaps,
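// Illustrative sketch (not part of the patch): using the new ContractionOp
// builder overload directly from an OpBuilder, without assembling the AffineMap
// attributes by hand. `b`, `loc`, `lhs`, `rhs` and `acc` are assumed to exist,
// with `lhs`, `rhs`, `acc` of compatible vector types; `m`, `n`, `k` are
// AffineExprs for dims 0, 1, 2 of the builder's context.
auto op = b.create<mlir::vector::ContractionOp>(
    loc, lhs, rhs, acc,
    llvm::ArrayRef<llvm::ArrayRef<mlir::AffineExpr>>{{m, k}, {k, n}, {m, n}},
    llvm::ArrayRef<llvm::StringRef>{mlir::getParallelIteratorTypeName(),
                                    mlir::getParallelIteratorTypeName(),
                                    mlir::getReductionIteratorTypeName()});
mlir::Value result = op.getResult();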
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -111,6 +111,41 @@
   return permutationMap;
 }
 
+static void getMaxDimAndSymbol(ArrayRef<SmallVector<AffineExpr, 4>> exprsList,
+                               int64_t &maxDim, int64_t &maxSym) {
+  for (auto exprs : exprsList) {
+    for (auto expr : exprs) {
+      expr.walk([&maxDim, &maxSym](AffineExpr e) {
+        if (auto d = e.dyn_cast<AffineDimExpr>())
+          maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
+        if (auto s = e.dyn_cast<AffineSymbolExpr>())
+          maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
+      });
+    }
+  }
+}
+
+SmallVector<AffineMap, 4>
+AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
+  int64_t maxDim = -1, maxSym = -1;
+  getMaxDimAndSymbol(exprsList, maxDim, maxSym);
+  SmallVector<AffineMap, 4> maps;
+  maps.reserve(exprsList.size());
+  for (auto exprs : exprsList)
+    maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
+                                  /*symbolCount=*/maxSym + 1, exprs));
+  return maps;
+}
+
+SmallVector<AffineMap, 4>
+AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
+  SmallVector<SmallVector<AffineExpr, 4>, 4> exprsVector;
+  exprsVector.reserve(exprsList.size());
+  for (const auto &el : exprsList)
+    exprsVector.emplace_back(el.begin(), el.end());
+  return inferFromExprList(exprsVector);
+}
+
 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
                                             MLIRContext *context) {
   SmallVector<AffineExpr, 4> dimExprs;
diff --git a/mlir/test/EDSC/CMakeLists.txt b/mlir/test/EDSC/CMakeLists.txt
--- a/mlir/test/EDSC/CMakeLists.txt
+++ b/mlir/test/EDSC/CMakeLists.txt
@@ -14,6 +14,7 @@
   MLIRLoopOps
   MLIRStandardOps
   MLIRTransforms
+  MLIRVectorOps
   LLVMCore
   LLVMSupport
   )
@@ -25,5 +26,6 @@
   MLIRLinalgOps
   MLIRLoopOps
   MLIRStandardOps
+  MLIRVectorOps
   MLIRTransforms
   )
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -11,8 +11,9 @@
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/EDSC/Builders.h"
+#include "mlir/Dialect/VectorOps/EDSC/Intrinsics.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/EDSC/Intrinsics.h"
@@ -470,7 +471,8 @@
   auto i1Type = IntegerType::get(1, &globalContext());
   auto i8Type = IntegerType::get(8, &globalContext());
   auto memrefType = MemRefType::get({}, i1Type, {}, 0);
-  auto f = makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType});
+  auto f =
+      makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType});
 
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
@@ -1009,6 +1011,36 @@
   f.erase();
 }
 
+// CHECK-LABEL: func @vector_matmul_test(
+//  CHECK-SAME: %[[A:.*]]: vector<4x16xf32>,
+//  CHECK-SAME: %[[B:.*]]: vector<16x8xf32>,
+//  CHECK-SAME: %[[C:.*]]: vector<4x8xf32>)
+//       CHECK: vector.contract {{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
+//  CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
+//  CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
+//  CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
+//  CHECK-SAME: %[[A]], %[[B]], %[[C]]
+//  CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
+TEST_FUNC(vector_matmul_test) {
+  using namespace edsc;
+  using namespace edsc::ops;
+
+  int64_t M = 4, N = 8, K = 16;
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto mkVectorType = VectorType::get({M, K}, f32Type);
+  auto knVectorType = VectorType::get({K, N}, f32Type);
+  auto mnVectorType = VectorType::get({M, N}, f32Type);
+  auto f = makeFunction("vector_matmul_test", {},
+                        {mkVectorType, knVectorType, mnVectorType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
+  vector_matmul(A, B, C);
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;