diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -900,6 +900,7 @@ $dst`,` $src `:` type($dst)`,` type($src) attr-dict }]; let verifier = [{ return ::verify(*this); }]; + let hasFolder = 1; } def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix", diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -1066,6 +1067,26 @@ return success(); } +/// This is a common class used for patterns of the form +/// "someop(memrefcast) -> someop". It folds the source of any memref.cast +/// into the root operation directly. 
+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>();
+    if (!cast || cast.getOperand().getType().isa<UnrankedMemRefType>())
+      continue; // Not a cast, or folding would yield an unranked operand.
+    operand.set(cast.getOperand());
+    folded = true;
+  }
+  return success(folded);
+}
+
+LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
+                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
+  return foldMemRefCast(*this); // In-place fold; `results` is left untouched.
+}
+
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @memcpy_after_cast
+func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
+  // CHECK-NOT: memref.cast
+  // CHECK: gpu.memcpy
+  %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
+  %1 = memref.cast %arg1 : memref<10xf32> to memref<?xf32>
+  gpu.memcpy %0,%1 : memref<?xf32>, memref<?xf32>
+  return
+}