diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/SmallSet.h"
 
 #define DEBUG_TYPE "normalize-memrefs"
 
@@ -24,39 +25,45 @@
 /// All memrefs passed across functions with non-trivial layout maps are
 /// converted to ones with trivial identity layout maps.
-
-// Input :-
-// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
-// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
-//   (memref<16xf64, #tile>) {
-//   affine.for %arg3 = 0 to 16 {
-//     %a = affine.load %A[%arg3] : memref<16xf64, #tile>
-//     %p = mulf %a, %a : f64
-//     affine.store %p, %A[%arg3] : memref<16xf64, #tile>
-//   }
-//   %c = alloc() : memref<16xf64, #tile>
-//   %d = affine.load %c[0] : memref<16xf64, #tile>
-//   return %A: memref<16xf64, #tile>
-// }
-
-// Output :-
-// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-//   -> memref<4x4xf64> {
-//   affine.for %arg3 = 0 to 16 {
-//     %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-//     %3 = mulf %2, %2 : f64
-//     affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-//   }
-//   %0 = alloc() : memref<16xf64, #map0>
-//   %1 = affine.load %0[0] : memref<16xf64, #map0>
-//   return %arg0 : memref<4x4xf64>
-// }
-
+/// If all the memref types/uses in a function are normalizable, we treat
+/// such functions as normalizable. Also, if a normalizable function is known
+/// to call a non-normalizable function, we treat that function as
+/// non-normalizable as well. We assume external functions to be normalizable.
+///
+/// Input :-
+/// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+/// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
+///   (memref<16xf64, #tile>) {
+///   affine.for %arg3 = 0 to 16 {
+///     %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+///     %p = mulf %a, %a : f64
+///     affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+///   }
+///   %c = alloc() : memref<16xf64, #tile>
+///   %d = affine.load %c[0] : memref<16xf64, #tile>
+///   return %A: memref<16xf64, #tile>
+/// }
+///
+/// Output :-
+/// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+///   -> memref<4x4xf64> {
+///   affine.for %arg3 = 0 to 16 {
+///     %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+///     %3 = mulf %2, %2 : f64
+///     affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+///   }
+///   %0 = alloc() : memref<16xf64, #map0>
+///   %1 = affine.load %0[0] : memref<16xf64, #map0>
+///   return %arg0 : memref<4x4xf64>
+/// }
+///
 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
   void runOnOperation() override;
-  void runOnFunction(FuncOp funcOp);
+  void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
   bool areMemRefsNormalizable(FuncOp funcOp);
-  void updateFunctionSignature(FuncOp funcOp);
+  void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
+  void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
+                                           DenseSet<FuncOp> &normalizableFuncs);
 };
 
 } // end anonymous namespace
 
@@ -67,41 +74,109 @@
 void NormalizeMemRefs::runOnOperation() {
   ModuleOp moduleOp = getOperation();
-  // We traverse each function within the module in order to normalize the
-  // memref type arguments.
-  // TODO: Handle external functions.
+  // We maintain all normalizable FuncOps in a DenseSet. It is initialized
+  // with all the functions within a module; functions that turn out not to
+  // be normalizable are then removed from this set.
+  // TODO: Change this to work on FuncLikeOp once there is an operation
+  // interface for it.
+  DenseSet<FuncOp> normalizableFuncs;
+  // Initialize `normalizableFuncs` with all the functions within a module.
+  moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
+
+  // Traverse all the functions, applying a filter that determines whether
+  // each one is normalizable. All callers/callees of a non-normalizable
+  // function also become non-normalizable, even if they don't pass any
+  // non-normalizable memrefs themselves.
   moduleOp.walk([&](FuncOp funcOp) {
-    if (areMemRefsNormalizable(funcOp))
-      runOnFunction(funcOp);
+    if (normalizableFuncs.contains(funcOp)) {
+      if (!areMemRefsNormalizable(funcOp)) {
+        // Since this function is not normalizable, we set all of its callers
+        // and callees as not normalizable as well.
+        // TODO: Drop this conservative assumption in the future.
+        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
+                                            normalizableFuncs);
+      }
+    }
   });
+
+  // The functions that remain in the set are subjected to normalization.
+  for (FuncOp &funcOp : normalizableFuncs)
+    normalizeFuncOpMemRefs(funcOp, moduleOp);
 }
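
The walk above, together with `setCalleesAndCallersNonNormalizable` (defined further down), amounts to a transitive invalidation over the call graph. The recursion terminates because a function is erased from the set before its neighbors are visited. A minimal standalone sketch of that invariant, using plain C++ containers instead of MLIR types; the names and the `CallGraph` shape are illustrative only, not part of the patch:

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Illustrative stand-in: maps a function name to its callers and callees.
using CallGraph = std::map<std::string, std::vector<std::string>>;

// Mirrors setCalleesAndCallersNonNormalizable(): erase `func` from the
// normalizable set first, then recurse into every neighbor that is still in
// the set. Erasing before recursing is what guarantees termination, even in
// the presence of recursive or mutually recursive calls.
void invalidate(const std::string &func, const CallGraph &neighbors,
                std::set<std::string> &normalizable) {
  if (normalizable.erase(func) == 0)
    return; // Already invalidated; stop here.
  auto it = neighbors.find(func);
  if (it == neighbors.end())
    return;
  for (const std::string &next : it->second)
    invalidate(next, neighbors, normalizable);
}
```

With the chain A -> B -> C from test case set #6 below, invalidating B removes all three functions from the set, which is why none of them gets normalized.
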
 
-// Return true if this operation dereferences one or more memref's.
-// TODO: Temporary utility, will be replaced when this is modeled through
-// side-effects/op traits.
+/// Return true if this operation dereferences one or more memref's.
+/// TODO: Temporary utility, will be replaced when this is modeled through
+/// side-effects/op traits.
 static bool isMemRefDereferencingOp(Operation &op) {
   return isa<AffineLoadOp, AffineStoreOp, AffineDmaStartOp, AffineDmaWaitOp>(
       op);
 }
 
-// Check whether all the uses of oldMemRef are either dereferencing uses or the
-// op is of type : DeallocOp, CallOp. Only if these constraints are satisfied
-// will the value become a candidate for replacement.
+/// Check whether all the uses of oldMemRef are either dereferencing uses or
+/// uses in ops of type DeallocOp, CallOp or ReturnOp. Only if these
+/// constraints are satisfied will the value become a candidate for
+/// replacement.
+/// TODO: Extend this for DimOps.
 static bool isMemRefNormalizable(Value::user_range opUsers) {
   if (llvm::any_of(opUsers, [](Operation *op) {
         if (isMemRefDereferencingOp(*op))
           return false;
-        return !isa<DeallocOp, CallOp>(*op);
+        return !isa<DeallocOp, CallOp, ReturnOp>(*op);
       }))
     return false;
   return true;
 }
 
-// Check whether all the uses of AllocOps, CallOps and function arguments of a
-// function are either of dereferencing type or of type: DeallocOp, CallOp. Only
-// if these constraints are satisfied will the function become a candidate for
-// normalization.
+/// Set all the calling functions and the callees of the function as not
+/// normalizable.
+void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
+    FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
+  if (!normalizableFuncs.contains(funcOp))
+    return;
+
+  normalizableFuncs.erase(funcOp);
+  // Functions that call this function.
+  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
+  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+    // TODO: Extend this for ops that are FunctionLike. This would require
+    // creating an OpInterface for FunctionLike ops.
+    FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
+    for (FuncOp &funcOp : normalizableFuncs) {
+      if (parentFuncOp == funcOp) {
+        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
+                                            normalizableFuncs);
+        break;
+      }
+    }
+  }
+
+  // Functions called by this function.
+  funcOp.walk([&](CallOp callOp) {
+    StringRef callee = callOp.getCallee();
+    for (FuncOp &funcOp : normalizableFuncs) {
+      // We compare the callee's name against each normalizable FuncOp's name.
+      if (callee == funcOp.getName()) {
+        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
+                                            normalizableFuncs);
+        break;
+      }
+    }
+  });
+}
+
+/// Check whether all the uses of AllocOps, CallOps and function arguments of
+/// a function are either dereferencing uses or uses in DeallocOp, CallOp or
+/// ReturnOp. Only if these constraints are satisfied will the function become
+/// a candidate for normalization. We follow a conservative approach here:
+/// even if the non-normalizable memref is not part of the function's argument
+/// or return types, we still label the entire function as non-normalizable.
+/// We assume external functions to be normalizable.
 bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
+  // We assume external functions to be normalizable.
+  if (funcOp.isExternal())
+    return true;
+
   if (funcOp
           .walk([&](AllocOp allocOp) -> WalkResult {
             Value oldMemRef = allocOp.getResult();
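
`areMemRefsNormalizable` and `isMemRefNormalizable` boil down to a single predicate over a memref's users: every use must be either a dereferencing affine use (whose access indices can be remapped through the layout map) or one of the type-only uses DeallocOp, CallOp and, with this patch, ReturnOp. A hedged restatement of that predicate, with an invented `UseKind` enum standing in for the MLIR op classes:

```cpp
#include <algorithm>
#include <vector>

// Invented for illustration: coarse categories of a memref's users.
enum class UseKind {
  AffineLoad,  // dereferencing: indices can be remapped by the layout map
  AffineStore, // dereferencing: likewise
  Dealloc,     // type-only use, safe to retype
  Call,        // type-only use, handled by rewriting the call site
  Return,      // type-only use, handled by updating the signature (new here)
  Unknown      // anything else blocks normalization
};

// A memref is normalizable iff none of its users is of an unknown kind.
bool isNormalizable(const std::vector<UseKind> &users) {
  return std::none_of(users.begin(), users.end(),
                      [](UseKind k) { return k == UseKind::Unknown; });
}
```
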
@@ -136,28 +211,138 @@
   return true;
 }
 
-// Fetch the updated argument list and result of the function and update the
-// function signature.
-void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp) {
+/// Fetch the updated argument list and result types of the function and
+/// update the function signature. This updates the function's return type at
+/// the caller site, and in case the return type is a normalized memref it
+/// also updates the calling function's signature.
+/// TODO: An update to the calling function's signature is required only if
+/// the returned value is in turn used in a ReturnOp of the calling function.
+void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
+                                               ModuleOp moduleOp) {
   FunctionType functionType = funcOp.getType();
-  SmallVector<Type, 8> argTypes;
   SmallVector<Type, 4> resultTypes;
+  FunctionType newFuncType;
+  resultTypes = llvm::to_vector<4>(functionType.getResults());
 
-  for (const auto &arg : llvm::enumerate(funcOp.getArguments()))
-    argTypes.push_back(arg.value().getType());
+  // An external function's signature was already updated in
+  // 'normalizeFuncOpMemRefs()'.
+  if (!funcOp.isExternal()) {
+    SmallVector<Type, 8> argTypes;
+    for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
+      argTypes.push_back(argEn.value().getType());
 
-  resultTypes = llvm::to_vector<4>(functionType.getResults());
-  // We create a new function type and modify the function signature with this
-  // new type.
-  FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes,
-                                               /*results=*/resultTypes,
-                                               /*context=*/&getContext());
-
-  // TODO: Handle ReturnOps to update function results the caller site.
-  funcOp.setType(newFuncType);
+    // Traverse ReturnOps to check if an update to the return type in the
+    // function signature is required.
+    funcOp.walk([&](ReturnOp returnOp) {
+      for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
+        Type opType = operandEn.value().getType();
+        MemRefType memrefType = opType.dyn_cast<MemRefType>();
+        // If the type is not a memref, or the memref type is the same as that
+        // in the function's return signature, then no update is required.
+        if (!memrefType || memrefType == resultTypes[operandEn.index()])
+          continue;
+        // Update the function's return type signature.
+        // A return type gets normalized either as a result of function
+        // argument normalization, AllocOp normalization, or an update made at
+        // a CallOp. There can be many call flows inside a function, and an
+        // update to a specific ReturnOp may not have been made yet, so we
+        // check that the result memref type is normalized.
+        // TODO: When selective normalization is implemented, handle the
+        // multiple-results case where some are normalized and some aren't.
+        if (memrefType.getAffineMaps().empty())
+          resultTypes[operandEn.index()] = memrefType;
+      }
+    });
+
+    // We create a new function type and modify the function signature with
+    // this new type.
+    newFuncType = FunctionType::get(/*inputs=*/argTypes,
+                                    /*results=*/resultTypes,
+                                    /*context=*/&getContext());
+  }
+
+  // Since we update the function signature, the result types at the caller
+  // site may change. Those results might even be used by the caller function
+  // in its own ReturnOps, in which case the caller function's signature
+  // changes as well. Hence we record such caller functions in
+  // 'funcOpsToUpdate' and update their signatures too.
+  llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
+  // We iterate over all symbolic uses of the function and update the return
+  // type at the caller site.
+  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
+  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+    Operation *callOp = symbolUse.getUser();
+    OpBuilder builder(callOp);
+    StringRef callee = cast<CallOp>(callOp).getCallee();
+    Operation *newCallOp = builder.create<CallOp>(
+        callOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
+        callOp->getOperands());
+    bool replacingMemRefUsesFailed = false;
+    bool returnTypeChanged = false;
+    for (unsigned resIndex : llvm::seq<unsigned>(0, callOp->getNumResults())) {
+      OpResult oldResult = callOp->getResult(resIndex);
+      OpResult newResult = newCallOp->getResult(resIndex);
+      // If the result is not of memref type, or the resulting memref already
+      // had a trivial layout map, then no use replacement is needed here.
+      if (oldResult.getType() == newResult.getType())
+        continue;
+      AffineMap layoutMap =
+          oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
+      if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
+                                          /*extraIndices=*/{},
+                                          /*indexRemap=*/layoutMap,
+                                          /*extraOperands=*/{},
+                                          /*symbolOperands=*/{},
+                                          /*domInstFilter=*/nullptr,
+                                          /*postDomInstFilter=*/nullptr,
+                                          /*allowNonDereferencingOps=*/true,
+                                          /*replaceInDeallocOp=*/true))) {
+        // If the replacement failed (due to escapes, for example), bail out.
+        // This should never be reached, though, because only normalizable
+        // functions are processed here.
+        newCallOp->erase();
+        replacingMemRefUsesFailed = true;
+        break;
+      }
+      returnTypeChanged = true;
+    }
+    if (replacingMemRefUsesFailed)
+      continue;
+    // Replace all uses for the other, non-memref result types.
+    callOp->replaceAllUsesWith(newCallOp);
+    callOp->erase();
+    if (returnTypeChanged) {
+      // Since the return type changed, it might lead to a change in the
+      // caller function's signature.
+      // TODO: If funcOp doesn't return any memref type then there is no need
+      // to update the signature.
+      // TODO: Further optimization - check whether the memref is indeed part
+      // of a ReturnOp in the parentFuncOp; only then is a signature update
+      // required.
+      // TODO: Extend this for ops that are FunctionLike. This would require
+      // creating an OpInterface for FunctionLike ops.
+      FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
+      funcOpsToUpdate.insert(parentFuncOp);
+    }
+  }
+  // An external function's signature was already updated in
+  // 'normalizeFuncOpMemRefs()', so it doesn't need to be updated here again.
+  if (!funcOp.isExternal())
+    funcOp.setType(newFuncType);
+
+  // Update the signatures of the functions that call the current function.
+  // Only if the return type of the current function has a normalized memref
+  // does a caller function become a candidate for a signature update.
+  for (FuncOp parentFuncOp : funcOpsToUpdate)
+    updateFunctionSignature(parentFuncOp, moduleOp);
 }
 
-void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
+/// Normalizes the memrefs within a function, which includes those arising
+/// from AllocOps and CallOps as well as the function's arguments. The
+/// ModuleOp argument is used to update the function's signature after
+/// normalization.
+void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
+                                              ModuleOp moduleOp) {
   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
   // alloc ops first and then process since normalizeMemRef replaces/erases ops
   // during memref rewriting.
@@ -169,22 +354,27 @@
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);
 
+  FunctionType functionType = funcOp.getType();
+  SmallVector<Type, 8> inputTypes;
   // Walk over each argument of a function to perform memref normalization (if
-  // any).
-  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
-    Type argType = funcOp.getArgument(argIndex).getType();
+  // any).
+  for (unsigned argIndex :
+       llvm::seq<unsigned>(0, functionType.getNumInputs())) {
+    Type argType = functionType.getInput(argIndex);
     MemRefType memrefType = argType.dyn_cast<MemRefType>();
     // Check whether argument is of MemRef type. Any other argument type can
     // simply be part of the final function signature.
-    if (!memrefType)
+    if (!memrefType) {
+      inputTypes.push_back(argType);
       continue;
+    }
 
     // Fetch a new memref type after normalizing the old memref to have an
     // identity map layout.
     MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                    /*numSymbolicOperands=*/0);
-    if (newMemRefType == memrefType) {
+    if (newMemRefType == memrefType || funcOp.isExternal()) {
       // Either memrefType already had an identity map or the map couldn't be
       // transformed to an identity map.
+      inputTypes.push_back(newMemRefType);
       continue;
     }
 
@@ -202,7 +392,7 @@
                                         /*domInstFilter=*/nullptr,
                                         /*postDomInstFilter=*/nullptr,
                                         /*allowNonDereferencingOps=*/true,
-                                        /*handleDeallocOp=*/true))) {
+                                        /*replaceInDeallocOp=*/true))) {
       // If it failed (due to escapes, for example), bail out, removing the
       // temporary argument inserted previously.
       funcOp.front().eraseArgument(argIndex);
@@ -214,5 +404,36 @@
     funcOp.front().eraseArgument(argIndex + 1);
   }
 
-  updateFunctionSignature(funcOp);
+  // In a function with a body, memrefs in the return type signature get
+  // normalized as a result of normalizing the function's arguments, AllocOps
+  // or CallOps' result types. Since an external function doesn't have a body,
+  // memrefs in its return type signature can only get normalized by iterating
+  // over the individual return types.
+  if (funcOp.isExternal()) {
+    SmallVector<Type, 4> resultTypes;
+    for (unsigned resIndex :
+         llvm::seq<unsigned>(0, functionType.getNumResults())) {
+      Type resType = functionType.getResult(resIndex);
+      MemRefType memrefType = resType.dyn_cast<MemRefType>();
+      // Check whether the result is of MemRef type. Any other result type can
+      // simply be part of the final function signature.
+      if (!memrefType) {
+        resultTypes.push_back(resType);
+        continue;
+      }
+      // Compute a new memref type after normalizing the old memref to have
+      // an identity map layout.
+      MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+                                                     /*numSymbolicOperands=*/0);
+      resultTypes.push_back(newMemRefType);
+    }
+
+    FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
+                                                 /*results=*/resultTypes,
+                                                 /*context=*/&getContext());
+    // Set the new function signature for this external function.
+    funcOp.setType(newFuncType);
+  }
+  updateFunctionSignature(funcOp, moduleOp);
 }
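
Two code paths above are worth contrasting: functions with bodies get their return types updated indirectly, through ReturnOps and call-site rewrites, while external functions have each type in their signature rewritten directly. A minimal sketch of the external-function case, with string stand-ins for types and a `normalize` callback in place of MLIR's `normalizeMemRefType()`; all names here are illustrative:

```cpp
#include <functional>
#include <string>
#include <vector>

// Illustrative stand-in for a FunctionType: just lists of type names.
struct Signature {
  std::vector<std::string> inputs;
  std::vector<std::string> results;
};

// Rebuild a signature by normalizing every input and result type
// independently. This is all that can be done for an external function:
// there is no body whose loads, stores, or ReturnOps could drive the update.
Signature normalizeExternalSignature(
    const Signature &old,
    const std::function<std::string(const std::string &)> &normalize) {
  Signature updated;
  for (const std::string &input : old.inputs)
    updated.inputs.push_back(normalize(input));
  for (const std::string &result : old.results)
    updated.results.push_back(normalize(result));
  return updated;
}
```

With the #tile map from the header comment, such a callback would take "memref<16xf64, #tile>" to "memref<4x4xf64>", matching the external-function CHECK lines in the test below.
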
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -274,12 +274,12 @@
   // for the memref to be used in a non-dereferencing way outside of the
   // region where this replacement is happening.
   if (!isMemRefDereferencingOp(*op)) {
-    // Currently we support the following non-dereferencing types to be a
-    // candidate for replacement: Dealloc and CallOp.
-    // TODO: Add support for other kinds of ops.
     if (!allowNonDereferencingOps)
       return failure();
-    if (!(isa<DeallocOp, CallOp>(*op)))
+    // Currently, the following non-dereferencing ops are supported as
+    // candidates for replacement: DeallocOp, CallOp and ReturnOp.
+    // TODO: Add support for other kinds of ops.
+    if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
       return failure();
   }
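
The Utils.cpp change widens the non-dereferencing whitelist from {DeallocOp, CallOp} to {DeallocOp, CallOp, ReturnOp} and keeps the `allowNonDereferencingOps` opt-in. A compact standalone restatement of that gate, with an invented enum standing in for the op classes:

```cpp
// Illustrative stand-in for the op classes checked in Utils.cpp.
enum class UseKind { Dealloc, Call, Return, Other };

// Mirrors the gate in replaceAllMemRefUsesWith(): a non-dereferencing use is
// replaceable only when the caller opted in via allowNonDereferencingOps
// *and* the op is one of the whitelisted kinds. ReturnOp is the kind this
// patch adds.
bool mayReplaceNonDereferencingUse(UseKind kind,
                                   bool allowNonDereferencingOps) {
  if (!allowNonDereferencingOps)
    return false;
  return kind == UseKind::Dealloc || kind == UseKind::Call ||
         kind == UseKind::Return;
}
```
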
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -126,14 +126,6 @@
   return
 }
 
-// Memref escapes; no normalization.
-// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
-func @escaping() -> memref<64xf32, affine_map<(d0) -> (d0 + 2)>> {
-  // CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
-  %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
-  return %A : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
-}
-
 // Semi-affine maps, normalization not implemented yet.
 // CHECK-LABEL: func @semi_affine_layout_map
 func @semi_affine_layout_map(%s0: index, %s1: index) {
@@ -205,9 +197,125 @@
   return %d : i1
 }
 
-// Test case 4: No normalization should take place because the function is returning the memref.
-// CHECK-LABEL: func @memref_used_in_return
-// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}>
-func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
-  return %A : memref<8xf64, #tile>
+// The test cases from here onwards deal with normalization of memrefs in
+// function signatures and at call sites.
+
+// Test case 4: Check successful memref normalization in case of inter/intra-recursive calls.
+// CHECK-LABEL: func @ret_multiple_argument_type
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
+func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
+  %a = affine.load %A[0] : memref<16xf64, #tile>
+  %p = mulf %a, %a : f64
+  %cond = constant 1 : i1
+  cond_br %cond, ^bb1, ^bb2
+^bb1:
+  %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+  return %res2, %p: memref<8xf64, #tile>, f64
+^bb2:
+  return %C, %p: memref<8xf64, #tile>, f64
+}
+
+// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
+// CHECK: %true = constant true
+// CHECK: cond_br %true, ^bb1, ^bb2
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: %[[res:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
+// CHECK: ^bb2: // pred: ^bb0
+// CHECK: return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64
+
+// CHECK-LABEL: func @ret_single_argument_type
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>) {
+  %a = alloc() : memref<8xf64, #tile>
+  %b = alloc() : memref<16xf64, #tile>
+  %d = constant 23.0 : f64
+  call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+  call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+  %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
+  %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+  return %b, %a: memref<16xf64, #tile>, memref<8xf64, #tile>
+}
+
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
+// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: %cst = constant 2.300000e+01 : f64
+// CHECK: %[[resA:[0-9]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: %[[resB:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: %[[resC:[0-9]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
+// CHECK: %[[resD:[0-9]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64>
+
+// Test case set #5: Check normalization in a chain of interconnected functions.
+// CHECK-LABEL: func @func_A
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_A(%A: memref<8xf64, #tile>) {
+  call @func_B(%A) : (memref<8xf64, #tile>) -> ()
+  return
+}
+// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()
+
+// CHECK-LABEL: func @func_B
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_B(%A: memref<8xf64, #tile>) {
+  call @func_C(%A) : (memref<8xf64, #tile>) -> ()
+  return
+}
+// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()
+
+// CHECK-LABEL: func @func_C
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_C(%A: memref<8xf64, #tile>) {
+  return
+}
+
+// Test case set #6: Check that no normalization takes place in a chain
+// A -> B -> C where B has an unsupported use of the memref.
+// CHECK-LABEL: func @some_func_A
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_A(%A: memref<8xf64, #tile>) {
+  call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
+  return
+}
+// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
+
+// CHECK-LABEL: func @some_func_B
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_B(%A: memref<8xf64, #tile>) {
+  "test.test"(%A) : (memref<8xf64, #tile>) -> ()
+  call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
+  return
+}
+// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
+
+// CHECK-LABEL: func @some_func_C
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_C(%A: memref<8xf64, #tile>) {
+  return
+}
+
+// Test case set #7: Check normalization in case of external functions.
+// CHECK-LABEL: func @external_func_A
+// CHECK-SAME: (memref<4x4xf64>)
+func @external_func_A(memref<16xf64, #tile>) -> ()
+
+// CHECK-LABEL: func @external_func_B
+// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
+func @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
+
+// CHECK-LABEL: func @simply_call_external()
+func @simply_call_external() {
+  %a = alloc() : memref<16xf64, #tile>
+  call @external_func_A(%a) : (memref<16xf64, #tile>) -> ()
+  return
+}
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()
+
+// CHECK-LABEL: func @use_value_of_external
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64) -> memref<2x4xf64>
+func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
+  %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
+  return %res : memref<8xf64, #tile>
 }
+// CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
+// CHECK: return %{{.*}} : memref<2x4xf64>
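
For reference, the tests above are driven through mlir-opt via the pass's -normalize-memrefs flag. To schedule the pass programmatically, something along these lines should work; this is a sketch assuming the pre-existing createNormalizeMemRefsPass() factory declared in mlir/Transforms/Passes.h, not part of the patch:

```cpp
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Run memref normalization over a whole module. The pass is module-scoped
// because it rewrites function signatures and the call sites that use them.
LogicalResult normalizeModuleMemRefs(ModuleOp module, MLIRContext *context) {
  PassManager pm(context);
  pm.addPass(createNormalizeMemRefsPass()); // factory assumed from Passes.h
  return pm.run(module);
}
```
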