diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -342,8 +342,8 @@ undefined behavior. }]; - let arguments = (ins Arg:$input); - let results = (outs Arg:$output); + let arguments = (ins Arg:$input); + let results = (outs Arg:$output); let extraClassDeclaration = [{ Value getSource() { return input();} @@ -376,7 +376,7 @@ ``` }]; - let arguments = (ins Arg:$memref); + let arguments = (ins Arg:$memref); let hasFolder = 1; let assemblyFormat = "$memref attr-dict `:` type($memref)"; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -474,7 +474,11 @@ // CloneOp //===----------------------------------------------------------------------===// -static LogicalResult verify(CloneOp op) { return success(); } +static LogicalResult verify(CloneOp op) { + if (!op.input().getType().isa()) + return op.emitOpError("operand must be a memref"); + return success(); +} void CloneOp::getEffects( SmallVectorImpl> @@ -545,7 +549,7 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(DeallocOp op) { - if (!op.memref().getType().isa()) + if (!op.memref().getType().isa()) return op.emitOpError("operand must be a memref"); return success(); } diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir --- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir @@ -139,7 +139,7 @@ { func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) { // expected-error @+2 {{unhandled deallocation type}} - // expected-error @+1 {{'memref.dealloc' op operand #0 must be memref of any type values}} + // expected-error @+1 {{'memref.dealloc' op operand #0 must be unranked.memref of any type values or memref of any type values}} memref.dealloc %arg0 : memref<4x?xf32, 3> return } @@ -154,7 +154,7 @@ { func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) { // expected-error @+2 {{unhandled deallocation type}} - // expected-error @+1 {{op operand #0 must be memref of any type values}} + // expected-error @+1 {{op operand #0 must be unranked.memref of any type values or memref of any type values}} memref.dealloc %arg0 : memref<4x5xf32> return } diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -60,3 +60,19 @@ %1 = memref.tensor_load %0 : memref<2xf32> return } + +// CHECK-LABEL: func @memref_clone +func @memref_clone() { + %0 = memref.alloc() : memref<2xf32> + %1 = memref.cast %0 : memref<2xf32> to memref<*xf32> + %2 = memref.clone %1 : memref<*xf32> to memref<*xf32> + return +} + +// CHECK-LABEL: func @memref_dealloc +func @memref_dealloc() { + %0 = memref.alloc() : memref<2xf32> + %1 = memref.cast %0 : memref<2xf32> to memref<*xf32> + memref.dealloc %1 : memref<*xf32> + return +} diff --git a/mlir/test/Transforms/buffer-deallocation.mlir b/mlir/test/Transforms/buffer-deallocation.mlir --- a/mlir/test/Transforms/buffer-deallocation.mlir +++ b/mlir/test/Transforms/buffer-deallocation.mlir @@ -90,6 +90,43 @@ // ----- +// Test case: See above. + +// CHECK-LABEL: func @condBranchUnrankedType +func @condBranchUnrankedType( + %arg0: i1, + %arg1: memref<*xf32>, + %arg2: memref<*xf32>, + %arg3: index) { + cond_br %arg0, ^bb1, ^bb2(%arg3: index) +^bb1: + br ^bb3(%arg1 : memref<*xf32>) +^bb2(%0: index): + %1 = memref.alloc(%0) : memref + %2 = memref.cast %1 : memref to memref<*xf32> + test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>) + br ^bb3(%2 : memref<*xf32>) +^bb3(%3: memref<*xf32>): + test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>) + return +} + +// CHECK-NEXT: cond_br +// CHECK: %[[ALLOC0:.*]] = memref.clone +// CHECK-NEXT: br ^bb3(%[[ALLOC0]] +// CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) +// CHECK: test.buffer_based +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone +// CHECK-NEXT: memref.dealloc %[[ALLOC1]] +// CHECK-NEXT: br ^bb3 +// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}}) +// CHECK: test.copy(%[[ALLOC3]], +// CHECK-NEXT: memref.dealloc %[[ALLOC3]] +// CHECK-NEXT: return + +// ----- + // Test Case: // bb0 // / \ diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1865,8 +1865,8 @@ let description = [{ Represents a copy operation. }]; - let arguments = (ins Res:$source, - Res:$target); + let arguments = (ins Res:$source, + Res:$target); let assemblyFormat = [{ `(` $source `,` $target `)` `:` `(` type($source) `,` type($target) `)` attr-dict @@ -1901,7 +1901,8 @@ let description = [{ A buffer based operation, that uses memRefs as input and output. }]; - let arguments = (ins AnyMemRef:$input, AnyMemRef:$output); + let arguments = (ins AnyRankedOrUnrankedMemRef:$input, + AnyRankedOrUnrankedMemRef:$output); } def BufferBasedOp : BufferBasedOpBase<"buffer_based", []>{