diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -119,34 +119,38 @@
   return hasOnlyScalarElementwiseOp(genericOp.getRegion());
 }
 
-static VectorType extractVectorTypeFromScalarView(Value v) {
-  MemRefType mt = v.getType().cast<MemRefType>();
-  return mt.getShape().empty()
-             ? VectorType()
-             : VectorType::get(mt.getShape(), mt.getElementType());
+static VectorType extractVectorTypeFromShapedValue(Value v) {
+  auto st = v.getType().cast<ShapedType>();
+  if (st.isa<MemRefType>() && st.getShape().empty())
+    return VectorType();
+  return VectorType::get(st.getShape(), st.getElementType());
 }
 
-static Value transferReadVector(OpBuilder &builder, Value memref) {
+static Value transferReadVector(OpBuilder &builder, Value source) {
   edsc::ScopedContext scope(builder);
-  auto memrefType = memref.getType().cast<MemRefType>();
-  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
-    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
-    return vector_transfer_read(vectorType, memref, indices);
+  auto shapedType = source.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
+    return vector_transfer_read(vectorType, source, indices);
   }
-  return std_load(memref);
+  return std_load(source);
 }
 
-static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
+static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
   edsc::ScopedContext scope(builder);
-  auto memrefType = memref.getType().cast<MemRefType>();
-  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
-    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
+  Operation *write;
+  auto shapedType = dest.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
+    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
     if (vectorType != value.getType())
       value = vector_broadcast(vectorType, value);
-    vector_transfer_write(value, memref, indices);
+    write = vector_transfer_write(value, dest, indices);
   } else {
-    std_store(value, memref);
+    write = std_store(value, dest);
   }
+  if (!write->getResults().empty())
+    return write->getResult(0);
+  return Value();
 }
 
 namespace {
@@ -167,10 +171,12 @@
   void vectorize(Operation &scalarOp) {
     auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
     if (yieldOp) {
-      for (auto outputAndMemref :
-           llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
-        Value vectorValue = vectorize(std::get<0>(outputAndMemref));
-        transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
+      for (auto outputs : llvm::enumerate(yieldOp.values())) {
+        Value vectorValue = vectorize(outputs.value());
+        Value result = transferWriteVector(builder, vectorValue,
+                                           generic.getOutput(outputs.index()));
+        if (result)
+          results.push_back(result);
       }
       return;
     }
@@ -182,6 +188,8 @@
     }
   }
 
+  llvm::ArrayRef<Value> getResults() { return results; }
+
 private:
   // Transforms a scalar value into its vectorized counterpart, recursively
   // vectorizing operations as necessary using the underlying builder.
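
Note: `transferWriteVector` now mirrors the buffer/tensor duality of `vector.transfer_write` itself. A minimal sketch of the two forms it can emit (shapes and value names are illustrative, not taken from the patch):

  // Buffer destination: in-place write, no SSA result.
  vector.transfer_write %v, %buf[%c0, %c0] : vector<4x256xf32>, memref<4x256xf32>
  // Tensor destination: value semantics, the write returns the updated tensor.
  %t1 = vector.transfer_write %v, %t0[%c0, %c0] : vector<4x256xf32>, tensor<4x256xf32>

Only the tensor form produces a result, which is why the helper now returns an optional `Value` and the vectorizer accumulates the non-null ones in `results`.
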
@@ -261,6 +269,7 @@
   OpBuilder &builder;
   linalg::GenericOp generic;
   llvm::DenseMap<Value, Value> valueCache;
+  SmallVector<Value, 8> results;
 };
 } // namespace
 
@@ -271,6 +280,8 @@
   for (Operation &scalarOp : op.region().front()) {
     vectorizer.vectorize(scalarOp);
   }
+  if (!op->getResults().empty())
+    op->replaceAllUsesWith(vectorizer.getResults());
 }
 
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -331,32 +342,14 @@
   LLVM_DEBUG(dbgs() << dbgPref << "Rewrite linalg op as vector.contract: "
                     << *op);
   auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value viewA = linalgOp.getInput(0);
-  Value viewB = linalgOp.getInput(1);
-  Value viewC = linalgOp.getOutputBuffer(0);
-  VectorType vtA = extractVectorTypeFromScalarView(viewA);
-  VectorType vtB = extractVectorTypeFromScalarView(viewB);
-  VectorType vtC = extractVectorTypeFromScalarView(viewC);
-  Value zero = std_constant_index(0);
-  SmallVector<Value, 4> indicesA, indicesB, indicesC;
-  if (vtA)
-    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
-  if (vtB)
-    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
-  if (vtC)
-    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
-  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
-                : std_load(viewA, indicesA).value;
-  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
-                : std_load(viewB, indicesB).value;
-  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
-                : std_load(viewC, indicesC).value;
+  Value a = transferReadVector(builder, linalgOp.getInput(0));
+  Value b = transferReadVector(builder, linalgOp.getInput(1));
+  Value c = transferReadVector(builder, linalgOp.getOutput(0));
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  if (vtC)
-    vector_transfer_write(res, viewC, indicesC);
-  else
-    std_store(res, viewC, indicesC);
+  Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
+  if (writeResult)
+    linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`
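
The rewritten contraction path funnels both the memref and the tensor case through the same read/write helpers. For a matmul on tensors the emitted IR is, schematically (a condensed sketch with illustrative names; the FileCheck-verified form is in the `@matmul_tensors` test below):

  %a = vector.transfer_read %lhs[%c0, %c0], %pad : tensor<8x4xf32>, vector<8x4xf32>
  %b = vector.transfer_read %rhs[%c0, %c0], %pad : tensor<4x12xf32>, vector<4x12xf32>
  %c = vector.transfer_read %acc[%c0, %c0], %pad : tensor<8x12xf32>, vector<8x12xf32>
  %r = vector.contract {...} %a, %b, %c : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
  %w = vector.transfer_write %r, %acc[%c0, %c0] : vector<8x12xf32>, tensor<8x12xf32>

with `%w` then replacing all uses of the linalg op's result.
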
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2039,28 +2039,28 @@
 /// Builder that sets padding to zero.
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vector, Value memref, ValueRange indices,
+                           VectorType vector, Value source, ValueRange indices,
                            AffineMap permutationMap,
                            ArrayRef<bool> maybeMasked) {
-  Type elemType = memref.getType().cast<MemRefType>().getElementType();
+  Type elemType = source.getType().cast<ShapedType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
   if (maybeMasked.empty())
-    return build(builder, result, vector, memref, indices, permutationMap,
+    return build(builder, result, vector, source, indices, permutationMap,
                  padding, ArrayAttr());
   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
-  build(builder, result, vector, memref, indices, permutationMap, padding,
+  build(builder, result, vector, source, indices, permutationMap, padding,
         maskedArrayAttr);
 }
 
 /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
 /// (resp. zero).
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value memref,
+                           VectorType vectorType, Value source,
                            ValueRange indices, ArrayRef<bool> maybeMasked) {
   auto permMap = getTransferMinorIdentityMap(
-      memref.getType().cast<MemRefType>(), vectorType);
-  build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
+      source.getType().cast<ShapedType>(), vectorType);
+  build(builder, result, vectorType, source, indices, permMap, maybeMasked);
 }
 
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
@@ -2251,7 +2251,7 @@
                            ArrayRef<bool> maybeMasked) {
   auto vectorType = vector.getType().cast<VectorType>();
   auto permMap = getTransferMinorIdentityMap(
-      source.getType().cast<MemRefType>(), vectorType);
+      source.getType().cast<ShapedType>(), vectorType);
   if (maybeMasked.empty())
     return build(builder, result, vector, source, indices, permMap,
                  ArrayAttr());
@@ -2327,7 +2327,7 @@
 }
 
 static LogicalResult verify(TransferWriteOp op) {
-  // Consistency of elemental types in memref and vector.
+  // Consistency of elemental types in shape and vector.
   ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
   auto permutationMap = op.permutation_map();
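
With the builders generalized from `MemRefType` to `ShapedType`, the zero-padding and minor-identity-map conveniences apply to tensor sources as well; in the assembly format the two cases differ only in the source type, e.g. (illustrative):

  %v0 = vector.transfer_read %m[%c0], %pad : memref<16xf32>, vector<16xf32>
  %v1 = vector.transfer_read %t[%c0], %pad : tensor<16xf32>, vector<16xf32>
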
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -189,7 +189,7 @@
 // CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
 // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 // CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> 
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 // CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
 // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
 // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -209,3 +209,108 @@
 // CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
 // CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
 // CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+
+func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
+  %arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>,
+  %i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+  tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+  tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>) {
+  %c1_f32 = constant 1.0 : f32
+  %r:10 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg1, %arg2: tensor<4x256xf32>, tensor<256xf32>)
+    outs(
+    %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
+    %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
+    %arg14 : f32):
+    %6 = addf %arg4, %arg6 : f32
+    %7 = cmpf "ogt", %arg3, %arg6 : f32
+    %8 = constant 2.0 : f32
+    %9 = divf %arg5, %i : f32
+    %10 = exp2 %arg5 : f32
+    %11 = mulf %arg5, %8 : f32
+    %12 = rsqrt %arg5 : f32
+    %13 = select %7, %arg5, %arg6 : f32
+    %14 = subf %arg5, %arg4 : f32
+    %15 = tanh %arg5 : f32
+    linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
+      f32, f32, f32, f32, f32, f32, f32, f32
+  } -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+  return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9:
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+}
+
+// CHECK-LABEL: func @generic_vectorize_tensor
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
+// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
+// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
+// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
+// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
+// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
+// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
+// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
+// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
+// CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
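
Note in the checks above how every yielded value is written into a tensor derived from the same init operand `%[[ARG0]]`, yet each `vector.transfer_write` produces an independent SSA result `%[[R0]]`..`%[[R9]]`. With value semantics the writes cannot clobber one another, e.g. (names illustrative):

  %r0 = vector.transfer_write %add, %init[%c0, %c0] : vector<4x256xf32>, tensor<4x256xf32>
  %r1 = vector.transfer_write %mul, %init[%c0, %c0] : vector<4x256xf32>, tensor<4x256xf32>
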
+
+func @matmul_tensors(
+  %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
+    -> tensor<8x12xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
+                     outs(%arg2: tensor<8x12xf32>)
+    -> tensor<8x12xf32>
+  return %0 : tensor<8x12xf32>
+}
+
+// CHECK-LABEL: func @matmul_tensors
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+// CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
+// CHECK: return %[[W]] : tensor<8x12xf32>