diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -961,7 +961,8 @@
            "{}">:$transpose_paddings,
        DefaultValuedAttr<StrAttr,
            "::mlir::bufferization::CopyTensorOp::getOperationName()">:$copy_back_op);
   let results = (outs TransformHandleTypeInterface:$padded,
-                      TransformHandleTypeInterface:$pad);
+                      TransformHandleTypeInterface:$pad,
+                      TransformHandleTypeInterface:$copy);
   let assemblyFormat =
     "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1612,7 +1612,7 @@
 transform::PadOp::apply(transform::TransformRewriter &rewriter,
                         transform::TransformResults &results,
                         transform::TransformState &state) {
-  SmallVector<Operation *> paddedOps, padOps;
+  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
 
   for (Operation *target : state.getPayloadOps(getTarget())) {
     auto linalgTarget = dyn_cast<LinalgOp>(target);
@@ -1707,10 +1707,18 @@
     rewriter.replaceOp(linalgTarget, replacements);
     paddedOps.push_back(paddedOp);
     padOps.append(newPadOps.begin(), newPadOps.end());
+    if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
+      for (Value v : replacements) {
+        Operation *copyBackOp = v.getDefiningOp();
+        if (llvm::find(copyBackOps, copyBackOp) == copyBackOps.end())
+          copyBackOps.push_back(copyBackOp);
+      }
+    }
   }
 
   results.set(cast<OpResult>(getPadded()), paddedOps);
   results.set(cast<OpResult>(getPad()), padOps);
+  results.set(cast<OpResult>(getCopy()), copyBackOps);
 
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -461,18 +461,20 @@
   AnalysisState state(bufferizationOptions);
 
 #ifndef NDEBUG
-  // Ops with nested tensor ops are not supported yet. At the moment, this
-  // function just bufferizes the given op itself, but not its body.
-  op->walk([&](Operation *nestedOp) {
-    if (op == nestedOp)
-      return;
-    if (llvm::any_of(nestedOp->getOperands(),
-                     [](Value v) { return v.getType().isa<TensorType>(); }))
-      llvm_unreachable("ops with nested tensor ops are not supported yet");
-    if (llvm::any_of(nestedOp->getResults(),
-                     [](Value v) { return v.getType().isa<TensorType>(); }))
-      llvm_unreachable("ops with nested tensor ops are not supported yet");
-  });
+  if (!options.bufferizeDestinationOnly) {
+    // Ops with nested tensor ops are not supported yet. At the moment, this
+    // function just bufferizes the given op itself, but not its body.
+    op->walk([&](Operation *nestedOp) {
+      if (op == nestedOp)
+        return;
+      if (llvm::any_of(nestedOp->getOperands(),
+                       [](Value v) { return v.getType().isa<TensorType>(); }))
+        llvm_unreachable("ops with nested tensor ops are not supported yet");
+      if (llvm::any_of(nestedOp->getResults(),
+                       [](Value v) { return v.getType().isa<TensorType>(); }))
+        llvm_unreachable("ops with nested tensor ops are not supported yet");
+    });
+  }
 #endif // NDEBUG
 
   // Gather tensor results.
diff --git a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
@@ -0,0 +1,110 @@
+
+// RUN: mlir-opt --test-transform-dialect-interpreter="enable-expensive-checks=1" %s | FileCheck %s
+
+// CHECK-LABEL: func @matmul_divisible
+// CHECK: scf.forall
+// CHECK-NOT: memref.copy
+// CHECK: linalg.fill
+// CHECK: scf.for
+// CHECK: memref.alloc() : memref<128x16xf32, 3>
+// CHECK: scf.forall
+// CHECK: vector.create_mask
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: memref.alloc() : memref<16x128xf32, 3>
+// CHECK: scf.forall
+// CHECK: vector.create_mask
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: memref.alloc() : memref<128x128xf32, 3>
+// CHECK: scf.forall
+// CHECK: vector.create_mask
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: linalg.matmul
+// CHECK: scf.forall
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+func.func @matmul_divisible(%A: tensor<1024x1024xf32>,
+                            %B: tensor<1024x1024xf32>,
+                            %C: tensor<1024x1024xf32>)
+    -> tensor<1024x1024xf32>
+{
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%cst : f32)
+                   outs(%C : tensor<1024x1024xf32>)
+      -> tensor<1024x1024xf32>
+  %1 = linalg.matmul ins(%A, %B : tensor<1024x1024xf32>, tensor<1024x1024xf32>)
+                     outs(%0 : tensor<1024x1024xf32>)
+      -> tensor<1024x1024xf32>
+  return %1 : tensor<1024x1024xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  // Fuse linalg.fill into linalg.matmul and tile.
+  %matmul_op = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+  %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+  %forall_op, %tiled_matmul_op = transform.structured.tile_to_forall_op %matmul_op
+      num_threads [] tile_sizes [128, 128](mapping = [#gpu.block<y>, #gpu.block<x>])
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op
+      %fill_op into %forall_op
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // Tile linalg.matmul a second time.
+  %tiled_linalg_op, %loops = transform.structured.tile %tiled_matmul_op[0, 0, 16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // Pad linalg.matmul.
+  %padded, %pad, %copy_back = transform.structured.pad %tiled_linalg_op
+      {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+       padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 1],
+       copy_back_op = "linalg.copy"}
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+  // Map and tile tensor.pad.
+  %pad_forall_op, %tiled_pad_op = transform.structured.gpu.map_copy_to_threads
+      %pad total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // Map and tile copy back.
+  %copy_forall_op, %tiled_copy_op = transform.structured.gpu.map_copy_to_threads
+      %copy_back total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // Apply masked vectorization to padding ops.
+  transform.structured.masked_vectorize %tiled_pad_op vector_sizes [128, 4]
+      : !transform.any_op
+
+  // Assign shared memory buffer to padding.
+  %buffer, %new_ops = transform.structured.bufferize_to_allocation
+      %pad_forall_op {memory_space = 3, bufferize_destination_only}
+      : !transform.any_op
+
+  // Bufferize.
+  %func_op_1 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+  transform.bufferization.eliminate_empty_tensors %func_op_1 : !transform.any_op
+  transform.apply_dce to %func_op_1 : !transform.any_op
+  transform.apply_cse to %func_op_1 : !transform.any_op
+  %bufferized = transform.bufferization.one_shot_bufferize
+      layout{IdentityLayoutMap} %arg1 {bufferize_function_boundaries=true}
+      : (!transform.any_op) -> !transform.any_op
+
+  // Apply vectorization to copy back from shared memory.
+  // TODO: Find a way to retain the handle to linalg.copy throughout
+  // bufferization.
+  %func_op_2 = transform.structured.match ops{["func.func"]} in %bufferized
+      : (!transform.any_op) -> !transform.any_op
+  %bufferized_copy_back = transform.structured.match ops{["linalg.copy"]} in %func_op_2
+      : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize
+      %bufferized_copy_back vector_sizes [128, 4] : !transform.any_op
+
+  // Canonicalize, cleanup and vector lowering. This step also removes buffer
+  // self-copies.
+  transform.apply_patterns to %func_op_2 {
+    transform.apply_patterns.canonicalization
+    transform.apply_patterns.vector.lower_masked_transfers
+  } {apply_cse} : !transform.any_op
+}
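
Illustrative sketch (not part of the patch): the new third result of transform.structured.pad is a handle to the copy-back ops inserted after padding, so they can be transformed directly, as the test above does before bufferization. The payload and handle names below (%matmul, %padded, %pad, %copy) are placeholders.

  // Hypothetical standalone example; %matmul is assumed to be a handle to a
  // payload linalg.matmul. The third result (%copy) points at the "linalg.copy"
  // ops that write the padded result back into the original destination.
  %padded, %pad, %copy = transform.structured.pad %matmul
      {padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
       padding_dimensions = [0, 1, 2],
       copy_back_op = "linalg.copy"}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
  // The copy-back handle can then be used like any other handle, e.g.
  // vectorized with a mask:
  transform.structured.masked_vectorize %copy vector_sizes [128, 4] : !transform.any_op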