diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -53,6 +53,18 @@
     for (Operation *userOp : result.getUsers())
       if (forwardSlice->count(userOp) == 0)
         getForwardSliceImpl(userOp, forwardSlice, filter);
+  } else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+    // Forward slice of all results.
+    for (Value result : op->getResults()) {
+      for (Operation *userOp : result.getUsers())
+        if (forwardSlice->count(userOp) == 0)
+          getForwardSliceImpl(userOp, forwardSlice, filter);
+    }
+    // Forward slice of all ops within the region.
+    for (Operation &bbOp : linalgOp->getRegion(0).front().getOperations()) {
+      if (forwardSlice->count(&bbOp) == 0)
+        getForwardSliceImpl(&bbOp, forwardSlice, filter);
+    }
   } else {
     assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
     for (Value result : op->getResults()) {
@@ -107,6 +119,10 @@
       auto *loopOp = loopIv.getOperation();
       if (backwardSlice->count(loopOp) == 0)
         getBackwardSliceImpl(loopOp, backwardSlice, filter);
+    } else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(
+                   blockArg.getOwner()->getParentOp())) {
+      if (backwardSlice->count(linalgOp) == 0)
+        getBackwardSliceImpl(linalgOp, backwardSlice, filter);
     } else if (blockArg.getOwner() !=
                &op->getParentOfType<FuncOp>().getBody().front()) {
       op->emitError("unsupported CF for operand ") << en.index();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -23,6 +24,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -36,6 +39,245 @@
 
 #define DEBUG_TYPE "linalg-vectorization"
 
+/// Helper data structure to represent the result of vectorization.
+/// In certain specific cases, like terminators, we do not want to propagate
+/// the newly produced results to the users of the original op.
+enum VectorizationStatus {
+  // Op failed to vectorize.
+  Failure = 0,
+  // Op vectorized and the custom function took care of replacement logic.
+  NoReplace,
+  // Op vectorized into a new Op whose results will replace original Op's
+  // results.
+  NewOp
+  // TODO: support values if Op vectorized to Many-Ops whose results we need to
+  // aggregate for replacement.
+};
+
+struct VectorizationResult {
+  enum VectorizationStatus status = VectorizationStatus::Failure;
+  Operation *newOp;
+};
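+
+// For example (illustrative only), a custom hook that performs its own
+// replacement logic returns:
+//   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+// while the generic path below returns the op it created:
+//   return VectorizationResult{VectorizationStatus::NewOp, newOp};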
+
+static VectorType extractVectorTypeFromShapedValue(Value v) {
+  auto st = v.getType().cast<ShapedType>();
+  if (st.isa<MemRefType>() && st.getShape().empty())
+    return VectorType();
+  return VectorType::get(st.getShape(), st.getElementType());
+}
+
+static Value transferReadVector(OpBuilder &builder, Value source) {
+  edsc::ScopedContext scope(builder);
+  auto shapedType = source.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
+    return vector_transfer_read(vectorType, source, indices);
+  }
+  return std_load(source);
+}
+
+static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
+  edsc::ScopedContext scope(builder);
+  Operation *write;
+  auto shapedType = dest.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
+    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
+    if (vectorType != value.getType())
+      value = vector_broadcast(vectorType, value);
+    write = vector_transfer_write(value, dest, indices);
+  } else {
+    write = std_store(value, dest);
+  }
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
+  if (!write->getResults().empty())
+    return write->getResult(0);
+  return Value();
+}
+
+// If the value's type doesn't have the given shape, broadcast it.
+static Value broadcastIfNeeded(OpBuilder &builder, Value value,
+                               ArrayRef<int64_t> shape) {
+  auto vecType = value.getType().dyn_cast<VectorType>();
+  if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
+    return value;
+  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
+                                                   : value.getType());
+  return builder.create<vector::BroadcastOp>(
+      builder.getInsertionPoint()->getLoc(), newVecType, value);
+}
+
+// Emit a vector.transfer_write (or std.store) for each value yielded by the
+// body of `linalgOp` and collect the produced results.
+static VectorizationResult
+vectorizeLinalgYield(OpBuilder &builder, Operation *op,
+                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
+                     SmallVectorImpl<Value> &results) {
+  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
+  if (!yieldOp)
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+  for (auto outputs : llvm::enumerate(yieldOp.values())) {
+    Value vectorValue = bvm.lookup(outputs.value());
+    Value result = transferWriteVector(builder, vectorValue,
+                                       linalgOp.getOutput(outputs.index()));
+    if (result)
+      results.push_back(result);
+  }
+  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+}
+
+// Custom vectorization function type. Produce a vector form of Operation*
+// assuming all its vectorized operands are already in the BlockAndValueMapping.
+// Return a VectorizationResult with `status == Failure` if the Operation
+// cannot be vectorized.
+using CustomVectorizationHook = std::function<VectorizationResult(
+    Operation *, const BlockAndValueMapping &)>;
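+
+// For example, a hook that only knows how to vectorize the multiplication op
+// of a contraction (see `vectorizeContraction` in `vectorizeLinalgOp` below)
+// has the following shape; the body sketched here is illustrative:
+//   CustomVectorizationHook hook =
+//       [&](Operation *op,
+//           const BlockAndValueMapping &bvm) -> VectorizationResult {
+//     if (!isa<MulIOp>(op) && !isa<MulFOp>(op))
+//       return VectorizationResult{VectorizationStatus::Failure, nullptr};
+//     // ... build the vector op from the `bvm`-mapped operands ...
+//   };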
+
+// Takes a scalar operation and builds its vectorized counterpart.
+static VectorizationResult
+vectorizeOneOp(OpBuilder &builder, Operation *op, BlockAndValueMapping &bvm,
+               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
+
+  // First, try to apply any CustomVectorizationHook.
+  if (!customVectorizationHooks.empty()) {
+    for (auto &customFunc : customVectorizationHooks) {
+      VectorizationResult result = customFunc(op, bvm);
+      if (result.status == VectorizationStatus::Failure)
+        continue;
+      return result;
+    }
+  }
+
+  // Constant ops don't get vectorized but rather broadcasted at their users.
+  // Clone so that the constant is not confined to the linalgOp block.
+  if (isa<ConstantOp>(op))
+    return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
+
+  // Only ops with the ElementwiseMappable trait are allowed in the generic
+  // vectorization.
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+
+  // Generic vectorization path for ElementwiseMappable ops.
+  //   a. Get the first max-ranked shape among the vectorized operands.
+  SmallVector<int64_t, 4> firstMaxRankedShape;
+  for (Value operand : op->getOperands()) {
+    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
+    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
+      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
+  }
+  //   b. Broadcast each operand to that shape, if needed.
+  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
+    return firstMaxRankedShape.empty()
+               ? bvm.lookup(v)
+               : broadcastIfNeeded(builder, bvm.lookup(v),
+                                   firstMaxRankedShape);
+  });
+  //   c. For elementwise ops, each result is a vector with
+  //      firstMaxRankedShape.
+  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
+    return firstMaxRankedShape.empty()
+               ? t
+               : VectorType::get(firstMaxRankedShape, t);
+  });
+  OperationState state(op->getLoc(), op->getName());
+  state.addAttributes(op->getAttrs());
+  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
+  state.addTypes(llvm::to_vector<4>(returnTypes));
+
+  return VectorizationResult{VectorizationStatus::NewOp,
+                             builder.createOperation(state)};
+}
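+
+// For example (shapes as in the vectorization tests), an `addf` whose operands
+// are mapped to a vector<4x256xf32> and a vector<256xf32> is rebuilt, after
+// broadcasting the lower-ranked operand, as:
+//   %b = vector.broadcast %v1 : vector<256xf32> to vector<4x256xf32>
+//   %r = addf %v0, %b : vector<4x256xf32>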
+
+// Rewrite the body of `linalgOp` in vector form: scalar body ops become
+// vector ops, block arguments become vector.transfer_read ops and
+// linalg.yield becomes vector.transfer_write ops.
+static LogicalResult vectorizeAsLinalgGeneric(
+    OpBuilder &builder, LinalgOp linalgOp,
+    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
+  // Certain Linalg ops do not have a region but only a region builder.
+  // If so, build the region so we can vectorize.
+  std::unique_ptr<Region> owningRegion;
+  Region *region;
+  if (linalgOp->getNumRegions() > 0) {
+    region = &linalgOp->getRegion(0);
+  } else {
+    // RAII avoid remaining in block.
+    OpBuilder::InsertionGuard g(builder);
+    owningRegion = std::make_unique<Region>();
+    region = owningRegion.get();
+    Block *block = builder.createBlock(region);
+    auto elementTypes = llvm::to_vector<4>(
+        llvm::map_range(linalgOp.getShapedOperandTypes(),
+                        [](ShapedType t) { return t.getElementType(); }));
+    block->addArguments(elementTypes);
+    linalgOp.getRegionBuilder()(*block);
+  }
+
+  Block *block = &region->front();
+  llvm::SetVector<Operation *> slice;
+  mlir::getBackwardSlice(block->getTerminator(), &slice, [&](Operation *op) {
+    return op->getBlock() == block;
+  });
+  // Add the terminator to the slice for proper vectorization of
+  // vector.transfer_write.
+  slice.insert(block->getTerminator());
+
+  BlockAndValueMapping bvm;
+  // Set the mapping for values defined above. Such values can only be
+  // broadcast for now.
+  llvm::SetVector<Value> valuesSet;
+  mlir::getUsedValuesDefinedAbove(*region, valuesSet);
+  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
+
+  // Turn all BBArgs into vector.transfer_read / load.
+  for (auto bbarg : block->getArguments()) {
+    Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
+    AffineMap indexingMap = linalgOp.getIndexingMap(bbarg.getArgNumber());
+    // Scan for an opportunity for reuse.
+    // TODO: use a map.
+    bool reuse = false;
+    for (unsigned idx = 0, e = bbarg.getArgNumber(); idx < e; ++idx) {
+      if (linalgOp.getShapedOperand(idx) == vectorArg &&
+          linalgOp.getIndexingMap(idx) == indexingMap) {
+        Value vectorRead = bvm.lookup(linalgOp.getShapedOperand(idx));
+        LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: reuse vectorized bbarg("
+                          << bbarg.getArgNumber() << "): " << vectorRead);
+        bvm.map(bbarg, vectorRead);
+        bvm.map(vectorArg, vectorRead);
+        reuse = true;
+        break;
+      }
+    }
+    if (reuse)
+      continue;
+    Value vectorRead = transferReadVector(builder, vectorArg);
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
+                      << bbarg.getArgNumber() << "): " << vectorRead);
+    bvm.map(bbarg, vectorRead);
+    bvm.map(vectorArg, vectorRead);
+  }
+
+  // Special function that describes how to vectorize the yield in a linalg op.
+  SmallVector<Value, 4> results;
+  CustomVectorizationHook vectorizeYield =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
+  };
+  // Append the vectorizeYield hook.
+  auto allCustomFuncs = llvm::to_vector<4>(customVectorizationHooks);
+  allCustomFuncs.push_back(vectorizeYield);
+
+  // Iteratively call `vectorizeOneOp` on each op in the slice.
+  for (Operation *op : slice) {
+    VectorizationResult result =
+        vectorizeOneOp(builder, op, bvm, allCustomFuncs);
+    if (result.status == VectorizationStatus::Failure) {
+      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << *op);
+      return failure();
+    }
+    if (result.status == VectorizationStatus::NewOp) {
+      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
+                        << *result.newOp);
+      bvm.map(op->getResults(), result.newOp->getResults());
+    }
+  }
+
+  // If results from the yield have been filled, use them to replace all uses
+  // of the original op.
+  if (!results.empty())
+    linalgOp->replaceAllUsesWith(results);
+  return success();
+}
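+
+// A minimal usage sketch, mirroring the two call sites in `vectorizeLinalgOp`
+// below (the hooks argument is optional; `myCustomHook` is hypothetical):
+//   auto status =
+//       vectorizeAsLinalgGeneric(builder, linalgOp, {myCustomHook});
+//   if (failed(status))
+//     return;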
 
 static bool hasMultiplyAddBody(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
@@ -119,171 +361,6 @@
   return hasOnlyScalarElementwiseOp(genericOp.getRegion());
 }
 
-static VectorType extractVectorTypeFromShapedValue(Value v) {
-  auto st = v.getType().cast<ShapedType>();
-  if (st.isa<MemRefType>() && st.getShape().empty())
-    return VectorType();
-  return VectorType::get(st.getShape(), st.getElementType());
-}
-
-static Value transferReadVector(OpBuilder &builder, Value source) {
-  edsc::ScopedContext scope(builder);
-  auto shapedType = source.getType().cast<ShapedType>();
-  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
-    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
-    return vector_transfer_read(vectorType, source, indices);
-  }
-  return std_load(source);
-}
-
-static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
-  edsc::ScopedContext scope(builder);
-  Operation *write;
-  auto shapedType = dest.getType().cast<ShapedType>();
-  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
-    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
-    if (vectorType != value.getType())
-      value = vector_broadcast(vectorType, value);
-    write = vector_transfer_write(value, dest, indices);
-  } else {
-    write = std_store(value, dest);
-  }
-  if (!write->getResults().empty())
-    return write->getResult(0);
-  return Value();
-}
-
-namespace {
-// Transforms scalar operations into their vectorized counterparts,
-// while using the provided generic op to map:
-//   * Its arguments to transfer reads from the views of the generic op.
-//   * linalg.yield ops to transfer writes to the views of the generic op.
-class GenericVectorizer {
-public:
-  GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
-      : builder(builder), generic(generic) {}
-
-  // Takes a scalar operation and builds its vectorized counterpart or
-  // counterparts using the underlying builder.
-  // If operands of the scalar operation are referring to previously vectorized
-  // operations, then in their vectorized form these operands will be referring
-  // to previous vectorization results.
-  void vectorize(Operation &scalarOp) {
-    auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
-    if (yieldOp) {
-      for (auto outputs : llvm::enumerate(yieldOp.values())) {
-        Value vectorValue = vectorize(outputs.value());
-        Value result = transferWriteVector(builder, vectorValue,
-                                           generic.getOutput(outputs.index()));
-        if (result)
-          results.push_back(result);
-      }
-      return;
-    }
-    Operation *vectorOp = uncachedVectorize(scalarOp);
-    assert(scalarOp.getNumResults() == vectorOp->getNumResults());
-    for (auto result :
-         llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
-      valueCache[std::get<0>(result)] = std::get<1>(result);
-    }
-  }
-
-  llvm::ArrayRef<Value> getResults() { return results; }
-
-private:
-  // Transforms a scalar value into its vectorized counterpart, recursively
-  // vectorizing operations as necessary using the underlying builder.
-  // Keeps track of previously vectorized values and reuses vectorization
-  // results if these values come up again.
-  Value vectorize(Value scalarValue) {
-    // Don't vectorize values coming from outside the region.
-    if (scalarValue.getParentRegion() != &generic.region())
-      return scalarValue;
-    auto vectorValueIt = valueCache.find(scalarValue);
-    if (vectorValueIt != valueCache.end())
-      return vectorValueIt->second;
-
-    // If the value is from the region but not in the cache it means it is a
-    // block argument.
-    auto scalarArg = scalarValue.cast<BlockArgument>();
-    assert(scalarArg.getOwner() == &generic.region().front());
-    Value vectorArg = generic.getShapedOperand(scalarArg.getArgNumber());
-    Value vectorResult = transferReadVector(builder, vectorArg);
-    valueCache[scalarArg] = vectorResult;
-    return vectorResult;
-  }
-
-  // Return the largest shape of all the given values. Return an empty
-  // SmallVector if there are no vector value.
-  static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
-    SmallVector<int64_t, 4> largestShape;
-    int64_t maxSize = 1;
-    for (Value value : values) {
-      auto vecType = value.getType().dyn_cast<VectorType>();
-      if (!vecType)
-        continue;
-      if (maxSize < vecType.getNumElements()) {
-        maxSize = vecType.getNumElements();
-        largestShape.assign(vecType.getShape().begin(),
-                            vecType.getShape().end());
-      }
-    }
-    return largestShape;
-  }
-
-  // If the value's type doesn't have the given shape broadcast it.
-  Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
-    auto vecType = value.getType().dyn_cast<VectorType>();
-    if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
-      return value;
-    auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
-                                                     : value.getType());
-    return builder.create<vector::BroadcastOp>(
-        builder.getInsertionPoint()->getLoc(), newVecType, value);
-  }
-
-  // Takes a scalar operation and builds its vectorized counterpart or
-  // counterparts using underlying builder without involving any caches.
-  Operation *uncachedVectorize(Operation &base_scalarOp) {
-    SmallVector<Value, 4> vectorizedOperands;
-    for (Value operand : base_scalarOp.getOperands()) {
-      vectorizedOperands.push_back(vectorize(operand));
-    }
-    SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
-    for (Value &operand : vectorizedOperands)
-      operand = broadcastIfNeeded(operand, shape);
-    OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
-    state.addAttributes(base_scalarOp.getAttrs());
-    state.addOperands(vectorizedOperands);
-    if (shape.empty()) {
-      state.addTypes(base_scalarOp.getResultTypes());
-    } else {
-      SmallVector<Type, 4> vectorizedTypes;
-      for (auto Type : base_scalarOp.getResultTypes())
-        vectorizedTypes.push_back(VectorType::get(shape, Type));
-      state.addTypes(vectorizedTypes);
-    }
-    return builder.createOperation(state);
-  }
-
-  OpBuilder &builder;
-  linalg::GenericOp generic;
-  llvm::DenseMap<Value, Value> valueCache;
-  SmallVector<Value, 8> results;
-};
-} // namespace
-
-// Replaces elementwise linalg.generic ops with their bodies with scalar
-// operations from these bodies promoted to vector operations.
-static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
-  GenericVectorizer vectorizer(builder, op);
-  for (Operation &scalarOp : op.region().front()) {
-    vectorizer.vectorize(scalarOp);
-  }
-  if (!op->getResults().empty())
-    op->replaceAllUsesWith(vectorizer.getResults());
-}
-
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -327,12 +404,16 @@
     return;
   }
 
+  auto linalgOp = cast<LinalgOp>(op);
+  Location loc = linalgOp.getLoc();
+
   if (isElementwise(op)) {
     LLVM_DEBUG(dbgs() << dbgPref
-                      << "Rewrite linalg op as vector.transfer_read + "
-                         "vector_op + vector.transfer_write: "
-                      << *op);
-    return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
+                      << "Rewrite linalg op as vector.transfer_read + "
+                         "vector ops + vector.transfer_write: "
+                      << *op);
+    auto status = vectorizeAsLinalgGeneric(builder, linalgOp);
+    assert(succeeded(status) &&
+           "Unexpected vectorization failed despite preconditions");
+    return;
   }
 
   assert(succeeded(isContraction(op)) && "Expected contraction");
@@ -341,15 +422,28 @@
   // TODO: interface.
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
-  auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value a = transferReadVector(builder, linalgOp.getInput(0));
-  Value b = transferReadVector(builder, linalgOp.getInput(1));
-  Value c = transferReadVector(builder, linalgOp.getOutput(0));
-  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
-                              linalgOp.iterator_types());
-  Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
-  if (writeResult)
-    linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
+  // Special function that describes how to vectorize the multiplication op in
+  // a linalg contraction.
+  CustomVectorizationHook vectorizeContraction =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    if (!isa<MulIOp>(op) && !isa<MulFOp>(op))
+      return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    auto outShape = linalgOp.getOutputShapedType(0).getShape();
+    auto vType = outShape.empty()
+                     ? op->getResult(0).getType()
+                     : VectorType::get(outShape, op->getResult(0).getType());
+    auto zero =
+        builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
+    Operation *contract = builder.create<vector::ContractionOp>(
+        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)),
+        zero, linalgOp.indexing_maps(), linalgOp.iterator_types());
+    return VectorizationResult{VectorizationStatus::NewOp, contract};
+  };
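+  // For a matmul, this hook rewrites the body multiplication into (see the
+  // matmul_tensors test below):
+  //   %zero = constant dense<0.000000e+00> : vector<8x12xf32>
+  //   %tmp = vector.contract {...} %lhs, %rhs, %zero
+  //       : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+  // The trailing addf with the output operand is handled by the generic
+  // elementwise path.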
+  auto status =
+      vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
+  assert(succeeded(status) &&
+         "Unexpected vectorization failed despite preconditions");
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -183,22 +183,23 @@
 // CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
 // CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
 // CHECK-DAG:   %[[C0:.*]] = constant 0 : index
-// CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
-// CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-// CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-// CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-// CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-// CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-// CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-// CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
-// CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
-// CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
-// CHECK:   %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
-// CHECK:   %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
-// CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
-// CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-// CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
-// CHECK:   %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
+//
+// Any lexicographic ordering that preserves use-def chains is valid.
+// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> +// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> +// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32> +// CHECK-DAG: %[[V2B:.*]] = vector.broadcast %[[V2]] : vector<256xf32> to vector<4x256xf32> +// CHECK-DAG: %[[ADD:.*]] = addf %[[V2B]], %[[V0]] : vector<4x256xf32> +// CHECK-DAG: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> +// CHECK-DAG: %[[DIV:.*]] = divf %[[V0]], %[[ARG3B]] : vector<4x256xf32> +// CHECK-DAG: %[[EXP:.*]] = exp2 %[[V0]] : vector<4x256xf32> +// CHECK-DAG: %[[MUL:.*]] = mulf %[[V0]], %[[CST0]] : vector<4x256xf32> +// CHECK-DAG: %[[RSQRT:.*]] = rsqrt %[[V0]] : vector<4x256xf32> +// CHECK-DAG: %[[CMP:.*]] = cmpf ogt, %[[V1]], %[[V0]] : vector<4x256xf32> +// CHECK-DAG: %[[SEL:.*]] = select %[[CMP]], %[[V0]], %[[V0]] : vector<4x256xi1>, vector<4x256xf32> +// CHECK-DAG: %[[V2B:.*]] = vector.broadcast %[[V2]] : vector<256xf32> to vector<4x256xf32> +// CHECK-DAG: %[[SUB:.*]] = subf %[[V0]], %[[V2B]] : vector<4x256xf32> +// CHECK-DAG: %[[TAN:.*]] = tanh %[[V0]] : vector<4x256xf32> // CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> // CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> // CHECK: vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> @@ -267,22 +268,23 @@ // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32> // CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> // CHECK-DAG: %[[C0:.*]] = constant 0 : index -// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32> -// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> -// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> -// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> -// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> -// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> -// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> -// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> -// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> -// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> -// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> -// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> -// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> -// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> -// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> +// +// Any lexicographic ordering that preserves use-def chains is valid. 
+// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> +// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> +// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32> +// CHECK-DAG: %[[V2B:.*]] = vector.broadcast %[[V2]] : vector<256xf32> to vector<4x256xf32> +// CHECK-DAG: %[[ADD:.*]] = addf %[[V2B]], %[[V0]] : vector<4x256xf32> +// CHECK-DAG: %[[CMP:.*]] = cmpf ogt, %[[V1]], %[[V0]] : vector<4x256xf32> +// CHECK-DAG: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32> +// CHECK-DAG: %[[DIV:.*]] = divf %[[V0]], %[[ARG3B]] : vector<4x256xf32> +// CHECK-DAG: %[[EXP:.*]] = exp2 %[[V0]] : vector<4x256xf32> +// CHECK-DAG: %[[MUL:.*]] = mulf %[[V0]], %[[CST0]] : vector<4x256xf32> +// CHECK-DAG: %[[RSQRT:.*]] = rsqrt %[[V0]] : vector<4x256xf32> +// CHECK-DAG: %[[SEL:.*]] = select %[[CMP]], %[[V0]], %[[V0]] : vector<4x256xi1>, vector<4x256xf32> +// CHECK-DAG: %[[V2B:.*]] = vector.broadcast %[[V2]] : vector<256xf32> to vector<4x256xf32> +// CHECK-DAG: %[[SUB:.*]] = subf %[[V0]], %[[V2B]] : vector<4x256xf32> +// CHECK-DAG: %[[TAN:.*]] = tanh %[[V0]] : vector<4x256xf32> // CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> // CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> // CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> @@ -307,10 +309,15 @@ // CHECK-LABEL: func @matmul_tensors // CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>, // CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32> -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> -// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> -// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> -// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> -// CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32> +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32> +// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> +// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> +// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> +// +// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp. +// a later canonicalization fuses the add into vector.contract. 
+// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> +// CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32> +// CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32> // CHECK: return %[[W]] : tensor<8x12xf32>
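+//
+// Note: once the fusion canonicalization mentioned above lands, the addf is
+// expected to fold into the accumulator operand and the two ops above reduce
+// to a single accumulating contraction:
+//   %[[C]] = vector.contract {{.*}} %[[V0]], %[[V1]], %[[V2]]
+//       : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>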