diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1288,7 +1288,7 @@ OpBuilder<(ins "VectorType":$vector, "Value":$source, "ValueRange":$indices, "AffineMap":$permutationMap, CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that sets padding to 'getMinorIdentityMap'. + // Builder that sets permutation map to 'getMinorIdentityMap'. OpBuilder<(ins "VectorType":$vector, "Value":$source, "ValueRange":$indices, "Value":$padding, CArg<"ArrayRef", "{}">:$inBounds)>, @@ -1306,6 +1306,17 @@ "ArrayAttr":$inBounds)> ]; + let extraClassDeclaration = [{ + /// Temporary convenience builders to account for the fact that we do not + /// have 0-d vectors atm. These create a constant `vector<1xt>` and + /// insert/extract into it. + // Builder that reads `vector<1xt>` with the constant map `() -> (0)` and + // zero padding, then extracts element 0 and returns that scalar. + static Value createScalarOp(OpBuilder &builder, Location loc, Value source, + ValueRange indices, + ArrayRef inBounds = ArrayRef{}); + }]; + let hasCanonicalizer = 1; let hasFolder = 1; } @@ -1416,11 +1427,12 @@ }]; let builders = [ + // Builder that sets an empty mask. + OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, + "AffineMap":$permutationMap, CArg<"ArrayRef", "{}">:$inBounds)>, // Builder that sets permutation map to 'getMinorIdentityMap'. 
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, CArg<"ArrayRef", "{}">:$inBounds)>, - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMap":$permutationMap)>, OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>, OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, @@ -1429,6 +1441,18 @@ "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>, ]; + let extraClassDeclaration = [{ + /// Temporary convenience builders to account for the fact that we do not + /// have 0-d vectors atm. These create a constant `vector<1xt>` and + /// insert/extract into it. + // Builder that broadcasts a scalar `value` to `vector<1xt>` if needed and + // writes it to `dest` with the constant map `() -> (0)`. + static Operation *createScalarOp( + OpBuilder &builder, Location loc, Value value, + Value dest, ValueRange indices, + ArrayRef inBounds = ArrayRef{}); + }]; + let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -40,6 +40,9 @@ #define DEBUG_TYPE "linalg-vectorization" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X) + /// Return the unique instance of OpType in `block` if it is indeed unique. /// Return null if none or more than 1 instances exist. template @@ -106,7 +109,7 @@ /// ShapedType of `v`. 
static VectorType extractVectorTypeFromShapedValue(Value v) { auto st = v.getType().cast(); - if (st.isa() && st.getShape().empty()) + if (st.getShape().empty()) return VectorType(); return VectorType::get(st.getShape(), st.getElementType()); } @@ -163,16 +166,23 @@ return b.createOrFold(loc, targetVectorType, value); } -/// If value of assumed VectorType has a shape different than `shape`, build and -/// return a new vector.broadcast to `shape`. -/// Otherwise, just return value. -static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType, - Value value, OpOperand *outputOperand) { +/// Assuming `outputOperand` is an output operand of a LinalgOp, determine +/// whether a reduction is needed to produce a `targetType` and create that +/// reduction if it is the case. +static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value, + OpOperand *outputOperand) { + LDBG("Reduce " << value << " to type " << targetType); + LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n" + << *(outputOperand->getOwner())); auto linalgOp = cast(outputOperand->getOwner()); auto vecType = value.getType().dyn_cast(); - if (!vecType || vecType.getShape() == targetVectorType.getShape()) + VectorType targetVectorType = targetType.dyn_cast(); + if (!vecType) + return value; + if (targetVectorType && vecType.getShape() == targetVectorType.getShape()) return value; + // At this point, we know we need to reduce. Detect the reduction operator. unsigned pos = 0; MLIRContext *ctx = b.getContext(); SmallVector exprs; @@ -181,7 +191,6 @@ exprs.push_back(getAffineDimExpr(pos++, ctx)); auto loc = value.getLoc(); - // At this point, we know we need to reduce. Detect the reduction operator. auto maybeKind = matchLinalgReduction(outputOperand); assert(maybeKind && "Failed precondition: could not get reduction kind"); unsigned idx = 0; @@ -196,16 +205,18 @@ } /// Build a vector.transfer_read from `source` at indices set to all `0`. 
-/// If source has rank zero, build an memref.load. +/// If source has rank zero, build a `vector<1xt> transfer_read + extract`. /// Return the produced value. -static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType, +static Value buildVectorRead(OpBuilder &b, Value source, Type readType, AffineMap map) { Location loc = source.getLoc(); auto shapedType = source.getType().cast(); SmallVector indices(shapedType.getRank(), b.create(loc, 0)); - return b.create(loc, vectorType, source, indices, - map); + if (auto vectorType = readType.dyn_cast()) + return b.create(loc, vectorType, source, indices, + map); + return vector::TransferReadOp::createScalarOp(b, loc, source, indices); } /// Build a vector.transfer_write of `value` into `outputOperand` at indices set @@ -216,13 +227,14 @@ OpOperand *outputOperand) { Operation *write; Location loc = value.getLoc(); + auto linalgOp = cast(outputOperand->getOwner()); if (VectorType vectorType = extractVectorTypeFromShapedValue(outputOperand->get())) { - auto linalgOp = cast(outputOperand->getOwner()); AffineMap map = reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)); SmallVector transposeShape = applyPermutationMap(inversePermutation(map), vectorType.getShape()); + assert(!transposeShape.empty() && "unexpected empty transpose shape"); vectorType = VectorType::get(transposeShape, vectorType.getElementType()); SmallVector indices(linalgOp.getRank(outputOperand), b.create(loc, 0)); @@ -231,9 +243,12 @@ write = b.create(loc, value, outputOperand->get(), indices, map); } else { - write = b.create(loc, value, outputOperand->get()); + value = + reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand); + write = vector::TransferWriteOp::createScalarOp( + b, loc, value, outputOperand->get(), ValueRange{}); } - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write); + LDBG("vectorized op: " << *write); if (!write->getResults().empty()) return write->getResult(0); return 
Value(); @@ -329,7 +344,7 @@ static VectorizationResult vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm, ArrayRef customVectorizationHooks) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op); + LDBG("vectorize op " << *op); // 1. Try to apply any CustomVectorizationHook. if (!customVectorizationHooks.empty()) { @@ -466,33 +481,27 @@ continue; } // TODO: 0-d vectors. - if (linalgOp.getShape(opOperand).empty()) { - Value loaded = - b.create(linalgOp.getLoc(), opOperand->get()); - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg(" - << bbarg.getArgNumber() << "): " << loaded); - bvm.map(bbarg, loaded); - bvm.map(opOperand->get(), loaded); - continue; - } + Type readType; AffineMap map; - VectorType vectorType; - if (broadcastToMaximalCommonShape) { - map = inverseAndBroadcastProjectedPermuation( - linalgOp.getTiedIndexingMap(opOperand)); - vectorType = VectorType::get(commonVectorShape, - getElementTypeOrSelf(opOperand->get())); + if (linalgOp.getShape(opOperand).empty()) { + readType = bbarg.getType(); } else { - map = inversePermutation( - reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand))); - vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), + if (broadcastToMaximalCommonShape) { + map = inverseAndBroadcastProjectedPermuation( + linalgOp.getTiedIndexingMap(opOperand)); + readType = VectorType::get(commonVectorShape, + getElementTypeOrSelf(opOperand->get())); + } else { + map = inversePermutation( + reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand))); + readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), getElementTypeOrSelf(opOperand->get())); + } } - Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map); - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg(" - << bbarg.getArgNumber() << "): " << vectorRead); - bvm.map(bbarg, vectorRead); - bvm.map(opOperand->get(), vectorRead); + Value readValue = buildVectorRead(b, 
opOperand->get(), readType, map); + LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); + bvm.map(bbarg, readValue); + bvm.map(opOperand->get(), readValue); } auto hooks = llvm::to_vector<4>(customVectorizationHooks); @@ -516,12 +525,11 @@ for (Operation &op : block.getOperations()) { VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op); + LDBG("failed to vectorize: " << op); return failure(); } if (result.status == VectorizationStatus::NewOp) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: " - << *result.newOp;); + LDBG("new vector op: " << *result.newOp;); bvm.map(op.getResults(), result.newOp->getResults()); } } @@ -536,9 +544,9 @@ Location loc = linalgOp.getLoc(); // Vectorize other ops as vector contraction. // TODO: interface. - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - << "Rewrite linalg op as vector.contract: "; - linalgOp.dump()); + LDBG("" + << "Rewrite linalg op as vector.contract: "; + linalgOp.dump()); // Special function that describes how to vectorize the multiplication op in a // linalg contraction. CustomVectorizationHook vectorizeContraction = @@ -592,11 +600,15 @@ // TODO: probably need some extra checks for reduction followed by consumer // ops that may not commute (e.g. linear reduction + non-linear instructions). 
static LogicalResult reductionPreconditions(LinalgOp op) { - if (llvm::none_of(op.iterator_types(), isReductionIterator)) + if (llvm::none_of(op.iterator_types(), isReductionIterator)) { + LDBG("reduction precondition failed: no reduction iterator"); return failure(); + } for (OpOperand *opOperand : op.getOutputOperands()) { - if (!matchLinalgReduction(opOperand)) + if (!matchLinalgReduction(opOperand)) { + LDBG("reduction precondition failed: reduction detection failed"); return failure(); + } } return success(); } @@ -604,8 +616,10 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { auto linalgOp = cast(op); // All types must be static shape to go to vector. - if (linalgOp.hasDynamicShape()) + if (linalgOp.hasDynamicShape()) { + LDBG("precondition failed: dynamic shape"); return failure(); + } if (isElementwise(op)) return success(); if (isaContractionOpInterface(linalgOp)) @@ -613,10 +627,15 @@ // TODO: the common vector shape is equal to the static loop sizes only when // all indexing maps are projected permutations. For convs and stencils the // logic will need to evolve. 
- if (allIndexingsAreProjectedPermutation(linalgOp) && - succeeded(reductionPreconditions(linalgOp))) - return success(); - return failure(); + if (!allIndexingsAreProjectedPermutation(linalgOp)) { + LDBG("precondition failed: not projected permutations"); + return failure(); + } + if (failed(reductionPreconditions(linalgOp))) { + LDBG("precondition failed: reduction preconditions"); + return failure(); + } + return success(); } LogicalResult @@ -629,10 +648,10 @@ if (isaContractionOpInterface(linalgOp)) return vectorizeContraction(b, linalgOp, newResults); - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - << "Vectorize linalg op as a generic by broadcasting to " - "maximal common shape: " - << *op); + LDBG("" + << "Vectorize linalg op as a generic by broadcasting to " + "maximal common shape: " + << *op); return vectorizeAsLinalgGeneric(b, linalgOp, newResults, /*broadcastToMaximalCommonShape=*/true); } @@ -1200,9 +1219,8 @@ ValueRange values) { if (firstOp->getBlock() != secondOp->getBlock() || !firstOp->isBeforeInBlock(secondOp)) { - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " - << "interleavedUses precondition failed, firstOp: " - << *firstOp << ", second op: " << *secondOp); + LDBG("interleavedUses precondition failed, firstOp: " + << *firstOp << ", second op: " << *secondOp); return true; } for (auto v : values) { @@ -1214,10 +1232,8 @@ if (owner->getBlock() == firstOp->getBlock() && (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner))) continue; - LLVM_DEBUG(llvm::dbgs() - << "\n[" DEBUG_TYPE "]: " - << " found interleaved op " << *owner - << ", firstOp: " << *firstOp << ", second op: " << *secondOp); + LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp + << ", second op: " << *secondOp); return true; } } @@ -1248,15 +1264,14 @@ !viewOrAlloc.getDefiningOp()) return failure(); - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc); + LDBG(viewOrAlloc); // Ensure there is exactly one subview of `viewOrAlloc` 
defining `subView`. memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); if (!subViewOp) return failure(); Value subView = subViewOp.getResult(); - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " - << "with subView " << subView); + LDBG("with subView " << subView); // Find the copy into `subView` without interleaved uses. CopyOp copyOp; @@ -1265,8 +1280,7 @@ assert(newCopyOp.output().getType().isa()); if (newCopyOp.output() != subView) continue; - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " - << "copy candidate " << *newCopyOp); + LDBG("copy candidate " << *newCopyOp); if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) continue; copyOp = newCopyOp; @@ -1275,8 +1289,7 @@ } if (!copyOp) return failure(); - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " - << "with copy " << *copyOp); + LDBG("with copy " << *copyOp); // Find the fill into `viewOrAlloc` without interleaved uses before the copy. FillOp maybeFillOp; @@ -1285,8 +1298,7 @@ assert(newFillOp.output().getType().isa()); if (newFillOp.output() != viewOrAlloc) continue; - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " - << "fill candidate " << *newFillOp); + LDBG("fill candidate " << *newFillOp); if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) continue; maybeFillOp = newFillOp; @@ -1297,8 +1309,7 @@ if (maybeFillOp && xferOp.padding() != maybeFillOp.value()) return failure(); if (maybeFillOp) - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " - << "with maybeFillOp " << *maybeFillOp); + LDBG("with maybeFillOp " << *maybeFillOp); // `in` is the subview that linalg.copy reads. Replace it. 
Value in = copyOp.input(); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -2439,6 +2439,18 @@ /*mask=*/Value(), inBounds); } +Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc, + Value source, ValueRange indices, + ArrayRef inBounds) { + Type elemType = source.getType().cast().getElementType(); + auto vectorType = VectorType::get(ArrayRef{1}, elemType); + AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0, + getAffineConstantExpr(0, loc.getContext())); + Value read = builder.create(loc, vectorType, source, + indices, map, inBounds); + return builder.create(loc, read, ArrayRef{0}); +} + static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { SmallVector elidedAttrs; elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); @@ -2769,6 +2781,16 @@ // TransferWriteOp //===----------------------------------------------------------------------===// +void TransferWriteOp::build(OpBuilder &builder, OperationState &result, + Value vector, Value dest, ValueRange indices, + AffineMap permutationMap, ArrayRef inBounds) { + if (inBounds.empty()) + return build(builder, result, vector, dest, indices, permutationMap, + /*mask=*/Value(), ArrayAttr()); + build(builder, result, vector, dest, indices, permutationMap, + /*mask=*/Value(), builder.getBoolArrayAttr(inBounds)); +} + /// Builder that sets permutation map to 'getMinorIdentityMap'. 
void TransferWriteOp::build(OpBuilder &builder, OperationState &result, Value vector, Value source, ValueRange indices, @@ -2783,13 +2805,6 @@ build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr); } -void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, - AffineMap permutationMap) { - build(builder, result, vector, source, indices, permutationMap, - /*inBounds=*/ArrayAttr()); -} - void TransferWriteOp::build(OpBuilder &builder, OperationState &result, Value vector, Value source, ValueRange indices, AffineMapAttr permutationMap, @@ -2817,6 +2832,20 @@ mask, inBounds); } +Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc, + Value value, Value dest, + ValueRange indices, + ArrayRef inBounds) { + Value vectorOfAScalar = value; + if (!value.getType().isa()) + vectorOfAScalar = builder.create( + loc, VectorType::get({1}, value.getType()), value); + AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0, + getAffineConstantExpr(0, loc.getContext())); + return builder.create(loc, vectorOfAScalar, dest, + indices, map, inBounds); +} + static ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -203,8 +203,9 @@ // CHECK-LABEL: func @test_vectorize_fill func @test_vectorize_fill_scalar(%A : memref, %arg0 : f32) { - // CHECK-SAME: (%[[M:.*]]: memref, %[[V:.*]]: f32) - // CHECK: store %[[V]], %[[M]][] : memref + // CHECK-SAME: (%[[M:.*]]: memref, %[[val:.*]]: f32) + // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32> + // CHECK: vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref linalg.fill(%arg0, %A) : f32, memref return } @@ -223,8 +224,11 @@ // 
CHECK-LABEL: func @test_vectorize_copy_scalar func @test_vectorize_copy_scalar(%A : memref, %B : memref) { - // CHECK: %[[V:.*]] = memref.load {{.*}} : memref - // CHECK: store %[[V]], {{.*}} : memref + // CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref) + // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector<1xf32> + // CHECK: %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32> + // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32> + // CHECK: vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref linalg.copy(%A, %B) : memref, memref return } @@ -857,3 +861,42 @@ return %red : tensor<4xf32> } +// ----- + +// CHECK-LABEL: func @reduce_1d( +// CHECK-SAME: %[[A:.*]]: tensor<32xf32> +func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { + // CHECK-DAG: %[[F0_v1:.*]] = constant dense<0.000000e+00> : vector<1xf32> + // CHECK-DAG: %[[F0_v32:.*]] = constant dense<0.000000e+00> : vector<32xf32> + // CHECK-DAG: %[[C0:.*]] = constant 0 : index + %f0 = constant 0.000000e+00 : f32 + + // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor + %0 = linalg.init_tensor [] : tensor + + // CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][] + // CHECK-SAME: : vector<1xf32>, tensor + %1 = linalg.fill(%f0, %0) : f32, tensor -> tensor + + // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] + // CHECK-SAME: : tensor<32xf32>, vector<32xf32> + // CHECK: %[[a:.*]] = addf %[[r]], %[[F0_v32]] : vector<32xf32> + // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind, %[[a]] [0] + // CHECK-SAME: : vector<32xf32> to f32 + // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<1xf32> + // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] + // CHECK-SAME: : vector<1xf32>, tensor + %2 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()>], + iterator_types = ["reduction"]} + ins(%arg0 : tensor<32xf32>) + outs(%1 : tensor) { + ^bb0(%a: f32, %b: f32): // 
no predecessors + %3 = addf %a, %b : f32 + linalg.yield %3 : f32 + } -> tensor + + return %2 : tensor +} +