diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -122,8 +122,16 @@
   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
   Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
   void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
+  /// Sets the size of dimension `i` to the constant `v`.
+  void setConstantSize(unsigned i, int64_t v) {
+    d.setConstantSize(rewriter(), loc(), i, v);
+  }
   Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
   void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
+  /// Sets the stride of dimension `i` to the constant `v`.
+  void setConstantStride(unsigned i, int64_t v) {
+    d.setConstantStride(rewriter(), loc(), i, v);
+  }
 
   operator Value() { return d; }
 
@@ -161,6 +169,59 @@
   }
 };
 
+/// Lowers linalg.reshape to a new view descriptor of the result rank.
+/// The allocated pointer, aligned pointer and offset are forwarded from the
+/// source descriptor, so the result aliases the same buffer; sizes and strides
+/// are materialized as constants read from the result MemRefType. For now,
+/// only results with a static shape and static strides are supported; any
+/// other reshape (or a result type the converter rejects) fails to match.
+class ReshapeOpConversion : public LLVMOpLowering {
+public:
+  explicit ReshapeOpConversion(MLIRContext *context,
+                               LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reshapeOp = cast<ReshapeOp>(op);
+    MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+
+    // Sizes are emitted as constants: every result dimension must be static.
+    if (!dstType.hasStaticShape())
+      return matchFailure();
+
+    // Strides are emitted as constants too. The offset is not checked because
+    // it is copied from the source descriptor rather than taken from the type.
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto res = getStridesAndOffset(dstType, strides, offset);
+    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
+          return ShapedType::isDynamicStrideOrOffset(val);
+        }))
+      return matchFailure();
+
+    // The type converter signals failure with a null Type; bail out instead
+    // of building the new descriptor from it.
+    Type llvmDstType = lowering.convertType(dstType);
+    if (!llvmDstType)
+      return matchFailure();
+
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    ReshapeOpOperandAdaptor adaptor(operands);
+    BaseViewConversionHelper baseDesc(adaptor.view());
+    BaseViewConversionHelper desc(llvmDstType);
+    desc.setAllocatedPtr(baseDesc.allocatedPtr());
+    desc.setAlignedPtr(baseDesc.alignedPtr());
+    desc.setOffset(baseDesc.offset());
+    for (auto en : llvm::enumerate(dstType.getShape()))
+      desc.setConstantSize(en.index(), en.value());
+    for (auto en : llvm::enumerate(strides))
+      desc.setConstantStride(en.index(), en.value());
+    rewriter.replaceOp(op, {desc});
+    return matchSuccess();
+  }
+};
+
 /// Conversion pattern that transforms a linalg.slice op into:
 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
@@ -508,8 +556,8 @@
 void mlir::populateLinalgToLLVMConversionPatterns(
     LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
     MLIRContext *ctx) {
-  patterns.insert(ctx, converter);
+  patterns.insert(ctx, converter);
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -196,3 +196,63 @@
 // CHECK-LABEL: func @matmul_vec_indexed(
 // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+
+func @reshape_static(%arg0: memref<3x4x5xf32>) {
+  // Reshapes that expand a contiguous memref with unit dims and collapse it back.
+  %0 = linalg.reshape %arg0 [(i, j, k, l, m) -> (i, j),
+                             (i, j, k, l, m) -> (k),
+                             (i, j, k, l, m) -> (l, m)] :
+    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  %r0 = linalg.reshape %0 [(i, j, k, l, m) -> (i, j),
+                           (i, j, k, l, m) -> (k),
+                           (i, j, k, l, m) -> (l, m)] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  return
+}
+// CHECK-LABEL: func @reshape_static(
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">