diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir @@ -77,7 +77,7 @@ scf.for %arg0 = %c0 to %iters step %c1 { // linalg.matmul writes %C in place, need to reset it to zero every time. // This is accounts for about 10-15% perf hit on small sizes. - // Once linalg on tensors is ready, fusing fill at teh register level will + // Once linalg on tensors is ready, fusing fill at the register level will // be easy. %z = constant 0.0 : !elem_type_c linalg.fill(%C, %z) : !row_major_C, !elem_type_c diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir @@ -75,7 +75,7 @@ scf.for %arg0 = %c0 to %iters step %c1 { // linalg.matmul writes %C in place, need to reset it to zero every time. // This is accounts for about 10-15% perf hit on small sizes. - // Once linalg on tensors is ready, fusing fill at teh register level will + // Once linalg on tensors is ready, fusing fill at the register level will // be easy. linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> () diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir @@ -84,7 +84,7 @@ scf.for %arg0 = %c0 to %iters step %c1 { // linalg.matmul writes %C in place, need to reset it to zero every time. // This is accounts for about 10-15% perf hit on small sizes. - // Once linalg on tensors is ready, fusing fill at teh register level will + // Once linalg on tensors is ready, fusing fill at the register level will // be easy. 
linalg.fill(%C, %f0) : !row_major_C, !elem_type_c call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) : diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir @@ -1,12 +1,11 @@ // RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \ // RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \ -// TODO: extend vectorization with interfaces so that it works with sexti -// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16" | \ +// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16 vectorize" | \ // RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \ // RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \ // RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \ -// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \ +// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -mlir-disable-threading | \ // RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \ // Activate to dump assembly // R_UN: -dump-object-file -object-filename=/tmp/a.o \ @@ -18,9 +17,9 @@ !elem_type_a = type i8 !elem_type_b = type i8 !elem_type_c = type i32 -!row_major_A = type memref<${M}x${K}x!elem_type_a> -!row_major_B = type memref<${K}x${N}x!elem_type_b> -!row_major_C = type memref<${M}x${N}x!elem_type_c> +!row_major_A = type memref<24x64x!elem_type_a> +!row_major_B = type memref<64x192x!elem_type_b> +!row_major_C = type memref<24x192x!elem_type_c> func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C) // TODO: activate manually for now. @@ -33,9 +32,9 @@ func @print_perf(%iters: index, %total_time: f64) { %c2 = constant 2 : index - %cM = constant ${M} : index - %cN = constant ${N} : index - %cK = constant ${K} : index + %cM = constant 24 : index + %cN = constant 192 : index + %cK = constant 64 : index %mn = muli %cM, %cN : index %mnk = muli %mn, %cK : index @@ -65,7 +64,7 @@ %c0 = constant 0: index %c1 = constant 1: index - %iters = constant ${ITERS}: index + %iters = constant 100: index /// Run and dump performance for matmul. /// Preheating run: @@ -77,7 +76,7 @@ scf.for %arg0 = %c0 to %iters step %c1 { // linalg.matmul writes %C in place, need to reset it to zero every time. // This is accounts for about 10-15% perf hit on small sizes. - // Once linalg on tensors is ready, fusing fill at teh register level will + // Once linalg on tensors is ready, fusing fill at the register level will // be easy. linalg.fill(%C, %v0) : !row_major_C, !elem_type_c call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> () diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -38,6 +38,71 @@ #define DEBUG_TYPE "linalg-vectorization" +/// Return true if the use-def chain from `v` to `from` consists of 0 or more +/// unary single-operand operations. 
+// TODO: relax to multi-operands with constants, which are technically unary ops
+// as needed (e.g. add5).
+static bool isChainOfUnaryOpsFrom(Value v, Value from) {
+  while (true) {
+    if (v == from)
+      return true;
+    Operation *op = v.getDefiningOp();
+    if (!op || op->getNumOperands() != 1)
+      return false;
+    v = op->getOperand(0);
+  }
+}
+
+/// Return the unique instance of OpType in `block` if it is indeed unique.
+/// Return null if none or more than one instance exists.
+template <typename OpType>
+static OpType getSingleOpOfType(Block &block) {
+  OpType res;
+  block.walk([&](OpType op) {
+    if (res) {
+      res = nullptr;
+      return WalkResult::interrupt();
+    }
+    res = op;
+    return WalkResult::advance();
+  });
+  return res;
+}
+
+/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
+/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
+/// unary operations that may change the type.
+template <typename AddOpType, typename MulOpType>
+static bool isAddMul(Block &block) {
+  if (block.getNumArguments() != 3)
+    return false;
+  auto yieldOp = block.getTerminator();
+  if (yieldOp->getNumOperands() != 1)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isAddMul: "; block.dump());
+  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
+  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
+  if (!addOp || !mulOp)
+    return false;
+
+  Value argA = block.getArgument(0), argB = block.getArgument(1);
+  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
+  Value mul = mulOp->getResult(0);
+  Value argC = block.getArgument(2);
+  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
+  Value add = addOp->getResult(0);
+  Value res = yieldOp->getOperand(0);
+  // Result traces back to add.
+  auto un = isChainOfUnaryOpsFrom;
+  bool success = un(res, add);
+  // One of the operands of add traces back to argC, the other to the mul.
+  success &= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC));
+  // One of the operands of mul traces back to argA, the other to argB.
+  success &= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA));
+  return success;
+}
+
 /// Helper data structure to represent the result of vectorization.
 /// In certain specific cases, like terminators, we do not want to propagate.
 enum VectorizationStatus {
@@ -305,55 +370,34 @@
   return success();
 }
 
-/// Detect whether `r` exactly computes a floating-point or integer
-/// multiply-accumulate.
-static bool hasMultiplyAddBody(Region &r) {
-  if (!llvm::hasSingleElement(r))
-    return false;
-  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
-    return false;
-
-  using mlir::matchers::m_Val;
-  auto a = m_Val(r.getArgument(0));
-  auto b = m_Val(r.getArgument(1));
-  auto c = m_Val(r.getArgument(2));
-  // TODO: Update this detection once we have matcher support for specifying
-  // that any permutation of operands matches.
-  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
-  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
-  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
-  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
-  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
-  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
-  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
-  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
-  return pattern1.match(&r.front().back()) ||
-         pattern2.match(&r.front().back()) ||
-         pattern3.match(&r.front().back()) ||
-         pattern4.match(&r.front().back()) ||
-         pattern5.match(&r.front().back()) ||
-         pattern6.match(&r.front().back()) ||
-         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
-}
-
 /// Detect whether the LinalgOp `op` is a contraction.
-// TODO: Should be Tablegen'd from a single source that generates the op itself.
+/// A Linalg contraction is defined in general terms:
+///   1. Has 2 input and 1 output shapes.
+///   2. Has at least one reduction dimension.
+///   3. Has only projected permutation indexing maps.
+///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
+/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
+/// operations that may change the type (e.g. for mixed-precision).
+/// As a consequence, when vectorization of such an op occurs, the only special
+/// behavior is that the (unique) MulOpType is vectorized into a
+/// `vector.contract`. All other ops are handled in a generic fashion.
+/// In the future, we may wish to allow more input arguments and elementwise and
+/// constant operations that do not involve the reduction dimension(s).
 static LogicalResult isContraction(Operation *op) {
-  // TODO: interface for named ops.
-  if (isa(op))
-    return success();
-
-  auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp)
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isContraction: "; op->dump());
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
     return failure();
 
-  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr, AffineMap>();
+  auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr, AffineMap>();
   return success(
-      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+      linalgOp.getNumInputs() == 2 && linalgOp.getNumOutputs() == 1 &&
+      linalgOp.getNumReductionLoops() > 0 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
-      hasMultiplyAddBody(genericOp.region()));
+      // TODO: more fields than add/mul.
+      (isAddMul<AddFOp, MulFOp>(linalgOp->getRegion(0).front()) ||
+       isAddMul<AddIOp, MulIOp>(linalgOp->getRegion(0).front())));
 }
 
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -382,7 +426,7 @@
     if (!genericOp.getOutputIndexingMap(i).isIdentity())
       return false;
   }
-  // Currently limit the input indexing map to minor identity as other
+  // Currently restrict the input indexing map to minor identity as other
  // permutations might require adding transpose ops to convert the vector read
  // to the right shape.
  for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
@@ -479,6 +523,150 @@
       "Unexpected vectorization failed despite preconditions");
 }
 
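To make the generalized contraction definition above concrete, the hand-written generic below (illustrative only, not part of this patch; the function name and shapes are chosen for the example) is accepted by the new `isContraction`/`isAddMul` checks: u1, u2 and u5 are identities, u3 and u4 are `sexti` ops, and the field is (`addi`, `muli`) — exactly the mixed-precision i8 x i8 -> i32 case exercised by `matmul_i8_i8_i32` later in this patch.

  #mixed_matmul_trait = {
    args_in = 2,
    args_out = 1,
    indexing_maps = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (k, n)>,
      affine_map<(m, n, k) -> (m, n)>
    ],
    iterator_types = ["parallel", "parallel", "reduction"]
  }

  // Body yields `c + sexti(a) * sexti(b)`, i.e. u5(u1(c) + u2(u3(a) * u4(b)))
  // with the unary ops being sign extensions.
  func @mixed_precision_matmul(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
                               %C: memref<8x32xi32>) {
    linalg.generic #mixed_matmul_trait
        ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
       outs(%C : memref<8x32xi32>) {
      ^bb0(%a: i8, %b: i8, %c: i32):
        %0 = sexti %a : i8 to i32
        %1 = sexti %b : i8 to i32
        %2 = muli %0, %1 : i32
        %3 = addi %c, %2 : i32
        linalg.yield %3 : i32
    }
    return
  }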
+//----------------------------------------------------------------------------//
+// Misc. conv vectorization patterns.
+//----------------------------------------------------------------------------//
+// TODO: cleanup all this.
+template <typename ConvOp, int N>
+LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
+    ConvOp op, PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MLIRContext *context = op.getContext();
+  edsc::ScopedContext scope(rewriter, loc);
+
+  ShapedType inShapeType = op.getInputShapedType(0);
+  ShapedType kShapeType = op.getInputShapedType(1);
+
+  ArrayRef<int64_t> inShape = inShapeType.getShape();
+  ArrayRef<int64_t> kShape = kShapeType.getShape();
+
+  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
+    return failure();
+
+  SmallVector<AffineExpr, 4> mapping;
+  SmallVector<int64_t, 4> vectorDims;
+  // Fail to apply when the size of a non-vectorized dimension is not 1.
+  for (unsigned i = 0; i < N; i++) {
+    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
+      return failure();
+
+    if (mask[i] && inShape[i] != kShape[i])
+      return failure();
+
+    if (mask[i]) {
+      mapping.push_back(getAffineDimExpr(i, context));
+      vectorDims.push_back(inShape[i]);
+    }
+  }
+
+  Value input = op.getInput(0);
+  Value kernel = op.getInput(1);
+  Value output = op.getOutputBuffer(0);
+
+  unsigned rank = inShapeType.getRank();
+  unsigned numDims = mapping.size();
+  Type elemType = inShapeType.getElementType();
+
+  auto map = AffineMap::get(rank, 0, mapping, context);
+  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
+  auto vecType = VectorType::get(vectorDims, elemType);
+
+  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
+  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
+
+  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));
+
+  std::array<AffineMap, 3> indexingMaps{
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::get(numDims, 0, {}, context)};
+
+  std::vector<StringRef> iteratorTypes(numDims, "reduction");
+
+  auto result = rewriter.create<vector::ContractionOp>(
+      loc, inputVec, kernelVec, acc,
+      rewriter.getAffineMapArrayAttr(indexingMaps),
+      rewriter.getStrArrayAttr(iteratorTypes));
+
+  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
+  rewriter.eraseOp(op);
+  return success();
+}
+
+using ConvOpConst = ConvOpVectorization;
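Roughly, for a promoted 1-D convolution tile of size 3 (the configuration checked by linalg-to-vector.mlir later in this patch), the rewrite above produces IR of the following shape. This is a hand-written sketch: the memref shapes, value names and the elided permutation maps are illustrative, not FileCheck'd output.

  func @vectorized_conv_1d_tile(%input: memref<3xf32>, %kernel: memref<3xf32>,
                                %output: memref<1xf32>) {
    %c0 = constant 0 : index
    %pad = constant 0.000000e+00 : f32
    // Both operands are read as 3-element vectors over the masked dimension.
    %in  = vector.transfer_read %input[%c0], %pad : memref<3xf32>, vector<3xf32>
    %ker = vector.transfer_read %kernel[%c0], %pad : memref<3xf32>, vector<3xf32>
    %acc = constant 0.000000e+00 : f32
    // A contraction with only reduction iterators, i.e. a dot product.
    %dot = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>,
                                             affine_map<(d0) -> (d0)>,
                                             affine_map<(d0) -> ()>],
                            iterator_types = ["reduction"]}
             %in, %ker, %acc : vector<3xf32>, vector<3xf32> into f32
    store %dot, %output[%c0] : memref<1xf32>
    return
  }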
+
+/// Inserts tiling, promotion and vectorization patterns for ConvOp
+/// conversion into the corresponding pattern lists.
+template <typename ConvOp, int N>
+static void
+populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
+                              OwningRewritePatternList &promotionPatterns,
+                              OwningRewritePatternList &vectorizationPatterns,
+                              ArrayRef<int64_t> tileSizes,
+                              MLIRContext *context) {
+  if (tileSizes.size() < N)
+    return;
+
+  constexpr static StringRef kTiledMarker = "TILED";
+  constexpr static StringRef kPromotedMarker = "PROMOTED";
+  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
+      context, LinalgTilingOptions().setTileSizes(tileSizes),
+      LinalgTransformationFilter(ArrayRef<Identifier>{},
+                                 Identifier::get(kTiledMarker, context)));
+
+  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
+      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
+      LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
+                                 Identifier::get(kPromotedMarker, context)));
+
+  SmallVector<bool, 4> mask(N);
+  int offset = tileSizes.size() - N;
+  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
+                 [](int64_t i) -> bool { return i > 1; });
+
+  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
+}
+
+void mlir::linalg::populateConvVectorizationPatterns(
+    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
+    ArrayRef<int64_t> tileSizes) {
+  OwningRewritePatternList tiling, promotion, vectorization;
+  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
+                                            tileSizes, context);
+
+  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
+                                              tileSizes, context);
+
+  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
+                                              tileSizes, context);
+
+  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
+                                             tileSizes, context);
+
+  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
+                                               tileSizes, context);
+
+  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
+                                               tileSizes, context);
+
+  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
+                                              tileSizes, context);
+
+  populateVectorizationPatterns<ConvNDHWCOp, 5>(
+      tiling, promotion, vectorization, tileSizes, context);
+
+  populateVectorizationPatterns<ConvNCDHWOp, 5>(
+      tiling, promotion, vectorization, tileSizes, context);
+
+  patterns.push_back(std::move(tiling));
+  patterns.push_back(std::move(promotion));
+  patterns.push_back(std::move(vectorization));
+}
+
+//----------------------------------------------------------------------------//
+// Forwarding patterns
+//----------------------------------------------------------------------------//
+
 /// Check whether there is any interleaved use of any `values` between `firstOp`
 /// and `secondOp`. Conservatively return `true` if any op or value is in a
 /// different block.
@@ -649,139 +837,3 @@
   return success();
 }
-
-template <typename ConvOp, int N>
-LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
-    ConvOp op, PatternRewriter &rewriter) const {
-  Location loc = op.getLoc();
-  MLIRContext *context = op.getContext();
-  edsc::ScopedContext scope(rewriter, loc);
-
-  ShapedType inShapeType = op.getInputShapedType(0);
-  ShapedType kShapeType = op.getInputShapedType(1);
-
-  ArrayRef<int64_t> inShape = inShapeType.getShape();
-  ArrayRef<int64_t> kShape = kShapeType.getShape();
-
-  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
-    return failure();
-
-  SmallVector<AffineExpr, 4> mapping;
-  SmallVector<int64_t, 4> vectorDims;
-  // Fail to apply when the size of not vectorized dimension is not 1.
- for (unsigned i = 0; i < N; i++) { - if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1)) - return failure(); - - if (mask[i] && inShape[i] != kShape[i]) - return failure(); - - if (mask[i]) { - mapping.push_back(getAffineDimExpr(i, context)); - vectorDims.push_back(inShape[i]); - } - } - - Value input = op.getInput(0); - Value kernel = op.getInput(1); - Value output = op.getOutputBuffer(0); - - unsigned rank = inShapeType.getRank(); - unsigned numDims = mapping.size(); - Type elemType = inShapeType.getElementType(); - - auto map = AffineMap::get(rank, 0, mapping, context); - SmallVector zeros(rank, std_constant_index(0)); - auto vecType = VectorType::get(vectorDims, elemType); - - auto inputVec = vector_transfer_read(vecType, input, zeros, map); - auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map); - - auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType)); - - std::array indexingMaps{ - AffineMap::getMultiDimIdentityMap(numDims, context), - AffineMap::getMultiDimIdentityMap(numDims, context), - AffineMap::get(numDims, 0, {}, context)}; - - std::vector iteratorTypes(numDims, "reduction"); - - auto result = rewriter.create( - loc, inputVec, kernelVec, acc, - rewriter.getAffineMapArrayAttr(indexingMaps), - rewriter.getStrArrayAttr(iteratorTypes)); - - rewriter.create(loc, result, output, ValueRange(zeros)); - rewriter.eraseOp(op); - return success(); -} - -using ConvOpConst = ConvOpVectorization; - -/// Inserts tiling, promotion and vectorization pattern for ConvOp -/// conversion into corresponding pattern lists. -template -static void -populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns, - OwningRewritePatternList &promotionPatterns, - OwningRewritePatternList &vectorizationPatterns, - ArrayRef tileSizes, - MLIRContext *context) { - if (tileSizes.size() < N) - return; - - constexpr static StringRef kTiledMarker = "TILED"; - constexpr static StringRef kPromotedMarker = "PROMOTED"; - tilingPatterns.insert>( - context, LinalgTilingOptions().setTileSizes(tileSizes), - LinalgTransformationFilter(ArrayRef{}, - Identifier::get(kTiledMarker, context))); - - promotionPatterns.insert>( - context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter(Identifier::get(kTiledMarker, context), - Identifier::get(kPromotedMarker, context))); - - SmallVector mask(N); - int offset = tileSizes.size() - N; - std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(), - [](int64_t i) -> bool { return i > 1; }); - - vectorizationPatterns.insert>(context, mask); -} - -void mlir::linalg::populateConvVectorizationPatterns( - MLIRContext *context, SmallVectorImpl &patterns, - ArrayRef tileSizes) { - OwningRewritePatternList tiling, promotion, vectorization; - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes, context); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes, context); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes, context); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes, context); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes, context); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes, context); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes, context); - - populateVectorizationPatterns( - tiling, promotion, vectorization, tileSizes, context); - - populateVectorizationPatterns( - tiling, 
promotion, vectorization, tileSizes, context); - - patterns.push_back(std::move(tiling)); - patterns.push_back(std::move(promotion)); - patterns.push_back(std::move(vectorization)); -} diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir --- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir +++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse | FileCheck %s +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse -split-input-file +// | FileCheck %s // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)> // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> @@ -6,16 +7,11 @@ // CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)> // CHECK-DAG: #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)> -func @conv_1d(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_1d ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return -} - // CHECK-LABEL: @conv_1d // CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref // CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref // CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, %arg1: memref, %arg2: memref) { // CHECK-DAG: %[[c12:.*]] = constant 12 : index // CHECK-DAG: %[[c4:.*]] = constant 4 : index // CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32 @@ -50,3 +46,8 @@ // CHECK: scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] { // CHECK: %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref // CHECK: store %[[v23]], %[[v10]][%[[arg5]]] : memref + linalg.conv_1d ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + return +} + diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1,8 +1,6 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s -// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// ----- // CHECK-LABEL: contraction_dot func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { @@ -13,6 +11,8 @@ return } +// ----- + // CHECK-LABEL: contraction_matvec func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { // CHECK: vector.contract @@ -22,6 +22,8 @@ return } +// ----- + // CHECK-LABEL: contraction_matmul func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { // CHECK: vector.contract @@ -31,6 +33,8 @@ return } +// ----- + // CHECK-LABEL: contraction_batch_matmul func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { // CHECK: vector.contract @@ -41,6 +45,8 @@ return } +// ----- + #matmul_trait = { args_in = 2, args_out = 1, @@ -51,8 +57,20 @@ ], iterator_types = ["parallel", "parallel", "reduction"] } + +// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @vectorization_test func 
@vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { + // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> + // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]] + // CHECK-SAME: vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> linalg.generic #matmul_trait ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>) outs(%C : memref<8x32xf32>) { @@ -63,15 +81,33 @@ } return } -// CHECK-LABEL: func @vectorization_test -// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> -// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> -// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> -// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> +// ----- + +#matmul_trait = { + args_in = 2, + args_out = 1, + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @vectorization_test_integer func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, %C: memref<8x32xi32>) { + // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> + // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> + // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> + // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], + // CHECK-SAME: vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> linalg.generic #matmul_trait ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>) outs(%C : memref<8x32xi32>) { @@ -82,58 +118,71 @@ } return } -// CHECK-LABEL: func @vectorization_test_integer -// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> -// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> -// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> -// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> -// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> +// ----- + +// CHECK-LABEL: func @vectorization_test_2 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { + // CHECK: vector.contract {{.*}} : + // vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> linalg.matmul ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) outs(%C: memref<8x32xf32>) return } -// CHECK-LABEL: func @vectorization_test_2 -// CHECK: vector.contract {{.*}} : -// vector<8x16xf32>, 
vector<16x32xf32> into vector<8x32xf32> +// ----- + +// CHECK-LABEL: func @test_vectorize_fill func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { + // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> + // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> linalg.fill(%A, %arg0) : memref<8x16xf32>, f32 return } -// CHECK-LABEL: func @test_vectorize_fill -// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> -// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> +// ----- + +// CHECK-LABEL: func @test_vectorize_fill func @test_vectorize_fill_scalar(%A : memref, %arg0 : f32) { + // CHECK-SAME: (%[[M:.*]]: memref, %[[V:.*]]: f32) + // CHECK: store %[[V]], %[[M]][] : memref linalg.fill(%A, %arg0) : memref, f32 return } -// CHECK-LABEL: func @test_vectorize_fill -// CHECK-SAME: (%[[M:.*]]: memref, %[[V:.*]]: f32) -// CHECK: store %[[V]], %[[M]][] : memref +// ----- + +// CHECK-LABEL: func @test_vectorize_copy func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { + // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> linalg.copy(%A, %B) : memref<8x16xf32>, memref<8x16xf32> return } -// CHECK-LABEL: func @test_vectorize_copy -// CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> -// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> +// ----- + +// CHECK-LABEL: func @test_vectorize_copy_scalar func @test_vectorize_copy_scalar(%A : memref, %B : memref) { + // CHECK: %[[V:.*]] = load {{.*}} : memref + // CHECK: store %[[V]], {{.*}} : memref linalg.copy(%A, %B) : memref, memref return } -// CHECK-LABEL: func @test_vectorize_copy_scalar -// CHECK: %[[V:.*]] = load {{.*}} : memref -// CHECK: store %[[V]], {{.*}} : memref -func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>, - %arg2: memref<256xf32>, %i: f32) { +// ----- + +// CHECK-LABEL: func @generic_vectorize + // CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>, + // CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32) +func @generic_vectorize(%arg0: memref<4x256xf32>, + %arg1: memref<4x256xf32>, + %arg2: memref<256xf32>, %i: f32) { + // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32> + // CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> + // CHECK-DAG: %[[C0:.*]] = constant 0 : index %c1_f32 = constant 1.0 : f32 linalg.generic { args_in = 0 : i64, @@ -159,57 +208,56 @@ memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, + // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> + // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32> + // CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> + // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32, %arg14 : f32): + // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> + // CHECK: %[[ADD:.*]] = addf 
%[[V0B]], %[[V1]] : vector<4x256xf32> %6 = addf %arg4, %arg6 : f32 + // CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> %7 = cmpf ogt, %arg3, %arg6 : f32 + // CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> %8 = constant 2.0 : f32 + // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> %9 = divf %arg5, %i : f32 + // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> %10 = exp2 %arg5 : f32 + // CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> %11 = mulf %arg5, %8 : f32 + // CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> %12 = rsqrt %arg5 : f32 + // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> %13 = select %7, %arg5, %arg6 : f32 + // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> + // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> %14 = subf %arg5, %arg4 : f32 + // CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> %15 = tanh %arg5 : f32 + // CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> + // CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32 } return } -// CHECK-LABEL: func @generic_vectorize -// CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>, -// CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32) -// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32> -// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> -// CHECK-DAG: %[[C0:.*]] = constant 0 : index -// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> -// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32> -// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> -// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> -// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> -// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> -// CHECK: %[[DIV:.*]] = 
divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> -// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> -// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> -// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> -// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> -// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> -// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> -// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> -// CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> -// CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> +// ----- + +// CHECK-LABEL: func @generic_vectorize_tensor +// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>, +// CHECK-SAME: %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32) func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>, %arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>, %i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, @@ -240,82 +288,105 @@ ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32, %arg14 : f32): + // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32> + // CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> + // CHECK-DAG: %[[C0:.*]] = constant 0 : index + // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> + // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32> + // CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> + // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> + // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> + // CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32> %6 = addf %arg4, %arg6 : f32 + // CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> %7 = cmpf ogt, %arg3, %arg6 : f32 + // CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> %8 = constant 2.0 : f32 + // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> %9 = divf %arg5, %i : f32 + // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> %10 = exp2 %arg5 : f32 + // 
CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> %11 = mulf %arg5, %8 : f32 + // CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> %12 = rsqrt %arg5 : f32 + // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> %13 = select %7, %arg5, %arg6 : f32 + // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> + // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> %14 = subf %arg5, %arg4 : f32 + // CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> %15 = tanh %arg5 : f32 + // CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> + // CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32 } -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> + // CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9: tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> } -// CHECK-LABEL: func @generic_vectorize_tensor -// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>, -// CHECK-SAME: %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32) -// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32> -// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> -// CHECK-DAG: %[[C0:.*]] = constant 0 : index -// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> -// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32> -// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> -// CHECK: %[[V1:.*]] = 
vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> -// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> -// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> -// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> -// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> -// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> -// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> -// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> -// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> -// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> -// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> -// CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> -// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> +// ----- +// CHECK-LABEL: func @matmul_tensors +// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>, +// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32> func @matmul_tensors( %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>) -> tensor<8x12xf32> { + // CHECK-DAG: %[[C0:.*]] = constant 0 : index + // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32> + // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> + // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> + // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> + // + // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp. + // a later canonicalization fuses the add into vector.contract. 
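  // Concretely (hand-written illustration, not part of the CHECKs): before
  // canonicalization the lowering is
  //   %tmp = vector.contract ... %[[V0]], %[[V1]], %[[VEC_C0]]
  //            : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
  //   %res = addf %[[V2]], %tmp : vector<8x12xf32>
  // and the later canonicalization folds the addf away by using %[[V2]] as the
  // contraction accumulator:
  //   %res = vector.contract ... %[[V0]], %[[V1]], %[[V2]]
  //            : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>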
+ // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> + // CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32> + // CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32> %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>) outs(%arg2: tensor<8x12xf32>) -> tensor<8x12xf32> + // CHECK: return %[[W]] : tensor<8x12xf32> return %0 : tensor<8x12xf32> } -// CHECK-LABEL: func @matmul_tensors -// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>, -// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32> -// CHECK-DAG: %[[C0:.*]] = constant 0 : index -// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32> -// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> -// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> -// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> -// -// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp. -// a later canonicalization fuses the add into vector.contract. -// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> -// CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32> -// CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32> -// CHECK: return %[[W]] : tensor<8x12xf32> +// ----- + +// CHECK-LABEL: func @matmul_i8_i8_i32 +// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: memref<4x6xi8> +// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: memref<6x12xi8> +// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32> +func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) { + // CHECK-DAG: %[[C0:.*]] = constant 0 : index + // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8> + // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8> + // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8> + // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32> + // + // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp. + // a later canonicalization fuses the add into vector.contract. 
+ // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] + // CHECK-SAME: vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8> + // CHECK: %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32> + // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32> + // CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} + // CHECK-SAME: vector<4x12xi32>, memref<4x12xi32> + linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>) + outs(%c: memref<4x12xi32>) + return +} diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -493,9 +493,11 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) { OwningRewritePatternList patterns; + // TODO: remove all this in favor of a single LinalgOp. patterns.insert< LinalgVectorizationPattern, LinalgVectorizationPattern, + LinalgVectorizationPattern, LinalgVectorizationPattern, LinalgVectorizationPattern, LinalgVectorizationPattern, LinalgVectorizationPattern, LinalgVectorizationPattern,