diff --git a/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h
@@ -0,0 +1,53 @@
+//===- Builders.h - MLIR Declarative Vector Builders ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides intuitive composable interfaces for building structured MLIR
+// snippets in a declarative fashion.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_VECTOROPS_EDSC_BUILDERS_H_
+#define MLIR_DIALECT_VECTOROPS_EDSC_BUILDERS_H_
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+
+namespace mlir {
+namespace edsc {
+namespace ops {
+
+/// Build a generic vector contraction, that is a `vector.contract` op with
+/// specified `iteratorTypes`. The client is responsible for specifying proper
+/// indexings when creating the StructuredIndexed.
+/// The computation represents a notional (A * B + C) where indexings specify
+/// which dimensions are reduced and reordered.
+/// Return the result of the `vector.contract` op.
+///
+/// Prerequisites:
+/// A, B and C capture values of proper vector types, and indexing expressions
+/// that match semantics of the `vector.contract` op.
+Value vector_contraction(StructuredIndexed A, StructuredIndexed B,
+                         StructuredIndexed C,
+                         ArrayRef<IteratorType> iteratorTypes);
+
+/// Build a generic vector contraction that computes a matmul on vectors.
+/// Return the result of C(i, j) + sum_k {A(i, k) * B(k, j)} on vectors.
+///
+/// Prerequisites:
+/// A, B and C capture values of proper vector types. For instance
+/// `A: vector<4x8xf32>`, `B: vector<8x16xf32>` and `C: vector<4x16xf32>`.
+Value vector_matmul(Value A, Value B, Value C);
+
+} // namespace ops
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOROPS_EDSC_BUILDERS_H_
diff --git a/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h
@@ -0,0 +1,23 @@
+//===- Intrinsics.h - MLIR EDSC Intrinsics for VectorOps --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
+#define MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
+
+#include "mlir/Dialect/VectorOps/EDSC/Builders.h"
+
+namespace mlir {
+namespace edsc {
+namespace intrinsics {
+
+using vector_contract = ValueBuilder<vector::ContractionOp>;
+
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -141,7 +141,11 @@
   }];
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value lhs, Value rhs, "
-    "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
+    "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">,
+    OpBuilder<
+    "Builder *builder, OperationState &result, Value lhs, Value rhs, "
+    "Value acc, ArrayRef<ArrayRef<AffineExpr>> indexingExprs, "
+    "ArrayRef<StringRef> iteratorTypes">];
   let extraClassDeclaration = [{
     VectorType getLhsType() {
       return lhs().getType().cast<VectorType>();
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -436,8 +436,9 @@
   StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
       : ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
     assert((v.getType().isa<MemRefType>() ||
-            v.getType().isa<RankedTensorType>()) &&
-           "MemRef or RankedTensor expected");
+            v.getType().isa<RankedTensorType>() ||
+            v.getType().isa<VectorType>()) &&
+           "MemRef, RankedTensor or Vector expected");
   }
   StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
       : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -63,6 +63,14 @@
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
 
+  /// Returns one AffineMap per entry of `exprsList`, with that entry's exprs
+  /// as results. Every map gets dimCount = 1 + the largest dim position and
+  /// symbolCount = 1 + the largest symbol position used across `exprsList`.
+  static SmallVector<AffineMap, 4>
+  inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList);
+  static SmallVector<AffineMap, 4>
+  inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList);
+
   MLIRContext *getContext() const;
 
   explicit operator bool() { return map != nullptr; }
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -130,18 +130,6 @@
 } // namespace edsc
 } // namespace mlir
 
-static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
-                           unsigned &pos) {
-  for (auto sidx : structuredIndices) {
-    for (auto expr : sidx.getExprs()) {
-      expr.walk([&pos](AffineExpr e) {
-        if (auto d = e.dyn_cast<AffineDimExpr>())
-          pos = std::max(pos, d.getPosition());
-      });
-    }
-  }
-}
-
 Operation *mlir::edsc::makeGenericLinalgOp(
     ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
@@ -155,20 +143,16 @@
   auto *ctx = builder.getContext();
   unsigned nInputs = inputs.size();
   unsigned nOutputs = outputs.size();
-  unsigned maxPos = 0;
-  getMaxDimIndex(inputs, maxPos);
-  getMaxDimIndex(outputs, maxPos);
-  // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
-  unsigned nDims = maxPos + 1;
-
-  SmallVector<AffineMap, 4> maps;
-  maps.reserve(nInputs + nOutputs);
-  for (auto in : inputs)
-    maps.push_back(
-        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
-  for (auto out : outputs)
-    maps.push_back(
-        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
+
+  SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
+  exprsList.reserve(nInputs + nOutputs);
+  for (auto structuredIndexed : inputs)
+    exprsList.emplace_back(structuredIndexed.getExprs().begin(),
+                           structuredIndexed.getExprs().end());
+  for (auto structuredIndexed : outputs)
+    exprsList.emplace_back(structuredIndexed.getExprs().begin(),
+                           structuredIndexed.getExprs().end());
+  auto maps = AffineMap::inferFromExprList(exprsList);
 
   unsigned nViews = nInputs + nOutputs;
   SmallVector<Value, 4> values;
diff --git a/mlir/lib/Dialect/VectorOps/CMakeLists.txt b/mlir/lib/Dialect/VectorOps/CMakeLists.txt
--- a/mlir/lib/Dialect/VectorOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/VectorOps/CMakeLists.txt
@@ -3,6 +3,7 @@
   VectorOps.cpp
   VectorTransforms.cpp
   VectorUtils.cpp
+  EDSC/Builders.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps
diff --git a/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp b/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp
@@ -0,0 +1,41 @@
+//===- Builders.cpp - MLIR Declarative Vector Builders --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/VectorOps/EDSC/Builders.h"
+#include "mlir/Dialect/VectorOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/Functional.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::edsc::ops;
+
+Value mlir::edsc::ops::vector_contraction(
+    StructuredIndexed A, StructuredIndexed B, StructuredIndexed C,
+    ArrayRef<IteratorType> iteratorTypes) {
+  using IndexingExprs = ArrayRef<ArrayRef<AffineExpr>>;
+  return vector_contract(
+      A.getValue(), B.getValue(), C.getValue(),
+      IndexingExprs{A.getExprs(), B.getExprs(), C.getExprs()},
+      ArrayRef<StringRef>{functional::map(toString, iteratorTypes)});
+}
+
+Value mlir::edsc::ops::vector_matmul(Value A, Value B, Value C) {
+  AffineExpr m, n, k;
+  bindDims(ScopedContext::getContext(), m, n, k);
+  return vector_contraction(StructuredIndexed(A, {m, k}),
+                            StructuredIndexed(B, {k, n}),
+                            StructuredIndexed(C, {m, n}),
+                            {IteratorType::Parallel, IteratorType::Parallel,
+                             IteratorType::Reduction});
+}
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -64,6 +64,22 @@
 // ContractionOp
 //===----------------------------------------------------------------------===//
 
+// Builds a ContractionOp whose result type is the type of `acc` and whose
+// indexing maps are inferred from `indexingExprs` (one map per entry, sharing
+// a common dim/symbol count; see AffineMap::inferFromExprList).
+void vector::ContractionOp::build(Builder *builder, OperationState &result,
+                                  Value lhs, Value rhs, Value acc,
+                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
+                                  ArrayRef<StringRef> iteratorTypes) {
+  result.addOperands({lhs, rhs, acc});
+  result.addTypes(acc.getType());
+  result.addAttribute(getIndexingMapsAttrName(),
+                      builder->getAffineMapArrayAttr(
+                          AffineMap::inferFromExprList(indexingExprs)));
+  result.addAttribute(getIteratorTypesAttrName(),
+                      builder->getStrArrayAttr(iteratorTypes));
+}
+
 void vector::ContractionOp::build(Builder *builder, OperationState &result,
                                   Value lhs, Value rhs, Value acc,
                                   ArrayAttr indexingMaps,
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -111,6 +111,49 @@
   return permutationMap;
 }
 
+/// Walks every expression in `exprsList` and raises `maxDim` / `maxSym` to the
+/// largest AffineDimExpr / AffineSymbolExpr position encountered.
+template <typename AffineExprContainer>
+static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
+                               int64_t &maxDim, int64_t &maxSym) {
+  for (const auto &exprs : exprsList) {
+    for (auto expr : exprs) {
+      expr.walk([&maxDim, &maxSym](AffineExpr e) {
+        if (auto d = e.dyn_cast<AffineDimExpr>())
+          maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
+        if (auto s = e.dyn_cast<AffineSymbolExpr>())
+          maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
+      });
+    }
+  }
+}
+
+/// Shared implementation of the AffineMap::inferFromExprList overloads: builds
+/// one map per entry of `exprsList`, all with a common dim and symbol count.
+template <typename AffineExprContainer>
+static SmallVector<AffineMap, 4>
+inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
+  // Start below zero so that lists with no dims (symbols) yield a count of 0.
+  int64_t maxDim = -1, maxSym = -1;
+  getMaxDimAndSymbol(exprsList, maxDim, maxSym);
+  SmallVector<AffineMap, 4> maps;
+  maps.reserve(exprsList.size());
+  for (const auto &exprs : exprsList)
+    maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
+                                  /*symbolCount=*/maxSym + 1, exprs));
+  return maps;
+}
+
+SmallVector<AffineMap, 4>
+AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
+  return ::inferFromExprList(exprsList);
+}
+
+SmallVector<AffineMap, 4>
+AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
+  return ::inferFromExprList(exprsList);
+}
+
 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
                                             MLIRContext *context) {
   SmallVector<AffineExpr, 4> dimExprs;
diff --git a/mlir/test/EDSC/CMakeLists.txt b/mlir/test/EDSC/CMakeLists.txt
--- a/mlir/test/EDSC/CMakeLists.txt
+++ b/mlir/test/EDSC/CMakeLists.txt
@@ -14,6 +14,7 @@
   MLIRLoopOps
   MLIRStandardOps
   MLIRTransforms
+  MLIRVectorOps
   LLVMCore
   LLVMSupport
 )
@@ -25,5 +26,6 @@
   MLIRLinalgOps
   MLIRLoopOps
   MLIRStandardOps
+  MLIRVectorOps
   MLIRTransforms
 )
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/LoopOps/EDSC/Builders.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/VectorOps/EDSC/Intrinsics.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -981,6 +982,36 @@
   f.erase();
 }
 
+// CHECK-LABEL: func @vector_matmul_test(
+//  CHECK-SAME: %[[A:.*]]: vector<4x16xf32>,
+//  CHECK-SAME: %[[B:.*]]: vector<16x8xf32>,
+//  CHECK-SAME: %[[C:.*]]: vector<4x8xf32>)
+//       CHECK:   vector.contract {{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
+//  CHECK-SAME:                          affine_map<(d0, d1, d2) -> (d2, d1)>,
+//  CHECK-SAME:                          affine_map<(d0, d1, d2) -> (d0, d1)>],
+//  CHECK-SAME:                          {{.*}}["parallel", "parallel", "reduction"]
+//  CHECK-SAME: %[[A]], %[[B]], %[[C]]
+//  CHECK-SAME:   vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
+TEST_FUNC(vector_matmul_test) {
+  using namespace edsc;
+  using namespace edsc::ops;
+
+  int64_t M = 4, N = 8, K = 16;
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto mkVectorType = VectorType::get({M, K}, f32Type);
+  auto knVectorType = VectorType::get({K, N}, f32Type);
+  auto mnVectorType = VectorType::get({M, N}, f32Type);
+  auto f = makeFunction("vector_matmul_test", {},
+                        {mkVectorType, knVectorType, mnVectorType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
+  vector_matmul(A, B, C);
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;