diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -23,6 +23,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -36,6 +38,275 @@ #define DEBUG_TYPE "linalg-vectorization" +/// Helper data structure to represent the result of vectorization. +/// Some ops, like terminators, must not have their results mapped/replaced. +enum VectorizationStatus { + /// Op failed to vectorize. + Failure = 0, + /// Op vectorized and custom function took care of replacement logic. + NoReplace, + /// Op vectorized into a new Op whose results will replace original Op's + /// results. + NewOp + // TODO: support values if Op vectorized to Many-Ops whose results we need to + // aggregate for replacement. +}; +struct VectorizationResult { + /// Return status from vectorizing the current op. + enum VectorizationStatus status = VectorizationStatus::Failure; + /// New vectorized operation to replace the current op. + /// Replacement behavior is specified by `status`. + Operation *newOp = nullptr; +}; + +/// Return a vector type of the same shape and element type as the (assumed) +/// ShapedType of `v`. +static VectorType extractVectorTypeFromShapedValue(Value v) { + auto st = v.getType().cast(); + if (st.isa() && st.getShape().empty()) + return VectorType(); + return VectorType::get(st.getShape(), st.getElementType()); +} + +/// Build a vector.transfer_read from `source` at indices set to all `0`. +/// If source has rank zero, build an std.load. +/// Return the produced value.
+static Value buildVectorRead(OpBuilder &builder, Value source) { + edsc::ScopedContext scope(builder); + auto shapedType = source.getType().cast(); + if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) { + SmallVector indices(shapedType.getRank(), std_constant_index(0)); + return vector_transfer_read(vectorType, source, indices); + } + return std_load(source); +} + +/// Build a vector.transfer_write of `value` into `dest` at indices set to all +/// `0`. If `dest` has rank zero, build an std.store. +/// Return the produced value or null if no value is produced. +static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) { + edsc::ScopedContext scope(builder); + Operation *write; + auto shapedType = dest.getType().cast(); + if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) { + SmallVector indices(shapedType.getRank(), std_constant_index(0)); + if (vectorType != value.getType()) + value = vector_broadcast(vectorType, value); + write = vector_transfer_write(value, dest, indices); + } else { + write = std_store(value, dest); + } + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write); + if (!write->getResults().empty()) + return write->getResult(0); + return Value(); +} + +/// If value of assumed VectorType has a shape different than `shape`, build and +/// return a new vector.broadcast to `shape`. +/// Otherwise, just return value. +static Value broadcastIfNeeded(OpBuilder &builder, Value value, + ArrayRef shape) { + auto vecType = value.getType().dyn_cast(); + if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape)) + return value; + auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType() + : value.getType()); + return builder.create( + builder.getInsertionPoint()->getLoc(), newVecType, value); +} + +// Custom vectorization function type. Produce a vector form of Operation* +// assuming all its vectorized operands are already in the BlockAndValueMapping.
+// Return a Failure VectorizationResult if the Operation cannot be vectorized. +using CustomVectorizationHook = std::function; + +/// Helper function to vectorize the terminator of a `linalgOp`. New result +/// vector values are appended to `results`. +/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm +/// that it should not try to map produced operations: this is the purpose of +/// the `results` argument to capture such values and make them available for +/// RAUW to the vectorization algorithm. +/// This function is meant to be used as a CustomVectorizationHook. +static VectorizationResult +vectorizeLinalgYield(OpBuilder &builder, Operation *op, + const BlockAndValueMapping &bvm, LinalgOp linalgOp, + SmallVectorImpl &results) { + auto yieldOp = dyn_cast(op); + if (!yieldOp) + return VectorizationResult{VectorizationStatus::Failure, nullptr}; + for (auto outputs : llvm::enumerate(yieldOp.values())) { + // TODO: Scan for an opportunity for reuse. + // TODO: use a map. + Value vectorValue = bvm.lookup(outputs.value()); + Value result = buildVectorWrite(builder, vectorValue, + linalgOp.getOutput(outputs.index())); + if (result) + results.push_back(result); + } + return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; +} + +/// Generic vectorization for a single operation `op`, given already vectorized +/// operands carried by `bvm`. Vectorization occurs as follows: +/// 1. Try to apply any of the `customVectorizationHooks` and return its +/// result on success. +/// 2. Clone any constant in the current scope without vectorization: each +/// consumer of the constant will later determine the shape to which the +/// constant needs to be broadcast. +/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose +/// of the `customVectorizationHooks` to cover such cases. +/// 4. Clone `op` in vector form to a vector of shape prescribed by the first +/// operand of maximal rank.
Other operands have smaller rank and are +/// broadcast accordingly. It is assumed this broadcast is always legal, +/// otherwise, it means one of the `customVectorizationHooks` is incorrect. +/// +/// This function assumes all operands of `op` have been vectorized and are in +/// the `bvm` mapping. As a consequence, this function is meant to be called on +/// a topologically-sorted list of ops. +/// This function does not update `bvm` but returns a VectorizationResult that +/// instructs the caller what `bvm` update needs to occur. +static VectorizationResult +vectorizeOneOp(OpBuilder &builder, Operation *op, + const BlockAndValueMapping &bvm, + ArrayRef customVectorizationHooks) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op); + + // 1. Try to apply any CustomVectorizationHook. + if (!customVectorizationHooks.empty()) { + for (auto &customFunc : customVectorizationHooks) { + VectorizationResult result = customFunc(op, bvm); + if (result.status == VectorizationStatus::Failure) + continue; + return result; + } + } + + // 2. Constant ops don't get vectorized but rather broadcasted at their users. + // Clone so that the constant is not confined to the linalgOp block. + if (isa(op)) + return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)}; + + // 3. Only ElementwiseMappable are allowed in the generic vectorization. + if (!op->hasTrait()) + return VectorizationResult{VectorizationStatus::Failure, nullptr}; + + // 4. Generic vectorization path for ElementwiseMappable ops. + // a. first get the first max ranked shape. + SmallVector firstMaxRankedShape; + for (Value operand : op->getOperands()) { + auto vt = bvm.lookup(operand).getType().dyn_cast(); + if (vt && firstMaxRankedShape.size() < vt.getShape().size()) + firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end()); + } + // b. broadcast each operand if needed.
+ auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) { + return firstMaxRankedShape.empty() + ? bvm.lookup(v) + : broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape); + }); + // c. for elementwise, the result is the vector with the firstMaxRankedShape + auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) { + return firstMaxRankedShape.empty() + ? t + : VectorType::get(firstMaxRankedShape, t); + }); + + // Build and return the new op. + OperationState state(op->getLoc(), op->getName()); + state.addAttributes(op->getAttrs()); + state.addOperands(llvm::to_vector<4>(vectorizedOperands)); + state.addTypes(llvm::to_vector<4>(returnTypes)); + return VectorizationResult{VectorizationStatus::NewOp, + builder.createOperation(state)}; +} + +/// Generic vectorization function that rewrites the body of a `linalgOp` into +/// vector form. Generic vectorization proceeds as follows: +/// 1. The region for the linalg op is created if necessary. +/// 2. Values defined above the region are mapped to themselves and will be +/// broadcasted on a per-need basis by their consumers. +/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d +/// load). +/// TODO: Reuse opportunities for RAR dependencies. +/// 4. Register CustomVectorizationHook for YieldOp to capture the results. +/// 5. Iteratively call vectorizeOneOp on the region operations. +/// 6. RAUW the linalg op by the results captured vectorizing the YieldOp. +static LogicalResult vectorizeAsLinalgGeneric( + OpBuilder &builder, LinalgOp linalgOp, + ArrayRef customVectorizationHooks = {}) { + // 1. Certain Linalg ops do not have a region but only a region builder. + // If so, build the region so we can vectorize. + std::unique_ptr owningRegion; + Region *region; + if (linalgOp->getNumRegions() > 0) { + region = &linalgOp->getRegion(0); + } else { + // RAII guard: restore the insertion point moved by createBlock.
+ OpBuilder::InsertionGuard g(builder); + owningRegion = std::make_unique(); + region = owningRegion.get(); + Block *block = builder.createBlock(region); + auto elementTypes = llvm::to_vector<4>( + llvm::map_range(linalgOp.getShapedOperandTypes(), + [](ShapedType t) { return t.getElementType(); })); + block->addArguments(elementTypes); + linalgOp.getRegionBuilder()(*block); + } + Block *block = &region->front(); + + BlockAndValueMapping bvm; + // 2. Values defined above the region can only be broadcast for now. Make them + // map to themselves. + llvm::SetVector valuesSet; + mlir::getUsedValuesDefinedAbove(*region, valuesSet); + bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef()); + + // 3. Turn all BBArgs into vector.transfer_read / load. + // Map both the bbarg and its shaped operand to the vector read. + for (auto bbarg : block->getArguments()) { + Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber()); + Value vectorRead = buildVectorRead(builder, vectorArg); + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg(" + << bbarg.getArgNumber() << "): " << vectorRead); + bvm.map(bbarg, vectorRead); + bvm.map(vectorArg, vectorRead); + } + + // 4. Register CustomVectorizationHook for yieldOp. + SmallVector results; + CustomVectorizationHook vectorizeYield = + [&](Operation *op, + const BlockAndValueMapping &bvm) -> VectorizationResult { + return vectorizeLinalgYield(builder, op, bvm, linalgOp, results); + }; + // Append the vectorizeYield hook. + auto hooks = llvm::to_vector<4>(customVectorizationHooks); + hooks.push_back(vectorizeYield); + + // 5. Iteratively call `vectorizeOneOp` on each op in the block.
+ for (Operation &op : block->getOperations()) { + VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks); + if (result.status == VectorizationStatus::Failure) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op); + return failure(); + } + if (result.status == VectorizationStatus::NewOp) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: " + << *result.newOp); + bvm.map(op.getResults(), result.newOp->getResults()); + } + } + + // 6. RAUW the linalg op by the results captured vectorizing the YieldOp. + if (!results.empty()) + linalgOp->replaceAllUsesWith(results); + return success(); +} + +/// Detect whether `r` exactly computes a floating-point or integer +/// multiply-accumulate. static bool hasMultiplyAddBody(Region &r) { if (!llvm::hasSingleElement(r)) return false; @@ -65,6 +336,7 @@ pattern7.match(&r.front().back()) || pattern8.match(&r.front().back()); } +/// Detect whether the LinalgOp `op` is a contraction. // TODO: Should be Tablegen'd from a single source that generates the op itself. static LogicalResult isContraction(Operation *op) { // TODO: interface for named ops. @@ -84,6 +356,7 @@ hasMultiplyAddBody(genericOp.region())); } +/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
static bool hasOnlyScalarElementwiseOp(Region &r) { if (!llvm::hasSingleElement(r)) return false; @@ -119,171 +392,6 @@ return hasOnlyScalarElementwiseOp(genericOp.getRegion()); } -static VectorType extractVectorTypeFromShapedValue(Value v) { - auto st = v.getType().cast(); - if (st.isa() && st.getShape().empty()) - return VectorType(); - return VectorType::get(st.getShape(), st.getElementType()); -} - -static Value transferReadVector(OpBuilder &builder, Value source) { - edsc::ScopedContext scope(builder); - auto shapedType = source.getType().cast(); - if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) { - SmallVector indices(shapedType.getRank(), std_constant_index(0)); - return vector_transfer_read(vectorType, source, indices); - } - return std_load(source); -} - -static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) { - edsc::ScopedContext scope(builder); - Operation *write; - auto shapedType = dest.getType().cast(); - if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) { - SmallVector indices(shapedType.getRank(), std_constant_index(0)); - if (vectorType != value.getType()) - value = vector_broadcast(vectorType, value); - write = vector_transfer_write(value, dest, indices); - } else { - write = std_store(value, dest); - } - if (!write->getResults().empty()) - return write->getResult(0); - return Value(); -} - -namespace { -// Transforms scalar operations into their vectorized counterparts, -// while using the provided generic op to map: -// * Its arguments to transfer reads from the views of the generic op. -// * linalg.yield ops to transfer writes to the views of the generic op. -class GenericVectorizer { -public: - GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic) - : builder(builder), generic(generic) {} - - // Takes a scalar operation and builds its vectorized counterpart or - // counterparts using the underlying builder.
- // If operands of the scalar operation are referring to previously vectorized - // operations, then in their vectorized form these operands will be referring - // to previous vectorization results. - void vectorize(Operation &scalarOp) { - auto yieldOp = dyn_cast(scalarOp); - if (yieldOp) { - for (auto outputs : llvm::enumerate(yieldOp.values())) { - Value vectorValue = vectorize(outputs.value()); - Value result = transferWriteVector(builder, vectorValue, - generic.getOutput(outputs.index())); - if (result) - results.push_back(result); - } - return; - } - Operation *vectorOp = uncachedVectorize(scalarOp); - assert(scalarOp.getNumResults() == vectorOp->getNumResults()); - for (auto result : - llvm::zip(scalarOp.getResults(), vectorOp->getResults())) { - valueCache[std::get<0>(result)] = std::get<1>(result); - } - } - - llvm::ArrayRef getResults() { return results; } - -private: - // Transforms a scalar value into its vectorized counterpart, recursively - // vectorizing operations as necessary using the underlying builder. - // Keeps track of previously vectorized values and reuses vectorization - // results if these values come up again. - Value vectorize(Value scalarValue) { - // Don't vectorize values coming from outside the region. - if (scalarValue.getParentRegion() != &generic.region()) - return scalarValue; - auto vectorValueIt = valueCache.find(scalarValue); - if (vectorValueIt != valueCache.end()) - return vectorValueIt->second; - - // If the value is from the region but not in the cache it means it is a - // block argument. - auto scalarArg = scalarValue.cast(); - assert(scalarArg.getOwner() == &generic.region().front()); - Value vectorArg = generic.getShapedOperand(scalarArg.getArgNumber()); - Value vectorResult = transferReadVector(builder, vectorArg); - valueCache[scalarArg] = vectorResult; - return vectorResult; - } - - // Return the largest shape of all the given values. Return an empty - // SmallVector if there are no vector value.
- static SmallVector getLargestShape(ArrayRef values) { - SmallVector largestShape; - int64_t maxSize = 1; - for (Value value : values) { - auto vecType = value.getType().dyn_cast(); - if (!vecType) - continue; - if (maxSize < vecType.getNumElements()) { - maxSize = vecType.getNumElements(); - largestShape.assign(vecType.getShape().begin(), - vecType.getShape().end()); - } - } - return largestShape; - } - - // If the value's type doesn't have the given shape broadcast it. - Value broadcastIfNeeded(Value value, ArrayRef shape) { - auto vecType = value.getType().dyn_cast(); - if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape)) - return value; - auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType() - : value.getType()); - return builder.create( - builder.getInsertionPoint()->getLoc(), newVecType, value); - } - - // Takes a scalar operation and builds its vectorized counterpart or - // counterparts using underlying builder without involving any caches.
- Operation *uncachedVectorize(Operation &base_scalarOp) { - SmallVector vectorizedOperands; - for (Value operand : base_scalarOp.getOperands()) { - vectorizedOperands.push_back(vectorize(operand)); - } - SmallVector shape = getLargestShape(vectorizedOperands); - for (Value &operand : vectorizedOperands) - operand = broadcastIfNeeded(operand, shape); - OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName()); - state.addAttributes(base_scalarOp.getAttrs()); - state.addOperands(vectorizedOperands); - if (shape.empty()) { - state.addTypes(base_scalarOp.getResultTypes()); - } else { - SmallVector vectorizedTypes; - for (auto Type : base_scalarOp.getResultTypes()) - vectorizedTypes.push_back(VectorType::get(shape, Type)); - state.addTypes(vectorizedTypes); - } - return builder.createOperation(state); - } - - OpBuilder &builder; - linalg::GenericOp generic; - llvm::DenseMap valueCache; - SmallVector results; -}; -} // namespace - -// Replaces elementwise linalg.generic ops with their bodies with scalar -// operations from these bodies promoted to vector operations. -static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) { - GenericVectorizer vectorizer(builder, op); - for (Operation &scalarOp : op.region().front()) { - vectorizer.vectorize(scalarOp); - } - if (!op->getResults().empty()) - op->replaceAllUsesWith(vectorizer.getResults()); -} - LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { auto linalgOp = cast(op); // All types must be static shape to go to vector. @@ -313,7 +421,7 @@ // Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << dbgPref << "Rewrite linalg.fill as vector.broadcast: " << *op); - transferWriteVector(builder, fillOp.value(), fillOp.output()); + buildVectorWrite(builder, fillOp.value(), fillOp.output()); return; } if (auto copyOp = dyn_cast(op)) { @@ -322,17 +430,24 @@ << "Rewrite linalg.copy as vector.transfer_read + " "vector.transfer_write: " << *op); - Value vector = transferReadVector(builder, copyOp.input()); - transferWriteVector(builder, vector, copyOp.output()); + Value vector = buildVectorRead(builder, copyOp.input()); + buildVectorWrite(builder, vector, copyOp.output()); return; } + auto linalgOp = cast(op); + Location loc = linalgOp.getLoc(); + if (isElementwise(op)) { LLVM_DEBUG(dbgs() << dbgPref << "Rewrite linalg op as vector.transfer_read + " "vector_op + vector.transfer_write: " << *op); - return vectorizeElementwise(cast(op), builder); + auto status = vectorizeAsLinalgGeneric(builder, linalgOp); + (void)status; // Only used in the assert below. + assert(succeeded(status) && + "Unexpected vectorization failed despite preconditions"); + return; } assert(succeeded(isContraction(op)) && "Expected contraction"); @@ -341,15 +456,29 @@ // TODO: interface. LLVM_DEBUG(dbgs() << dbgPref << "Rewrite linalg op as vector.contract: " << *op); - auto linalgOp = cast(op); - Value a = transferReadVector(builder, linalgOp.getInput(0)); - Value b = transferReadVector(builder, linalgOp.getInput(1)); - Value c = transferReadVector(builder, linalgOp.getOutput(0)); - Value res = vector_contract(a, b, c, linalgOp.indexing_maps(), - linalgOp.iterator_types()); - Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0)); - if (writeResult) - linalgOp->replaceAllUsesWith(ArrayRef(writeResult)); + // Special function that describes how to vectorize the multiplication op in a + // linalg contraction.
+ CustomVectorizationHook vectorizeContraction = + [&](Operation *op, + const BlockAndValueMapping &bvm) -> VectorizationResult { + if (!isa(op)) + return VectorizationResult{VectorizationStatus::Failure, nullptr}; + auto outShape = linalgOp.getOutputShapedType(0).getShape(); + auto vType = outShape.empty() + ? op->getResult(0).getType() + : VectorType::get(outShape, op->getResult(0).getType()); + auto zero = + builder.create(loc, vType, builder.getZeroAttr(vType)); + Operation *contract = builder.create( + loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero, + linalgOp.indexing_maps(), linalgOp.iterator_types()); + return VectorizationResult{VectorizationStatus::NewOp, contract}; + }; + auto status = + vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction}); + (void)status; // Only used in the assert below. + assert(succeeded(status) && + "Unexpected vectorization failed despite preconditions"); } /// Check whether there is any interleaved use of any `values` between `firstOp` diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -183,13 +183,13 @@ // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32> // CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> // CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32> +// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> // CHECK: %[[ADD:.*]] = addf %[[V0B]],
%[[V1]] : vector<4x256xf32> -// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> // CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> // CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> @@ -267,13 +267,13 @@ // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32> // CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> // CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32> +// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> // CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> // CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> // CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> @@ -307,10 +307,15 @@ // CHECK-LABEL: func @matmul_tensors // CHECK-SAME: (%[[ARG0:.*]]:
tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>, // CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32> -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> -// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> -// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> -// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> -// CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32> +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32> +// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> +// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> +// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> +// +// The linalg contraction lowers to %tmp = vector.contract %a, %b, %c0, followed by addf %c, %tmp. +// A later canonicalization fuses the add into vector.contract. +// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> +// CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32> +// CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32> // CHECK: return %[[W]] : tensor<8x12xf32>