Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -87,11 +87,14 @@
 /// Build a vector.transfer_read from `source` at indices set to all `0`.
 /// If source has rank zero, build a memref.load.
 /// Return the produced value.
-static Value buildVectorRead(OpBuilder &builder, Value source) {
+static Value buildVectorRead(OpBuilder &builder, Value source,
+                             VectorType vectorType, AffineMap map) {
   edsc::ScopedContext scope(builder);
   auto shapedType = source.getType().cast<ShapedType>();
-  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+  if (vectorType) {
     SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+    if (map)
+      return vector_transfer_read(vectorType, source, indices, map);
     return vector_transfer_read(vectorType, source, indices);
   }
   return memref_load(source);
@@ -238,6 +241,51 @@
                              builder.createOperation(state)};
 }
 
+/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
+static bool hasOnlyScalarElementwiseOp(Region &r) {
+  if (!llvm::hasSingleElement(r))
+    return false;
+  for (Operation &op : r.front()) {
+    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+          OpTrait::hasElementwiseMappableTraits(&op)) ||
+        llvm::any_of(op.getResultTypes(),
+                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
+      return false;
+  }
+  return true;
+}
+
+// Return true if the op is an element-wise linalg op.
+static bool isElementwise(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return false;
+  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+    return false;
+  // TODO: relax the restrictions on indexing map.
+  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
+    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+      return false;
+  }
+  if (linalgOp->getNumRegions() != 1)
+    return false;
+  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
+}
+
+// Calculate the map to apply to transfer_read to convert the input shape into
+// the output shape.
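+// For example, for the input indexing map
+//   (d0, d1, d2, d3) -> (d3, d1)
+// exercised by the test below, the returned read map is
+//   (d0, d1) -> (0, d1, 0, d0):
+// results that are dims transpose the data read from the memref, and the
+// constant-0 results broadcast it along the loops the input does not index.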
+static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
+  AffineMap linalgMap = linalgOp.getIndexingMap(argIndex);
+  MLIRContext *context = linalgMap.getContext();
+  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
+  SmallVector<AffineExpr> exprs(linalgMap.getNumInputs(), zero);
+  for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults())) {
+    exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context);
+  }
+  return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs,
+                        context);
+}
+
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
 /// vector form. Generic vectorization proceeds as follows:
 /// 1. The region for the linalg op is created if necessary.
@@ -282,7 +330,19 @@
   SmallVector<AffineMap> indexings;
   for (auto bbarg : block->getArguments()) {
     Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
-    Value vectorRead = buildVectorRead(builder, vectorArg);
+    AffineMap map;
+    VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg);
+    if (isElementwise(linalgOp) &&
+        !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) {
+      // Currently assume we don't support output permutations.
+      assert(linalgOp.getNumOutputs() > 0 &&
+             linalgOp.getOutputIndexingMap(0).isIdentity());
+      ArrayRef<int64_t> outputShape =
+          linalgOp.getOutputShapedType(0).getShape();
+      vectorType = VectorType::get(outputShape, vectorType.getElementType());
+      map = getTransferReadMap(linalgOp, bbarg.getArgNumber());
+    }
+    Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                       << bbarg.getArgNumber() << "): " << vectorRead);
     bvm.map(bbarg, vectorRead);
@@ -316,44 +376,6 @@
   return success();
 }
 
-/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
-static bool hasOnlyScalarElementwiseOp(Region &r) {
-  if (!llvm::hasSingleElement(r))
-    return false;
-  for (Operation &op : r.front()) {
-    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
-          OpTrait::hasElementwiseMappableTraits(&op)) ||
-        llvm::any_of(op.getResultTypes(),
-                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
-      return false;
-  }
-  return true;
-}
-
-// Return true if the op is an element-wise linalg op.
-static bool isElementwise(Operation *op) {
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
-    return false;
-  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
-    return false;
-  // TODO: relax the restrictions on indexing map.
-  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
-    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
-      return false;
-  }
-  // Currently bound the input indexing map to minor identity as other
-  // permutations might require adding transpose ops to convert the vector read
-  // to the right shape.
-  for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
-    if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
-      return false;
-  }
-  if (linalgOp->getNumRegions() != 1)
-    return false;
-  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
-}
-
 static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
                                           SmallVectorImpl<Value> &newResults) {
   assert(isaContractionOpInterface(linalgOp) &&
Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2294,8 +2294,7 @@
 static void printTransferAttrs(OpAsmPrinter &p,
                                VectorTransferOpInterface op) {
   SmallVector<StringRef, 3> elidedAttrs;
-  if (op.permutation_map() ==
-      getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType()))
+  if (op.permutation_map().isMinorIdentity())
     elidedAttrs.push_back(op.getPermutationMapAttrName());
   bool elideMasked = true;
   if (auto maybeMasked = op.masked()) {
Index: mlir/lib/IR/AffineMap.cpp
===================================================================
--- mlir/lib/IR/AffineMap.cpp
+++ mlir/lib/IR/AffineMap.cpp
@@ -106,8 +106,9 @@
 }
 
 bool AffineMap::isMinorIdentity() const {
-  return *this ==
-         getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
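+  // Broadcasting permutation maps such as (d0, d1) -> (0, d1, 0, d0) have
+  // more results than dims and can never be a minor identity; check this
+  // first, since getMinorIdentityMap requires numDims >= numResults.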
+  return getNumDims() >= getNumResults() &&
+         *this ==
+             getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 
 /// Returns true if this affine map is a minor identity up to broadcasted
Index: mlir/test/Dialect/Linalg/vectorization.mlir
===================================================================
--- mlir/test/Dialect/Linalg/vectorization.mlir
+++ mlir/test/Dialect/Linalg/vectorization.mlir
@@ -341,6 +341,42 @@
 // -----
 
+// Test different input maps.
+#matmul_trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d1, d0)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (0, d1, 0, d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>
+// CHECK: func @vectorization_transpose
+// CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP0]]} : memref<14x7xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP1]]} : memref<16x14xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP2]]} : memref<16x14x7x8xf32>, vector<7x14x8x16xf32>
+// CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+// CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<7x14x8x16xf32>, memref<7x14x8x16xf32>
+func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>,
+                              %C: memref<16x14x7x8xf32>, %D: memref<7x14x8x16xf32>) {
+  linalg.generic #matmul_trait
+    ins(%A, %B, %C : memref<14x7xf32>, memref<16x14xf32>, memref<16x14x7x8xf32>)
+   outs(%D : memref<7x14x8x16xf32>) {
+    ^bb(%a: f32, %b: f32, %c: f32, %d: f32) :
+      %e = addf %a, %b: f32
+      %f = addf %e, %c: f32
+      linalg.yield %f : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_tensors
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
 // CHECK-SAME:   %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
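For reference, a hand-written sketch of the vectorized read the new test expects for %A, reconstructed from the CHECK lines above (not compiler output):

  %c0 = constant 0 : index
  %cst = constant 0.000000e+00 : f32
  // Transposes the two memref dims into vector dims 1 and 0, and broadcasts
  // along vector dims 2 and 3, which %A does not index.
  %0 = vector.transfer_read %A[%c0, %c0], %cst
         {permutation_map = affine_map<(d0, d1) -> (d1, d0, 0, 0)>}
       : memref<14x7xf32>, vector<7x14x8x16xf32>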