Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -238,6 +238,79 @@
           builder.createOperation(state)};
 }
 
+/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
+static bool hasOnlyScalarElementwiseOp(Region &r) {
+  if (!llvm::hasSingleElement(r))
+    return false;
+  for (Operation &op : r.front()) {
+    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+          OpTrait::hasElementwiseMappableTraits(&op)) ||
+        llvm::any_of(op.getResultTypes(),
+                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
+      return false;
+  }
+  return true;
+}
+
+// Return true if the op is an element-wise linalg op.
+static bool isElementwise(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return false;
+  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+    return false;
+  // TODO: relax the restrictions on indexing map.
+  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
+    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+      return false;
+  }
+  if (linalgOp->getNumRegions() != 1)
+    return false;
+  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
+}
+
+// Broadcast and transpose a linalgOp argument to a shape matching the linalg
+// output shape.
+static Value broadcastAndTranspose(OpBuilder &builder, Value vectorRead,
+                                   LinalgOp linalgOp, unsigned argIndex) {
+  Location loc = linalgOp.getLoc();
+  // Output permutations are currently not supported.
+  assert(linalgOp.getNumOutputs() > 0 &&
+         linalgOp.getOutputIndexingMap(0).isIdentity());
+  ArrayRef<int64_t> outputShape = linalgOp.getOutputShapedType(0).getShape();
+  AffineMap map = linalgOp.getIndexingMap(argIndex);
+  unsigned numMissingDim = map.getNumInputs() - map.getNumResults();
+  SmallVector<int64_t, 4> permutation(map.getNumInputs());
+  llvm::SmallVector<bool, 4> existingDim(map.getNumInputs(), false);
+  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
+    unsigned pos = map.getDimPosition(i);
+    existingDim[pos] = true;
+    permutation[pos] = numMissingDim + i;
+  }
+  // If there are missing dimensions, first broadcast the vector.
+  if (numMissingDim > 0) {
+    // Insert the missing dimensions as the outermost dimensions.
+    unsigned index = 0;
+    SmallVector<int64_t, 4> broadcastShape;
+    for (unsigned i : llvm::seq<unsigned>(0, map.getNumInputs())) {
+      if (!existingDim[i]) {
+        broadcastShape.push_back(outputShape[i]);
+        permutation[i] = index++;
+      }
+    }
+    auto sourceType = vectorRead.getType().cast<VectorType>();
+    broadcastShape.append(sourceType.getShape().begin(),
+                          sourceType.getShape().end());
+    auto newVecType =
+        VectorType::get(broadcastShape, sourceType.getElementType());
+    vectorRead =
+        builder.create<vector::BroadcastOp>(loc, newVecType, vectorRead);
+  }
+  Value transposeOp =
+      builder.create<vector::TransposeOp>(loc, vectorRead, permutation);
+  return transposeOp;
+}
+
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
 /// vector form. Generic vectorization proceeds as follows:
 ///   1. The region for the linalg op is created if necessary.
@@ -285,6 +358,11 @@
     Value vectorRead = buildVectorRead(builder, vectorArg);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                       << bbarg.getArgNumber() << "): " << vectorRead);
+    // Apply broadcast and transpose if needed.
+    if (isElementwise(linalgOp) &&
+        !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity())
+      vectorRead = broadcastAndTranspose(builder, vectorRead, linalgOp,
+                                         bbarg.getArgNumber());
     bvm.map(bbarg, vectorRead);
     bvm.map(vectorArg, vectorRead);
   }
@@ -316,44 +394,6 @@
   return success();
 }
 
-/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
-static bool hasOnlyScalarElementwiseOp(Region &r) {
-  if (!llvm::hasSingleElement(r))
-    return false;
-  for (Operation &op : r.front()) {
-    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
-          OpTrait::hasElementwiseMappableTraits(&op)) ||
-        llvm::any_of(op.getResultTypes(),
-                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
-      return false;
-  }
-  return true;
-}
-
-// Return true if the op is an element-wise linalg op.
-static bool isElementwise(Operation *op) {
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
-    return false;
-  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
-    return false;
-  // TODO: relax the restrictions on indexing map.
-  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
-    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
-      return false;
-  }
-  // Currently bound the input indexing map to minor identity as other
-  // permutations might require adding transpose ops to convert the vector read
-  // to the right shape.
-  for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
-    if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
-      return false;
-  }
-  if (linalgOp->getNumRegions() != 1)
-    return false;
-  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
-}
-
 static LogicalResult
 vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
                      SmallVectorImpl<Value> &newResults) {
   assert(isaContractionOpInterface(linalgOp) &&
Index: mlir/test/Dialect/Linalg/vectorization.mlir
===================================================================
--- mlir/test/Dialect/Linalg/vectorization.mlir
+++ mlir/test/Dialect/Linalg/vectorization.mlir
@@ -341,6 +341,44 @@
 
 // -----
 
+// Test different input maps.
+#matmul_trait = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d1, d0)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+// CHECK-LABEL: func @vectorization_transpose
+// CHECK: vector.transfer_read {{.*}} : memref<14x7xf32>, vector<14x7xf32>
+// CHECK: vector.broadcast {{.*}} : vector<14x7xf32> to vector<8x16x14x7xf32>
+// CHECK: vector.transpose {{.*}}, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}} : memref<16x14xf32>, vector<16x14xf32>
+// CHECK: vector.broadcast {{.*}} : vector<16x14xf32> to vector<7x8x16x14xf32>
+// CHECK: vector.transpose {{.*}}, [0, 3, 1, 2] : vector<7x8x16x14xf32> to vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}} : memref<16x14x7x8xf32>, vector<16x14x7x8xf32>
+// CHECK: vector.transpose {{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
+// CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<7x14x8x16xf32>, memref<7x14x8x16xf32>
+func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>,
+                              %C: memref<16x14x7x8xf32>, %D: memref<7x14x8x16xf32>) {
+  linalg.generic #matmul_trait
+      ins(%A, %B, %C : memref<14x7xf32>, memref<16x14xf32>, memref<16x14x7x8xf32>)
+     outs(%D : memref<7x14x8x16xf32>) {
+    ^bb(%a: f32, %b: f32, %c: f32, %d: f32) :
+      %e = addf %a, %b : f32
+      %f = addf %e, %c : f32
+      linalg.yield %f : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_tensors
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
 // CHECK-SAME:   %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
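
Note (illustrative, not part of the patch): the sketch below re-implements the shape bookkeeping of broadcastAndTranspose on plain STL containers and traces it on the first input of the new test, map (d0, d1, d2, d3) -> (d1, d0) with source 14x7 and identity output 7x14x8x16. Dims the map drops become leading broadcast dims; dims it produces are shifted past them and then permuted so that vector.transpose (where result dim i takes source dim permutation[i]) yields the iteration-space shape. All names here are hypothetical; only the logic mirrors the patch.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // resultDims holds the dim positions the input's indexing map produces,
    // e.g. (d1, d0) -> {1, 0}. outputShape is the identity output shape,
    // i.e. the iteration-space shape. Computes the broadcast shape and the
    // transpose permutation the same way broadcastAndTranspose does.
    static void computeBroadcastAndPermutation(
        const std::vector<unsigned> &resultDims,
        const std::vector<int64_t> &outputShape,
        const std::vector<int64_t> &sourceShape,
        std::vector<int64_t> &broadcastShape,
        std::vector<int64_t> &permutation) {
      unsigned numInputs = outputShape.size();
      unsigned numMissingDim = numInputs - resultDims.size();
      permutation.assign(numInputs, 0);
      std::vector<bool> existingDim(numInputs, false);
      // Dims the map produces keep their data; they end up after the
      // broadcast dims that get prepended, hence the numMissingDim offset.
      for (unsigned i = 0; i < resultDims.size(); ++i) {
        existingDim[resultDims[i]] = true;
        permutation[resultDims[i]] = numMissingDim + i;
      }
      // Dims the map drops are materialized as leading broadcast dims.
      unsigned index = 0;
      for (unsigned i = 0; i < numInputs; ++i) {
        if (!existingDim[i]) {
          broadcastShape.push_back(outputShape[i]);
          permutation[i] = index++;
        }
      }
      broadcastShape.insert(broadcastShape.end(), sourceShape.begin(),
                            sourceShape.end());
    }

    int main() {
      std::vector<int64_t> broadcastShape, permutation;
      computeBroadcastAndPermutation(/*resultDims=*/{1, 0},
                                     /*outputShape=*/{7, 14, 8, 16},
                                     /*sourceShape=*/{14, 7},
                                     broadcastShape, permutation);
      // Prints "8 16 14 7 | 3 2 0 1": broadcast vector<14x7xf32> to
      // vector<8x16x14x7xf32>, then transpose with [3, 2, 0, 1] to get
      // vector<7x14x8x16xf32>, matching the CHECK lines in the test.
      for (int64_t d : broadcastShape)
        std::cout << d << ' ';
      std::cout << "| ";
      for (int64_t p : permutation)
        std::cout << p << ' ';
      std::cout << '\n';
    }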