diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -658,7 +658,7 @@ // CallOp //===----------------------------------------------------------------------===// -def CallOp : Std_Op<"call", [CallOpInterface]> { +def CallOp : Std_Op<"call", [MemRefsNormalizable, CallOpInterface]> { let summary = "call operation"; let description = [{ The `call` operation represents a direct call to a function that is within @@ -1388,7 +1388,8 @@ // DeallocOp //===----------------------------------------------------------------------===// -def DeallocOp : Std_Op<"dealloc", [MemoryEffects<[MemFree]>]> { +def DeallocOp : Std_Op<"dealloc", + [MemRefsNormalizable, MemoryEffects<[MemFree]>]> { let summary = "memory deallocation operation"; let description = [{ The `dealloc` operation frees the region of memory referenced by a memref @@ -2125,7 +2126,7 @@ //===----------------------------------------------------------------------===// def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, ReturnLike, - Terminator]> { + Terminator, MemRefsNormalizable]> { let summary = "return operation"; let description = [{ The `return` operation represents a return operation within a function. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1698,6 +1698,8 @@ NativeOpTrait<"SameOperandsAndResultElementType">; // Op is a terminator. def Terminator : NativeOpTrait<"IsTerminator">; +// Op can be safely normalized in the presence of MemRefs with layout maps. +def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; // Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator<string op>
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1212,6 +1212,14 @@
   }
 };
 
+/// This trait flags operations whose memref operands and results may be
+/// rewritten by memref normalization (e.g. `-normalize-memrefs`): a memref
+/// with a non-identity layout map used by such an op can be replaced by its
+/// normalized, identity-layout equivalent. The trait adds no verification.
+template <typename ConcreteType>
+struct MemRefsNormalizable
+    : public TraitBase<ConcreteType, MemRefsNormalizable> {};
+
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -120,10 +120,10 @@
 /// TODO: Extend this for DimOps.
 static bool isMemRefNormalizable(Value::user_range opUsers) {
-  if (llvm::any_of(opUsers, [](Operation *op) {
-        if (isMemRefDereferencingOp(*op))
-          return false;
-        return !isa<DeallocOp, CallOp, ReturnOp>(*op);
-      }))
-    return false;
-  return true;
+  // A memref is normalizable only if every user either carries the
+  // MemRefsNormalizable trait or is a memref-dereferencing op (as classified
+  // by isMemRefDereferencingOp).
+  return llvm::all_of(opUsers, [](Operation *op) {
+    return op->hasTrait<OpTrait::MemRefsNormalizable>() ||
+           isMemRefDereferencingOp(*op);
+  });
 }
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -279,7 +279,6 @@
       // Currently we support the following non-dereferencing ops to be a
-      // candidate for replacement: Dealloc, CallOp and ReturnOp.
-      // TODO: Add support for other kinds of ops.
-      if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
+      // candidate for replacement: ops marked MemRefsNormalizable.
+      if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
         return failure();
     }
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -618,6 +618,16 @@
   let arguments = (ins I32, OptionalAttr<I32Attr>:$optional_attr);
   let results = (outs I32);
 }
+
+// Test for memrefs normalization of an op with normalizable memrefs.
+def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
+  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
+}
+// Test for memrefs normalization of an op without normalizable memrefs.
+def OpNonNorm : TEST_Op<"op_nonnorm"> {
+  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
+}
+
 // Pattern add the argument plus a increasing static number hidden in
 // OpMTest function. That value is set into the optional argument.
 // That way, we will know if operations is called once or twice.
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s
+
+// "test.op_norm" carries the MemRefsNormalizable trait, so memrefs used by it
+// are rewritten to identity layouts; "test.op_nonnorm" does not, so memrefs
+// used by it keep their #map0 layout.
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>
+
+// CHECK-LABEL: func @test_norm(
+func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
+  %0 = alloc() : memref<1x16x14x14xf32, #map0>
+  "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+  dealloc %0 : memref<1x16x14x14xf32, #map0>
+
+  // CHECK: %0 = alloc() : memref<1x16x1x1x32x64xf32>
+  // CHECK: "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
+  // CHECK: dealloc %0 : memref<1x16x1x1x32x64xf32>
+  return
+}
+
+// CHECK-LABEL: func @test_nonnorm(
+func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
+  %0 = alloc() : memref<1x16x14x14xf32, #map0>
+  "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+  dealloc %0 : memref<1x16x14x14xf32, #map0>
+
+  // CHECK: %0 = alloc() : memref<1x16x14x14xf32, #map0>
+  // CHECK: "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+  // CHECK: dealloc %0 : memref<1x16x14x14xf32, #map0>
+  return
+}
+
+// CHECK-LABEL: func @test_norm_mix(
+func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
+  %0 = alloc() : memref<1x16x14x14xf32, #map0>
+  "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
+  dealloc %0 : memref<1x16x14x14xf32, #map0>
+
+  // CHECK: %0 = alloc() : memref<1x16x1x1x32x64xf32>
+  // CHECK: "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
+  // CHECK: dealloc %0 : memref<1x16x1x1x32x64xf32>
+  return
+}