Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -192,62 +192,15 @@
 /// assumes that `reductionOp` has two operands and one of them is the reduction
 /// initial value.
 static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
-                                 Value outputArg,
-                                 const SmallVector<bool> &reductionMask,
-                                 const BlockAndValueMapping &bvm) {
+                                 Value valueToReduce,
+                                 const SmallVector<bool> &reductionMask) {
   auto maybeKind = getKindForOp(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
-  Value operandToReduce = reduceOp->getOperand(0) == outputArg
-                              ? reduceOp->getOperand(1)
-                              : reduceOp->getOperand(0);
-  Value vec = bvm.lookup(operandToReduce);
-  return b.create<vector::MultiDimReductionOp>(reduceOp->getLoc(), vec,
-                                               reductionMask, *maybeKind);
+  return b.create<vector::MultiDimReductionOp>(
+      reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
 }
 
-/// Read the initial value associated to the given `outputOperand`.
-static Value readInitialValue(OpBuilder &b, LinalgOp linalgOp,
-                              OpOperand *outputOperand) {
-  AffineMap map = inversePermutation(
-      reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)));
-  Type readType;
-  if (linalgOp.getShape(outputOperand).empty()) {
-    readType = getElementTypeOrSelf(outputOperand->get());
-  } else {
-    readType = VectorType::get(map.compose(linalgOp.getShape(outputOperand)),
-                               getElementTypeOrSelf(outputOperand->get()));
-  }
-  Value vectorRead = buildVectorRead(b, outputOperand->get(), readType, map);
-  return vectorRead;
-}
-
-/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
-/// whether a reduction is needed to produce a `targetType` and create that
-/// reduction if it is the case.
-static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
-                            OpOperand *outputOperand,
-                            const BlockAndValueMapping &bvm) {
-  LDBG("Reduce " << value << " to type " << targetType);
-  LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
-       << *(outputOperand->getOwner()));
-  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  VectorType targetVectorType = targetType.dyn_cast<VectorType>();
-  if (!vecType)
-    return value;
-  if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
-    return value;
-
-  // At this point, we know we need to reduce. Detect the reduction operator.
-  unsigned pos = 0;
-  MLIRContext *ctx = b.getContext();
-  SmallVector<AffineExpr> exprs;
-  for (auto s : linalgOp.iterator_types())
-    if (isParallelIterator(s))
-      exprs.push_back(getAffineDimExpr(pos++, ctx));
-
-  Operation *reduceOp = matchLinalgReduction(outputOperand);
-  assert(reduceOp && "Failed precondition: could not math a reduction");
+static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
   unsigned idx = 0;
   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
   for (auto attr : linalgOp.iterator_types()) {
@@ -255,24 +208,7 @@
     reductionMask[idx] = true;
     ++idx;
   }
-  assert(reduceOp->getNumOperands() == 2 &&
-         "Only support binary reduce op right now");
-  unsigned outputPos =
-      outputOperand->getOperandNumber() - linalgOp.getNumInputs();
-  Value outputArg = linalgOp.getRegionOutputArgs()[outputPos];
-  // Reduce across the iteration space.
-  Value reduce =
-      buildMultiDimReduce(b, reduceOp, outputArg, reductionMask, bvm);
-
-  // Read the original output value.
-  Value initialValue = readInitialValue(b, linalgOp, outputOperand);
-
-  // Combine the output argument with the reduced value.
-  OperationState state(reduceOp->getLoc(), reduceOp->getName());
-  state.addAttributes(reduceOp->getAttrs());
-  state.addOperands({reduce, initialValue});
-  state.addTypes(initialValue.getType());
-  return b.createOperation(state)->getResult(0);
+  return reductionMask;
 }
 
 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -280,8 +216,7 @@
 /// currently being vectorized. If `dest` has null rank, build a memref.store.
 /// Return the produced value or null if no value is produced.
 static Value buildVectorWrite(OpBuilder &b, Value value,
-                              OpOperand *outputOperand,
-                              const BlockAndValueMapping &bvm) {
+                              OpOperand *outputOperand) {
   Operation *write;
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
@@ -296,12 +231,9 @@
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(b, value, vectorType.getShape());
-    value = reduceIfNeeded(b, vectorType, value, outputOperand, bvm);
     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                               indices, map);
   } else {
-    value = reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand,
-                           bvm);
     write = vector::TransferWriteOp::createScalarOp(
         b, loc, value, outputOperand->get(), ValueRange{});
   }
@@ -336,7 +268,7 @@
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
     Value newResult = buildVectorWrite(
-        b, vectorValue, linalgOp.getOutputOperand(outputs.index()), bvm);
+        b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
     if (newResult)
       newResults.push_back(newResult);
   }
@@ -379,6 +311,17 @@
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
 }
 
+/// Create a new vectorized version of `op` with the given operands and types.
+static Operation *createVectorizedOp(OpBuilder &b, Operation *op,
+                                     ValueRange newOperands,
+                                     ArrayRef<Type> types) {
+  OperationState state(op->getLoc(), op->getName());
+  state.addAttributes(op->getAttrs());
+  state.addOperands(newOperands);
+  state.addTypes(types);
+  return b.createOperation(state);
+}
+
 /// Generic vectorization for a single operation `op`, given already vectorized
 /// operands carried by `bvm`. Vectorization occurs as follows:
 /// 1. Try to apply any of the `customVectorizationHooks` and return its
@@ -399,7 +342,8 @@
 /// This function does not update `bvm` but returns a VectorizationStatus that
 /// instructs the caller what `bvm` update needs to occur.
 static VectorizationResult
-vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
+vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
+               const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
   LDBG("vectorize op " << *op);
 
@@ -422,7 +366,36 @@
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
 
-  // 4. Generic vectorization path for ElementwiseMappable ops.
+  // 4. Check if the operation is a reduction.
+  for (Value operand : op->getOperands()) {
+    auto arg = operand.dyn_cast<BlockArgument>();
+    if (!arg || arg.getArgNumber() < linalgOp.getNumInputs())
+      continue;
+    SmallVector<Operation *> reductionOps;
+    Value reduceValue = matchReduction(
+        linalgOp.getRegionOutputArgs(),
+        arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps);
+    if (!reduceValue)
+      continue;
+    Value reduceVec = bvm.lookup(reduceValue);
+    Value outputVec = bvm.lookup(operand);
+    auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
+    auto outputType = outputVec.getType().dyn_cast<VectorType>();
+    // Reduce only if needed, as the value may already have been reduced for
+    // contraction vectorization.
+    if (!reduceType ||
+        (outputType && reduceType.getShape() == outputType.getShape()))
+      continue;
+    SmallVector<bool> reductionMask = getReductionMask(linalgOp);
+    Value reduce =
+        buildMultiDimReduce(b, reductionOps[0], reduceVec, reductionMask);
+    // Combine the output argument with the reduced value.
+    return VectorizationResult{
+        VectorizationStatus::NewOp,
+        createVectorizedOp(b, op, {reduce, outputVec}, reduce.getType())};
+  }
+
+  // 5. Generic vectorization path for ElementwiseMappable ops.
   //   a. first get the first max ranked shape.
   SmallVector<int64_t, 4> firstMaxRankedShape;
   for (Value operand : op->getOperands()) {
@@ -444,12 +417,10 @@
       });
 
   // Build and return the new op.
-  OperationState state(op->getLoc(), op->getName());
-  state.addAttributes(op->getAttrs());
-  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
-  state.addTypes(llvm::to_vector<4>(returnTypes));
-  return VectorizationResult{VectorizationStatus::NewOp,
-                             b.createOperation(state)};
+  return VectorizationResult{
+      VectorizationStatus::NewOp,
+      createVectorizedOp(b, op, llvm::to_vector<4>(vectorizedOperands),
+                         llvm::to_vector<4>(returnTypes))};
 }
 
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -544,7 +515,8 @@
     if (linalgOp.getShape(opOperand).empty()) {
      readType = bbarg.getType();
    } else {
-      if (broadcastToMaximalCommonShape) {
+      if (broadcastToMaximalCommonShape &&
+          opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
        map = inverseAndBroadcastProjectedPermuation(
            linalgOp.getTiedIndexingMap(opOperand));
        readType = VectorType::get(commonVectorShape,
@@ -581,7 +553,7 @@
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block.getOperations()) {
-    VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
+    VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LDBG("failed to vectorize: " << op);
       return failure();
Index: mlir/test/Dialect/Linalg/vectorization.mlir
===================================================================
--- mlir/test/Dialect/Linalg/vectorization.mlir
+++ mlir/test/Dialect/Linalg/vectorization.mlir
@@ -749,9 +749,9 @@
                  -> tensor<4x16xf32> {
   // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: math.exp {{.*}} : vector<4x16x8xf32>
   // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
-  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: addf {{.*}} : vector<4x16xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
   // CHECK: return {{.*}} : tensor<4x16xf32>
@@ -782,11 +782,11 @@
 {
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32>
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
   // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
-  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
   // CHECK: addf {{.*}} : vector<2x5xf32>
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
   // CHECK: return {{.*}} : tensor<5x2xf32>
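
For illustration, a sketch of the kind of op the new reduction path in vectorizeOneOp handles, matching the first updated test above. The function name, indexing maps, and iterator types below are hypothetical, inferred from the CHECK lines rather than taken from the patch:

  #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>

  // Sum-of-exponentials: reduce exp(%in) over d2 into %out.
  func @red_exp(%in: tensor<4x16x8xf32>, %out: tensor<4x16xf32>) -> tensor<4x16xf32> {
    %res = linalg.generic {indexing_maps = [#map0, #map1],
                           iterator_types = ["parallel", "parallel", "reduction"]}
        ins(%in : tensor<4x16x8xf32>) outs(%out : tensor<4x16xf32>) {
      ^bb0(%a: f32, %acc: f32):
        %e = math.exp %a : f32
        %s = addf %e, %acc : f32
        linalg.yield %s : f32
    } -> tensor<4x16xf32>
    return %res : tensor<4x16xf32>
  }

After vectorization this becomes the sequence the test checks: both tensors are read up front, math.exp stays at full rank vector<4x16x8xf32>, vector.multi_reduction folds d2 away, and a single vector-level addf combines the reduced value with the initial value read from the output tensor before the transfer_write. Reading the output ahead of the elementwise ops is why the CHECK lines for the transfer_reads are reordered in both tests.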